mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
commit
92b163e95c
@ -2,17 +2,17 @@
|
|||||||
|
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
from fastapi import Query
|
from fastapi import Query, Body
|
||||||
from fastapi.routing import APIRouter, HTTPException
|
from fastapi.routing import APIRouter, HTTPException
|
||||||
from pydantic import BaseModel, Field, parse_obj_as
|
from pydantic import BaseModel, Field, parse_obj_as
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
from invokeai.backend import BaseModelType, ModelType
|
from invokeai.backend import BaseModelType, ModelType
|
||||||
|
from invokeai.backend.model_management import AddModelResult
|
||||||
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
|
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
|
||||||
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
|
||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
|
|
||||||
|
|
||||||
class VaeRepo(BaseModel):
|
class VaeRepo(BaseModel):
|
||||||
repo_id: str = Field(description="The repo ID to use for this VAE")
|
repo_id: str = Field(description="The repo ID to use for this VAE")
|
||||||
path: Optional[str] = Field(description="The path to the VAE")
|
path: Optional[str] = Field(description="The path to the VAE")
|
||||||
@ -51,9 +51,12 @@ class CreateModelResponse(BaseModel):
|
|||||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||||
status: str = Field(description="The status of the API response")
|
status: str = Field(description="The status of the API response")
|
||||||
|
|
||||||
class ImportModelRequest(BaseModel):
|
class ImportModelResponse(BaseModel):
|
||||||
name: str = Field(description="A model path, repo_id or URL to import")
|
name: str = Field(description="The name of the imported model")
|
||||||
prediction_type: Optional[Literal['epsilon','v_prediction','sample']] = Field(description='Prediction type for SDv2 checkpoint files')
|
# base_model: str = Field(description="The base model")
|
||||||
|
# model_type: str = Field(description="The model type")
|
||||||
|
info: AddModelResult = Field(description="The model info")
|
||||||
|
status: str = Field(description="The status of the API response")
|
||||||
|
|
||||||
class ConversionRequest(BaseModel):
|
class ConversionRequest(BaseModel):
|
||||||
name: str = Field(description="The name of the new model")
|
name: str = Field(description="The name of the new model")
|
||||||
@ -86,7 +89,6 @@ async def list_models(
|
|||||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
@models_router.post(
|
@models_router.post(
|
||||||
"/",
|
"/",
|
||||||
operation_id="update_model",
|
operation_id="update_model",
|
||||||
@ -109,27 +111,38 @@ async def update_model(
|
|||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
@models_router.post(
|
@models_router.post(
|
||||||
"/",
|
"/import",
|
||||||
operation_id="import_model",
|
operation_id="import_model",
|
||||||
responses={200: {"status": "success"}},
|
responses= {
|
||||||
|
201: {"description" : "The model imported successfully"},
|
||||||
|
404: {"description" : "The model could not be found"},
|
||||||
|
},
|
||||||
|
status_code=201,
|
||||||
|
response_model=ImportModelResponse
|
||||||
)
|
)
|
||||||
async def import_model(
|
async def import_model(
|
||||||
model_request: ImportModelRequest
|
name: str = Query(description="A model path, repo_id or URL to import"),
|
||||||
) -> None:
|
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||||
""" Add Model """
|
) -> ImportModelResponse:
|
||||||
items_to_import = set([model_request.name])
|
""" Add a model using its local path, repo_id, or remote URL """
|
||||||
|
items_to_import = {name}
|
||||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||||
items_to_import = items_to_import,
|
items_to_import = items_to_import,
|
||||||
prediction_type_helper = lambda x: prediction_types.get(model_request.prediction_type)
|
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
||||||
)
|
)
|
||||||
if len(installed_models) > 0:
|
if info := installed_models.get(name):
|
||||||
logger.info(f'Successfully imported {model_request.name}')
|
logger.info(f'Successfully imported {name}, got {info}')
|
||||||
|
return ImportModelResponse(
|
||||||
|
name = name,
|
||||||
|
info = info,
|
||||||
|
status = "success",
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(f'Model {model_request.name} not imported')
|
logger.error(f'Model {name} not imported')
|
||||||
raise HTTPException(status_code=500, detail=f'Model {model_request.name} not imported')
|
raise HTTPException(status_code=404, detail=f'Model {name} not found')
|
||||||
|
|
||||||
@models_router.delete(
|
@models_router.delete(
|
||||||
"/{model_name}",
|
"/{model_name}",
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
from typing import Literal, Optional, Union, List
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
import copy
|
import copy
|
||||||
|
from typing import List, Literal, Optional
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
|
||||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||||
|
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||||
|
InvocationConfig, InvocationContext)
|
||||||
|
|
||||||
|
|
||||||
class ModelInfo(BaseModel):
|
class ModelInfo(BaseModel):
|
||||||
model_name: str = Field(description="Info to load submodel")
|
model_name: str = Field(description="Info to load submodel")
|
||||||
@ -30,7 +31,6 @@ class VaeField(BaseModel):
|
|||||||
# TODO: better naming?
|
# TODO: better naming?
|
||||||
vae: ModelInfo = Field(description="Info to load vae submodel")
|
vae: ModelInfo = Field(description="Info to load vae submodel")
|
||||||
|
|
||||||
|
|
||||||
class ModelLoaderOutput(BaseInvocationOutput):
|
class ModelLoaderOutput(BaseInvocationOutput):
|
||||||
"""Model loader output"""
|
"""Model loader output"""
|
||||||
|
|
||||||
@ -43,25 +43,26 @@ class ModelLoaderOutput(BaseInvocationOutput):
|
|||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
class PipelineModelField(BaseModel):
|
class MainModelField(BaseModel):
|
||||||
"""Pipeline model field"""
|
"""Main model field"""
|
||||||
|
|
||||||
model_name: str = Field(description="Name of the model")
|
model_name: str = Field(description="Name of the model")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
|
||||||
class PipelineModelLoaderInvocation(BaseInvocation):
|
class MainModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a pipeline model, outputting its submodels."""
|
"""Loads a main model, outputting its submodels."""
|
||||||
|
|
||||||
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
|
type: Literal["main_model_loader"] = "main_model_loader"
|
||||||
|
|
||||||
model: PipelineModelField = Field(description="The model to load")
|
model: MainModelField = Field(description="The model to load")
|
||||||
# TODO: precision?
|
# TODO: precision?
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {
|
||||||
|
"title": "Model Loader",
|
||||||
"tags": ["model", "loader"],
|
"tags": ["model", "loader"],
|
||||||
"type_hints": {
|
"type_hints": {
|
||||||
"model": "model"
|
"model": "model"
|
||||||
@ -175,6 +176,14 @@ class LoraLoaderInvocation(BaseInvocation):
|
|||||||
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
||||||
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
|
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
|
||||||
|
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"title": "Lora Loader",
|
||||||
|
"tags": ["lora", "loader"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||||
|
|
||||||
# TODO: ui rewrite
|
# TODO: ui rewrite
|
||||||
@ -221,3 +230,56 @@ class LoraLoaderInvocation(BaseInvocation):
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
class VAEModelField(BaseModel):
|
||||||
|
"""Vae model field"""
|
||||||
|
|
||||||
|
model_name: str = Field(description="Name of the model")
|
||||||
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
class VaeLoaderOutput(BaseInvocationOutput):
|
||||||
|
"""Model loader output"""
|
||||||
|
|
||||||
|
#fmt: off
|
||||||
|
type: Literal["vae_loader_output"] = "vae_loader_output"
|
||||||
|
|
||||||
|
vae: VaeField = Field(default=None, description="Vae model")
|
||||||
|
#fmt: on
|
||||||
|
|
||||||
|
class VaeLoaderInvocation(BaseInvocation):
|
||||||
|
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||||
|
type: Literal["vae_loader"] = "vae_loader"
|
||||||
|
|
||||||
|
vae_model: VAEModelField = Field(description="The VAE to load")
|
||||||
|
|
||||||
|
# Schema customisation
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"title": "VAE Loader",
|
||||||
|
"tags": ["vae", "loader"],
|
||||||
|
"type_hints": {
|
||||||
|
"vae_model": "vae_model"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
|
||||||
|
base_model = self.vae_model.base_model
|
||||||
|
model_name = self.vae_model.model_name
|
||||||
|
model_type = ModelType.Vae
|
||||||
|
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
base_model=base_model,
|
||||||
|
model_name=model_name,
|
||||||
|
model_type=model_type,
|
||||||
|
):
|
||||||
|
raise Exception(f"Unkown vae name: {model_name}!")
|
||||||
|
return VaeLoaderOutput(
|
||||||
|
vae=VaeField(
|
||||||
|
vae = ModelInfo(
|
||||||
|
model_name = model_name,
|
||||||
|
base_model = base_model,
|
||||||
|
model_type = model_type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@ -135,6 +135,29 @@ class ModelManagerServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def heuristic_import(self,
|
||||||
|
items_to_import: Set[str],
|
||||||
|
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||||
|
)->Dict[str, AddModelResult]:
|
||||||
|
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
|
successfully imported items.
|
||||||
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
|
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||||
|
|
||||||
|
The prediction type helper is necessary to distinguish between
|
||||||
|
models based on Stable Diffusion 2 Base (requiring
|
||||||
|
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||||
|
(requiring SchedulerPredictionType.VPrediction). It is
|
||||||
|
generally impossible to do this programmatically, so the
|
||||||
|
prediction_type_helper usually asks the user to choose.
|
||||||
|
|
||||||
|
The result is a set of successfully installed models. Each element
|
||||||
|
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||||
|
that model.
|
||||||
|
'''
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def commit(self, conf_file: Path = None) -> None:
|
def commit(self, conf_file: Path = None) -> None:
|
||||||
"""
|
"""
|
||||||
@ -361,3 +384,24 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
def logger(self):
|
def logger(self):
|
||||||
return self.mgr.logger
|
return self.mgr.logger
|
||||||
|
|
||||||
|
def heuristic_import(self,
|
||||||
|
items_to_import: Set[str],
|
||||||
|
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||||
|
)->Dict[str, AddModelResult]:
|
||||||
|
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
|
successfully imported items.
|
||||||
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
|
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||||
|
|
||||||
|
The prediction type helper is necessary to distinguish between
|
||||||
|
models based on Stable Diffusion 2 Base (requiring
|
||||||
|
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||||
|
(requiring SchedulerPredictionType.VPrediction). It is
|
||||||
|
generally impossible to do this programmatically, so the
|
||||||
|
prediction_type_helper usually asks the user to choose.
|
||||||
|
|
||||||
|
The result is a set of successfully installed models. Each element
|
||||||
|
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||||
|
that model.
|
||||||
|
'''
|
||||||
|
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||||
|
@ -18,7 +18,7 @@ from tqdm import tqdm
|
|||||||
import invokeai.configs as configs
|
import invokeai.configs as configs
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
|
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
|
||||||
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
|
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
|
||||||
from invokeai.backend.util import download_with_resume
|
from invokeai.backend.util import download_with_resume
|
||||||
from ..util.logging import InvokeAILogger
|
from ..util.logging import InvokeAILogger
|
||||||
@ -166,17 +166,22 @@ class ModelInstall(object):
|
|||||||
# add requested models
|
# add requested models
|
||||||
for path in selections.install_models:
|
for path in selections.install_models:
|
||||||
logger.info(f'Installing {path} [{job}/{jobs}]')
|
logger.info(f'Installing {path} [{job}/{jobs}]')
|
||||||
self.heuristic_install(path)
|
self.heuristic_import(path)
|
||||||
job += 1
|
job += 1
|
||||||
|
|
||||||
self.mgr.commit()
|
self.mgr.commit()
|
||||||
|
|
||||||
def heuristic_install(self,
|
def heuristic_import(self,
|
||||||
model_path_id_or_url: Union[str,Path],
|
model_path_id_or_url: Union[str,Path],
|
||||||
models_installed: Set[Path]=None)->Set[Path]:
|
models_installed: Set[Path]=None)->Dict[str, AddModelResult]:
|
||||||
|
'''
|
||||||
|
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
||||||
|
:param models_installed: Set of installed models, used for recursive invocation
|
||||||
|
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
|
||||||
|
'''
|
||||||
|
|
||||||
if not models_installed:
|
if not models_installed:
|
||||||
models_installed = set()
|
models_installed = dict()
|
||||||
|
|
||||||
# A little hack to allow nested routines to retrieve info on the requested ID
|
# A little hack to allow nested routines to retrieve info on the requested ID
|
||||||
self.current_id = model_path_id_or_url
|
self.current_id = model_path_id_or_url
|
||||||
@ -185,24 +190,24 @@ class ModelInstall(object):
|
|||||||
try:
|
try:
|
||||||
# checkpoint file, or similar
|
# checkpoint file, or similar
|
||||||
if path.is_file():
|
if path.is_file():
|
||||||
models_installed.add(self._install_path(path))
|
models_installed.update(self._install_path(path))
|
||||||
|
|
||||||
# folders style or similar
|
# folders style or similar
|
||||||
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
||||||
models_installed.add(self._install_path(path))
|
models_installed.update(self._install_path(path))
|
||||||
|
|
||||||
# recursive scan
|
# recursive scan
|
||||||
elif path.is_dir():
|
elif path.is_dir():
|
||||||
for child in path.iterdir():
|
for child in path.iterdir():
|
||||||
self.heuristic_install(child, models_installed=models_installed)
|
self.heuristic_import(child, models_installed=models_installed)
|
||||||
|
|
||||||
# huggingface repo
|
# huggingface repo
|
||||||
elif len(str(path).split('/')) == 2:
|
elif len(str(path).split('/')) == 2:
|
||||||
models_installed.add(self._install_repo(str(path)))
|
models_installed.update(self._install_repo(str(path)))
|
||||||
|
|
||||||
# a URL
|
# a URL
|
||||||
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
|
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
|
||||||
models_installed.add(self._install_url(model_path_id_or_url))
|
models_installed.update(self._install_url(model_path_id_or_url))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
||||||
@ -214,24 +219,25 @@ class ModelInstall(object):
|
|||||||
|
|
||||||
# install a model from a local path. The optional info parameter is there to prevent
|
# install a model from a local path. The optional info parameter is there to prevent
|
||||||
# the model from being probed twice in the event that it has already been probed.
|
# the model from being probed twice in the event that it has already been probed.
|
||||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Path:
|
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]:
|
||||||
try:
|
try:
|
||||||
# logger.debug(f'Probing {path}')
|
model_result = None
|
||||||
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
||||||
model_name = path.stem if info.format=='checkpoint' else path.name
|
model_name = path.stem if info.format=='checkpoint' else path.name
|
||||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||||
attributes = self._make_attributes(path,info)
|
attributes = self._make_attributes(path,info)
|
||||||
self.mgr.add_model(model_name = model_name,
|
model_result = self.mgr.add_model(model_name = model_name,
|
||||||
base_model = info.base_type,
|
base_model = info.base_type,
|
||||||
model_type = info.model_type,
|
model_type = info.model_type,
|
||||||
model_attributes = attributes,
|
model_attributes = attributes,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f'{str(e)} Skipping registration.')
|
logger.warning(f'{str(e)} Skipping registration.')
|
||||||
return path
|
return {}
|
||||||
|
return {str(path): model_result}
|
||||||
|
|
||||||
def _install_url(self, url: str)->Path:
|
def _install_url(self, url: str)->dict:
|
||||||
# copy to a staging area, probe, import and delete
|
# copy to a staging area, probe, import and delete
|
||||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||||
location = download_with_resume(url,Path(staging))
|
location = download_with_resume(url,Path(staging))
|
||||||
@ -244,7 +250,7 @@ class ModelInstall(object):
|
|||||||
# staged version will be garbage-collected at this time
|
# staged version will be garbage-collected at this time
|
||||||
return self._install_path(Path(models_path), info)
|
return self._install_path(Path(models_path), info)
|
||||||
|
|
||||||
def _install_repo(self, repo_id: str)->Path:
|
def _install_repo(self, repo_id: str)->dict:
|
||||||
hinfo = HfApi().model_info(repo_id)
|
hinfo = HfApi().model_info(repo_id)
|
||||||
|
|
||||||
# we try to figure out how to download this most economically
|
# we try to figure out how to download this most economically
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for invokeai.backend.model_management
|
Initialization file for invokeai.backend.model_management
|
||||||
"""
|
"""
|
||||||
from .model_manager import ModelManager, ModelInfo
|
from .model_manager import ModelManager, ModelInfo, AddModelResult
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
||||||
|
|
||||||
|
@ -233,14 +233,14 @@ import hashlib
|
|||||||
import textwrap
|
import textwrap
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, List, Tuple, Union, Set, Callable, types
|
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
|
||||||
from shutil import rmtree
|
from shutil import rmtree
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from omegaconf.dictconfig import DictConfig
|
from omegaconf.dictconfig import DictConfig
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
@ -278,8 +278,13 @@ class InvalidModelError(Exception):
|
|||||||
"Raised when an invalid model is requested"
|
"Raised when an invalid model is requested"
|
||||||
pass
|
pass
|
||||||
|
|
||||||
MAX_CACHE_SIZE = 6.0 # GB
|
class AddModelResult(BaseModel):
|
||||||
|
name: str = Field(description="The name of the model after import")
|
||||||
|
model_type: ModelType = Field(description="The type of model")
|
||||||
|
base_model: BaseModelType = Field(description="The base model")
|
||||||
|
config: ModelConfigBase = Field(description="The configuration of the model")
|
||||||
|
|
||||||
|
MAX_CACHE_SIZE = 6.0 # GB
|
||||||
|
|
||||||
class ConfigMeta(BaseModel):
|
class ConfigMeta(BaseModel):
|
||||||
version: str
|
version: str
|
||||||
@ -571,13 +576,16 @@ class ModelManager(object):
|
|||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
model_attributes: dict,
|
model_attributes: dict,
|
||||||
clobber: bool = False,
|
clobber: bool = False,
|
||||||
) -> None:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
Update the named model with a dictionary of attributes. Will fail with an
|
||||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||||
On a successful update, the config will be changed in memory and the
|
On a successful update, the config will be changed in memory and the
|
||||||
method will return True. Will fail with an assertion error if provided
|
method will return True. Will fail with an assertion error if provided
|
||||||
attributes are incorrect or the model name is missing.
|
attributes are incorrect or the model name is missing.
|
||||||
|
|
||||||
|
The returned dict has the same format as the dict returned by
|
||||||
|
model_info().
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
@ -601,12 +609,18 @@ class ModelManager(object):
|
|||||||
old_model_cache.unlink()
|
old_model_cache.unlink()
|
||||||
|
|
||||||
# remove in-memory cache
|
# remove in-memory cache
|
||||||
# note: it not garantie to release memory(model can has other references)
|
# note: it not guaranteed to release memory(model can has other references)
|
||||||
cache_ids = self.cache_keys.pop(model_key, [])
|
cache_ids = self.cache_keys.pop(model_key, [])
|
||||||
for cache_id in cache_ids:
|
for cache_id in cache_ids:
|
||||||
self.cache.uncache_model(cache_id)
|
self.cache.uncache_model(cache_id)
|
||||||
|
|
||||||
self.models[model_key] = model_config
|
self.models[model_key] = model_config
|
||||||
|
return AddModelResult(
|
||||||
|
name = model_name,
|
||||||
|
model_type = model_type,
|
||||||
|
base_model = base_model,
|
||||||
|
config = model_config,
|
||||||
|
)
|
||||||
|
|
||||||
def search_models(self, search_folder):
|
def search_models(self, search_folder):
|
||||||
self.logger.info(f"Finding Models In: {search_folder}")
|
self.logger.info(f"Finding Models In: {search_folder}")
|
||||||
@ -729,7 +743,7 @@ class ModelManager(object):
|
|||||||
if (new_models_found or imported_models) and self.config_path:
|
if (new_models_found or imported_models) and self.config_path:
|
||||||
self.commit()
|
self.commit()
|
||||||
|
|
||||||
def autoimport(self)->set[Path]:
|
def autoimport(self)->Dict[str, AddModelResult]:
|
||||||
'''
|
'''
|
||||||
Scan the autoimport directory (if defined) and import new models, delete defunct models.
|
Scan the autoimport directory (if defined) and import new models, delete defunct models.
|
||||||
'''
|
'''
|
||||||
@ -742,7 +756,6 @@ class ModelManager(object):
|
|||||||
prediction_type_helper = ask_user_for_prediction_type,
|
prediction_type_helper = ask_user_for_prediction_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
installed = set()
|
|
||||||
scanned_dirs = set()
|
scanned_dirs = set()
|
||||||
|
|
||||||
config = self.app_config
|
config = self.app_config
|
||||||
@ -756,13 +769,14 @@ class ModelManager(object):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
self.logger.info(f'Scanning {autodir} for models to import')
|
self.logger.info(f'Scanning {autodir} for models to import')
|
||||||
|
installed = dict()
|
||||||
|
|
||||||
autodir = self.app_config.root_path / autodir
|
autodir = self.app_config.root_path / autodir
|
||||||
if not autodir.exists():
|
if not autodir.exists():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
items_scanned = 0
|
items_scanned = 0
|
||||||
new_models_found = set()
|
new_models_found = dict()
|
||||||
|
|
||||||
for root, dirs, files in os.walk(autodir):
|
for root, dirs, files in os.walk(autodir):
|
||||||
items_scanned += len(dirs) + len(files)
|
items_scanned += len(dirs) + len(files)
|
||||||
@ -772,7 +786,7 @@ class ModelManager(object):
|
|||||||
scanned_dirs.add(path)
|
scanned_dirs.add(path)
|
||||||
continue
|
continue
|
||||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
||||||
new_models_found.update(installer.heuristic_install(path))
|
new_models_found.update(installer.heuristic_import(path))
|
||||||
scanned_dirs.add(path)
|
scanned_dirs.add(path)
|
||||||
|
|
||||||
for f in files:
|
for f in files:
|
||||||
@ -780,7 +794,7 @@ class ModelManager(object):
|
|||||||
if path in known_paths or path.parent in scanned_dirs:
|
if path in known_paths or path.parent in scanned_dirs:
|
||||||
continue
|
continue
|
||||||
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
||||||
new_models_found.update(installer.heuristic_install(path))
|
new_models_found.update(installer.heuristic_import(path))
|
||||||
|
|
||||||
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
|
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
|
||||||
installed.update(new_models_found)
|
installed.update(new_models_found)
|
||||||
@ -790,7 +804,7 @@ class ModelManager(object):
|
|||||||
def heuristic_import(self,
|
def heuristic_import(self,
|
||||||
items_to_import: Set[str],
|
items_to_import: Set[str],
|
||||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||||
)->Set[str]:
|
)->Dict[str, AddModelResult]:
|
||||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
successfully imported items.
|
successfully imported items.
|
||||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
@ -803,17 +817,20 @@ class ModelManager(object):
|
|||||||
generally impossible to do this programmatically, so the
|
generally impossible to do this programmatically, so the
|
||||||
prediction_type_helper usually asks the user to choose.
|
prediction_type_helper usually asks the user to choose.
|
||||||
|
|
||||||
|
The result is a set of successfully installed models. Each element
|
||||||
|
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||||
|
that model.
|
||||||
'''
|
'''
|
||||||
# avoid circular import here
|
# avoid circular import here
|
||||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
successfully_installed = set()
|
successfully_installed = dict()
|
||||||
|
|
||||||
installer = ModelInstall(config = self.app_config,
|
installer = ModelInstall(config = self.app_config,
|
||||||
prediction_type_helper = prediction_type_helper,
|
prediction_type_helper = prediction_type_helper,
|
||||||
model_manager = self)
|
model_manager = self)
|
||||||
for thing in items_to_import:
|
for thing in items_to_import:
|
||||||
try:
|
try:
|
||||||
installed = installer.heuristic_install(thing)
|
installed = installer.heuristic_import(thing)
|
||||||
successfully_installed.update(installed)
|
successfully_installed.update(installed)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.warning(f'{thing} could not be imported: {str(e)}')
|
self.logger.warning(f'{thing} could not be imported: {str(e)}')
|
||||||
|
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@ -12,7 +12,7 @@
|
|||||||
margin: 0;
|
margin: 0;
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
<script type="module" crossorigin src="./assets/index-8a3e9251.js"></script>
|
<script type="module" crossorigin src="./assets/index-c0367e37.js"></script>
|
||||||
</head>
|
</head>
|
||||||
|
|
||||||
<body dir="ltr">
|
<body dir="ltr">
|
||||||
|
17
invokeai/frontend/web/dist/locales/en.json
vendored
17
invokeai/frontend/web/dist/locales/en.json
vendored
@ -24,16 +24,13 @@
|
|||||||
},
|
},
|
||||||
"common": {
|
"common": {
|
||||||
"hotkeysLabel": "Hotkeys",
|
"hotkeysLabel": "Hotkeys",
|
||||||
"themeLabel": "Theme",
|
"darkMode": "Dark Mode",
|
||||||
|
"lightMode": "Light Mode",
|
||||||
"languagePickerLabel": "Language",
|
"languagePickerLabel": "Language",
|
||||||
"reportBugLabel": "Report Bug",
|
"reportBugLabel": "Report Bug",
|
||||||
"githubLabel": "Github",
|
"githubLabel": "Github",
|
||||||
"discordLabel": "Discord",
|
"discordLabel": "Discord",
|
||||||
"settingsLabel": "Settings",
|
"settingsLabel": "Settings",
|
||||||
"darkTheme": "Dark",
|
|
||||||
"lightTheme": "Light",
|
|
||||||
"greenTheme": "Green",
|
|
||||||
"oceanTheme": "Ocean",
|
|
||||||
"langArabic": "العربية",
|
"langArabic": "العربية",
|
||||||
"langEnglish": "English",
|
"langEnglish": "English",
|
||||||
"langDutch": "Nederlands",
|
"langDutch": "Nederlands",
|
||||||
@ -55,6 +52,7 @@
|
|||||||
"unifiedCanvas": "Unified Canvas",
|
"unifiedCanvas": "Unified Canvas",
|
||||||
"linear": "Linear",
|
"linear": "Linear",
|
||||||
"nodes": "Node Editor",
|
"nodes": "Node Editor",
|
||||||
|
"modelmanager": "Model Manager",
|
||||||
"postprocessing": "Post Processing",
|
"postprocessing": "Post Processing",
|
||||||
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
||||||
"postProcessing": "Post Processing",
|
"postProcessing": "Post Processing",
|
||||||
@ -336,6 +334,7 @@
|
|||||||
"modelManager": {
|
"modelManager": {
|
||||||
"modelManager": "Model Manager",
|
"modelManager": "Model Manager",
|
||||||
"model": "Model",
|
"model": "Model",
|
||||||
|
"vae": "VAE",
|
||||||
"allModels": "All Models",
|
"allModels": "All Models",
|
||||||
"checkpointModels": "Checkpoints",
|
"checkpointModels": "Checkpoints",
|
||||||
"diffusersModels": "Diffusers",
|
"diffusersModels": "Diffusers",
|
||||||
@ -351,6 +350,7 @@
|
|||||||
"scanForModels": "Scan For Models",
|
"scanForModels": "Scan For Models",
|
||||||
"addManually": "Add Manually",
|
"addManually": "Add Manually",
|
||||||
"manual": "Manual",
|
"manual": "Manual",
|
||||||
|
"baseModel": "Base Model",
|
||||||
"name": "Name",
|
"name": "Name",
|
||||||
"nameValidationMsg": "Enter a name for your model",
|
"nameValidationMsg": "Enter a name for your model",
|
||||||
"description": "Description",
|
"description": "Description",
|
||||||
@ -363,6 +363,7 @@
|
|||||||
"repoIDValidationMsg": "Online repository of your model",
|
"repoIDValidationMsg": "Online repository of your model",
|
||||||
"vaeLocation": "VAE Location",
|
"vaeLocation": "VAE Location",
|
||||||
"vaeLocationValidationMsg": "Path to where your VAE is located.",
|
"vaeLocationValidationMsg": "Path to where your VAE is located.",
|
||||||
|
"variant": "Variant",
|
||||||
"vaeRepoID": "VAE Repo ID",
|
"vaeRepoID": "VAE Repo ID",
|
||||||
"vaeRepoIDValidationMsg": "Online repository of your VAE",
|
"vaeRepoIDValidationMsg": "Online repository of your VAE",
|
||||||
"width": "Width",
|
"width": "Width",
|
||||||
@ -524,7 +525,8 @@
|
|||||||
"initialImage": "Initial Image",
|
"initialImage": "Initial Image",
|
||||||
"showOptionsPanel": "Show Options Panel",
|
"showOptionsPanel": "Show Options Panel",
|
||||||
"hidePreview": "Hide Preview",
|
"hidePreview": "Hide Preview",
|
||||||
"showPreview": "Show Preview"
|
"showPreview": "Show Preview",
|
||||||
|
"controlNetControlMode": "Control Mode"
|
||||||
},
|
},
|
||||||
"settings": {
|
"settings": {
|
||||||
"models": "Models",
|
"models": "Models",
|
||||||
@ -547,7 +549,8 @@
|
|||||||
"general": "General",
|
"general": "General",
|
||||||
"generation": "Generation",
|
"generation": "Generation",
|
||||||
"ui": "User Interface",
|
"ui": "User Interface",
|
||||||
"availableSchedulers": "Available Schedulers"
|
"favoriteSchedulers": "Favorite Schedulers",
|
||||||
|
"favoriteSchedulersPlaceholder": "No schedulers favorited"
|
||||||
},
|
},
|
||||||
"toast": {
|
"toast": {
|
||||||
"serverError": "Server Error",
|
"serverError": "Server Error",
|
||||||
|
@ -67,6 +67,7 @@
|
|||||||
"@fontsource-variable/inter": "^5.0.3",
|
"@fontsource-variable/inter": "^5.0.3",
|
||||||
"@fontsource/inter": "^5.0.3",
|
"@fontsource/inter": "^5.0.3",
|
||||||
"@mantine/core": "^6.0.14",
|
"@mantine/core": "^6.0.14",
|
||||||
|
"@mantine/form": "^6.0.15",
|
||||||
"@mantine/hooks": "^6.0.14",
|
"@mantine/hooks": "^6.0.14",
|
||||||
"@reduxjs/toolkit": "^1.9.5",
|
"@reduxjs/toolkit": "^1.9.5",
|
||||||
"@roarr/browser-log-writer": "^1.1.5",
|
"@roarr/browser-log-writer": "^1.1.5",
|
||||||
|
@ -53,6 +53,7 @@
|
|||||||
"linear": "Linear",
|
"linear": "Linear",
|
||||||
"nodes": "Node Editor",
|
"nodes": "Node Editor",
|
||||||
"batch": "Batch Manager",
|
"batch": "Batch Manager",
|
||||||
|
"modelmanager": "Model Manager",
|
||||||
"postprocessing": "Post Processing",
|
"postprocessing": "Post Processing",
|
||||||
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
||||||
"postProcessing": "Post Processing",
|
"postProcessing": "Post Processing",
|
||||||
@ -334,6 +335,7 @@
|
|||||||
"modelManager": {
|
"modelManager": {
|
||||||
"modelManager": "Model Manager",
|
"modelManager": "Model Manager",
|
||||||
"model": "Model",
|
"model": "Model",
|
||||||
|
"vae": "VAE",
|
||||||
"allModels": "All Models",
|
"allModels": "All Models",
|
||||||
"checkpointModels": "Checkpoints",
|
"checkpointModels": "Checkpoints",
|
||||||
"diffusersModels": "Diffusers",
|
"diffusersModels": "Diffusers",
|
||||||
@ -349,6 +351,7 @@
|
|||||||
"scanForModels": "Scan For Models",
|
"scanForModels": "Scan For Models",
|
||||||
"addManually": "Add Manually",
|
"addManually": "Add Manually",
|
||||||
"manual": "Manual",
|
"manual": "Manual",
|
||||||
|
"baseModel": "Base Model",
|
||||||
"name": "Name",
|
"name": "Name",
|
||||||
"nameValidationMsg": "Enter a name for your model",
|
"nameValidationMsg": "Enter a name for your model",
|
||||||
"description": "Description",
|
"description": "Description",
|
||||||
@ -361,6 +364,7 @@
|
|||||||
"repoIDValidationMsg": "Online repository of your model",
|
"repoIDValidationMsg": "Online repository of your model",
|
||||||
"vaeLocation": "VAE Location",
|
"vaeLocation": "VAE Location",
|
||||||
"vaeLocationValidationMsg": "Path to where your VAE is located.",
|
"vaeLocationValidationMsg": "Path to where your VAE is located.",
|
||||||
|
"variant": "Variant",
|
||||||
"vaeRepoID": "VAE Repo ID",
|
"vaeRepoID": "VAE Repo ID",
|
||||||
"vaeRepoIDValidationMsg": "Online repository of your VAE",
|
"vaeRepoIDValidationMsg": "Online repository of your VAE",
|
||||||
"width": "Width",
|
"width": "Width",
|
||||||
|
@ -4,6 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { PartialAppConfig } from 'app/types/invokeai';
|
import { PartialAppConfig } from 'app/types/invokeai';
|
||||||
import ImageUploader from 'common/components/ImageUploader';
|
import ImageUploader from 'common/components/ImageUploader';
|
||||||
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
|
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
|
||||||
|
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
|
||||||
import Lightbox from 'features/lightbox/components/Lightbox';
|
import Lightbox from 'features/lightbox/components/Lightbox';
|
||||||
import SiteHeader from 'features/system/components/SiteHeader';
|
import SiteHeader from 'features/system/components/SiteHeader';
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
@ -15,11 +16,10 @@ import InvokeTabs from 'features/ui/components/InvokeTabs';
|
|||||||
import ParametersDrawer from 'features/ui/components/ParametersDrawer';
|
import ParametersDrawer from 'features/ui/components/ParametersDrawer';
|
||||||
import i18n from 'i18n';
|
import i18n from 'i18n';
|
||||||
import { ReactNode, memo, useEffect } from 'react';
|
import { ReactNode, memo, useEffect } from 'react';
|
||||||
|
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
|
||||||
|
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
||||||
import GlobalHotkeys from './GlobalHotkeys';
|
import GlobalHotkeys from './GlobalHotkeys';
|
||||||
import Toaster from './Toaster';
|
import Toaster from './Toaster';
|
||||||
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
|
||||||
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
|
|
||||||
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
|
|
||||||
|
|
||||||
const DEFAULT_CONFIG = {};
|
const DEFAULT_CONFIG = {};
|
||||||
|
|
||||||
|
@ -3,20 +3,21 @@ import { memo } from 'react';
|
|||||||
import { InputFieldTemplate, InputFieldValue } from '../types/types';
|
import { InputFieldTemplate, InputFieldValue } from '../types/types';
|
||||||
import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent';
|
import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent';
|
||||||
import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
|
import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
|
||||||
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
|
||||||
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
|
||||||
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
|
|
||||||
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
|
||||||
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
|
|
||||||
import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
|
import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
|
||||||
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
|
||||||
|
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
||||||
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
|
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
|
||||||
|
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
||||||
|
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
|
||||||
|
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
||||||
|
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
|
||||||
|
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
|
||||||
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
|
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
|
||||||
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
|
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
|
||||||
import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
||||||
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
|
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
|
||||||
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
|
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
||||||
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
|
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
|
||||||
|
|
||||||
type InputFieldComponentProps = {
|
type InputFieldComponentProps = {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
@ -152,6 +153,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (type === 'vae_model' && template.type === 'vae_model') {
|
||||||
|
return (
|
||||||
|
<VaeModelInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (type === 'array' && template.type === 'array') {
|
if (type === 'array' && template.type === 'array') {
|
||||||
return (
|
return (
|
||||||
<ArrayInputFieldComponent
|
<ArrayInputFieldComponent
|
||||||
|
@ -6,13 +6,13 @@ import {
|
|||||||
ModelInputFieldValue,
|
ModelInputFieldValue,
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
|
|
||||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
|
||||||
import { FieldComponentProps } from './types';
|
|
||||||
import { forEach, isString } from 'lodash-es';
|
|
||||||
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||||
|
import { forEach, isString } from 'lodash-es';
|
||||||
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
const ModelInputFieldComponent = (
|
const ModelInputFieldComponent = (
|
||||||
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
|
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
|
||||||
@ -22,18 +22,18 @@ const ModelInputFieldComponent = (
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { data: pipelineModels } = useListModelsQuery({
|
const { data: mainModels } = useListModelsQuery({
|
||||||
model_type: 'main',
|
model_type: 'main',
|
||||||
});
|
});
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!pipelineModels) {
|
if (!mainModels) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
const data: SelectItem[] = [];
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
forEach(pipelineModels.entities, (model, id) => {
|
forEach(mainModels.entities, (model, id) => {
|
||||||
if (!model) {
|
if (!model) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -46,11 +46,11 @@ const ModelInputFieldComponent = (
|
|||||||
});
|
});
|
||||||
|
|
||||||
return data;
|
return data;
|
||||||
}, [pipelineModels]);
|
}, [mainModels]);
|
||||||
|
|
||||||
const selectedModel = useMemo(
|
const selectedModel = useMemo(
|
||||||
() => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]],
|
() => mainModels?.entities[field.value ?? mainModels.ids[0]],
|
||||||
[pipelineModels?.entities, pipelineModels?.ids, field.value]
|
[mainModels?.entities, mainModels?.ids, field.value]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleValueChanged = useCallback(
|
const handleValueChanged = useCallback(
|
||||||
@ -71,18 +71,18 @@ const ModelInputFieldComponent = (
|
|||||||
);
|
);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (field.value && pipelineModels?.ids.includes(field.value)) {
|
if (field.value && mainModels?.ids.includes(field.value)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const firstModel = pipelineModels?.ids[0];
|
const firstModel = mainModels?.ids[0];
|
||||||
|
|
||||||
if (!isString(firstModel)) {
|
if (!isString(firstModel)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
handleValueChanged(firstModel);
|
handleValueChanged(firstModel);
|
||||||
}, [field.value, handleValueChanged, pipelineModels?.ids]);
|
}, [field.value, handleValueChanged, mainModels?.ids]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIMantineSelect
|
<IAIMantineSelect
|
||||||
|
@ -0,0 +1,97 @@
|
|||||||
|
import { SelectItem } from '@mantine/core';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import {
|
||||||
|
VaeModelInputFieldTemplate,
|
||||||
|
VaeModelInputFieldValue,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
|
const VaeModelInputFieldComponent = (
|
||||||
|
props: FieldComponentProps<
|
||||||
|
VaeModelInputFieldValue,
|
||||||
|
VaeModelInputFieldTemplate
|
||||||
|
>
|
||||||
|
) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const { data: vaeModels } = useListModelsQuery({
|
||||||
|
model_type: 'vae',
|
||||||
|
});
|
||||||
|
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() => vaeModels?.entities[field.value ?? vaeModels.ids[0]],
|
||||||
|
[vaeModels?.entities, vaeModels?.ids, field.value]
|
||||||
|
);
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!vaeModels) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(vaeModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.name,
|
||||||
|
group: BASE_MODEL_NAME_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [vaeModels]);
|
||||||
|
|
||||||
|
const handleValueChanged = useCallback(
|
||||||
|
(v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
fieldValueChanged({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
value: v,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch, field.name, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (field.value && vaeModels?.ids.includes(field.value)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
handleValueChanged('auto');
|
||||||
|
}, [field.value, handleValueChanged, vaeModels?.ids]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIMantineSelect
|
||||||
|
tooltip={selectedModel?.description}
|
||||||
|
label={
|
||||||
|
selectedModel?.base_model &&
|
||||||
|
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
||||||
|
}
|
||||||
|
value={field.value}
|
||||||
|
placeholder="Pick one"
|
||||||
|
data={data}
|
||||||
|
onChange={handleValueChanged}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(VaeModelInputFieldComponent);
|
@ -17,6 +17,7 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
|||||||
ClipField: 'clip',
|
ClipField: 'clip',
|
||||||
VaeField: 'vae',
|
VaeField: 'vae',
|
||||||
model: 'model',
|
model: 'model',
|
||||||
|
vae_model: 'vae_model',
|
||||||
array: 'array',
|
array: 'array',
|
||||||
item: 'item',
|
item: 'item',
|
||||||
ColorField: 'color',
|
ColorField: 'color',
|
||||||
@ -116,6 +117,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
|||||||
title: 'Model',
|
title: 'Model',
|
||||||
description: 'Models are models.',
|
description: 'Models are models.',
|
||||||
},
|
},
|
||||||
|
vae_model: {
|
||||||
|
color: 'teal',
|
||||||
|
colorCssVar: getColorTokenCssVariable('teal'),
|
||||||
|
title: 'Model',
|
||||||
|
description: 'Models are models.',
|
||||||
|
},
|
||||||
array: {
|
array: {
|
||||||
color: 'gray',
|
color: 'gray',
|
||||||
colorCssVar: getColorTokenCssVariable('gray'),
|
colorCssVar: getColorTokenCssVariable('gray'),
|
||||||
|
@ -64,6 +64,7 @@ export type FieldType =
|
|||||||
| 'vae'
|
| 'vae'
|
||||||
| 'control'
|
| 'control'
|
||||||
| 'model'
|
| 'model'
|
||||||
|
| 'vae_model'
|
||||||
| 'array'
|
| 'array'
|
||||||
| 'item'
|
| 'item'
|
||||||
| 'color'
|
| 'color'
|
||||||
@ -91,6 +92,7 @@ export type InputFieldValue =
|
|||||||
| ControlInputFieldValue
|
| ControlInputFieldValue
|
||||||
| EnumInputFieldValue
|
| EnumInputFieldValue
|
||||||
| ModelInputFieldValue
|
| ModelInputFieldValue
|
||||||
|
| VaeModelInputFieldValue
|
||||||
| ArrayInputFieldValue
|
| ArrayInputFieldValue
|
||||||
| ItemInputFieldValue
|
| ItemInputFieldValue
|
||||||
| ColorInputFieldValue
|
| ColorInputFieldValue
|
||||||
@ -116,6 +118,7 @@ export type InputFieldTemplate =
|
|||||||
| ControlInputFieldTemplate
|
| ControlInputFieldTemplate
|
||||||
| EnumInputFieldTemplate
|
| EnumInputFieldTemplate
|
||||||
| ModelInputFieldTemplate
|
| ModelInputFieldTemplate
|
||||||
|
| VaeModelInputFieldTemplate
|
||||||
| ArrayInputFieldTemplate
|
| ArrayInputFieldTemplate
|
||||||
| ItemInputFieldTemplate
|
| ItemInputFieldTemplate
|
||||||
| ColorInputFieldTemplate
|
| ColorInputFieldTemplate
|
||||||
@ -228,6 +231,11 @@ export type ModelInputFieldValue = FieldValueBase & {
|
|||||||
value?: string;
|
value?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type VaeModelInputFieldValue = FieldValueBase & {
|
||||||
|
type: 'vae_model';
|
||||||
|
value?: string;
|
||||||
|
};
|
||||||
|
|
||||||
export type ArrayInputFieldValue = FieldValueBase & {
|
export type ArrayInputFieldValue = FieldValueBase & {
|
||||||
type: 'array';
|
type: 'array';
|
||||||
value?: (string | number)[];
|
value?: (string | number)[];
|
||||||
@ -305,6 +313,21 @@ export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'conditioning';
|
type: 'conditioning';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type UNetInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: undefined;
|
||||||
|
type: 'unet';
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ClipInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: undefined;
|
||||||
|
type: 'clip';
|
||||||
|
};
|
||||||
|
|
||||||
|
export type VaeInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: undefined;
|
||||||
|
type: 'vae';
|
||||||
|
};
|
||||||
|
|
||||||
export type ControlInputFieldTemplate = InputFieldTemplateBase & {
|
export type ControlInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: undefined;
|
default: undefined;
|
||||||
type: 'control';
|
type: 'control';
|
||||||
@ -322,6 +345,11 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'model';
|
type: 'model';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: string;
|
||||||
|
type: 'vae_model';
|
||||||
|
};
|
||||||
|
|
||||||
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
|
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: [];
|
default: [];
|
||||||
type: 'array';
|
type: 'array';
|
||||||
|
@ -3,27 +3,28 @@ import { OpenAPIV3 } from 'openapi-types';
|
|||||||
import { FIELD_TYPE_MAP } from '../types/constants';
|
import { FIELD_TYPE_MAP } from '../types/constants';
|
||||||
import { isSchemaObject } from '../types/typeGuards';
|
import { isSchemaObject } from '../types/typeGuards';
|
||||||
import {
|
import {
|
||||||
BooleanInputFieldTemplate,
|
|
||||||
EnumInputFieldTemplate,
|
|
||||||
FloatInputFieldTemplate,
|
|
||||||
ImageInputFieldTemplate,
|
|
||||||
IntegerInputFieldTemplate,
|
|
||||||
LatentsInputFieldTemplate,
|
|
||||||
ConditioningInputFieldTemplate,
|
|
||||||
UNetInputFieldTemplate,
|
|
||||||
ClipInputFieldTemplate,
|
|
||||||
VaeInputFieldTemplate,
|
|
||||||
ControlInputFieldTemplate,
|
|
||||||
StringInputFieldTemplate,
|
|
||||||
ModelInputFieldTemplate,
|
|
||||||
ArrayInputFieldTemplate,
|
ArrayInputFieldTemplate,
|
||||||
ItemInputFieldTemplate,
|
BooleanInputFieldTemplate,
|
||||||
|
ClipInputFieldTemplate,
|
||||||
ColorInputFieldTemplate,
|
ColorInputFieldTemplate,
|
||||||
InputFieldTemplateBase,
|
ConditioningInputFieldTemplate,
|
||||||
OutputFieldTemplate,
|
ControlInputFieldTemplate,
|
||||||
TypeHints,
|
EnumInputFieldTemplate,
|
||||||
FieldType,
|
FieldType,
|
||||||
|
FloatInputFieldTemplate,
|
||||||
ImageCollectionInputFieldTemplate,
|
ImageCollectionInputFieldTemplate,
|
||||||
|
ImageInputFieldTemplate,
|
||||||
|
InputFieldTemplateBase,
|
||||||
|
IntegerInputFieldTemplate,
|
||||||
|
ItemInputFieldTemplate,
|
||||||
|
LatentsInputFieldTemplate,
|
||||||
|
ModelInputFieldTemplate,
|
||||||
|
OutputFieldTemplate,
|
||||||
|
StringInputFieldTemplate,
|
||||||
|
TypeHints,
|
||||||
|
UNetInputFieldTemplate,
|
||||||
|
VaeInputFieldTemplate,
|
||||||
|
VaeModelInputFieldTemplate,
|
||||||
} from '../types/types';
|
} from '../types/types';
|
||||||
|
|
||||||
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
||||||
@ -175,6 +176,21 @@ const buildModelInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildVaeModelInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): VaeModelInputFieldTemplate => {
|
||||||
|
const template: VaeModelInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'vae_model',
|
||||||
|
inputRequirement: 'always',
|
||||||
|
inputKind: 'direct',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildImageInputFieldTemplate = ({
|
const buildImageInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -441,6 +457,9 @@ export const buildInputFieldTemplate = (
|
|||||||
if (['model'].includes(fieldType)) {
|
if (['model'].includes(fieldType)) {
|
||||||
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
||||||
}
|
}
|
||||||
|
if (['vae_model'].includes(fieldType)) {
|
||||||
|
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
|
||||||
|
}
|
||||||
if (['enum'].includes(fieldType)) {
|
if (['enum'].includes(fieldType)) {
|
||||||
return buildEnumInputFieldTemplate({ schemaObject, baseField });
|
return buildEnumInputFieldTemplate({ schemaObject, baseField });
|
||||||
}
|
}
|
||||||
|
@ -75,6 +75,10 @@ export const buildInputFieldValue = (
|
|||||||
if (template.type === 'model') {
|
if (template.type === 'model') {
|
||||||
fieldValue.value = undefined;
|
fieldValue.value = undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (template.type === 'vae_model') {
|
||||||
|
fieldValue.value = undefined;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return fieldValue;
|
return fieldValue;
|
||||||
|
@ -0,0 +1,68 @@
|
|||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
|
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
|
||||||
|
import {
|
||||||
|
IMAGE_TO_IMAGE_GRAPH,
|
||||||
|
IMAGE_TO_LATENTS,
|
||||||
|
INPAINT,
|
||||||
|
INPAINT_GRAPH,
|
||||||
|
LATENTS_TO_IMAGE,
|
||||||
|
MAIN_MODEL_LOADER,
|
||||||
|
TEXT_TO_IMAGE_GRAPH,
|
||||||
|
VAE_LOADER,
|
||||||
|
} from './constants';
|
||||||
|
|
||||||
|
export const addVAEToGraph = (
|
||||||
|
graph: NonNullableGraph,
|
||||||
|
state: RootState
|
||||||
|
): void => {
|
||||||
|
const { vae: vaeId } = state.generation;
|
||||||
|
const vae_model = modelIdToVAEModelField(vaeId);
|
||||||
|
|
||||||
|
if (vaeId !== 'auto') {
|
||||||
|
graph.nodes[VAE_LOADER] = {
|
||||||
|
type: 'vae_loader',
|
||||||
|
id: VAE_LOADER,
|
||||||
|
vae_model,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (graph.id === TEXT_TO_IMAGE_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (graph.id === IMAGE_TO_IMAGE_GRAPH) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: IMAGE_TO_LATENTS,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (graph.id === INPAINT_GRAPH) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: INPAINT,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
@ -1,31 +1,26 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import {
|
import {
|
||||||
ImageDTO,
|
ImageDTO,
|
||||||
ImageResizeInvocation,
|
ImageResizeInvocation,
|
||||||
ImageToLatentsInvocation,
|
ImageToLatentsInvocation,
|
||||||
RandomIntInvocation,
|
|
||||||
RangeOfSizeInvocation,
|
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
import { log } from 'app/logging/useLogger';
|
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||||
|
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||||
|
import { addVAEToGraph } from './addVAEToGraph';
|
||||||
import {
|
import {
|
||||||
ITERATE,
|
IMAGE_TO_IMAGE_GRAPH,
|
||||||
|
IMAGE_TO_LATENTS,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
PIPELINE_MODEL_LOADER,
|
LATENTS_TO_LATENTS,
|
||||||
|
MAIN_MODEL_LOADER,
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
RANDOM_INT,
|
|
||||||
RANGE_OF_SIZE,
|
|
||||||
IMAGE_TO_IMAGE_GRAPH,
|
|
||||||
IMAGE_TO_LATENTS,
|
|
||||||
LATENTS_TO_LATENTS,
|
|
||||||
RESIZE,
|
RESIZE,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { set } from 'lodash-es';
|
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
|
||||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
|
||||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'nodes' });
|
const moduleLog = log.child({ namespace: 'nodes' });
|
||||||
|
|
||||||
@ -52,7 +47,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
// The bounding box determines width and height, not the width and height params
|
// The bounding box determines width and height, not the width and height params
|
||||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||||
|
|
||||||
const model = modelIdToPipelineModelField(modelId);
|
const model = modelIdToMainModelField(modelId);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
@ -81,9 +76,9 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
type: 'noise',
|
type: 'noise',
|
||||||
id: NOISE,
|
id: NOISE,
|
||||||
},
|
},
|
||||||
[PIPELINE_MODEL_LOADER]: {
|
[MAIN_MODEL_LOADER]: {
|
||||||
type: 'pipeline_model_loader',
|
type: 'main_model_loader',
|
||||||
id: PIPELINE_MODEL_LOADER,
|
id: MAIN_MODEL_LOADER,
|
||||||
model,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
@ -110,7 +105,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
edges: [
|
edges: [
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -120,7 +115,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -128,16 +123,6 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_IMAGE,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: LATENTS_TO_LATENTS,
|
node_id: LATENTS_TO_LATENTS,
|
||||||
@ -170,17 +155,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: IMAGE_TO_LATENTS,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
|
||||||
field: 'unet',
|
field: 'unet',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -277,6 +252,9 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add VAE
|
||||||
|
addVAEToGraph(graph, state);
|
||||||
|
|
||||||
// add dynamic prompts, mutating `graph`
|
// add dynamic prompts, mutating `graph`
|
||||||
addDynamicPromptsToGraph(graph, state);
|
addDynamicPromptsToGraph(graph, state);
|
||||||
|
|
||||||
|
@ -1,23 +1,24 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import {
|
import {
|
||||||
ImageDTO,
|
ImageDTO,
|
||||||
InpaintInvocation,
|
InpaintInvocation,
|
||||||
RandomIntInvocation,
|
RandomIntInvocation,
|
||||||
RangeOfSizeInvocation,
|
RangeOfSizeInvocation,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||||
import { log } from 'app/logging/useLogger';
|
import { addVAEToGraph } from './addVAEToGraph';
|
||||||
import {
|
import {
|
||||||
|
INPAINT,
|
||||||
|
INPAINT_GRAPH,
|
||||||
ITERATE,
|
ITERATE,
|
||||||
PIPELINE_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
RANDOM_INT,
|
RANDOM_INT,
|
||||||
RANGE_OF_SIZE,
|
RANGE_OF_SIZE,
|
||||||
INPAINT_GRAPH,
|
|
||||||
INPAINT,
|
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'nodes' });
|
const moduleLog = log.child({ namespace: 'nodes' });
|
||||||
|
|
||||||
@ -55,7 +56,7 @@ export const buildCanvasInpaintGraph = (
|
|||||||
// We may need to set the inpaint width and height to scale the image
|
// We may need to set the inpaint width and height to scale the image
|
||||||
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
|
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
|
||||||
|
|
||||||
const model = modelIdToPipelineModelField(modelId);
|
const model = modelIdToMainModelField(modelId);
|
||||||
|
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
id: INPAINT_GRAPH,
|
id: INPAINT_GRAPH,
|
||||||
@ -101,9 +102,9 @@ export const buildCanvasInpaintGraph = (
|
|||||||
id: NEGATIVE_CONDITIONING,
|
id: NEGATIVE_CONDITIONING,
|
||||||
prompt: negativePrompt,
|
prompt: negativePrompt,
|
||||||
},
|
},
|
||||||
[PIPELINE_MODEL_LOADER]: {
|
[MAIN_MODEL_LOADER]: {
|
||||||
type: 'pipeline_model_loader',
|
type: 'main_model_loader',
|
||||||
id: PIPELINE_MODEL_LOADER,
|
id: MAIN_MODEL_LOADER,
|
||||||
model,
|
model,
|
||||||
},
|
},
|
||||||
[RANGE_OF_SIZE]: {
|
[RANGE_OF_SIZE]: {
|
||||||
@ -142,7 +143,7 @@ export const buildCanvasInpaintGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -152,7 +153,7 @@ export const buildCanvasInpaintGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -162,7 +163,7 @@ export const buildCanvasInpaintGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'unet',
|
field: 'unet',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -170,16 +171,6 @@ export const buildCanvasInpaintGraph = (
|
|||||||
field: 'unet',
|
field: 'unet',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: INPAINT,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: RANGE_OF_SIZE,
|
node_id: RANGE_OF_SIZE,
|
||||||
@ -203,6 +194,9 @@ export const buildCanvasInpaintGraph = (
|
|||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Add VAE
|
||||||
|
addVAEToGraph(graph, state);
|
||||||
|
|
||||||
// handle seed
|
// handle seed
|
||||||
if (shouldRandomizeSeed) {
|
if (shouldRandomizeSeed) {
|
||||||
// Random int node to generate the starting seed
|
// Random int node to generate the starting seed
|
||||||
|
@ -1,21 +1,18 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api/types';
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||||
|
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||||
|
import { addVAEToGraph } from './addVAEToGraph';
|
||||||
import {
|
import {
|
||||||
ITERATE,
|
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
PIPELINE_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
RANDOM_INT,
|
|
||||||
RANGE_OF_SIZE,
|
|
||||||
TEXT_TO_IMAGE_GRAPH,
|
TEXT_TO_IMAGE_GRAPH,
|
||||||
TEXT_TO_LATENTS,
|
TEXT_TO_LATENTS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
|
||||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
|
||||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Canvas tab's Text to Image graph.
|
* Builds the Canvas tab's Text to Image graph.
|
||||||
@ -38,7 +35,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
// The bounding box determines width and height, not the width and height params
|
// The bounding box determines width and height, not the width and height params
|
||||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||||
|
|
||||||
const model = modelIdToPipelineModelField(modelId);
|
const model = modelIdToMainModelField(modelId);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
@ -76,9 +73,9 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
},
|
},
|
||||||
[PIPELINE_MODEL_LOADER]: {
|
[MAIN_MODEL_LOADER]: {
|
||||||
type: 'pipeline_model_loader',
|
type: 'main_model_loader',
|
||||||
id: PIPELINE_MODEL_LOADER,
|
id: MAIN_MODEL_LOADER,
|
||||||
model,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
@ -109,7 +106,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -119,7 +116,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -129,7 +126,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'unet',
|
field: 'unet',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -147,16 +144,6 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
field: 'latents',
|
field: 'latents',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_IMAGE,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: NOISE,
|
node_id: NOISE,
|
||||||
@ -170,6 +157,9 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Add VAE
|
||||||
|
addVAEToGraph(graph, state);
|
||||||
|
|
||||||
// add dynamic prompts, mutating `graph`
|
// add dynamic prompts, mutating `graph`
|
||||||
addDynamicPromptsToGraph(graph, state);
|
addDynamicPromptsToGraph(graph, state);
|
||||||
|
|
||||||
|
@ -1,28 +1,29 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import {
|
import {
|
||||||
ImageCollectionInvocation,
|
ImageCollectionInvocation,
|
||||||
ImageResizeInvocation,
|
ImageResizeInvocation,
|
||||||
ImageToLatentsInvocation,
|
ImageToLatentsInvocation,
|
||||||
IterateInvocation,
|
IterateInvocation,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
import { log } from 'app/logging/useLogger';
|
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||||
|
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||||
|
import { addVAEToGraph } from './addVAEToGraph';
|
||||||
import {
|
import {
|
||||||
|
IMAGE_COLLECTION,
|
||||||
|
IMAGE_COLLECTION_ITERATE,
|
||||||
|
IMAGE_TO_IMAGE_GRAPH,
|
||||||
|
IMAGE_TO_LATENTS,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
PIPELINE_MODEL_LOADER,
|
LATENTS_TO_LATENTS,
|
||||||
|
MAIN_MODEL_LOADER,
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
IMAGE_TO_IMAGE_GRAPH,
|
|
||||||
IMAGE_TO_LATENTS,
|
|
||||||
LATENTS_TO_LATENTS,
|
|
||||||
RESIZE,
|
RESIZE,
|
||||||
IMAGE_COLLECTION,
|
|
||||||
IMAGE_COLLECTION_ITERATE,
|
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
|
||||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
|
||||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'nodes' });
|
const moduleLog = log.child({ namespace: 'nodes' });
|
||||||
|
|
||||||
@ -69,7 +70,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
throw new Error('No initial image found in state');
|
throw new Error('No initial image found in state');
|
||||||
}
|
}
|
||||||
|
|
||||||
const model = modelIdToPipelineModelField(modelId);
|
const model = modelIdToMainModelField(modelId);
|
||||||
|
|
||||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
@ -89,9 +90,9 @@ export const buildLinearImageToImageGraph = (
|
|||||||
type: 'noise',
|
type: 'noise',
|
||||||
id: NOISE,
|
id: NOISE,
|
||||||
},
|
},
|
||||||
[PIPELINE_MODEL_LOADER]: {
|
[MAIN_MODEL_LOADER]: {
|
||||||
type: 'pipeline_model_loader',
|
type: 'main_model_loader',
|
||||||
id: PIPELINE_MODEL_LOADER,
|
id: MAIN_MODEL_LOADER,
|
||||||
model,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
@ -118,7 +119,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
edges: [
|
edges: [
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -128,7 +129,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -136,16 +137,6 @@ export const buildLinearImageToImageGraph = (
|
|||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_IMAGE,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: LATENTS_TO_LATENTS,
|
node_id: LATENTS_TO_LATENTS,
|
||||||
@ -176,19 +167,10 @@ export const buildLinearImageToImageGraph = (
|
|||||||
field: 'noise',
|
field: 'noise',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: IMAGE_TO_LATENTS,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
|
||||||
field: 'unet',
|
field: 'unet',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -322,6 +304,8 @@ export const buildLinearImageToImageGraph = (
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
// Add VAE
|
||||||
|
addVAEToGraph(graph, state);
|
||||||
|
|
||||||
// add dynamic prompts, mutating `graph`
|
// add dynamic prompts, mutating `graph`
|
||||||
addDynamicPromptsToGraph(graph, state);
|
addDynamicPromptsToGraph(graph, state);
|
||||||
|
@ -1,17 +1,18 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||||
|
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||||
|
import { addVAEToGraph } from './addVAEToGraph';
|
||||||
import {
|
import {
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
PIPELINE_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
TEXT_TO_IMAGE_GRAPH,
|
TEXT_TO_IMAGE_GRAPH,
|
||||||
TEXT_TO_LATENTS,
|
TEXT_TO_LATENTS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
|
||||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
|
||||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
|
||||||
|
|
||||||
export const buildLinearTextToImageGraph = (
|
export const buildLinearTextToImageGraph = (
|
||||||
state: RootState
|
state: RootState
|
||||||
@ -27,7 +28,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
height,
|
height,
|
||||||
} = state.generation;
|
} = state.generation;
|
||||||
|
|
||||||
const model = modelIdToPipelineModelField(modelId);
|
const model = modelIdToMainModelField(modelId);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
@ -65,9 +66,9 @@ export const buildLinearTextToImageGraph = (
|
|||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
},
|
},
|
||||||
[PIPELINE_MODEL_LOADER]: {
|
[MAIN_MODEL_LOADER]: {
|
||||||
type: 'pipeline_model_loader',
|
type: 'main_model_loader',
|
||||||
id: PIPELINE_MODEL_LOADER,
|
id: MAIN_MODEL_LOADER,
|
||||||
model,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
@ -98,7 +99,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -108,7 +109,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -118,7 +119,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
node_id: MAIN_MODEL_LOADER,
|
||||||
field: 'unet',
|
field: 'unet',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -136,16 +137,6 @@ export const buildLinearTextToImageGraph = (
|
|||||||
field: 'latents',
|
field: 'latents',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: PIPELINE_MODEL_LOADER,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_IMAGE,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: NOISE,
|
node_id: NOISE,
|
||||||
@ -159,6 +150,9 @@ export const buildLinearTextToImageGraph = (
|
|||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Add Custom VAE Support
|
||||||
|
addVAEToGraph(graph, state);
|
||||||
|
|
||||||
// add dynamic prompts, mutating `graph`
|
// add dynamic prompts, mutating `graph`
|
||||||
addDynamicPromptsToGraph(graph, state);
|
addDynamicPromptsToGraph(graph, state);
|
||||||
|
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
import { Graph } from 'services/api/types';
|
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
|
||||||
import { cloneDeep, omit, reduce } from 'lodash-es';
|
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { InputFieldValue } from 'features/nodes/types/types';
|
import { InputFieldValue } from 'features/nodes/types/types';
|
||||||
|
import { cloneDeep, omit, reduce } from 'lodash-es';
|
||||||
|
import { Graph } from 'services/api/types';
|
||||||
import { AnyInvocation } from 'services/events/types';
|
import { AnyInvocation } from 'services/events/types';
|
||||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||||
|
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* We need to do special handling for some fields
|
* We need to do special handling for some fields
|
||||||
@ -27,7 +28,13 @@ export const parseFieldValue = (field: InputFieldValue) => {
|
|||||||
|
|
||||||
if (field.type === 'model') {
|
if (field.type === 'model') {
|
||||||
if (field.value) {
|
if (field.value) {
|
||||||
return modelIdToPipelineModelField(field.value);
|
return modelIdToMainModelField(field.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (field.type === 'vae_model') {
|
||||||
|
if (field.value) {
|
||||||
|
return modelIdToVAEModelField(field.value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,7 +7,8 @@ export const NOISE = 'noise';
|
|||||||
export const RANDOM_INT = 'rand_int';
|
export const RANDOM_INT = 'rand_int';
|
||||||
export const RANGE_OF_SIZE = 'range_of_size';
|
export const RANGE_OF_SIZE = 'range_of_size';
|
||||||
export const ITERATE = 'iterate';
|
export const ITERATE = 'iterate';
|
||||||
export const PIPELINE_MODEL_LOADER = 'pipeline_model_loader';
|
export const MAIN_MODEL_LOADER = 'main_model_loader';
|
||||||
|
export const VAE_LOADER = 'vae_loader';
|
||||||
export const IMAGE_TO_LATENTS = 'image_to_latents';
|
export const IMAGE_TO_LATENTS = 'image_to_latents';
|
||||||
export const LATENTS_TO_LATENTS = 'latents_to_latents';
|
export const LATENTS_TO_LATENTS = 'latents_to_latents';
|
||||||
export const RESIZE = 'resize_image';
|
export const RESIZE = 'resize_image';
|
||||||
|
@ -0,0 +1,16 @@
|
|||||||
|
import { BaseModelType, MainModelField } from 'services/api/types';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Crudely converts a model id to a main model field
|
||||||
|
* TODO: Make better
|
||||||
|
*/
|
||||||
|
export const modelIdToMainModelField = (modelId: string): MainModelField => {
|
||||||
|
const [base_model, model_type, model_name] = modelId.split('/');
|
||||||
|
|
||||||
|
const field: MainModelField = {
|
||||||
|
base_model: base_model as BaseModelType,
|
||||||
|
model_name,
|
||||||
|
};
|
||||||
|
|
||||||
|
return field;
|
||||||
|
};
|
@ -1,18 +0,0 @@
|
|||||||
import { BaseModelType, PipelineModelField } from 'services/api/types';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Crudely converts a model id to a pipeline model field
|
|
||||||
* TODO: Make better
|
|
||||||
*/
|
|
||||||
export const modelIdToPipelineModelField = (
|
|
||||||
modelId: string
|
|
||||||
): PipelineModelField => {
|
|
||||||
const [base_model, model_type, model_name] = modelId.split('/');
|
|
||||||
|
|
||||||
const field: PipelineModelField = {
|
|
||||||
base_model: base_model as BaseModelType,
|
|
||||||
model_name,
|
|
||||||
};
|
|
||||||
|
|
||||||
return field;
|
|
||||||
};
|
|
@ -0,0 +1,16 @@
|
|||||||
|
import { BaseModelType, VAEModelField } from 'services/api/types';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Crudely converts a model id to a main model field
|
||||||
|
* TODO: Make better
|
||||||
|
*/
|
||||||
|
export const modelIdToVAEModelField = (modelId: string): VAEModelField => {
|
||||||
|
const [base_model, model_type, model_name] = modelId.split('/');
|
||||||
|
|
||||||
|
const field: VAEModelField = {
|
||||||
|
base_model: base_model as BaseModelType,
|
||||||
|
model_name,
|
||||||
|
};
|
||||||
|
|
||||||
|
return field;
|
||||||
|
};
|
@ -1,19 +1,19 @@
|
|||||||
import { Box, Flex } from '@chakra-ui/react';
|
import { Box, Flex } from '@chakra-ui/react';
|
||||||
import ModelSelect from 'features/system/components/ModelSelect';
|
import ModelSelect from 'features/system/components/ModelSelect';
|
||||||
|
import VAESelect from 'features/system/components/VAESelect';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import ParamScheduler from './ParamScheduler';
|
|
||||||
|
|
||||||
const ParamSchedulerAndModel = () => {
|
const ParamModelandVAE = () => {
|
||||||
return (
|
return (
|
||||||
<Flex gap={3} w="full">
|
<Flex gap={3} w="full">
|
||||||
<Box w="25rem">
|
|
||||||
<ParamScheduler />
|
|
||||||
</Box>
|
|
||||||
<Box w="full">
|
<Box w="full">
|
||||||
<ModelSelect />
|
<ModelSelect />
|
||||||
</Box>
|
</Box>
|
||||||
|
<Box w="full">
|
||||||
|
<VAESelect />
|
||||||
|
</Box>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(ParamSchedulerAndModel);
|
export default memo(ParamModelandVAE);
|
@ -14,6 +14,7 @@ import {
|
|||||||
SeedParam,
|
SeedParam,
|
||||||
StepsParam,
|
StepsParam,
|
||||||
StrengthParam,
|
StrengthParam,
|
||||||
|
VAEParam,
|
||||||
WidthParam,
|
WidthParam,
|
||||||
} from './parameterZodSchemas';
|
} from './parameterZodSchemas';
|
||||||
|
|
||||||
@ -47,6 +48,7 @@ export interface GenerationState {
|
|||||||
horizontalSymmetrySteps: number;
|
horizontalSymmetrySteps: number;
|
||||||
verticalSymmetrySteps: number;
|
verticalSymmetrySteps: number;
|
||||||
model: ModelParam;
|
model: ModelParam;
|
||||||
|
vae: VAEParam;
|
||||||
shouldUseSeamless: boolean;
|
shouldUseSeamless: boolean;
|
||||||
seamlessXAxis: boolean;
|
seamlessXAxis: boolean;
|
||||||
seamlessYAxis: boolean;
|
seamlessYAxis: boolean;
|
||||||
@ -81,6 +83,7 @@ export const initialGenerationState: GenerationState = {
|
|||||||
horizontalSymmetrySteps: 0,
|
horizontalSymmetrySteps: 0,
|
||||||
verticalSymmetrySteps: 0,
|
verticalSymmetrySteps: 0,
|
||||||
model: '',
|
model: '',
|
||||||
|
vae: '',
|
||||||
shouldUseSeamless: false,
|
shouldUseSeamless: false,
|
||||||
seamlessXAxis: true,
|
seamlessXAxis: true,
|
||||||
seamlessYAxis: true,
|
seamlessYAxis: true,
|
||||||
@ -216,6 +219,9 @@ export const generationSlice = createSlice({
|
|||||||
modelSelected: (state, action: PayloadAction<string>) => {
|
modelSelected: (state, action: PayloadAction<string>) => {
|
||||||
state.model = action.payload;
|
state.model = action.payload;
|
||||||
},
|
},
|
||||||
|
vaeSelected: (state, action: PayloadAction<string>) => {
|
||||||
|
state.vae = action.payload;
|
||||||
|
},
|
||||||
},
|
},
|
||||||
extraReducers: (builder) => {
|
extraReducers: (builder) => {
|
||||||
builder.addCase(configChanged, (state, action) => {
|
builder.addCase(configChanged, (state, action) => {
|
||||||
@ -260,6 +266,7 @@ export const {
|
|||||||
setVerticalSymmetrySteps,
|
setVerticalSymmetrySteps,
|
||||||
initialImageChanged,
|
initialImageChanged,
|
||||||
modelSelected,
|
modelSelected,
|
||||||
|
vaeSelected,
|
||||||
setShouldUseNoiseSettings,
|
setShouldUseNoiseSettings,
|
||||||
setSeamless,
|
setSeamless,
|
||||||
setSeamlessXAxis,
|
setSeamlessXAxis,
|
||||||
|
@ -135,6 +135,15 @@ export const zModel = z.string();
|
|||||||
* Type alias for model parameter, inferred from its zod schema
|
* Type alias for model parameter, inferred from its zod schema
|
||||||
*/
|
*/
|
||||||
export type ModelParam = z.infer<typeof zModel>;
|
export type ModelParam = z.infer<typeof zModel>;
|
||||||
|
/**
|
||||||
|
* Zod schema for VAE parameter
|
||||||
|
* TODO: Make this a dynamically generated enum?
|
||||||
|
*/
|
||||||
|
export const zVAE = z.string();
|
||||||
|
/**
|
||||||
|
* Type alias for model parameter, inferred from its zod schema
|
||||||
|
*/
|
||||||
|
export type VAEParam = z.infer<typeof zVAE>;
|
||||||
/**
|
/**
|
||||||
* Validates/type-guards a value as a model parameter
|
* Validates/type-guards a value as a model parameter
|
||||||
*/
|
*/
|
||||||
|
@ -1,125 +0,0 @@
|
|||||||
import {
|
|
||||||
Button,
|
|
||||||
Flex,
|
|
||||||
Modal,
|
|
||||||
ModalBody,
|
|
||||||
ModalCloseButton,
|
|
||||||
ModalContent,
|
|
||||||
ModalFooter,
|
|
||||||
ModalHeader,
|
|
||||||
ModalOverlay,
|
|
||||||
Text,
|
|
||||||
useDisclosure,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
|
|
||||||
import IAIButton from 'common/components/IAIButton';
|
|
||||||
|
|
||||||
import { FaArrowLeft, FaPlus } from 'react-icons/fa';
|
|
||||||
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
import type { RootState } from 'app/store/store';
|
|
||||||
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
|
||||||
import AddCheckpointModel from './AddCheckpointModel';
|
|
||||||
import AddDiffusersModel from './AddDiffusersModel';
|
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
|
||||||
|
|
||||||
function AddModelBox({
|
|
||||||
text,
|
|
||||||
onClick,
|
|
||||||
}: {
|
|
||||||
text: string;
|
|
||||||
onClick?: () => void;
|
|
||||||
}) {
|
|
||||||
return (
|
|
||||||
<Flex
|
|
||||||
position="relative"
|
|
||||||
width="50%"
|
|
||||||
height={40}
|
|
||||||
justifyContent="center"
|
|
||||||
alignItems="center"
|
|
||||||
onClick={onClick}
|
|
||||||
as={Button}
|
|
||||||
>
|
|
||||||
<Text fontWeight="bold">{text}</Text>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
export default function AddModel() {
|
|
||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
|
||||||
|
|
||||||
const addNewModelUIOption = useAppSelector(
|
|
||||||
(state: RootState) => state.ui.addNewModelUIOption
|
|
||||||
);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const addModelModalClose = () => {
|
|
||||||
onClose();
|
|
||||||
dispatch(setAddNewModelUIOption(null));
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<>
|
|
||||||
<IAIButton
|
|
||||||
aria-label={t('modelManager.addNewModel')}
|
|
||||||
tooltip={t('modelManager.addNewModel')}
|
|
||||||
onClick={onOpen}
|
|
||||||
size="sm"
|
|
||||||
>
|
|
||||||
<Flex columnGap={2} alignItems="center">
|
|
||||||
<FaPlus />
|
|
||||||
{t('modelManager.addNew')}
|
|
||||||
</Flex>
|
|
||||||
</IAIButton>
|
|
||||||
|
|
||||||
<Modal
|
|
||||||
isOpen={isOpen}
|
|
||||||
onClose={addModelModalClose}
|
|
||||||
size="3xl"
|
|
||||||
closeOnOverlayClick={false}
|
|
||||||
>
|
|
||||||
<ModalOverlay />
|
|
||||||
<ModalContent margin="auto">
|
|
||||||
<ModalHeader>{t('modelManager.addNewModel')} </ModalHeader>
|
|
||||||
{addNewModelUIOption !== null && (
|
|
||||||
<IAIIconButton
|
|
||||||
aria-label={t('common.back')}
|
|
||||||
tooltip={t('common.back')}
|
|
||||||
onClick={() => dispatch(setAddNewModelUIOption(null))}
|
|
||||||
position="absolute"
|
|
||||||
variant="ghost"
|
|
||||||
zIndex={1}
|
|
||||||
size="sm"
|
|
||||||
insetInlineEnd={12}
|
|
||||||
top={2}
|
|
||||||
icon={<FaArrowLeft />}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
<ModalCloseButton />
|
|
||||||
<ModalBody>
|
|
||||||
{addNewModelUIOption == null && (
|
|
||||||
<Flex columnGap={4}>
|
|
||||||
<AddModelBox
|
|
||||||
text={t('modelManager.addCheckpointModel')}
|
|
||||||
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
|
|
||||||
/>
|
|
||||||
<AddModelBox
|
|
||||||
text={t('modelManager.addDiffuserModel')}
|
|
||||||
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
)}
|
|
||||||
{addNewModelUIOption == 'ckpt' && <AddCheckpointModel />}
|
|
||||||
{addNewModelUIOption == 'diffusers' && <AddDiffusersModel />}
|
|
||||||
</ModalBody>
|
|
||||||
<ModalFooter />
|
|
||||||
</ModalContent>
|
|
||||||
</Modal>
|
|
||||||
</>
|
|
||||||
);
|
|
||||||
}
|
|
@ -1,339 +0,0 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
|
|
||||||
import IAIButton from 'common/components/IAIButton';
|
|
||||||
import IAIInput from 'common/components/IAIInput';
|
|
||||||
import IAINumberInput from 'common/components/IAINumberInput';
|
|
||||||
import { useEffect, useState } from 'react';
|
|
||||||
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
|
||||||
|
|
||||||
import {
|
|
||||||
Flex,
|
|
||||||
FormControl,
|
|
||||||
FormLabel,
|
|
||||||
HStack,
|
|
||||||
Text,
|
|
||||||
VStack,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
|
|
||||||
// import { addNewModel } from 'app/socketio/actions';
|
|
||||||
import { Field, Formik } from 'formik';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
import type { InvokeModelConfigProps } from 'app/types/invokeai';
|
|
||||||
import type { RootState } from 'app/store/store';
|
|
||||||
import type { FieldInputProps, FormikProps } from 'formik';
|
|
||||||
import { isEqual, pickBy } from 'lodash-es';
|
|
||||||
import ModelConvert from './ModelConvert';
|
|
||||||
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
|
|
||||||
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
|
|
||||||
import IAIForm from 'common/components/IAIForm';
|
|
||||||
|
|
||||||
const selector = createSelector(
|
|
||||||
[systemSelector],
|
|
||||||
(system) => {
|
|
||||||
const { openModel, model_list } = system;
|
|
||||||
return {
|
|
||||||
model_list,
|
|
||||||
openModel,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const MIN_MODEL_SIZE = 64;
|
|
||||||
const MAX_MODEL_SIZE = 2048;
|
|
||||||
|
|
||||||
export default function CheckpointModelEdit() {
|
|
||||||
const { openModel, model_list } = useAppSelector(selector);
|
|
||||||
const isProcessing = useAppSelector(
|
|
||||||
(state: RootState) => state.system.isProcessing
|
|
||||||
);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const [editModelFormValues, setEditModelFormValues] =
|
|
||||||
useState<InvokeModelConfigProps>({
|
|
||||||
name: '',
|
|
||||||
description: '',
|
|
||||||
config: 'configs/stable-diffusion/v1-inference.yaml',
|
|
||||||
weights: '',
|
|
||||||
vae: '',
|
|
||||||
width: 512,
|
|
||||||
height: 512,
|
|
||||||
default: false,
|
|
||||||
format: 'ckpt',
|
|
||||||
});
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (openModel) {
|
|
||||||
const retrievedModel = pickBy(model_list, (_val, key) => {
|
|
||||||
return isEqual(key, openModel);
|
|
||||||
});
|
|
||||||
setEditModelFormValues({
|
|
||||||
name: openModel,
|
|
||||||
description: retrievedModel[openModel]?.description,
|
|
||||||
config: retrievedModel[openModel]?.config,
|
|
||||||
weights: retrievedModel[openModel]?.weights,
|
|
||||||
vae: retrievedModel[openModel]?.vae,
|
|
||||||
width: retrievedModel[openModel]?.width,
|
|
||||||
height: retrievedModel[openModel]?.height,
|
|
||||||
default: retrievedModel[openModel]?.default,
|
|
||||||
format: 'ckpt',
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}, [model_list, openModel]);
|
|
||||||
|
|
||||||
const editModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
|
|
||||||
dispatch(
|
|
||||||
addNewModel({
|
|
||||||
...values,
|
|
||||||
width: Number(values.width),
|
|
||||||
height: Number(values.height),
|
|
||||||
})
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
return openModel ? (
|
|
||||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
|
||||||
<Flex alignItems="center" gap={4} justifyContent="space-between">
|
|
||||||
<Text fontSize="lg" fontWeight="bold">
|
|
||||||
{openModel}
|
|
||||||
</Text>
|
|
||||||
<ModelConvert model={openModel} />
|
|
||||||
</Flex>
|
|
||||||
<Flex
|
|
||||||
flexDirection="column"
|
|
||||||
maxHeight={window.innerHeight - 270}
|
|
||||||
overflowY="scroll"
|
|
||||||
paddingInlineEnd={8}
|
|
||||||
>
|
|
||||||
<Formik
|
|
||||||
enableReinitialize={true}
|
|
||||||
initialValues={editModelFormValues}
|
|
||||||
onSubmit={editModelFormSubmitHandler}
|
|
||||||
>
|
|
||||||
{({ handleSubmit, errors, touched }) => (
|
|
||||||
<IAIForm onSubmit={handleSubmit}>
|
|
||||||
<VStack rowGap={2} alignItems="start">
|
|
||||||
{/* Description */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.description && touched.description}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="description" fontSize="sm">
|
|
||||||
{t('modelManager.description')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="description"
|
|
||||||
name="description"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.description && touched.description ? (
|
|
||||||
<IAIFormErrorMessage>
|
|
||||||
{errors.description}
|
|
||||||
</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.descriptionValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* Config */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.config && touched.config}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="config" fontSize="sm">
|
|
||||||
{t('modelManager.config')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="config"
|
|
||||||
name="config"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.config && touched.config ? (
|
|
||||||
<IAIFormErrorMessage>{errors.config}</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.configValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* Weights */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.weights && touched.weights}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="config" fontSize="sm">
|
|
||||||
{t('modelManager.modelLocation')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="weights"
|
|
||||||
name="weights"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.weights && touched.weights ? (
|
|
||||||
<IAIFormErrorMessage>
|
|
||||||
{errors.weights}
|
|
||||||
</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.modelLocationValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* VAE */}
|
|
||||||
<FormControl isInvalid={!!errors.vae && touched.vae}>
|
|
||||||
<FormLabel htmlFor="vae" fontSize="sm">
|
|
||||||
{t('modelManager.vaeLocation')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="vae"
|
|
||||||
name="vae"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.vae && touched.vae ? (
|
|
||||||
<IAIFormErrorMessage>{errors.vae}</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.vaeLocationValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
<HStack width="100%">
|
|
||||||
{/* Width */}
|
|
||||||
<FormControl isInvalid={!!errors.width && touched.width}>
|
|
||||||
<FormLabel htmlFor="width" fontSize="sm">
|
|
||||||
{t('modelManager.width')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field id="width" name="width">
|
|
||||||
{({
|
|
||||||
field,
|
|
||||||
form,
|
|
||||||
}: {
|
|
||||||
field: FieldInputProps<number>;
|
|
||||||
form: FormikProps<InvokeModelConfigProps>;
|
|
||||||
}) => (
|
|
||||||
<IAINumberInput
|
|
||||||
id="width"
|
|
||||||
name="width"
|
|
||||||
min={MIN_MODEL_SIZE}
|
|
||||||
max={MAX_MODEL_SIZE}
|
|
||||||
step={64}
|
|
||||||
value={form.values.width}
|
|
||||||
onChange={(value) =>
|
|
||||||
form.setFieldValue(field.name, Number(value))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Field>
|
|
||||||
|
|
||||||
{!!errors.width && touched.width ? (
|
|
||||||
<IAIFormErrorMessage>
|
|
||||||
{errors.width}
|
|
||||||
</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.widthValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* Height */}
|
|
||||||
<FormControl isInvalid={!!errors.height && touched.height}>
|
|
||||||
<FormLabel htmlFor="height" fontSize="sm">
|
|
||||||
{t('modelManager.height')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field id="height" name="height">
|
|
||||||
{({
|
|
||||||
field,
|
|
||||||
form,
|
|
||||||
}: {
|
|
||||||
field: FieldInputProps<number>;
|
|
||||||
form: FormikProps<InvokeModelConfigProps>;
|
|
||||||
}) => (
|
|
||||||
<IAINumberInput
|
|
||||||
id="height"
|
|
||||||
name="height"
|
|
||||||
min={MIN_MODEL_SIZE}
|
|
||||||
max={MAX_MODEL_SIZE}
|
|
||||||
step={64}
|
|
||||||
value={form.values.height}
|
|
||||||
onChange={(value) =>
|
|
||||||
form.setFieldValue(field.name, Number(value))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Field>
|
|
||||||
|
|
||||||
{!!errors.height && touched.height ? (
|
|
||||||
<IAIFormErrorMessage>
|
|
||||||
{errors.height}
|
|
||||||
</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.heightValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
</HStack>
|
|
||||||
|
|
||||||
<IAIButton
|
|
||||||
type="submit"
|
|
||||||
className="modal-close-btn"
|
|
||||||
isLoading={isProcessing}
|
|
||||||
>
|
|
||||||
{t('modelManager.updateModel')}
|
|
||||||
</IAIButton>
|
|
||||||
</VStack>
|
|
||||||
</IAIForm>
|
|
||||||
)}
|
|
||||||
</Formik>
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
) : (
|
|
||||||
<Flex
|
|
||||||
sx={{
|
|
||||||
width: '100%',
|
|
||||||
justifyContent: 'center',
|
|
||||||
alignItems: 'center',
|
|
||||||
borderRadius: 'base',
|
|
||||||
bg: 'base.900',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Text fontWeight={500}>Pick A Model To Edit</Text>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
}
|
|
@ -1,281 +0,0 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
|
|
||||||
import IAIButton from 'common/components/IAIButton';
|
|
||||||
import IAIInput from 'common/components/IAIInput';
|
|
||||||
import { useEffect, useState } from 'react';
|
|
||||||
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
|
||||||
|
|
||||||
import { Flex, FormControl, FormLabel, Text, VStack } from '@chakra-ui/react';
|
|
||||||
|
|
||||||
// import { addNewModel } from 'app/socketio/actions';
|
|
||||||
import { Field, Formik } from 'formik';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
import type { InvokeDiffusersModelConfigProps } from 'app/types/invokeai';
|
|
||||||
import type { RootState } from 'app/store/store';
|
|
||||||
import { isEqual, pickBy } from 'lodash-es';
|
|
||||||
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
|
|
||||||
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
|
|
||||||
import IAIForm from 'common/components/IAIForm';
|
|
||||||
|
|
||||||
const selector = createSelector(
|
|
||||||
[systemSelector],
|
|
||||||
(system) => {
|
|
||||||
const { openModel, model_list } = system;
|
|
||||||
return {
|
|
||||||
model_list,
|
|
||||||
openModel,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
export default function DiffusersModelEdit() {
|
|
||||||
const { openModel, model_list } = useAppSelector(selector);
|
|
||||||
const isProcessing = useAppSelector(
|
|
||||||
(state: RootState) => state.system.isProcessing
|
|
||||||
);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const [editModelFormValues, setEditModelFormValues] =
|
|
||||||
useState<InvokeDiffusersModelConfigProps>({
|
|
||||||
name: '',
|
|
||||||
description: '',
|
|
||||||
repo_id: '',
|
|
||||||
path: '',
|
|
||||||
vae: { repo_id: '', path: '' },
|
|
||||||
default: false,
|
|
||||||
format: 'diffusers',
|
|
||||||
});
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (openModel) {
|
|
||||||
const retrievedModel = pickBy(model_list, (_val, key) => {
|
|
||||||
return isEqual(key, openModel);
|
|
||||||
});
|
|
||||||
|
|
||||||
setEditModelFormValues({
|
|
||||||
name: openModel,
|
|
||||||
description: retrievedModel[openModel]?.description,
|
|
||||||
path:
|
|
||||||
retrievedModel[openModel]?.path &&
|
|
||||||
retrievedModel[openModel]?.path !== 'None'
|
|
||||||
? retrievedModel[openModel]?.path
|
|
||||||
: '',
|
|
||||||
repo_id:
|
|
||||||
retrievedModel[openModel]?.repo_id &&
|
|
||||||
retrievedModel[openModel]?.repo_id !== 'None'
|
|
||||||
? retrievedModel[openModel]?.repo_id
|
|
||||||
: '',
|
|
||||||
vae: {
|
|
||||||
repo_id: retrievedModel[openModel]?.vae?.repo_id
|
|
||||||
? retrievedModel[openModel]?.vae?.repo_id
|
|
||||||
: '',
|
|
||||||
path: retrievedModel[openModel]?.vae?.path
|
|
||||||
? retrievedModel[openModel]?.vae?.path
|
|
||||||
: '',
|
|
||||||
},
|
|
||||||
default: retrievedModel[openModel]?.default,
|
|
||||||
format: 'diffusers',
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}, [model_list, openModel]);
|
|
||||||
|
|
||||||
const editModelFormSubmitHandler = (
|
|
||||||
values: InvokeDiffusersModelConfigProps
|
|
||||||
) => {
|
|
||||||
const diffusersModelToEdit = values;
|
|
||||||
|
|
||||||
if (values.path === '') delete diffusersModelToEdit.path;
|
|
||||||
if (values.repo_id === '') delete diffusersModelToEdit.repo_id;
|
|
||||||
if (values.vae.path === '') delete diffusersModelToEdit.vae.path;
|
|
||||||
if (values.vae.repo_id === '') delete diffusersModelToEdit.vae.repo_id;
|
|
||||||
|
|
||||||
dispatch(addNewModel(values));
|
|
||||||
};
|
|
||||||
|
|
||||||
return openModel ? (
|
|
||||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
|
||||||
<Flex alignItems="center">
|
|
||||||
<Text fontSize="lg" fontWeight="bold">
|
|
||||||
{openModel}
|
|
||||||
</Text>
|
|
||||||
</Flex>
|
|
||||||
<Flex flexDirection="column" overflowY="scroll" paddingInlineEnd={8}>
|
|
||||||
<Formik
|
|
||||||
enableReinitialize={true}
|
|
||||||
initialValues={editModelFormValues}
|
|
||||||
onSubmit={editModelFormSubmitHandler}
|
|
||||||
>
|
|
||||||
{({ handleSubmit, errors, touched }) => (
|
|
||||||
<IAIForm onSubmit={handleSubmit}>
|
|
||||||
<VStack rowGap={2} alignItems="start">
|
|
||||||
{/* Description */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.description && touched.description}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="description" fontSize="sm">
|
|
||||||
{t('modelManager.description')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="description"
|
|
||||||
name="description"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.description && touched.description ? (
|
|
||||||
<IAIFormErrorMessage>
|
|
||||||
{errors.description}
|
|
||||||
</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.descriptionValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* Path */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.path && touched.path}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="path" fontSize="sm">
|
|
||||||
{t('modelManager.modelLocation')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="path"
|
|
||||||
name="path"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.path && touched.path ? (
|
|
||||||
<IAIFormErrorMessage>{errors.path}</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.modelLocationValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* Repo ID */}
|
|
||||||
<FormControl isInvalid={!!errors.repo_id && touched.repo_id}>
|
|
||||||
<FormLabel htmlFor="repo_id" fontSize="sm">
|
|
||||||
{t('modelManager.repo_id')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="repo_id"
|
|
||||||
name="repo_id"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.repo_id && touched.repo_id ? (
|
|
||||||
<IAIFormErrorMessage>
|
|
||||||
{errors.repo_id}
|
|
||||||
</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.repoIDValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* VAE Path */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.vae?.path && touched.vae?.path}
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="vae.path" fontSize="sm">
|
|
||||||
{t('modelManager.vaeLocation')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="vae.path"
|
|
||||||
name="vae.path"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.vae?.path && touched.vae?.path ? (
|
|
||||||
<IAIFormErrorMessage>
|
|
||||||
{errors.vae?.path}
|
|
||||||
</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.vaeLocationValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* VAE Repo ID */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.vae?.repo_id && touched.vae?.repo_id}
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="vae.repo_id" fontSize="sm">
|
|
||||||
{t('modelManager.vaeRepoID')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems="start">
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="vae.repo_id"
|
|
||||||
name="vae.repo_id"
|
|
||||||
type="text"
|
|
||||||
width="full"
|
|
||||||
/>
|
|
||||||
{!!errors.vae?.repo_id && touched.vae?.repo_id ? (
|
|
||||||
<IAIFormErrorMessage>
|
|
||||||
{errors.vae?.repo_id}
|
|
||||||
</IAIFormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<IAIFormHelperText>
|
|
||||||
{t('modelManager.vaeRepoIDValidationMsg')}
|
|
||||||
</IAIFormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
<IAIButton
|
|
||||||
type="submit"
|
|
||||||
className="modal-close-btn"
|
|
||||||
isLoading={isProcessing}
|
|
||||||
>
|
|
||||||
{t('modelManager.updateModel')}
|
|
||||||
</IAIButton>
|
|
||||||
</VStack>
|
|
||||||
</IAIForm>
|
|
||||||
)}
|
|
||||||
</Formik>
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
) : (
|
|
||||||
<Flex
|
|
||||||
sx={{
|
|
||||||
width: '100%',
|
|
||||||
justifyContent: 'center',
|
|
||||||
alignItems: 'center',
|
|
||||||
borderRadius: 'base',
|
|
||||||
bg: 'base.900',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Text fontWeight={'500'}>Pick A Model To Edit</Text>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
}
|
|
@ -1,313 +0,0 @@
|
|||||||
import {
|
|
||||||
Flex,
|
|
||||||
Modal,
|
|
||||||
ModalBody,
|
|
||||||
ModalCloseButton,
|
|
||||||
ModalContent,
|
|
||||||
ModalFooter,
|
|
||||||
ModalHeader,
|
|
||||||
ModalOverlay,
|
|
||||||
Radio,
|
|
||||||
RadioGroup,
|
|
||||||
Text,
|
|
||||||
Tooltip,
|
|
||||||
useDisclosure,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
// import { mergeDiffusersModels } from 'app/socketio/actions';
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import IAIButton from 'common/components/IAIButton';
|
|
||||||
import IAIInput from 'common/components/IAIInput';
|
|
||||||
import IAISelect from 'common/components/IAISelect';
|
|
||||||
import { diffusersModelsSelector } from 'features/system/store/systemSelectors';
|
|
||||||
import { useState } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import * as InvokeAI from 'app/types/invokeai';
|
|
||||||
import IAISlider from 'common/components/IAISlider';
|
|
||||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
|
||||||
|
|
||||||
export default function MergeModels() {
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
|
||||||
|
|
||||||
const diffusersModels = useAppSelector(diffusersModelsSelector);
|
|
||||||
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const [modelOne, setModelOne] = useState<string>(
|
|
||||||
Object.keys(diffusersModels)[0]
|
|
||||||
);
|
|
||||||
const [modelTwo, setModelTwo] = useState<string>(
|
|
||||||
Object.keys(diffusersModels)[1]
|
|
||||||
);
|
|
||||||
const [modelThree, setModelThree] = useState<string>('none');
|
|
||||||
|
|
||||||
const [mergedModelName, setMergedModelName] = useState<string>('');
|
|
||||||
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
|
|
||||||
|
|
||||||
const [modelMergeInterp, setModelMergeInterp] = useState<
|
|
||||||
'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'
|
|
||||||
>('weighted_sum');
|
|
||||||
|
|
||||||
const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState<
|
|
||||||
'root' | 'custom'
|
|
||||||
>('root');
|
|
||||||
|
|
||||||
const [modelMergeCustomSaveLoc, setModelMergeCustomSaveLoc] =
|
|
||||||
useState<string>('');
|
|
||||||
|
|
||||||
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
|
|
||||||
|
|
||||||
const modelOneList = Object.keys(diffusersModels).filter(
|
|
||||||
(model) => model !== modelTwo && model !== modelThree
|
|
||||||
);
|
|
||||||
|
|
||||||
const modelTwoList = Object.keys(diffusersModels).filter(
|
|
||||||
(model) => model !== modelOne && model !== modelThree
|
|
||||||
);
|
|
||||||
|
|
||||||
const modelThreeList = [
|
|
||||||
{ key: t('modelManager.none'), value: 'none' },
|
|
||||||
...Object.keys(diffusersModels)
|
|
||||||
.filter((model) => model !== modelOne && model !== modelTwo)
|
|
||||||
.map((model) => ({ key: model, value: model })),
|
|
||||||
];
|
|
||||||
|
|
||||||
const isProcessing = useAppSelector(
|
|
||||||
(state: RootState) => state.system.isProcessing
|
|
||||||
);
|
|
||||||
|
|
||||||
const mergeModelsHandler = () => {
|
|
||||||
let modelsToMerge: string[] = [modelOne, modelTwo, modelThree];
|
|
||||||
modelsToMerge = modelsToMerge.filter((model) => model !== 'none');
|
|
||||||
|
|
||||||
const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = {
|
|
||||||
models_to_merge: modelsToMerge,
|
|
||||||
merged_model_name:
|
|
||||||
mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'),
|
|
||||||
alpha: modelMergeAlpha,
|
|
||||||
interp: modelMergeInterp,
|
|
||||||
model_merge_save_path:
|
|
||||||
modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
|
|
||||||
force: modelMergeForce,
|
|
||||||
};
|
|
||||||
|
|
||||||
dispatch(mergeDiffusersModels(mergeModelsInfo));
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<>
|
|
||||||
<IAIButton onClick={onOpen} size="sm">
|
|
||||||
<Flex columnGap={2} alignItems="center">
|
|
||||||
{t('modelManager.mergeModels')}
|
|
||||||
</Flex>
|
|
||||||
</IAIButton>
|
|
||||||
|
|
||||||
<Modal
|
|
||||||
isOpen={isOpen}
|
|
||||||
onClose={onClose}
|
|
||||||
size="4xl"
|
|
||||||
closeOnOverlayClick={false}
|
|
||||||
>
|
|
||||||
<ModalOverlay />
|
|
||||||
<ModalContent fontFamily="Inter" margin="auto" paddingInlineEnd={4}>
|
|
||||||
<ModalHeader>{t('modelManager.mergeModels')}</ModalHeader>
|
|
||||||
<ModalCloseButton />
|
|
||||||
<ModalBody>
|
|
||||||
<Flex flexDirection="column" rowGap={4}>
|
|
||||||
<Flex
|
|
||||||
sx={{
|
|
||||||
flexDirection: 'column',
|
|
||||||
marginBottom: 4,
|
|
||||||
padding: 4,
|
|
||||||
borderRadius: 'base',
|
|
||||||
rowGap: 1,
|
|
||||||
bg: 'base.900',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Text>{t('modelManager.modelMergeHeaderHelp1')}</Text>
|
|
||||||
<Text fontSize="sm" variant="subtext">
|
|
||||||
{t('modelManager.modelMergeHeaderHelp2')}
|
|
||||||
</Text>
|
|
||||||
</Flex>
|
|
||||||
<Flex columnGap={4}>
|
|
||||||
<IAISelect
|
|
||||||
label={t('modelManager.modelOne')}
|
|
||||||
validValues={modelOneList}
|
|
||||||
onChange={(e) => setModelOne(e.target.value)}
|
|
||||||
/>
|
|
||||||
<IAISelect
|
|
||||||
label={t('modelManager.modelTwo')}
|
|
||||||
validValues={modelTwoList}
|
|
||||||
onChange={(e) => setModelTwo(e.target.value)}
|
|
||||||
/>
|
|
||||||
<IAISelect
|
|
||||||
label={t('modelManager.modelThree')}
|
|
||||||
validValues={modelThreeList}
|
|
||||||
onChange={(e) => {
|
|
||||||
if (e.target.value !== 'none') {
|
|
||||||
setModelThree(e.target.value);
|
|
||||||
setModelMergeInterp('add_difference');
|
|
||||||
} else {
|
|
||||||
setModelThree('none');
|
|
||||||
setModelMergeInterp('weighted_sum');
|
|
||||||
}
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<IAIInput
|
|
||||||
label={t('modelManager.mergedModelName')}
|
|
||||||
value={mergedModelName}
|
|
||||||
onChange={(e) => setMergedModelName(e.target.value)}
|
|
||||||
/>
|
|
||||||
|
|
||||||
<Flex
|
|
||||||
sx={{
|
|
||||||
flexDirection: 'column',
|
|
||||||
padding: 4,
|
|
||||||
borderRadius: 'base',
|
|
||||||
gap: 4,
|
|
||||||
bg: 'base.900',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<IAISlider
|
|
||||||
label={t('modelManager.alpha')}
|
|
||||||
min={0.01}
|
|
||||||
max={0.99}
|
|
||||||
step={0.01}
|
|
||||||
value={modelMergeAlpha}
|
|
||||||
onChange={(v) => setModelMergeAlpha(v)}
|
|
||||||
withInput
|
|
||||||
withReset
|
|
||||||
handleReset={() => setModelMergeAlpha(0.5)}
|
|
||||||
withSliderMarks
|
|
||||||
/>
|
|
||||||
<Text variant="subtext" fontSize="sm">
|
|
||||||
{t('modelManager.modelMergeAlphaHelp')}
|
|
||||||
</Text>
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<Flex
|
|
||||||
sx={{
|
|
||||||
padding: 4,
|
|
||||||
borderRadius: 'base',
|
|
||||||
gap: 4,
|
|
||||||
bg: 'base.900',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Text fontWeight={500} fontSize="sm" variant="subtext">
|
|
||||||
{t('modelManager.interpolationType')}
|
|
||||||
</Text>
|
|
||||||
<RadioGroup
|
|
||||||
value={modelMergeInterp}
|
|
||||||
onChange={(
|
|
||||||
v:
|
|
||||||
| 'weighted_sum'
|
|
||||||
| 'sigmoid'
|
|
||||||
| 'inv_sigmoid'
|
|
||||||
| 'add_difference'
|
|
||||||
) => setModelMergeInterp(v)}
|
|
||||||
>
|
|
||||||
<Flex columnGap={4}>
|
|
||||||
{modelThree === 'none' ? (
|
|
||||||
<>
|
|
||||||
<Radio value="weighted_sum">
|
|
||||||
<Text fontSize="sm">
|
|
||||||
{t('modelManager.weightedSum')}
|
|
||||||
</Text>
|
|
||||||
</Radio>
|
|
||||||
<Radio value="sigmoid">
|
|
||||||
<Text fontSize="sm">{t('modelManager.sigmoid')}</Text>
|
|
||||||
</Radio>
|
|
||||||
<Radio value="inv_sigmoid">
|
|
||||||
<Text fontSize="sm">
|
|
||||||
{t('modelManager.inverseSigmoid')}
|
|
||||||
</Text>
|
|
||||||
</Radio>
|
|
||||||
</>
|
|
||||||
) : (
|
|
||||||
<Radio value="add_difference">
|
|
||||||
<Tooltip
|
|
||||||
label={t(
|
|
||||||
'modelManager.modelMergeInterpAddDifferenceHelp'
|
|
||||||
)}
|
|
||||||
>
|
|
||||||
<Text fontSize="sm">
|
|
||||||
{t('modelManager.addDifference')}
|
|
||||||
</Text>
|
|
||||||
</Tooltip>
|
|
||||||
</Radio>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
</RadioGroup>
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<Flex
|
|
||||||
sx={{
|
|
||||||
flexDirection: 'column',
|
|
||||||
padding: 4,
|
|
||||||
borderRadius: 'base',
|
|
||||||
gap: 4,
|
|
||||||
bg: 'base.900',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Flex columnGap={4}>
|
|
||||||
<Text fontWeight="500" fontSize="sm" variant="subtext">
|
|
||||||
{t('modelManager.mergedModelSaveLocation')}
|
|
||||||
</Text>
|
|
||||||
<RadioGroup
|
|
||||||
value={modelMergeSaveLocType}
|
|
||||||
onChange={(v: 'root' | 'custom') =>
|
|
||||||
setModelMergeSaveLocType(v)
|
|
||||||
}
|
|
||||||
>
|
|
||||||
<Flex columnGap={4}>
|
|
||||||
<Radio value="root">
|
|
||||||
<Text fontSize="sm">
|
|
||||||
{t('modelManager.invokeAIFolder')}
|
|
||||||
</Text>
|
|
||||||
</Radio>
|
|
||||||
|
|
||||||
<Radio value="custom">
|
|
||||||
<Text fontSize="sm">{t('modelManager.custom')}</Text>
|
|
||||||
</Radio>
|
|
||||||
</Flex>
|
|
||||||
</RadioGroup>
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
{modelMergeSaveLocType === 'custom' && (
|
|
||||||
<IAIInput
|
|
||||||
label={t('modelManager.mergedModelCustomSaveLocation')}
|
|
||||||
value={modelMergeCustomSaveLoc}
|
|
||||||
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<IAISimpleCheckbox
|
|
||||||
label={t('modelManager.ignoreMismatch')}
|
|
||||||
isChecked={modelMergeForce}
|
|
||||||
onChange={(e) => setModelMergeForce(e.target.checked)}
|
|
||||||
fontWeight="500"
|
|
||||||
/>
|
|
||||||
|
|
||||||
<IAIButton
|
|
||||||
onClick={mergeModelsHandler}
|
|
||||||
isLoading={isProcessing}
|
|
||||||
isDisabled={
|
|
||||||
modelMergeSaveLocType === 'custom' &&
|
|
||||||
modelMergeCustomSaveLoc === ''
|
|
||||||
}
|
|
||||||
>
|
|
||||||
{t('modelManager.merge')}
|
|
||||||
</IAIButton>
|
|
||||||
</Flex>
|
|
||||||
</ModalBody>
|
|
||||||
<ModalFooter />
|
|
||||||
</ModalContent>
|
|
||||||
</Modal>
|
|
||||||
</>
|
|
||||||
);
|
|
||||||
}
|
|
@ -1,76 +0,0 @@
|
|||||||
import {
|
|
||||||
Flex,
|
|
||||||
Modal,
|
|
||||||
ModalBody,
|
|
||||||
ModalCloseButton,
|
|
||||||
ModalContent,
|
|
||||||
ModalFooter,
|
|
||||||
ModalHeader,
|
|
||||||
ModalOverlay,
|
|
||||||
useDisclosure,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
import { cloneElement } from 'react';
|
|
||||||
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
import type { ReactElement } from 'react';
|
|
||||||
|
|
||||||
import CheckpointModelEdit from './CheckpointModelEdit';
|
|
||||||
import DiffusersModelEdit from './DiffusersModelEdit';
|
|
||||||
import ModelList from './ModelList';
|
|
||||||
|
|
||||||
type ModelManagerModalProps = {
|
|
||||||
children: ReactElement;
|
|
||||||
};
|
|
||||||
|
|
||||||
export default function ModelManagerModal({
|
|
||||||
children,
|
|
||||||
}: ModelManagerModalProps) {
|
|
||||||
const {
|
|
||||||
isOpen: isModelManagerModalOpen,
|
|
||||||
onOpen: onModelManagerModalOpen,
|
|
||||||
onClose: onModelManagerModalClose,
|
|
||||||
} = useDisclosure();
|
|
||||||
|
|
||||||
const model_list = useAppSelector(
|
|
||||||
(state: RootState) => state.system.model_list
|
|
||||||
);
|
|
||||||
|
|
||||||
const openModel = useAppSelector(
|
|
||||||
(state: RootState) => state.system.openModel
|
|
||||||
);
|
|
||||||
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
return (
|
|
||||||
<>
|
|
||||||
{cloneElement(children, {
|
|
||||||
onClick: onModelManagerModalOpen,
|
|
||||||
})}
|
|
||||||
<Modal
|
|
||||||
isOpen={isModelManagerModalOpen}
|
|
||||||
onClose={onModelManagerModalClose}
|
|
||||||
size="full"
|
|
||||||
>
|
|
||||||
<ModalOverlay />
|
|
||||||
<ModalContent>
|
|
||||||
<ModalCloseButton />
|
|
||||||
<ModalHeader>{t('modelManager.modelManager')}</ModalHeader>
|
|
||||||
<ModalBody>
|
|
||||||
<Flex width="100%" columnGap={8}>
|
|
||||||
<ModelList />
|
|
||||||
{openModel && model_list[openModel]['format'] === 'diffusers' ? (
|
|
||||||
<DiffusersModelEdit />
|
|
||||||
) : (
|
|
||||||
<CheckpointModelEdit />
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
</ModalBody>
|
|
||||||
<ModalFooter />
|
|
||||||
</ModalContent>
|
|
||||||
</Modal>
|
|
||||||
</>
|
|
||||||
);
|
|
||||||
}
|
|
@ -5,9 +5,9 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
import { modelSelected } from 'features/parameters/store/generationSlice';
|
import { modelSelected } from 'features/parameters/store/generationSlice';
|
||||||
|
|
||||||
import { forEach, isString } from 'lodash-es';
|
|
||||||
import { SelectItem } from '@mantine/core';
|
import { SelectItem } from '@mantine/core';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
|
import { forEach, isString } from 'lodash-es';
|
||||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
export const MODEL_TYPE_MAP = {
|
export const MODEL_TYPE_MAP = {
|
||||||
@ -23,18 +23,18 @@ const ModelSelect = () => {
|
|||||||
(state: RootState) => state.generation.model
|
(state: RootState) => state.generation.model
|
||||||
);
|
);
|
||||||
|
|
||||||
const { data: pipelineModels, isLoading } = useListModelsQuery({
|
const { data: mainModels, isLoading } = useListModelsQuery({
|
||||||
model_type: 'main',
|
model_type: 'main',
|
||||||
});
|
});
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!pipelineModels) {
|
if (!mainModels) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
const data: SelectItem[] = [];
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
forEach(pipelineModels.entities, (model, id) => {
|
forEach(mainModels.entities, (model, id) => {
|
||||||
if (!model) {
|
if (!model) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -47,11 +47,11 @@ const ModelSelect = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
return data;
|
return data;
|
||||||
}, [pipelineModels]);
|
}, [mainModels]);
|
||||||
|
|
||||||
const selectedModel = useMemo(
|
const selectedModel = useMemo(
|
||||||
() => pipelineModels?.entities[selectedModelId],
|
() => mainModels?.entities[selectedModelId],
|
||||||
[pipelineModels?.entities, selectedModelId]
|
[mainModels?.entities, selectedModelId]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleChangeModel = useCallback(
|
const handleChangeModel = useCallback(
|
||||||
@ -65,20 +65,18 @@ const ModelSelect = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
// If the selected model is not in the list of models, select the first one
|
if (selectedModelId && mainModels?.ids.includes(selectedModelId)) {
|
||||||
// Handles first-run setting of models, and the user deleting the previously-selected model
|
|
||||||
if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const firstModel = pipelineModels?.ids[0];
|
const firstModel = mainModels?.ids[0];
|
||||||
|
|
||||||
if (!isString(firstModel)) {
|
if (!isString(firstModel)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
handleChangeModel(firstModel);
|
handleChangeModel(firstModel);
|
||||||
}, [handleChangeModel, pipelineModels?.ids, selectedModelId]);
|
}, [handleChangeModel, mainModels?.ids, selectedModelId]);
|
||||||
|
|
||||||
return isLoading ? (
|
return isLoading ? (
|
||||||
<IAIMantineSelect
|
<IAIMantineSelect
|
||||||
|
@ -5,21 +5,18 @@ import StatusIndicator from './StatusIndicator';
|
|||||||
import { Link } from '@chakra-ui/react';
|
import { Link } from '@chakra-ui/react';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { FaBug, FaCube, FaDiscord, FaGithub, FaKeyboard } from 'react-icons/fa';
|
import { FaBug, FaDiscord, FaGithub, FaKeyboard } from 'react-icons/fa';
|
||||||
import { MdSettings } from 'react-icons/md';
|
import { MdSettings } from 'react-icons/md';
|
||||||
|
import { useFeatureStatus } from '../hooks/useFeatureStatus';
|
||||||
|
import ColorModeButton from './ColorModeButton';
|
||||||
import HotkeysModal from './HotkeysModal/HotkeysModal';
|
import HotkeysModal from './HotkeysModal/HotkeysModal';
|
||||||
import InvokeAILogoComponent from './InvokeAILogoComponent';
|
import InvokeAILogoComponent from './InvokeAILogoComponent';
|
||||||
import LanguagePicker from './LanguagePicker';
|
import LanguagePicker from './LanguagePicker';
|
||||||
import ModelManagerModal from './ModelManager/ModelManagerModal';
|
|
||||||
import SettingsModal from './SettingsModal/SettingsModal';
|
import SettingsModal from './SettingsModal/SettingsModal';
|
||||||
import { useFeatureStatus } from '../hooks/useFeatureStatus';
|
|
||||||
import ColorModeButton from './ColorModeButton';
|
|
||||||
|
|
||||||
const SiteHeader = () => {
|
const SiteHeader = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const isModelManagerEnabled =
|
|
||||||
useFeatureStatus('modelManager').isFeatureEnabled;
|
|
||||||
const isLocalizationEnabled =
|
const isLocalizationEnabled =
|
||||||
useFeatureStatus('localization').isFeatureEnabled;
|
useFeatureStatus('localization').isFeatureEnabled;
|
||||||
const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled;
|
const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled;
|
||||||
@ -37,20 +34,6 @@ const SiteHeader = () => {
|
|||||||
<Spacer />
|
<Spacer />
|
||||||
<StatusIndicator />
|
<StatusIndicator />
|
||||||
|
|
||||||
{isModelManagerEnabled && (
|
|
||||||
<ModelManagerModal>
|
|
||||||
<IAIIconButton
|
|
||||||
aria-label={t('modelManager.modelManager')}
|
|
||||||
tooltip={t('modelManager.modelManager')}
|
|
||||||
size="sm"
|
|
||||||
variant="link"
|
|
||||||
data-variant="link"
|
|
||||||
fontSize={20}
|
|
||||||
icon={<FaCube />}
|
|
||||||
/>
|
|
||||||
</ModelManagerModal>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<HotkeysModal>
|
<HotkeysModal>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
aria-label={t('common.hotkeysLabel')}
|
aria-label={t('common.hotkeysLabel')}
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
import { Flex, Link } from '@chakra-ui/react';
|
import { Flex, Link } from '@chakra-ui/react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { FaBug, FaCube, FaDiscord, FaGithub, FaKeyboard } from 'react-icons/fa';
|
import { FaBug, FaDiscord, FaGithub, FaKeyboard } from 'react-icons/fa';
|
||||||
import { MdSettings } from 'react-icons/md';
|
import { MdSettings } from 'react-icons/md';
|
||||||
import HotkeysModal from './HotkeysModal/HotkeysModal';
|
import HotkeysModal from './HotkeysModal/HotkeysModal';
|
||||||
import LanguagePicker from './LanguagePicker';
|
import LanguagePicker from './LanguagePicker';
|
||||||
import ModelManagerModal from './ModelManager/ModelManagerModal';
|
|
||||||
import SettingsModal from './SettingsModal/SettingsModal';
|
import SettingsModal from './SettingsModal/SettingsModal';
|
||||||
|
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
@ -13,8 +12,6 @@ import { useFeatureStatus } from '../hooks/useFeatureStatus';
|
|||||||
const SiteHeaderMenu = () => {
|
const SiteHeaderMenu = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const isModelManagerEnabled =
|
|
||||||
useFeatureStatus('modelManager').isFeatureEnabled;
|
|
||||||
const isLocalizationEnabled =
|
const isLocalizationEnabled =
|
||||||
useFeatureStatus('localization').isFeatureEnabled;
|
useFeatureStatus('localization').isFeatureEnabled;
|
||||||
const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled;
|
const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled;
|
||||||
@ -27,20 +24,6 @@ const SiteHeaderMenu = () => {
|
|||||||
flexDirection={{ base: 'column', xl: 'row' }}
|
flexDirection={{ base: 'column', xl: 'row' }}
|
||||||
gap={{ base: 4, xl: 1 }}
|
gap={{ base: 4, xl: 1 }}
|
||||||
>
|
>
|
||||||
{isModelManagerEnabled && (
|
|
||||||
<ModelManagerModal>
|
|
||||||
<IAIIconButton
|
|
||||||
aria-label={t('modelManager.modelManager')}
|
|
||||||
tooltip={t('modelManager.modelManager')}
|
|
||||||
size="sm"
|
|
||||||
variant="link"
|
|
||||||
data-variant="link"
|
|
||||||
fontSize={20}
|
|
||||||
icon={<FaCube />}
|
|
||||||
/>
|
|
||||||
</ModelManagerModal>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<HotkeysModal>
|
<HotkeysModal>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
aria-label={t('common.hotkeysLabel')}
|
aria-label={t('common.hotkeysLabel')}
|
||||||
|
@ -0,0 +1,89 @@
|
|||||||
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
|
||||||
|
import { SelectItem } from '@mantine/core';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { vaeSelected } from 'features/parameters/store/generationSlice';
|
||||||
|
import { MODEL_TYPE_MAP } from './ModelSelect';
|
||||||
|
|
||||||
|
const VAESelect = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const { data: vaeModels } = useListModelsQuery({
|
||||||
|
model_type: 'vae',
|
||||||
|
});
|
||||||
|
|
||||||
|
const selectedModelId = useAppSelector(
|
||||||
|
(state: RootState) => state.generation.vae
|
||||||
|
);
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!vaeModels) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [
|
||||||
|
{
|
||||||
|
value: 'auto',
|
||||||
|
label: 'Automatic',
|
||||||
|
group: 'Default',
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
forEach(vaeModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.name,
|
||||||
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [vaeModels]);
|
||||||
|
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() => vaeModels?.entities[selectedModelId],
|
||||||
|
[vaeModels?.entities, selectedModelId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleChangeModel = useCallback(
|
||||||
|
(v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dispatch(vaeSelected(v));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (selectedModelId && vaeModels?.ids.includes(selectedModelId)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
handleChangeModel('auto');
|
||||||
|
}, [handleChangeModel, vaeModels?.ids, selectedModelId]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIMantineSelect
|
||||||
|
tooltip={selectedModel?.description}
|
||||||
|
label={t('modelManager.vae')}
|
||||||
|
value={selectedModelId}
|
||||||
|
placeholder="Pick one"
|
||||||
|
data={data}
|
||||||
|
onChange={handleChangeModel}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(VAESelect);
|
@ -1,13 +1,14 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||||
import { setShouldShowGallery } from 'features/ui/store/uiSlice';
|
import { setShouldShowGallery } from 'features/ui/store/uiSlice';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
import { MdPhotoLibrary } from 'react-icons/md';
|
import { MdPhotoLibrary } from 'react-icons/md';
|
||||||
import { activeTabNameSelector, uiSelector } from '../store/uiSelectors';
|
import { activeTabNameSelector, uiSelector } from '../store/uiSelectors';
|
||||||
import { memo } from 'react';
|
import { NO_GALLERY_TABS } from './InvokeTabs';
|
||||||
|
|
||||||
const floatingGalleryButtonSelector = createSelector(
|
const floatingGalleryButtonSelector = createSelector(
|
||||||
[activeTabNameSelector, uiSelector],
|
[activeTabNameSelector, uiSelector],
|
||||||
@ -16,7 +17,9 @@ const floatingGalleryButtonSelector = createSelector(
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
shouldPinGallery,
|
shouldPinGallery,
|
||||||
shouldShowGalleryButton: !shouldShowGallery,
|
shouldShowGalleryButton: NO_GALLERY_TABS.includes(activeTabName)
|
||||||
|
? false
|
||||||
|
: !shouldShowGallery,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
{ memoizeOptions: { resultEqualityCheck: isEqual } }
|
{ memoizeOptions: { resultEqualityCheck: isEqual } }
|
||||||
|
@ -9,35 +9,35 @@ import {
|
|||||||
Tooltip,
|
Tooltip,
|
||||||
VisuallyHidden,
|
VisuallyHidden,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import AuxiliaryProgressIndicator from 'app/components/AuxiliaryProgressIndicator';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||||
|
import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent';
|
||||||
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
|
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
|
||||||
|
import { configSelector } from 'features/system/store/configSelectors';
|
||||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice';
|
import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice';
|
||||||
import { memo, MouseEvent, ReactNode, useCallback, useMemo } from 'react';
|
import { ResourceKey } from 'i18next';
|
||||||
|
import { isEqual } from 'lodash-es';
|
||||||
|
import { MouseEvent, ReactNode, memo, useCallback, useMemo } from 'react';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { FaCube, FaFont, FaImage } from 'react-icons/fa';
|
||||||
import { MdDeviceHub, MdGridOn } from 'react-icons/md';
|
import { MdDeviceHub, MdGridOn } from 'react-icons/md';
|
||||||
|
import { Panel, PanelGroup } from 'react-resizable-panels';
|
||||||
|
import { useMinimumPanelSize } from '../hooks/useMinimumPanelSize';
|
||||||
import {
|
import {
|
||||||
activeTabIndexSelector,
|
activeTabIndexSelector,
|
||||||
activeTabNameSelector,
|
activeTabNameSelector,
|
||||||
} from '../store/uiSelectors';
|
} from '../store/uiSelectors';
|
||||||
import { useTranslation } from 'react-i18next';
|
import ImageTab from './tabs/ImageToImage/ImageToImageTab';
|
||||||
import { ResourceKey } from 'i18next';
|
import ModelManagerTab from './tabs/ModelManager/ModelManagerTab';
|
||||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
import NodesTab from './tabs/Nodes/NodesTab';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import ResizeHandle from './tabs/ResizeHandle';
|
||||||
import { configSelector } from 'features/system/store/configSelectors';
|
|
||||||
import { isEqual } from 'lodash-es';
|
|
||||||
import { Panel, PanelGroup } from 'react-resizable-panels';
|
|
||||||
import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent';
|
|
||||||
import TextToImageTab from './tabs/TextToImage/TextToImageTab';
|
import TextToImageTab from './tabs/TextToImage/TextToImageTab';
|
||||||
import UnifiedCanvasTab from './tabs/UnifiedCanvas/UnifiedCanvasTab';
|
import UnifiedCanvasTab from './tabs/UnifiedCanvas/UnifiedCanvasTab';
|
||||||
import NodesTab from './tabs/Nodes/NodesTab';
|
|
||||||
import { FaFont, FaImage, FaLayerGroup } from 'react-icons/fa';
|
|
||||||
import ResizeHandle from './tabs/ResizeHandle';
|
|
||||||
import ImageTab from './tabs/ImageToImage/ImageToImageTab';
|
|
||||||
import AuxiliaryProgressIndicator from 'app/components/AuxiliaryProgressIndicator';
|
|
||||||
import { useMinimumPanelSize } from '../hooks/useMinimumPanelSize';
|
|
||||||
import BatchTab from './tabs/Batch/BatchTab';
|
|
||||||
|
|
||||||
export interface InvokeTabInfo {
|
export interface InvokeTabInfo {
|
||||||
id: InvokeTabName;
|
id: InvokeTabName;
|
||||||
@ -71,6 +71,11 @@ const tabs: InvokeTabInfo[] = [
|
|||||||
// icon: <Icon as={FaLayerGroup} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
|
// icon: <Icon as={FaLayerGroup} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
|
||||||
// content: <BatchTab />,
|
// content: <BatchTab />,
|
||||||
// },
|
// },
|
||||||
|
{
|
||||||
|
id: 'modelManager',
|
||||||
|
icon: <Icon as={FaCube} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
|
||||||
|
content: <ModelManagerTab />,
|
||||||
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
const enabledTabsSelector = createSelector(
|
const enabledTabsSelector = createSelector(
|
||||||
@ -87,6 +92,7 @@ const enabledTabsSelector = createSelector(
|
|||||||
|
|
||||||
const MIN_GALLERY_WIDTH = 300;
|
const MIN_GALLERY_WIDTH = 300;
|
||||||
const DEFAULT_GALLERY_PCT = 20;
|
const DEFAULT_GALLERY_PCT = 20;
|
||||||
|
export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager'];
|
||||||
|
|
||||||
const InvokeTabs = () => {
|
const InvokeTabs = () => {
|
||||||
const activeTab = useAppSelector(activeTabIndexSelector);
|
const activeTab = useAppSelector(activeTabIndexSelector);
|
||||||
@ -198,26 +204,28 @@ const InvokeTabs = () => {
|
|||||||
{tabPanels}
|
{tabPanels}
|
||||||
</TabPanels>
|
</TabPanels>
|
||||||
</Panel>
|
</Panel>
|
||||||
{shouldPinGallery && shouldShowGallery && (
|
{shouldPinGallery &&
|
||||||
<>
|
shouldShowGallery &&
|
||||||
<ResizeHandle />
|
!NO_GALLERY_TABS.includes(activeTabName) && (
|
||||||
<Panel
|
<>
|
||||||
ref={galleryPanelRef}
|
<ResizeHandle />
|
||||||
onResize={handleResizeGallery}
|
<Panel
|
||||||
id="gallery"
|
ref={galleryPanelRef}
|
||||||
order={3}
|
onResize={handleResizeGallery}
|
||||||
defaultSize={
|
id="gallery"
|
||||||
galleryMinSizePct > DEFAULT_GALLERY_PCT
|
order={3}
|
||||||
? galleryMinSizePct
|
defaultSize={
|
||||||
: DEFAULT_GALLERY_PCT
|
galleryMinSizePct > DEFAULT_GALLERY_PCT
|
||||||
}
|
? galleryMinSizePct
|
||||||
minSize={galleryMinSizePct}
|
: DEFAULT_GALLERY_PCT
|
||||||
maxSize={50}
|
}
|
||||||
>
|
minSize={galleryMinSizePct}
|
||||||
<ImageGalleryContent />
|
maxSize={50}
|
||||||
</Panel>
|
>
|
||||||
</>
|
<ImageGalleryContent />
|
||||||
)}
|
</Panel>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</PanelGroup>
|
</PanelGroup>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
);
|
);
|
||||||
|
@ -1,20 +1,21 @@
|
|||||||
import { memo } from 'react';
|
|
||||||
import { Box, Flex, useDisclosure } from '@chakra-ui/react';
|
import { Box, Flex, useDisclosure } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
|
|
||||||
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
|
|
||||||
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
|
|
||||||
import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth';
|
|
||||||
import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight';
|
|
||||||
import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength';
|
|
||||||
import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit';
|
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
|
||||||
import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel';
|
|
||||||
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
|
|
||||||
import IAICollapse from 'common/components/IAICollapse';
|
import IAICollapse from 'common/components/IAICollapse';
|
||||||
|
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
|
||||||
|
import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight';
|
||||||
|
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
|
||||||
|
import ParamModelandVAE from 'features/parameters/components/Parameters/Core/ParamModelandVAE';
|
||||||
|
import ParamScheduler from 'features/parameters/components/Parameters/Core/ParamScheduler';
|
||||||
|
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
|
||||||
|
import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth';
|
||||||
|
import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit';
|
||||||
|
import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength';
|
||||||
|
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
|
||||||
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
|
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[uiSelector, generationSelector],
|
[uiSelector, generationSelector],
|
||||||
@ -41,7 +42,7 @@ const ImageToImageTabCoreParameters = () => {
|
|||||||
>
|
>
|
||||||
{shouldUseSliders ? (
|
{shouldUseSliders ? (
|
||||||
<>
|
<>
|
||||||
<ParamSchedulerAndModel />
|
<ParamModelandVAE />
|
||||||
<Box pt={2}>
|
<Box pt={2}>
|
||||||
<ParamSeedFull />
|
<ParamSeedFull />
|
||||||
</Box>
|
</Box>
|
||||||
@ -58,7 +59,8 @@ const ImageToImageTabCoreParameters = () => {
|
|||||||
<ParamSteps />
|
<ParamSteps />
|
||||||
<ParamCFGScale />
|
<ParamCFGScale />
|
||||||
</Flex>
|
</Flex>
|
||||||
<ParamSchedulerAndModel />
|
<ParamModelandVAE />
|
||||||
|
<ParamScheduler />
|
||||||
<Box pt={2}>
|
<Box pt={2}>
|
||||||
<ParamSeedFull />
|
<ParamSeedFull />
|
||||||
</Box>
|
</Box>
|
||||||
|
@ -0,0 +1,81 @@
|
|||||||
|
import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react';
|
||||||
|
import i18n from 'i18n';
|
||||||
|
import { ReactNode, memo } from 'react';
|
||||||
|
import AddModelsPanel from './subpanels/AddModelsPanel';
|
||||||
|
import MergeModelsPanel from './subpanels/MergeModelsPanel';
|
||||||
|
import ModelManagerPanel from './subpanels/ModelManagerPanel';
|
||||||
|
|
||||||
|
type ModelManagerTabName = 'modelManager' | 'addModels' | 'mergeModels';
|
||||||
|
|
||||||
|
type ModelManagerTabInfo = {
|
||||||
|
id: ModelManagerTabName;
|
||||||
|
label: string;
|
||||||
|
content: ReactNode;
|
||||||
|
};
|
||||||
|
|
||||||
|
const modelManagerTabs: ModelManagerTabInfo[] = [
|
||||||
|
{
|
||||||
|
id: 'modelManager',
|
||||||
|
label: i18n.t('modelManager.modelManager'),
|
||||||
|
content: <ModelManagerPanel />,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'addModels',
|
||||||
|
label: i18n.t('modelManager.addModel'),
|
||||||
|
content: <AddModelsPanel />,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'mergeModels',
|
||||||
|
label: i18n.t('modelManager.mergeModels'),
|
||||||
|
content: <MergeModelsPanel />,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
const renderTabsList = () => {
|
||||||
|
const modelManagerTabListsToRender: ReactNode[] = [];
|
||||||
|
modelManagerTabs.forEach((modelManagerTab) => {
|
||||||
|
modelManagerTabListsToRender.push(
|
||||||
|
<Tab key={modelManagerTab.id}>{modelManagerTab.label}</Tab>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<TabList
|
||||||
|
sx={{
|
||||||
|
w: '100%',
|
||||||
|
color: 'base.200',
|
||||||
|
flexDirection: 'row',
|
||||||
|
borderBottomWidth: 2,
|
||||||
|
borderColor: 'accent.700',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{modelManagerTabListsToRender}
|
||||||
|
</TabList>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const renderTabPanels = () => {
|
||||||
|
const modelManagerTabPanelsToRender: ReactNode[] = [];
|
||||||
|
modelManagerTabs.forEach((modelManagerTab) => {
|
||||||
|
modelManagerTabPanelsToRender.push(
|
||||||
|
<TabPanel key={modelManagerTab.id}>{modelManagerTab.content}</TabPanel>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
return <TabPanels sx={{ p: 2 }}>{modelManagerTabPanelsToRender}</TabPanels>;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ModelManagerTab = () => {
|
||||||
|
return (
|
||||||
|
<Tabs
|
||||||
|
isLazy
|
||||||
|
variant="invokeAI"
|
||||||
|
sx={{ w: 'full', h: 'full', p: 2, gap: 4, flexDirection: 'column' }}
|
||||||
|
>
|
||||||
|
{renderTabsList()}
|
||||||
|
{renderTabPanels()}
|
||||||
|
</Tabs>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ModelManagerTab);
|
@ -0,0 +1,55 @@
|
|||||||
|
import { Divider, Flex } from '@chakra-ui/react';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import AddCheckpointModel from './AddModelsPanel/AddCheckpointModel';
|
||||||
|
import AddDiffusersModel from './AddModelsPanel/AddDiffusersModel';
|
||||||
|
|
||||||
|
export default function AddModelsPanel() {
|
||||||
|
const addNewModelUIOption = useAppSelector(
|
||||||
|
(state: RootState) => state.ui.addNewModelUIOption
|
||||||
|
);
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex flexDirection="column" gap={4}>
|
||||||
|
<Flex columnGap={4}>
|
||||||
|
<IAIButton
|
||||||
|
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
|
||||||
|
sx={{
|
||||||
|
backgroundColor:
|
||||||
|
addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.700',
|
||||||
|
'&:hover': {
|
||||||
|
backgroundColor:
|
||||||
|
addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.600',
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{t('modelManager.addCheckpointModel')}
|
||||||
|
</IAIButton>
|
||||||
|
<IAIButton
|
||||||
|
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
|
||||||
|
sx={{
|
||||||
|
backgroundColor:
|
||||||
|
addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.700',
|
||||||
|
'&:hover': {
|
||||||
|
backgroundColor:
|
||||||
|
addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.600',
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{t('modelManager.addDiffuserModel')}
|
||||||
|
</IAIButton>
|
||||||
|
</Flex>
|
||||||
|
|
||||||
|
<Divider />
|
||||||
|
|
||||||
|
{addNewModelUIOption == 'ckpt' && <AddCheckpointModel />}
|
||||||
|
{addNewModelUIOption == 'diffusers' && <AddDiffusersModel />}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
@ -10,13 +10,11 @@ import {
|
|||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
|
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
|
||||||
import IAIInput from 'common/components/IAIInput';
|
import IAIInput from 'common/components/IAIInput';
|
||||||
import IAINumberInput from 'common/components/IAINumberInput';
|
import IAINumberInput from 'common/components/IAINumberInput';
|
||||||
|
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||||
import React from 'react';
|
import React from 'react';
|
||||||
|
|
||||||
import SearchModels from './SearchModels';
|
|
||||||
|
|
||||||
// import { addNewModel } from 'app/socketio/actions';
|
// import { addNewModel } from 'app/socketio/actions';
|
||||||
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
@ -24,12 +22,13 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { Field, Formik } from 'formik';
|
import { Field, Formik } from 'formik';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import type { InvokeModelConfigProps } from 'app/types/invokeai';
|
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
import type { InvokeModelConfigProps } from 'app/types/invokeai';
|
||||||
import type { FieldInputProps, FormikProps } from 'formik';
|
|
||||||
import IAIForm from 'common/components/IAIForm';
|
import IAIForm from 'common/components/IAIForm';
|
||||||
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
|
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
|
||||||
|
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
||||||
|
import type { FieldInputProps, FormikProps } from 'formik';
|
||||||
|
import SearchModels from './SearchModels';
|
||||||
|
|
||||||
const MIN_MODEL_SIZE = 64;
|
const MIN_MODEL_SIZE = 64;
|
||||||
const MAX_MODEL_SIZE = 2048;
|
const MAX_MODEL_SIZE = 2048;
|
@ -66,7 +66,7 @@ export default function AddDiffusersModel() {
|
|||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex>
|
<Flex overflow="scroll" maxHeight={window.innerHeight - 270}>
|
||||||
<Formik
|
<Formik
|
||||||
initialValues={addModelFormValues}
|
initialValues={addModelFormValues}
|
||||||
onSubmit={addModelFormSubmitHandler}
|
onSubmit={addModelFormSubmitHandler}
|
@ -0,0 +1,260 @@
|
|||||||
|
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import IAIInput from 'common/components/IAIInput';
|
||||||
|
import IAISelect from 'common/components/IAISelect';
|
||||||
|
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { pickBy } from 'lodash-es';
|
||||||
|
import { useState } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
export default function MergeModelsPanel() {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const { data } = useListModelsQuery({
|
||||||
|
model_type: 'main',
|
||||||
|
});
|
||||||
|
|
||||||
|
const diffusersModels = pickBy(
|
||||||
|
data?.entities,
|
||||||
|
(value, _) => value?.model_format === 'diffusers'
|
||||||
|
);
|
||||||
|
|
||||||
|
const [modelOne, setModelOne] = useState<string>(
|
||||||
|
Object.keys(diffusersModels)[0]
|
||||||
|
);
|
||||||
|
const [modelTwo, setModelTwo] = useState<string>(
|
||||||
|
Object.keys(diffusersModels)[1]
|
||||||
|
);
|
||||||
|
const [modelThree, setModelThree] = useState<string>('none');
|
||||||
|
|
||||||
|
const [mergedModelName, setMergedModelName] = useState<string>('');
|
||||||
|
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
|
||||||
|
|
||||||
|
const [modelMergeInterp, setModelMergeInterp] = useState<
|
||||||
|
'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'
|
||||||
|
>('weighted_sum');
|
||||||
|
|
||||||
|
const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState<
|
||||||
|
'root' | 'custom'
|
||||||
|
>('root');
|
||||||
|
|
||||||
|
const [modelMergeCustomSaveLoc, setModelMergeCustomSaveLoc] =
|
||||||
|
useState<string>('');
|
||||||
|
|
||||||
|
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
|
||||||
|
|
||||||
|
const modelOneList = Object.keys(diffusersModels).filter(
|
||||||
|
(model) => model !== modelTwo && model !== modelThree
|
||||||
|
);
|
||||||
|
|
||||||
|
const modelTwoList = Object.keys(diffusersModels).filter(
|
||||||
|
(model) => model !== modelOne && model !== modelThree
|
||||||
|
);
|
||||||
|
|
||||||
|
const modelThreeList = [
|
||||||
|
{ key: t('modelManager.none'), value: 'none' },
|
||||||
|
...Object.keys(diffusersModels)
|
||||||
|
.filter((model) => model !== modelOne && model !== modelTwo)
|
||||||
|
.map((model) => ({ key: model, value: model })),
|
||||||
|
];
|
||||||
|
|
||||||
|
const isProcessing = useAppSelector(
|
||||||
|
(state: RootState) => state.system.isProcessing
|
||||||
|
);
|
||||||
|
|
||||||
|
const mergeModelsHandler = () => {
|
||||||
|
let modelsToMerge: string[] = [modelOne, modelTwo, modelThree];
|
||||||
|
modelsToMerge = modelsToMerge.filter((model) => model !== 'none');
|
||||||
|
|
||||||
|
const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = {
|
||||||
|
models_to_merge: modelsToMerge,
|
||||||
|
merged_model_name:
|
||||||
|
mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'),
|
||||||
|
alpha: modelMergeAlpha,
|
||||||
|
interp: modelMergeInterp,
|
||||||
|
model_merge_save_path:
|
||||||
|
modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
|
||||||
|
force: modelMergeForce,
|
||||||
|
};
|
||||||
|
|
||||||
|
dispatch(mergeDiffusersModels(mergeModelsInfo));
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex flexDirection="column" rowGap={4}>
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
flexDirection: 'column',
|
||||||
|
rowGap: 1,
|
||||||
|
bg: 'base.900',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Text>{t('modelManager.modelMergeHeaderHelp1')}</Text>
|
||||||
|
<Text fontSize="sm" variant="subtext">
|
||||||
|
{t('modelManager.modelMergeHeaderHelp2')}
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
<Flex columnGap={4}>
|
||||||
|
<IAISelect
|
||||||
|
label={t('modelManager.modelOne')}
|
||||||
|
validValues={modelOneList}
|
||||||
|
onChange={(e) => setModelOne(e.target.value)}
|
||||||
|
/>
|
||||||
|
<IAISelect
|
||||||
|
label={t('modelManager.modelTwo')}
|
||||||
|
validValues={modelTwoList}
|
||||||
|
onChange={(e) => setModelTwo(e.target.value)}
|
||||||
|
/>
|
||||||
|
<IAISelect
|
||||||
|
label={t('modelManager.modelThree')}
|
||||||
|
validValues={modelThreeList}
|
||||||
|
onChange={(e) => {
|
||||||
|
if (e.target.value !== 'none') {
|
||||||
|
setModelThree(e.target.value);
|
||||||
|
setModelMergeInterp('add_difference');
|
||||||
|
} else {
|
||||||
|
setModelThree('none');
|
||||||
|
setModelMergeInterp('weighted_sum');
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
|
||||||
|
<IAIInput
|
||||||
|
label={t('modelManager.mergedModelName')}
|
||||||
|
value={mergedModelName}
|
||||||
|
onChange={(e) => setMergedModelName(e.target.value)}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
flexDirection: 'column',
|
||||||
|
padding: 4,
|
||||||
|
borderRadius: 'base',
|
||||||
|
gap: 4,
|
||||||
|
bg: 'base.900',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<IAISlider
|
||||||
|
label={t('modelManager.alpha')}
|
||||||
|
min={0.01}
|
||||||
|
max={0.99}
|
||||||
|
step={0.01}
|
||||||
|
value={modelMergeAlpha}
|
||||||
|
onChange={(v) => setModelMergeAlpha(v)}
|
||||||
|
withInput
|
||||||
|
withReset
|
||||||
|
handleReset={() => setModelMergeAlpha(0.5)}
|
||||||
|
withSliderMarks
|
||||||
|
/>
|
||||||
|
<Text variant="subtext" fontSize="sm">
|
||||||
|
{t('modelManager.modelMergeAlphaHelp')}
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
padding: 4,
|
||||||
|
borderRadius: 'base',
|
||||||
|
gap: 4,
|
||||||
|
bg: 'base.900',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Text fontWeight={500} fontSize="sm" variant="subtext">
|
||||||
|
{t('modelManager.interpolationType')}
|
||||||
|
</Text>
|
||||||
|
<RadioGroup
|
||||||
|
value={modelMergeInterp}
|
||||||
|
onChange={(
|
||||||
|
v: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'
|
||||||
|
) => setModelMergeInterp(v)}
|
||||||
|
>
|
||||||
|
<Flex columnGap={4}>
|
||||||
|
{modelThree === 'none' ? (
|
||||||
|
<>
|
||||||
|
<Radio value="weighted_sum">
|
||||||
|
<Text fontSize="sm">{t('modelManager.weightedSum')}</Text>
|
||||||
|
</Radio>
|
||||||
|
<Radio value="sigmoid">
|
||||||
|
<Text fontSize="sm">{t('modelManager.sigmoid')}</Text>
|
||||||
|
</Radio>
|
||||||
|
<Radio value="inv_sigmoid">
|
||||||
|
<Text fontSize="sm">{t('modelManager.inverseSigmoid')}</Text>
|
||||||
|
</Radio>
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<Radio value="add_difference">
|
||||||
|
<Tooltip
|
||||||
|
label={t('modelManager.modelMergeInterpAddDifferenceHelp')}
|
||||||
|
>
|
||||||
|
<Text fontSize="sm">{t('modelManager.addDifference')}</Text>
|
||||||
|
</Tooltip>
|
||||||
|
</Radio>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
</RadioGroup>
|
||||||
|
</Flex>
|
||||||
|
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
flexDirection: 'column',
|
||||||
|
padding: 4,
|
||||||
|
borderRadius: 'base',
|
||||||
|
gap: 4,
|
||||||
|
bg: 'base.900',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Flex columnGap={4}>
|
||||||
|
<Text fontWeight="500" fontSize="sm" variant="subtext">
|
||||||
|
{t('modelManager.mergedModelSaveLocation')}
|
||||||
|
</Text>
|
||||||
|
<RadioGroup
|
||||||
|
value={modelMergeSaveLocType}
|
||||||
|
onChange={(v: 'root' | 'custom') => setModelMergeSaveLocType(v)}
|
||||||
|
>
|
||||||
|
<Flex columnGap={4}>
|
||||||
|
<Radio value="root">
|
||||||
|
<Text fontSize="sm">{t('modelManager.invokeAIFolder')}</Text>
|
||||||
|
</Radio>
|
||||||
|
|
||||||
|
<Radio value="custom">
|
||||||
|
<Text fontSize="sm">{t('modelManager.custom')}</Text>
|
||||||
|
</Radio>
|
||||||
|
</Flex>
|
||||||
|
</RadioGroup>
|
||||||
|
</Flex>
|
||||||
|
|
||||||
|
{modelMergeSaveLocType === 'custom' && (
|
||||||
|
<IAIInput
|
||||||
|
label={t('modelManager.mergedModelCustomSaveLocation')}
|
||||||
|
value={modelMergeCustomSaveLoc}
|
||||||
|
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
|
||||||
|
<IAISimpleCheckbox
|
||||||
|
label={t('modelManager.ignoreMismatch')}
|
||||||
|
isChecked={modelMergeForce}
|
||||||
|
onChange={(e) => setModelMergeForce(e.target.checked)}
|
||||||
|
fontWeight="500"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<IAIButton
|
||||||
|
onClick={mergeModelsHandler}
|
||||||
|
isLoading={isProcessing}
|
||||||
|
isDisabled={
|
||||||
|
modelMergeSaveLocType === 'custom' && modelMergeCustomSaveLoc === ''
|
||||||
|
}
|
||||||
|
>
|
||||||
|
{t('modelManager.merge')}
|
||||||
|
</IAIButton>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,46 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
|
||||||
|
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
||||||
|
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
||||||
|
import ModelList from './ModelManagerPanel/ModelList';
|
||||||
|
|
||||||
|
export default function ModelManagerPanel() {
|
||||||
|
const { data: mainModels } = useListModelsQuery({
|
||||||
|
model_type: 'main',
|
||||||
|
});
|
||||||
|
|
||||||
|
const openModel = useAppSelector(
|
||||||
|
(state: RootState) => state.system.openModel
|
||||||
|
);
|
||||||
|
|
||||||
|
const renderModelEditTabs = () => {
|
||||||
|
if (!openModel || !mainModels) return;
|
||||||
|
|
||||||
|
if (mainModels['entities'][openModel]['model_format'] === 'diffusers') {
|
||||||
|
return (
|
||||||
|
<DiffusersModelEdit
|
||||||
|
modelToEdit={openModel}
|
||||||
|
retrievedModel={mainModels['entities'][openModel]}
|
||||||
|
key={openModel}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
return (
|
||||||
|
<CheckpointModelEdit
|
||||||
|
modelToEdit={openModel}
|
||||||
|
retrievedModel={mainModels['entities'][openModel]}
|
||||||
|
key={openModel}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
return (
|
||||||
|
<Flex width="100%" columnGap={8}>
|
||||||
|
<ModelList />
|
||||||
|
{renderModelEditTabs()}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,141 @@
|
|||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
|
||||||
|
import { Divider, Flex, Text } from '@chakra-ui/react';
|
||||||
|
|
||||||
|
// import { addNewModel } from 'app/socketio/actions';
|
||||||
|
import { useForm } from '@mantine/form';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
import type { RootState } from 'app/store/store';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import IAIInput from 'common/components/IAIInput';
|
||||||
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
|
||||||
|
import { S } from 'services/api/types';
|
||||||
|
import ModelConvert from './ModelConvert';
|
||||||
|
|
||||||
|
const baseModelSelectData = [
|
||||||
|
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
||||||
|
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
|
||||||
|
];
|
||||||
|
|
||||||
|
const variantSelectData = [
|
||||||
|
{ value: 'normal', label: 'Normal' },
|
||||||
|
{ value: 'inpaint', label: 'Inpaint' },
|
||||||
|
{ value: 'depth', label: 'Depth' },
|
||||||
|
];
|
||||||
|
|
||||||
|
export type CheckpointModel =
|
||||||
|
| S<'StableDiffusion1ModelCheckpointConfig'>
|
||||||
|
| S<'StableDiffusion2ModelCheckpointConfig'>;
|
||||||
|
|
||||||
|
type CheckpointModelEditProps = {
|
||||||
|
modelToEdit: string;
|
||||||
|
retrievedModel: CheckpointModel;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
||||||
|
const isProcessing = useAppSelector(
|
||||||
|
(state: RootState) => state.system.isProcessing
|
||||||
|
);
|
||||||
|
|
||||||
|
const { modelToEdit, retrievedModel } = props;
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const checkpointEditForm = useForm({
|
||||||
|
initialValues: {
|
||||||
|
name: retrievedModel.name,
|
||||||
|
base_model: retrievedModel.base_model,
|
||||||
|
type: 'main',
|
||||||
|
path: retrievedModel.path,
|
||||||
|
description: retrievedModel.description,
|
||||||
|
model_format: 'checkpoint',
|
||||||
|
vae: retrievedModel.vae,
|
||||||
|
config: retrievedModel.config,
|
||||||
|
variant: retrievedModel.variant,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const editModelFormSubmitHandler = (values) => {
|
||||||
|
console.log(values);
|
||||||
|
};
|
||||||
|
|
||||||
|
return modelToEdit ? (
|
||||||
|
<Flex flexDirection="column" rowGap={4} width="100%">
|
||||||
|
<Flex justifyContent="space-between" alignItems="center">
|
||||||
|
<Flex flexDirection="column">
|
||||||
|
<Text fontSize="lg" fontWeight="bold">
|
||||||
|
{retrievedModel.name}
|
||||||
|
</Text>
|
||||||
|
<Text fontSize="sm" color="base.400">
|
||||||
|
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
<ModelConvert model={retrievedModel} />
|
||||||
|
</Flex>
|
||||||
|
<Divider />
|
||||||
|
|
||||||
|
<Flex
|
||||||
|
flexDirection="column"
|
||||||
|
maxHeight={window.innerHeight - 270}
|
||||||
|
overflowY="scroll"
|
||||||
|
>
|
||||||
|
<form
|
||||||
|
onSubmit={checkpointEditForm.onSubmit((values) =>
|
||||||
|
editModelFormSubmitHandler(values)
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
||||||
|
<IAIInput
|
||||||
|
label={t('modelManager.name')}
|
||||||
|
{...checkpointEditForm.getInputProps('name')}
|
||||||
|
/>
|
||||||
|
<IAIInput
|
||||||
|
label={t('modelManager.description')}
|
||||||
|
{...checkpointEditForm.getInputProps('description')}
|
||||||
|
/>
|
||||||
|
<IAIMantineSelect
|
||||||
|
label={t('modelManager.baseModel')}
|
||||||
|
data={baseModelSelectData}
|
||||||
|
{...checkpointEditForm.getInputProps('base_model')}
|
||||||
|
/>
|
||||||
|
<IAIMantineSelect
|
||||||
|
label={t('modelManager.variant')}
|
||||||
|
data={variantSelectData}
|
||||||
|
{...checkpointEditForm.getInputProps('variant')}
|
||||||
|
/>
|
||||||
|
<IAIInput
|
||||||
|
label={t('modelManager.modelLocation')}
|
||||||
|
{...checkpointEditForm.getInputProps('path')}
|
||||||
|
/>
|
||||||
|
<IAIInput
|
||||||
|
label={t('modelManager.vaeLocation')}
|
||||||
|
{...checkpointEditForm.getInputProps('vae')}
|
||||||
|
/>
|
||||||
|
<IAIInput
|
||||||
|
label={t('modelManager.config')}
|
||||||
|
{...checkpointEditForm.getInputProps('config')}
|
||||||
|
/>
|
||||||
|
<IAIButton disabled={isProcessing} type="submit">
|
||||||
|
{t('modelManager.updateModel')}
|
||||||
|
</IAIButton>
|
||||||
|
</Flex>
|
||||||
|
</form>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
) : (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
width: '100%',
|
||||||
|
justifyContent: 'center',
|
||||||
|
alignItems: 'center',
|
||||||
|
borderRadius: 'base',
|
||||||
|
bg: 'base.900',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Text fontWeight={500}>Pick A Model To Edit</Text>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,125 @@
|
|||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
|
||||||
|
import { Divider, Flex, Text } from '@chakra-ui/react';
|
||||||
|
|
||||||
|
// import { addNewModel } from 'app/socketio/actions';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
import { useForm } from '@mantine/form';
|
||||||
|
import type { RootState } from 'app/store/store';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import IAIInput from 'common/components/IAIInput';
|
||||||
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
|
||||||
|
import { S } from 'services/api/types';
|
||||||
|
|
||||||
|
type DiffusersModel =
|
||||||
|
| S<'StableDiffusion1ModelDiffusersConfig'>
|
||||||
|
| S<'StableDiffusion2ModelDiffusersConfig'>;
|
||||||
|
|
||||||
|
type DiffusersModelEditProps = {
|
||||||
|
modelToEdit: string;
|
||||||
|
retrievedModel: DiffusersModel;
|
||||||
|
};
|
||||||
|
|
||||||
|
const baseModelSelectData = [
|
||||||
|
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
||||||
|
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
|
||||||
|
];
|
||||||
|
|
||||||
|
const variantSelectData = [
|
||||||
|
{ value: 'normal', label: 'Normal' },
|
||||||
|
{ value: 'inpaint', label: 'Inpaint' },
|
||||||
|
{ value: 'depth', label: 'Depth' },
|
||||||
|
];
|
||||||
|
|
||||||
|
export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
|
||||||
|
const isProcessing = useAppSelector(
|
||||||
|
(state: RootState) => state.system.isProcessing
|
||||||
|
);
|
||||||
|
const { retrievedModel, modelToEdit } = props;
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const diffusersEditForm = useForm({
|
||||||
|
initialValues: {
|
||||||
|
name: retrievedModel.name,
|
||||||
|
base_model: retrievedModel.base_model,
|
||||||
|
type: 'main',
|
||||||
|
path: retrievedModel.path,
|
||||||
|
description: retrievedModel.description,
|
||||||
|
model_format: 'diffusers',
|
||||||
|
vae: retrievedModel.vae,
|
||||||
|
variant: retrievedModel.variant,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const editModelFormSubmitHandler = (values) => {
|
||||||
|
console.log(values);
|
||||||
|
};
|
||||||
|
|
||||||
|
return modelToEdit ? (
|
||||||
|
<Flex flexDirection="column" rowGap={4} width="100%">
|
||||||
|
<Flex flexDirection="column">
|
||||||
|
<Text fontSize="lg" fontWeight="bold">
|
||||||
|
{retrievedModel.name}
|
||||||
|
</Text>
|
||||||
|
<Text fontSize="sm" color="base.400">
|
||||||
|
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
<Divider />
|
||||||
|
|
||||||
|
<form
|
||||||
|
onSubmit={diffusersEditForm.onSubmit((values) =>
|
||||||
|
editModelFormSubmitHandler(values)
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
||||||
|
<IAIInput
|
||||||
|
label={t('modelManager.name')}
|
||||||
|
{...diffusersEditForm.getInputProps('name')}
|
||||||
|
/>
|
||||||
|
<IAIInput
|
||||||
|
label={t('modelManager.description')}
|
||||||
|
{...diffusersEditForm.getInputProps('description')}
|
||||||
|
/>
|
||||||
|
<IAIMantineSelect
|
||||||
|
label={t('modelManager.baseModel')}
|
||||||
|
data={baseModelSelectData}
|
||||||
|
{...diffusersEditForm.getInputProps('base_model')}
|
||||||
|
/>
|
||||||
|
<IAIMantineSelect
|
||||||
|
label={t('modelManager.variant')}
|
||||||
|
data={variantSelectData}
|
||||||
|
{...diffusersEditForm.getInputProps('variant')}
|
||||||
|
/>
|
||||||
|
<IAIInput
|
||||||
|
label={t('modelManager.modelLocation')}
|
||||||
|
{...diffusersEditForm.getInputProps('path')}
|
||||||
|
/>
|
||||||
|
<IAIInput
|
||||||
|
label={t('modelManager.vaeLocation')}
|
||||||
|
{...diffusersEditForm.getInputProps('vae')}
|
||||||
|
/>
|
||||||
|
<IAIButton disabled={isProcessing} type="submit">
|
||||||
|
{t('modelManager.updateModel')}
|
||||||
|
</IAIButton>
|
||||||
|
</Flex>
|
||||||
|
</form>
|
||||||
|
</Flex>
|
||||||
|
) : (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
width: '100%',
|
||||||
|
justifyContent: 'center',
|
||||||
|
alignItems: 'center',
|
||||||
|
borderRadius: 'base',
|
||||||
|
bg: 'base.900',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Text fontWeight={'500'}>Pick A Model To Edit</Text>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
@ -4,42 +4,28 @@ import {
|
|||||||
Radio,
|
Radio,
|
||||||
RadioGroup,
|
RadioGroup,
|
||||||
Text,
|
Text,
|
||||||
UnorderedList,
|
|
||||||
Tooltip,
|
Tooltip,
|
||||||
|
UnorderedList,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
// import { convertToDiffusers } from 'app/socketio/actions';
|
// import { convertToDiffusers } from 'app/socketio/actions';
|
||||||
import { RootState } from 'app/store/store';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import IAIAlertDialog from 'common/components/IAIAlertDialog';
|
import IAIAlertDialog from 'common/components/IAIAlertDialog';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import IAIInput from 'common/components/IAIInput';
|
import IAIInput from 'common/components/IAIInput';
|
||||||
import { useState, useEffect } from 'react';
|
import { useEffect, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { CheckpointModel } from './CheckpointModelEdit';
|
||||||
|
|
||||||
interface ModelConvertProps {
|
interface ModelConvertProps {
|
||||||
model: string;
|
model: CheckpointModel;
|
||||||
}
|
}
|
||||||
|
|
||||||
export default function ModelConvert(props: ModelConvertProps) {
|
export default function ModelConvert(props: ModelConvertProps) {
|
||||||
const { model } = props;
|
const { model } = props;
|
||||||
|
|
||||||
const model_list = useAppSelector(
|
|
||||||
(state: RootState) => state.system.model_list
|
|
||||||
);
|
|
||||||
|
|
||||||
const retrievedModel = model_list[model];
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const isProcessing = useAppSelector(
|
|
||||||
(state: RootState) => state.system.isProcessing
|
|
||||||
);
|
|
||||||
|
|
||||||
const isConnected = useAppSelector(
|
|
||||||
(state: RootState) => state.system.isConnected
|
|
||||||
);
|
|
||||||
|
|
||||||
const [saveLocation, setSaveLocation] = useState<string>('same');
|
const [saveLocation, setSaveLocation] = useState<string>('same');
|
||||||
const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
|
const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
|
||||||
|
|
||||||
@ -65,7 +51,7 @@ export default function ModelConvert(props: ModelConvertProps) {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIAlertDialog
|
<IAIAlertDialog
|
||||||
title={`${t('modelManager.convert')} ${model}`}
|
title={`${t('modelManager.convert')} ${model.name}`}
|
||||||
acceptCallback={modelConvertHandler}
|
acceptCallback={modelConvertHandler}
|
||||||
cancelCallback={modelConvertCancelHandler}
|
cancelCallback={modelConvertCancelHandler}
|
||||||
acceptButtonText={`${t('modelManager.convert')}`}
|
acceptButtonText={`${t('modelManager.convert')}`}
|
||||||
@ -73,11 +59,7 @@ export default function ModelConvert(props: ModelConvertProps) {
|
|||||||
<IAIButton
|
<IAIButton
|
||||||
size={'sm'}
|
size={'sm'}
|
||||||
aria-label={t('modelManager.convertToDiffusers')}
|
aria-label={t('modelManager.convertToDiffusers')}
|
||||||
isDisabled={
|
|
||||||
retrievedModel.status === 'active' || isProcessing || !isConnected
|
|
||||||
}
|
|
||||||
className=" modal-close-btn"
|
className=" modal-close-btn"
|
||||||
marginInlineEnd={8}
|
|
||||||
>
|
>
|
||||||
🧨 {t('modelManager.convertToDiffusers')}
|
🧨 {t('modelManager.convertToDiffusers')}
|
||||||
</IAIButton>
|
</IAIButton>
|
@ -1,36 +1,14 @@
|
|||||||
import { Box, Flex, Heading, Spacer, Spinner, Text } from '@chakra-ui/react';
|
import { Box, Flex, Spinner, Text } from '@chakra-ui/react';
|
||||||
import IAIInput from 'common/components/IAIInput';
|
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import IAIInput from 'common/components/IAIInput';
|
||||||
|
|
||||||
import AddModel from './AddModel';
|
|
||||||
import ModelListItem from './ModelListItem';
|
import ModelListItem from './ModelListItem';
|
||||||
import MergeModels from './MergeModels';
|
|
||||||
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
|
||||||
import type { SystemState } from 'features/system/store/systemSlice';
|
|
||||||
import { isEqual, map } from 'lodash-es';
|
|
||||||
|
|
||||||
import React, { useMemo, useState, useTransition } from 'react';
|
|
||||||
import type { ChangeEvent, ReactNode } from 'react';
|
import type { ChangeEvent, ReactNode } from 'react';
|
||||||
|
import React, { useMemo, useState, useTransition } from 'react';
|
||||||
const modelListSelector = createSelector(
|
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||||
systemSelector,
|
|
||||||
(system: SystemState) => {
|
|
||||||
const models = map(system.model_list, (model, key) => {
|
|
||||||
return { name: key, ...model };
|
|
||||||
});
|
|
||||||
return models;
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
function ModelFilterButton({
|
function ModelFilterButton({
|
||||||
label,
|
label,
|
||||||
@ -58,7 +36,9 @@ function ModelFilterButton({
|
|||||||
}
|
}
|
||||||
|
|
||||||
const ModelList = () => {
|
const ModelList = () => {
|
||||||
const models = useAppSelector(modelListSelector);
|
const { data: mainModels } = useListModelsQuery({
|
||||||
|
model_type: 'main',
|
||||||
|
});
|
||||||
|
|
||||||
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
|
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
|
||||||
|
|
||||||
@ -90,43 +70,49 @@ const ModelList = () => {
|
|||||||
const filteredModelListItemsToRender: ReactNode[] = [];
|
const filteredModelListItemsToRender: ReactNode[] = [];
|
||||||
const localFilteredModelListItemsToRender: ReactNode[] = [];
|
const localFilteredModelListItemsToRender: ReactNode[] = [];
|
||||||
|
|
||||||
models.forEach((model, i) => {
|
if (!mainModels) return;
|
||||||
if (model.name.toLowerCase().includes(searchText.toLowerCase())) {
|
|
||||||
|
const modelList = mainModels.entities;
|
||||||
|
|
||||||
|
Object.keys(modelList).forEach((model, i) => {
|
||||||
|
if (
|
||||||
|
modelList[model].name.toLowerCase().includes(searchText.toLowerCase())
|
||||||
|
) {
|
||||||
filteredModelListItemsToRender.push(
|
filteredModelListItemsToRender.push(
|
||||||
<ModelListItem
|
<ModelListItem
|
||||||
key={i}
|
key={i}
|
||||||
name={model.name}
|
modelKey={model}
|
||||||
status={model.status}
|
name={modelList[model].name}
|
||||||
description={model.description}
|
description={modelList[model].description}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
if (model.format === isSelectedFilter) {
|
if (modelList[model]?.model_format === isSelectedFilter) {
|
||||||
localFilteredModelListItemsToRender.push(
|
localFilteredModelListItemsToRender.push(
|
||||||
<ModelListItem
|
<ModelListItem
|
||||||
key={i}
|
key={i}
|
||||||
name={model.name}
|
modelKey={model}
|
||||||
status={model.status}
|
name={modelList[model].name}
|
||||||
description={model.description}
|
description={modelList[model].description}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (model.format !== 'diffusers') {
|
if (modelList[model]?.model_format !== 'diffusers') {
|
||||||
ckptModelListItemsToRender.push(
|
ckptModelListItemsToRender.push(
|
||||||
<ModelListItem
|
<ModelListItem
|
||||||
key={i}
|
key={i}
|
||||||
name={model.name}
|
modelKey={model}
|
||||||
status={model.status}
|
name={modelList[model].name}
|
||||||
description={model.description}
|
description={modelList[model].description}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
diffusersModelListItemsToRender.push(
|
diffusersModelListItemsToRender.push(
|
||||||
<ModelListItem
|
<ModelListItem
|
||||||
key={i}
|
key={i}
|
||||||
name={model.name}
|
modelKey={model}
|
||||||
status={model.status}
|
name={modelList[model].name}
|
||||||
description={model.description}
|
description={modelList[model].description}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -142,6 +128,23 @@ const ModelList = () => {
|
|||||||
<Flex flexDirection="column" rowGap={6}>
|
<Flex flexDirection="column" rowGap={6}>
|
||||||
{isSelectedFilter === 'all' && (
|
{isSelectedFilter === 'all' && (
|
||||||
<>
|
<>
|
||||||
|
<Box>
|
||||||
|
<Text
|
||||||
|
sx={{
|
||||||
|
fontWeight: '500',
|
||||||
|
py: 2,
|
||||||
|
px: 4,
|
||||||
|
mb: 4,
|
||||||
|
borderRadius: 'base',
|
||||||
|
width: 'max-content',
|
||||||
|
fontSize: 'sm',
|
||||||
|
bg: 'base.750',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{t('modelManager.diffusersModels')}
|
||||||
|
</Text>
|
||||||
|
{diffusersModelListItemsToRender}
|
||||||
|
</Box>
|
||||||
<Box>
|
<Box>
|
||||||
<Text
|
<Text
|
||||||
sx={{
|
sx={{
|
||||||
@ -160,50 +163,26 @@ const ModelList = () => {
|
|||||||
</Text>
|
</Text>
|
||||||
{ckptModelListItemsToRender}
|
{ckptModelListItemsToRender}
|
||||||
</Box>
|
</Box>
|
||||||
<Box>
|
|
||||||
<Text
|
|
||||||
sx={{
|
|
||||||
fontWeight: '500',
|
|
||||||
py: 2,
|
|
||||||
px: 4,
|
|
||||||
mb: 4,
|
|
||||||
borderRadius: 'base',
|
|
||||||
width: 'max-content',
|
|
||||||
fontSize: 'sm',
|
|
||||||
bg: 'base.750',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
{t('modelManager.diffusersModels')}
|
|
||||||
</Text>
|
|
||||||
{diffusersModelListItemsToRender}
|
|
||||||
</Box>
|
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{isSelectedFilter === 'ckpt' && (
|
|
||||||
<Flex flexDirection="column" marginTop={4}>
|
|
||||||
{ckptModelListItemsToRender}
|
|
||||||
</Flex>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{isSelectedFilter === 'diffusers' && (
|
{isSelectedFilter === 'diffusers' && (
|
||||||
<Flex flexDirection="column" marginTop={4}>
|
<Flex flexDirection="column" marginTop={4}>
|
||||||
{diffusersModelListItemsToRender}
|
{diffusersModelListItemsToRender}
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{isSelectedFilter === 'ckpt' && (
|
||||||
|
<Flex flexDirection="column" marginTop={4}>
|
||||||
|
{ckptModelListItemsToRender}
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
}, [models, searchText, t, isSelectedFilter]);
|
}, [mainModels, searchText, t, isSelectedFilter]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
|
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
|
||||||
<Flex justifyContent="space-between" alignItems="center" gap={2}>
|
|
||||||
<Heading size="md">{t('modelManager.availableModels')}</Heading>
|
|
||||||
<Spacer />
|
|
||||||
<AddModel />
|
|
||||||
<MergeModels />
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<IAIInput
|
<IAIInput
|
||||||
onChange={handleSearchFilter}
|
onChange={handleSearchFilter}
|
||||||
label={t('modelManager.search')}
|
label={t('modelManager.search')}
|
||||||
@ -211,7 +190,7 @@ const ModelList = () => {
|
|||||||
|
|
||||||
<Flex
|
<Flex
|
||||||
flexDirection="column"
|
flexDirection="column"
|
||||||
gap={1}
|
gap={4}
|
||||||
maxHeight={window.innerHeight - 240}
|
maxHeight={window.innerHeight - 240}
|
||||||
overflow="scroll"
|
overflow="scroll"
|
||||||
paddingInlineEnd={4}
|
paddingInlineEnd={4}
|
||||||
@ -222,16 +201,16 @@ const ModelList = () => {
|
|||||||
onClick={() => setIsSelectedFilter('all')}
|
onClick={() => setIsSelectedFilter('all')}
|
||||||
isActive={isSelectedFilter === 'all'}
|
isActive={isSelectedFilter === 'all'}
|
||||||
/>
|
/>
|
||||||
<ModelFilterButton
|
|
||||||
label={t('modelManager.checkpointModels')}
|
|
||||||
onClick={() => setIsSelectedFilter('ckpt')}
|
|
||||||
isActive={isSelectedFilter === 'ckpt'}
|
|
||||||
/>
|
|
||||||
<ModelFilterButton
|
<ModelFilterButton
|
||||||
label={t('modelManager.diffusersModels')}
|
label={t('modelManager.diffusersModels')}
|
||||||
onClick={() => setIsSelectedFilter('diffusers')}
|
onClick={() => setIsSelectedFilter('diffusers')}
|
||||||
isActive={isSelectedFilter === 'diffusers'}
|
isActive={isSelectedFilter === 'diffusers'}
|
||||||
/>
|
/>
|
||||||
|
<ModelFilterButton
|
||||||
|
label={t('modelManager.checkpointModels')}
|
||||||
|
onClick={() => setIsSelectedFilter('ckpt')}
|
||||||
|
isActive={isSelectedFilter === 'ckpt'}
|
||||||
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
{renderModelList ? (
|
{renderModelList ? (
|
@ -1,6 +1,6 @@
|
|||||||
import { DeleteIcon, EditIcon } from '@chakra-ui/icons';
|
import { DeleteIcon, EditIcon } from '@chakra-ui/icons';
|
||||||
import { Box, Button, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react';
|
import { Box, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react';
|
||||||
import { ModelStatus } from 'app/types/invokeai';
|
|
||||||
// import { deleteModel, requestModelChange } from 'app/socketio/actions';
|
// import { deleteModel, requestModelChange } from 'app/socketio/actions';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
@ -10,9 +10,9 @@ import { setOpenModel } from 'features/system/store/systemSlice';
|
|||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
type ModelListItemProps = {
|
type ModelListItemProps = {
|
||||||
|
modelKey: string;
|
||||||
name: string;
|
name: string;
|
||||||
status: ModelStatus;
|
description: string | undefined;
|
||||||
description: string;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export default function ModelListItem(props: ModelListItemProps) {
|
export default function ModelListItem(props: ModelListItemProps) {
|
||||||
@ -28,39 +28,24 @@ export default function ModelListItem(props: ModelListItemProps) {
|
|||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { name, status, description } = props;
|
const { modelKey, name, description } = props;
|
||||||
|
|
||||||
const handleChangeModel = () => {
|
|
||||||
dispatch(requestModelChange(name));
|
|
||||||
};
|
|
||||||
|
|
||||||
const openModelHandler = () => {
|
const openModelHandler = () => {
|
||||||
dispatch(setOpenModel(name));
|
dispatch(setOpenModel(modelKey));
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleModelDelete = () => {
|
const handleModelDelete = () => {
|
||||||
dispatch(deleteModel(name));
|
dispatch(deleteModel(modelKey));
|
||||||
dispatch(setOpenModel(null));
|
dispatch(setOpenModel(null));
|
||||||
};
|
};
|
||||||
|
|
||||||
const statusTextColor = () => {
|
|
||||||
switch (status) {
|
|
||||||
case 'active':
|
|
||||||
return 'ok.500';
|
|
||||||
case 'cached':
|
|
||||||
return 'warning.500';
|
|
||||||
case 'not loaded':
|
|
||||||
return 'inherit';
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
alignItems="center"
|
alignItems="center"
|
||||||
p={2}
|
p={2}
|
||||||
borderRadius="base"
|
borderRadius="base"
|
||||||
sx={
|
sx={
|
||||||
name === openModel
|
modelKey === openModel
|
||||||
? {
|
? {
|
||||||
bg: 'accent.750',
|
bg: 'accent.750',
|
||||||
_hover: {
|
_hover: {
|
||||||
@ -81,15 +66,6 @@ export default function ModelListItem(props: ModelListItemProps) {
|
|||||||
</Box>
|
</Box>
|
||||||
<Spacer onClick={openModelHandler} cursor="pointer" />
|
<Spacer onClick={openModelHandler} cursor="pointer" />
|
||||||
<Flex gap={2} alignItems="center">
|
<Flex gap={2} alignItems="center">
|
||||||
<Text color={statusTextColor()}>{status}</Text>
|
|
||||||
<Button
|
|
||||||
size="sm"
|
|
||||||
onClick={handleChangeModel}
|
|
||||||
isDisabled={status === 'active' || isProcessing || !isConnected}
|
|
||||||
>
|
|
||||||
{t('modelManager.load')}
|
|
||||||
</Button>
|
|
||||||
|
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
icon={<EditIcon />}
|
icon={<EditIcon />}
|
||||||
size="sm"
|
size="sm"
|
@ -1,17 +1,18 @@
|
|||||||
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
|
|
||||||
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
|
|
||||||
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
|
|
||||||
import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth';
|
|
||||||
import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight';
|
|
||||||
import { Box, Flex, useDisclosure } from '@chakra-ui/react';
|
import { Box, Flex, useDisclosure } from '@chakra-ui/react';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { memo } from 'react';
|
|
||||||
import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel';
|
|
||||||
import IAICollapse from 'common/components/IAICollapse';
|
import IAICollapse from 'common/components/IAICollapse';
|
||||||
|
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
|
||||||
|
import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight';
|
||||||
|
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
|
||||||
|
import ParamModelandVAE from 'features/parameters/components/Parameters/Core/ParamModelandVAE';
|
||||||
|
import ParamScheduler from 'features/parameters/components/Parameters/Core/ParamScheduler';
|
||||||
|
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
|
||||||
|
import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth';
|
||||||
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
|
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
|
||||||
|
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
uiSelector,
|
uiSelector,
|
||||||
@ -37,7 +38,7 @@ const TextToImageTabCoreParameters = () => {
|
|||||||
>
|
>
|
||||||
{shouldUseSliders ? (
|
{shouldUseSliders ? (
|
||||||
<>
|
<>
|
||||||
<ParamSchedulerAndModel />
|
<ParamModelandVAE />
|
||||||
<Box pt={2}>
|
<Box pt={2}>
|
||||||
<ParamSeedFull />
|
<ParamSeedFull />
|
||||||
</Box>
|
</Box>
|
||||||
@ -54,7 +55,8 @@ const TextToImageTabCoreParameters = () => {
|
|||||||
<ParamSteps />
|
<ParamSteps />
|
||||||
<ParamCFGScale />
|
<ParamCFGScale />
|
||||||
</Flex>
|
</Flex>
|
||||||
<ParamSchedulerAndModel />
|
<ParamModelandVAE />
|
||||||
|
<ParamScheduler />
|
||||||
<Box pt={2}>
|
<Box pt={2}>
|
||||||
<ParamSeedFull />
|
<ParamSeedFull />
|
||||||
</Box>
|
</Box>
|
||||||
|
@ -1,18 +1,19 @@
|
|||||||
import { memo } from 'react';
|
|
||||||
import { Box, Flex, useDisclosure } from '@chakra-ui/react';
|
import { Box, Flex, useDisclosure } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
|
|
||||||
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
|
|
||||||
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
|
|
||||||
import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength';
|
|
||||||
import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel';
|
|
||||||
import ParamBoundingBoxWidth from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxWidth';
|
|
||||||
import ParamBoundingBoxHeight from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight';
|
|
||||||
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
|
|
||||||
import IAICollapse from 'common/components/IAICollapse';
|
import IAICollapse from 'common/components/IAICollapse';
|
||||||
|
import ParamBoundingBoxHeight from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight';
|
||||||
|
import ParamBoundingBoxWidth from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxWidth';
|
||||||
|
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
|
||||||
|
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
|
||||||
|
import ParamModelandVAE from 'features/parameters/components/Parameters/Core/ParamModelandVAE';
|
||||||
|
import ParamScheduler from 'features/parameters/components/Parameters/Core/ParamScheduler';
|
||||||
|
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
|
||||||
|
import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength';
|
||||||
|
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
|
||||||
|
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
uiSelector,
|
uiSelector,
|
||||||
@ -38,7 +39,7 @@ const UnifiedCanvasCoreParameters = () => {
|
|||||||
>
|
>
|
||||||
{shouldUseSliders ? (
|
{shouldUseSliders ? (
|
||||||
<>
|
<>
|
||||||
<ParamSchedulerAndModel />
|
<ParamModelandVAE />
|
||||||
<Box pt={2}>
|
<Box pt={2}>
|
||||||
<ParamSeedFull />
|
<ParamSeedFull />
|
||||||
</Box>
|
</Box>
|
||||||
@ -55,7 +56,8 @@ const UnifiedCanvasCoreParameters = () => {
|
|||||||
<ParamSteps />
|
<ParamSteps />
|
||||||
<ParamCFGScale />
|
<ParamCFGScale />
|
||||||
</Flex>
|
</Flex>
|
||||||
<ParamSchedulerAndModel />
|
<ParamModelandVAE />
|
||||||
|
<ParamScheduler />
|
||||||
<Box pt={2}>
|
<Box pt={2}>
|
||||||
<ParamSeedFull />
|
<ParamSeedFull />
|
||||||
</Box>
|
</Box>
|
||||||
|
@ -7,6 +7,7 @@ export const tabMap = [
|
|||||||
'batch',
|
'batch',
|
||||||
// 'postprocessing',
|
// 'postprocessing',
|
||||||
// 'training',
|
// 'training',
|
||||||
|
'modelManager',
|
||||||
] as const;
|
] as const;
|
||||||
|
|
||||||
export type InvokeTabName = (typeof tabMap)[number];
|
export type InvokeTabName = (typeof tabMap)[number];
|
||||||
|
240
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
240
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
@ -76,9 +76,16 @@ export type paths = {
|
|||||||
*/
|
*/
|
||||||
get: operations["list_models"];
|
get: operations["list_models"];
|
||||||
/**
|
/**
|
||||||
* Import Model
|
* Update Model
|
||||||
* @description Add Model
|
* @description Add Model
|
||||||
*/
|
*/
|
||||||
|
post: operations["update_model"];
|
||||||
|
};
|
||||||
|
"/api/v1/models/import": {
|
||||||
|
/**
|
||||||
|
* Import Model
|
||||||
|
* @description Add a model using its local path, repo_id, or remote URL
|
||||||
|
*/
|
||||||
post: operations["import_model"];
|
post: operations["import_model"];
|
||||||
};
|
};
|
||||||
"/api/v1/models/{model_name}": {
|
"/api/v1/models/{model_name}": {
|
||||||
@ -227,6 +234,23 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
b?: number;
|
b?: number;
|
||||||
};
|
};
|
||||||
|
/** AddModelResult */
|
||||||
|
AddModelResult: {
|
||||||
|
/**
|
||||||
|
* Name
|
||||||
|
* @description The name of the model after import
|
||||||
|
*/
|
||||||
|
name: string;
|
||||||
|
/** @description The type of model */
|
||||||
|
model_type: components["schemas"]["ModelType"];
|
||||||
|
/** @description The base model */
|
||||||
|
base_model: components["schemas"]["BaseModelType"];
|
||||||
|
/**
|
||||||
|
* Config
|
||||||
|
* @description The configuration of the model
|
||||||
|
*/
|
||||||
|
config: components["schemas"]["ModelConfigBase"];
|
||||||
|
};
|
||||||
/**
|
/**
|
||||||
* BaseModelType
|
* BaseModelType
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
@ -1030,7 +1054,7 @@ export type components = {
|
|||||||
* @description The nodes in this graph
|
* @description The nodes in this graph
|
||||||
*/
|
*/
|
||||||
nodes?: {
|
nodes?: {
|
||||||
[key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
|
[key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* Edges
|
* Edges
|
||||||
@ -1073,7 +1097,7 @@ export type components = {
|
|||||||
* @description The results of node executions
|
* @description The results of node executions
|
||||||
*/
|
*/
|
||||||
results: {
|
results: {
|
||||||
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
|
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* Errors
|
* Errors
|
||||||
@ -1975,19 +1999,23 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
thumbnail_url: string;
|
thumbnail_url: string;
|
||||||
};
|
};
|
||||||
/** ImportModelRequest */
|
/** ImportModelResponse */
|
||||||
ImportModelRequest: {
|
ImportModelResponse: {
|
||||||
/**
|
/**
|
||||||
* Name
|
* Name
|
||||||
* @description A model path, repo_id or URL to import
|
* @description The name of the imported model
|
||||||
*/
|
*/
|
||||||
name: string;
|
name: string;
|
||||||
/**
|
/**
|
||||||
* Prediction Type
|
* Info
|
||||||
* @description Prediction type for SDv2 checkpoint files
|
* @description The model info
|
||||||
* @enum {string}
|
|
||||||
*/
|
*/
|
||||||
prediction_type?: "epsilon" | "v_prediction" | "sample";
|
info: components["schemas"]["AddModelResult"];
|
||||||
|
/**
|
||||||
|
* Status
|
||||||
|
* @description The status of the API response
|
||||||
|
*/
|
||||||
|
status: string;
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* InfillColorInvocation
|
* InfillColorInvocation
|
||||||
@ -2781,6 +2809,47 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
clip?: components["schemas"]["ClipField"];
|
clip?: components["schemas"]["ClipField"];
|
||||||
};
|
};
|
||||||
|
/**
|
||||||
|
* MainModelField
|
||||||
|
* @description Main model field
|
||||||
|
*/
|
||||||
|
MainModelField: {
|
||||||
|
/**
|
||||||
|
* Model Name
|
||||||
|
* @description Name of the model
|
||||||
|
*/
|
||||||
|
model_name: string;
|
||||||
|
/** @description Base model */
|
||||||
|
base_model: components["schemas"]["BaseModelType"];
|
||||||
|
};
|
||||||
|
/**
|
||||||
|
* MainModelLoaderInvocation
|
||||||
|
* @description Loads a main model, outputting its submodels.
|
||||||
|
*/
|
||||||
|
MainModelLoaderInvocation: {
|
||||||
|
/**
|
||||||
|
* Id
|
||||||
|
* @description The id of this node. Must be unique among all nodes.
|
||||||
|
*/
|
||||||
|
id: string;
|
||||||
|
/**
|
||||||
|
* Is Intermediate
|
||||||
|
* @description Whether or not this node is an intermediate node.
|
||||||
|
* @default false
|
||||||
|
*/
|
||||||
|
is_intermediate?: boolean;
|
||||||
|
/**
|
||||||
|
* Type
|
||||||
|
* @default main_model_loader
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
type?: "main_model_loader";
|
||||||
|
/**
|
||||||
|
* Model
|
||||||
|
* @description The model to load
|
||||||
|
*/
|
||||||
|
model: components["schemas"]["MainModelField"];
|
||||||
|
};
|
||||||
/**
|
/**
|
||||||
* MaskFromAlphaInvocation
|
* MaskFromAlphaInvocation
|
||||||
* @description Extracts the alpha channel of an image as a mask.
|
* @description Extracts the alpha channel of an image as a mask.
|
||||||
@ -2974,6 +3043,16 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
thr_d?: number;
|
thr_d?: number;
|
||||||
};
|
};
|
||||||
|
/** ModelConfigBase */
|
||||||
|
ModelConfigBase: {
|
||||||
|
/** Path */
|
||||||
|
path: string;
|
||||||
|
/** Description */
|
||||||
|
description?: string;
|
||||||
|
/** Model Format */
|
||||||
|
model_format?: string;
|
||||||
|
error?: components["schemas"]["ModelError"];
|
||||||
|
};
|
||||||
/**
|
/**
|
||||||
* ModelError
|
* ModelError
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
@ -3036,7 +3115,7 @@ export type components = {
|
|||||||
/** ModelsList */
|
/** ModelsList */
|
||||||
ModelsList: {
|
ModelsList: {
|
||||||
/** Models */
|
/** Models */
|
||||||
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[];
|
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[];
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* MultiplyInvocation
|
* MultiplyInvocation
|
||||||
@ -3425,47 +3504,6 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
scribble?: boolean;
|
scribble?: boolean;
|
||||||
};
|
};
|
||||||
/**
|
|
||||||
* PipelineModelField
|
|
||||||
* @description Pipeline model field
|
|
||||||
*/
|
|
||||||
PipelineModelField: {
|
|
||||||
/**
|
|
||||||
* Model Name
|
|
||||||
* @description Name of the model
|
|
||||||
*/
|
|
||||||
model_name: string;
|
|
||||||
/** @description Base model */
|
|
||||||
base_model: components["schemas"]["BaseModelType"];
|
|
||||||
};
|
|
||||||
/**
|
|
||||||
* PipelineModelLoaderInvocation
|
|
||||||
* @description Loads a pipeline model, outputting its submodels.
|
|
||||||
*/
|
|
||||||
PipelineModelLoaderInvocation: {
|
|
||||||
/**
|
|
||||||
* Id
|
|
||||||
* @description The id of this node. Must be unique among all nodes.
|
|
||||||
*/
|
|
||||||
id: string;
|
|
||||||
/**
|
|
||||||
* Is Intermediate
|
|
||||||
* @description Whether or not this node is an intermediate node.
|
|
||||||
* @default false
|
|
||||||
*/
|
|
||||||
is_intermediate?: boolean;
|
|
||||||
/**
|
|
||||||
* Type
|
|
||||||
* @default pipeline_model_loader
|
|
||||||
* @enum {string}
|
|
||||||
*/
|
|
||||||
type?: "pipeline_model_loader";
|
|
||||||
/**
|
|
||||||
* Model
|
|
||||||
* @description The model to load
|
|
||||||
*/
|
|
||||||
model: components["schemas"]["PipelineModelField"];
|
|
||||||
};
|
|
||||||
/**
|
/**
|
||||||
* PromptCollectionOutput
|
* PromptCollectionOutput
|
||||||
* @description Base class for invocations that output a collection of prompts
|
* @description Base class for invocations that output a collection of prompts
|
||||||
@ -4266,6 +4304,19 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
level?: 2 | 4;
|
level?: 2 | 4;
|
||||||
};
|
};
|
||||||
|
/**
|
||||||
|
* VAEModelField
|
||||||
|
* @description Vae model field
|
||||||
|
*/
|
||||||
|
VAEModelField: {
|
||||||
|
/**
|
||||||
|
* Model Name
|
||||||
|
* @description Name of the model
|
||||||
|
*/
|
||||||
|
model_name: string;
|
||||||
|
/** @description Base model */
|
||||||
|
base_model: components["schemas"]["BaseModelType"];
|
||||||
|
};
|
||||||
/** VaeField */
|
/** VaeField */
|
||||||
VaeField: {
|
VaeField: {
|
||||||
/**
|
/**
|
||||||
@ -4274,6 +4325,51 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
vae: components["schemas"]["ModelInfo"];
|
vae: components["schemas"]["ModelInfo"];
|
||||||
};
|
};
|
||||||
|
/**
|
||||||
|
* VaeLoaderInvocation
|
||||||
|
* @description Loads a VAE model, outputting a VaeLoaderOutput
|
||||||
|
*/
|
||||||
|
VaeLoaderInvocation: {
|
||||||
|
/**
|
||||||
|
* Id
|
||||||
|
* @description The id of this node. Must be unique among all nodes.
|
||||||
|
*/
|
||||||
|
id: string;
|
||||||
|
/**
|
||||||
|
* Is Intermediate
|
||||||
|
* @description Whether or not this node is an intermediate node.
|
||||||
|
* @default false
|
||||||
|
*/
|
||||||
|
is_intermediate?: boolean;
|
||||||
|
/**
|
||||||
|
* Type
|
||||||
|
* @default vae_loader
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
type?: "vae_loader";
|
||||||
|
/**
|
||||||
|
* Vae Model
|
||||||
|
* @description The VAE to load
|
||||||
|
*/
|
||||||
|
vae_model: components["schemas"]["VAEModelField"];
|
||||||
|
};
|
||||||
|
/**
|
||||||
|
* VaeLoaderOutput
|
||||||
|
* @description Model loader output
|
||||||
|
*/
|
||||||
|
VaeLoaderOutput: {
|
||||||
|
/**
|
||||||
|
* Type
|
||||||
|
* @default vae_loader_output
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
type?: "vae_loader_output";
|
||||||
|
/**
|
||||||
|
* Vae
|
||||||
|
* @description Vae model
|
||||||
|
*/
|
||||||
|
vae?: components["schemas"]["VaeField"];
|
||||||
|
};
|
||||||
/** VaeModelConfig */
|
/** VaeModelConfig */
|
||||||
VaeModelConfig: {
|
VaeModelConfig: {
|
||||||
/** Name */
|
/** Name */
|
||||||
@ -4474,7 +4570,7 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
requestBody: {
|
requestBody: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
responses: {
|
responses: {
|
||||||
@ -4511,7 +4607,7 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
requestBody: {
|
requestBody: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
responses: {
|
responses: {
|
||||||
@ -4731,13 +4827,13 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* Import Model
|
* Update Model
|
||||||
* @description Add Model
|
* @description Add Model
|
||||||
*/
|
*/
|
||||||
import_model: {
|
update_model: {
|
||||||
requestBody: {
|
requestBody: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["ImportModelRequest"];
|
"application/json": components["schemas"]["CreateModelRequest"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
responses: {
|
responses: {
|
||||||
@ -4755,6 +4851,36 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
/**
|
||||||
|
* Import Model
|
||||||
|
* @description Add a model using its local path, repo_id, or remote URL
|
||||||
|
*/
|
||||||
|
import_model: {
|
||||||
|
parameters: {
|
||||||
|
query: {
|
||||||
|
/** @description A model path, repo_id or URL to import */
|
||||||
|
name: string;
|
||||||
|
/** @description Prediction type for SDv2 checkpoint files */
|
||||||
|
prediction_type?: "v_prediction" | "epsilon" | "sample";
|
||||||
|
};
|
||||||
|
};
|
||||||
|
responses: {
|
||||||
|
/** @description The model imported successfully */
|
||||||
|
201: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["ImportModelResponse"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
/** @description The model could not be found */
|
||||||
|
404: never;
|
||||||
|
/** @description Validation Error */
|
||||||
|
422: {
|
||||||
|
content: {
|
||||||
|
"application/json": components["schemas"]["HTTPValidationError"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
/**
|
/**
|
||||||
* Delete Model
|
* Delete Model
|
||||||
* @description Delete Model
|
* @description Delete Model
|
||||||
|
@ -33,7 +33,8 @@ export type OffsetPaginatedResults_ImageDTO_ =
|
|||||||
// Models
|
// Models
|
||||||
export type ModelType = S<'ModelType'>;
|
export type ModelType = S<'ModelType'>;
|
||||||
export type BaseModelType = S<'BaseModelType'>;
|
export type BaseModelType = S<'BaseModelType'>;
|
||||||
export type PipelineModelField = S<'PipelineModelField'>;
|
export type MainModelField = S<'MainModelField'>;
|
||||||
|
export type VAEModelField = S<'VAEModelField'>;
|
||||||
export type ModelsList = S<'ModelsList'>;
|
export type ModelsList = S<'ModelsList'>;
|
||||||
|
|
||||||
// Graphs
|
// Graphs
|
||||||
@ -57,8 +58,8 @@ export type TextToLatentsInvocation = N<'TextToLatentsInvocation'>;
|
|||||||
export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>;
|
export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>;
|
||||||
export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>;
|
export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>;
|
||||||
export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>;
|
export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>;
|
||||||
export type PipelineModelLoaderInvocation = N<'PipelineModelLoaderInvocation'>;
|
|
||||||
export type ImageCollectionInvocation = N<'ImageCollectionInvocation'>;
|
export type ImageCollectionInvocation = N<'ImageCollectionInvocation'>;
|
||||||
|
export type MainModelLoaderInvocation = N<'MainModelLoaderInvocation'>;
|
||||||
|
|
||||||
// ControlNet Nodes
|
// ControlNet Nodes
|
||||||
export type ControlNetInvocation = N<'ControlNetInvocation'>;
|
export type ControlNetInvocation = N<'ControlNetInvocation'>;
|
||||||
|
File diff suppressed because one or more lines are too long
@ -1328,6 +1328,14 @@
|
|||||||
react-remove-scroll "^2.5.5"
|
react-remove-scroll "^2.5.5"
|
||||||
react-textarea-autosize "8.3.4"
|
react-textarea-autosize "8.3.4"
|
||||||
|
|
||||||
|
"@mantine/form@^6.0.15":
|
||||||
|
version "6.0.15"
|
||||||
|
resolved "https://registry.yarnpkg.com/@mantine/form/-/form-6.0.15.tgz#e78d953669888e01d3778ee8f62d469a12668c42"
|
||||||
|
integrity sha512-Tz4AuZZ/ddGvEh5zJbDyi9PlGqTilJBdCjRGIgs3zn3hQsfg+ku7/NUR5zNB64dcWPJvGKc074y4iopNIl3FWQ==
|
||||||
|
dependencies:
|
||||||
|
fast-deep-equal "^3.1.3"
|
||||||
|
klona "^2.0.5"
|
||||||
|
|
||||||
"@mantine/hooks@^6.0.14":
|
"@mantine/hooks@^6.0.14":
|
||||||
version "6.0.14"
|
version "6.0.14"
|
||||||
resolved "https://registry.yarnpkg.com/@mantine/hooks/-/hooks-6.0.14.tgz#5f52a75cdd36b14c13a5ffeeedc510d73db76dc0"
|
resolved "https://registry.yarnpkg.com/@mantine/hooks/-/hooks-6.0.14.tgz#5f52a75cdd36b14c13a5ffeeedc510d73db76dc0"
|
||||||
@ -4454,6 +4462,11 @@ klaw-sync@^6.0.0:
|
|||||||
dependencies:
|
dependencies:
|
||||||
graceful-fs "^4.1.11"
|
graceful-fs "^4.1.11"
|
||||||
|
|
||||||
|
klona@^2.0.5:
|
||||||
|
version "2.0.6"
|
||||||
|
resolved "https://registry.yarnpkg.com/klona/-/klona-2.0.6.tgz#85bffbf819c03b2f53270412420a4555ef882e22"
|
||||||
|
integrity sha512-dhG34DXATL5hSxJbIexCft8FChFXtmskoZYnoPWjXQuebWYCNkVeV3KkGegCK9CP1oswI/vQibS2GY7Em/sJJA==
|
||||||
|
|
||||||
kolorist@^1.7.0:
|
kolorist@^1.7.0:
|
||||||
version "1.8.0"
|
version "1.8.0"
|
||||||
resolved "https://registry.yarnpkg.com/kolorist/-/kolorist-1.8.0.tgz#edddbbbc7894bc13302cdf740af6374d4a04743c"
|
resolved "https://registry.yarnpkg.com/kolorist/-/kolorist-1.8.0.tgz#edddbbbc7894bc13302cdf740af6374d4a04743c"
|
||||||
|
Loading…
Reference in New Issue
Block a user