Update UI To Use New Model Manager (#3548)

PR for the Model Manager UI work related to 3.0

[DONE]

- Update ModelType Config names to be specific so that the front end can
parse them correctly.
- Rebuild frontend schema to reflect these changes.
- Update Linear UI Text To Image and Image to Image to work with the new
model loader.
- Updated the ModelInput component in the Node Editor to work with the
new changes.

[TODO REMEMBER]

- Add proper types for ModelLoaderType in `ModelSelect.tsx`

[TODO] 

- Everything else.
This commit is contained in:
blessedcoolant 2023-06-22 22:06:26 +12:00 committed by GitHub
commit 22c337b1aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
67 changed files with 710 additions and 668 deletions

View File

@ -7,8 +7,8 @@ from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import get_all_model_configs from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
MODEL_CONFIGS = Union[tuple(get_all_model_configs())] MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -62,8 +62,7 @@ class ConvertedModelResponse(BaseModel):
info: DiffusersModelInfo = Field(description="The converted model info") info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel): class ModelsList(BaseModel):
models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend models: list[MODEL_CONFIGS]
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
@models_router.get( @models_router.get(
@ -72,10 +71,10 @@ class ModelsList(BaseModel):
responses={200: {"model": ModelsList }}, responses={200: {"model": ModelsList }},
) )
async def list_models( async def list_models(
base_model: BaseModelType = Query( base_model: Optional[BaseModelType] = Query(
default=None, description="Base model" default=None, description="Base model"
), ),
model_type: ModelType = Query( model_type: Optional[ModelType] = Query(
default=None, description="The type of model to get" default=None, description="The type of model to get"
), ),
) -> ModelsList: ) -> ModelsList:

View File

@ -120,6 +120,22 @@ def custom_openapi():
invoker_schema["output"] = outputs_ref invoker_schema["output"] = outputs_ref
from invokeai.backend.model_management.models import get_model_config_enums
for model_config_format_enum in set(get_model_config_enums()):
name = model_config_format_enum.__qualname__
if name in openapi_schema["components"]["schemas"]:
# print(f"Config with name {name} already defined")
continue
# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
openapi_schema["components"]["schemas"][name] = dict(
title=name,
description="An enumeration.",
type="string",
enum=list(v.value for v in model_config_format_enum),
)
app.openapi_schema = openapi_schema app.openapi_schema = openapi_schema
return app.openapi_schema return app.openapi_schema

View File

@ -43,12 +43,19 @@ class ModelLoaderOutput(BaseInvocationOutput):
#fmt: on #fmt: on
class SD1ModelLoaderInvocation(BaseInvocation): class PipelineModelField(BaseModel):
"""Loading submodels of selected model.""" """Pipeline model field"""
type: Literal["sd1_model_loader"] = "sd1_model_loader" model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
model_name: str = Field(default="", description="Model to load")
class PipelineModelLoaderInvocation(BaseInvocation):
"""Loads a pipeline model, outputting its submodels."""
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
model: PipelineModelField = Field(description="The model to load")
# TODO: precision? # TODO: precision?
# Schema customisation # Schema customisation
@ -57,22 +64,24 @@ class SD1ModelLoaderInvocation(BaseInvocation):
"ui": { "ui": {
"tags": ["model", "loader"], "tags": ["model", "loader"],
"type_hints": { "type_hints": {
"model_name": "model" # TODO: rename to model_name? "model": "model"
} }
}, },
} }
def invoke(self, context: InvocationContext) -> ModelLoaderOutput: def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = BaseModelType.StableDiffusion1 # TODO: base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Pipeline
# TODO: not found exceptions # TODO: not found exceptions
if not context.services.model_manager.model_exists( if not context.services.model_manager.model_exists(
model_name=self.model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=ModelType.Pipeline, model_type=model_type,
): ):
raise Exception(f"Unkown model name: {self.model_name}!") raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
""" """
if not context.services.model_manager.model_exists( if not context.services.model_manager.model_exists(
@ -107,142 +116,39 @@ class SD1ModelLoaderInvocation(BaseInvocation):
return ModelLoaderOutput( return ModelLoaderOutput(
unet=UNetField( unet=UNetField(
unet=ModelInfo( unet=ModelInfo(
model_name=self.model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=ModelType.Pipeline, model_type=model_type,
submodel=SubModelType.UNet, submodel=SubModelType.UNet,
), ),
scheduler=ModelInfo( scheduler=ModelInfo(
model_name=self.model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=ModelType.Pipeline, model_type=model_type,
submodel=SubModelType.Scheduler, submodel=SubModelType.Scheduler,
), ),
loras=[], loras=[],
), ),
clip=ClipField( clip=ClipField(
tokenizer=ModelInfo( tokenizer=ModelInfo(
model_name=self.model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=ModelType.Pipeline, model_type=model_type,
submodel=SubModelType.Tokenizer, submodel=SubModelType.Tokenizer,
), ),
text_encoder=ModelInfo( text_encoder=ModelInfo(
model_name=self.model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=ModelType.Pipeline, model_type=model_type,
submodel=SubModelType.TextEncoder, submodel=SubModelType.TextEncoder,
), ),
loras=[], loras=[],
), ),
vae=VaeField( vae=VaeField(
vae=ModelInfo( vae=ModelInfo(
model_name=self.model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=ModelType.Pipeline, model_type=model_type,
submodel=SubModelType.Vae,
),
)
)
# TODO: optimize(less code copy)
class SD2ModelLoaderInvocation(BaseInvocation):
"""Loading submodels of selected model."""
type: Literal["sd2_model_loader"] = "sd2_model_loader"
model_name: str = Field(default="", description="Model to load")
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["model", "loader"],
"type_hints": {
"model_name": "model" # TODO: rename to model_name?
}
},
}
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = BaseModelType.StableDiffusion2 # TODO:
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
):
raise Exception(f"Unkown model name: {self.model_name}!")
"""
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Tokenizer,
):
raise Exception(
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
):
raise Exception(
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,
):
raise Exception(
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
)
"""
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.TextEncoder,
),
loras=[],
),
vae=VaeField(
vae=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Vae, submodel=SubModelType.Vae,
), ),
) )

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import torch import torch
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
from dataclasses import dataclass from dataclasses import dataclass
from invokeai.backend.model_management.model_manager import ( from invokeai.backend.model_management.model_manager import (
@ -69,19 +69,6 @@ class ModelManagerServiceBase(ABC):
) -> bool: ) -> bool:
pass pass
@abstractmethod
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
"""
Returns the name and typeof the default model, or None
if none is defined.
"""
pass
@abstractmethod
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
"""Sets the default model to the indicated name."""
pass
@abstractmethod @abstractmethod
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
""" """
@ -270,17 +257,6 @@ class ModelManagerService(ModelManagerServiceBase):
model_type, model_type,
) )
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
"""
Returns the name of the default model, or None
if none is defined.
"""
return self.mgr.default_model()
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
"""Sets the default model to the indicated name."""
self.mgr.set_default_model(model_name, base_model, model_type)
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
""" """
Given a model name returns a dict-like (OmegaConf) object describing it. Given a model name returns a dict-like (OmegaConf) object describing it.
@ -297,21 +273,10 @@ class ModelManagerService(ModelManagerServiceBase):
self, self,
base_model: Optional[BaseModelType] = None, base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None model_type: Optional[ModelType] = None
) -> dict: ) -> list[dict]:
# ) -> dict:
""" """
Return a dict of models in the format: Return a list of models.
{ model_type1:
{ model_name1: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description,
'format': 'folder'|'safetensors'|'ckpt'
},
model_name2: { etc }
},
model_type2:
{ model_name_n: etc
}
""" """
return self.mgr.list_models(base_model, model_type) return self.mgr.list_models(base_model, model_type)

View File

@ -266,6 +266,8 @@ class ModelManager(object):
for model_key, model_config in config.items(): for model_key, model_config in config.items():
model_name, base_model, model_type = self.parse_key(model_key) model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
# alias for config file
model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config) self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary # check config version number and update on disk/RAM if necessary
@ -445,38 +447,6 @@ class ModelManager(object):
_cache = self.cache, _cache = self.cache,
) )
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
"""
Returns the name of the default model, or None
if none is defined.
"""
for model_key, model_config in self.models.items():
if model_config.default:
return self.parse_key(model_key)
for model_key, _ in self.models.items():
return self.parse_key(model_key)
else:
return None # TODO: or redo as (None, None, None)
def set_default_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> None:
"""
Set the default model. The change will not take
effect until you call model_manager.commit()
"""
model_key = self.model_key(model_name, base_model, model_type)
if model_key not in self.models:
raise Exception(f"Unknown model: {model_key}")
for cur_model_key, config in self.models.items():
config.default = cur_model_key == model_key
def model_info( def model_info(
self, self,
model_name: str, model_name: str,
@ -503,9 +473,9 @@ class ModelManager(object):
self, self,
base_model: Optional[BaseModelType] = None, base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None, model_type: Optional[ModelType] = None,
) -> Dict[str, Dict[str, str]]: ) -> list[dict]:
""" """
Return a dict of models, in format [base_model][model_type][model_name] Return a list of models.
Please use model_manager.models() to get all the model names, Please use model_manager.models() to get all the model names,
model_manager.model_info('model-name') to get the stanza for the model model_manager.model_info('model-name') to get the stanza for the model
@ -513,7 +483,7 @@ class ModelManager(object):
object derived from models.yaml object derived from models.yaml
""" """
models = dict() models = []
for model_key in sorted(self.models, key=str.casefold): for model_key in sorted(self.models, key=str.casefold):
model_config = self.models[model_key] model_config = self.models[model_key]
@ -523,18 +493,16 @@ class ModelManager(object):
if model_type is not None and cur_model_type != model_type: if model_type is not None and cur_model_type != model_type:
continue continue
if cur_base_model not in models: model_dict = dict(
models[cur_base_model] = dict()
if cur_model_type not in models[cur_base_model]:
models[cur_base_model][cur_model_type] = dict()
models[cur_base_model][cur_model_type][cur_model_name] = dict(
**model_config.dict(exclude_defaults=True), **model_config.dict(exclude_defaults=True),
# OpenAPIModelInfoBase
name=cur_model_name, name=cur_model_name,
base_model=cur_base_model, base_model=cur_base_model,
type=cur_model_type, type=cur_model_type,
) )
models.append(model_dict)
return models return models
def print_models(self) -> None: def print_models(self) -> None:
@ -646,7 +614,9 @@ class ModelManager(object):
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
if model_class.save_to_config: if model_class.save_to_config:
# TODO: or exclude_unset better fits here? # TODO: or exclude_unset better fits here?
data_to_save[model_key] = model_config.dict(exclude_defaults=True) data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
# alias for config file
data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format")
yaml_str = OmegaConf.to_yaml(data_to_save) yaml_str = OmegaConf.to_yaml(data_to_save)
config_file_path = conf_file or self.config_path config_file_path = conf_file or self.config_path

View File

@ -1,3 +1,7 @@
import inspect
from enum import Enum
from pydantic import BaseModel
from typing import Literal, get_origin
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel from .vae import VaeModel
@ -29,10 +33,63 @@ MODEL_CLASSES = {
#}, #},
} }
def get_all_model_configs(): MODEL_CONFIGS = list()
configs = set() OPENAPI_MODEL_CONFIGS = list()
for models in MODEL_CLASSES.values():
for _, model in models.items(): class OpenAPIModelInfoBase(BaseModel):
configs.update(model._get_configs().values()) name: str
configs.discard(None) base_model: BaseModelType
return list(configs) # TODO: set, list or tuple type: ModelType
for base_model, models in MODEL_CLASSES.items():
for model_type, model_class in models.items():
model_configs = set(model_class._get_configs().values())
model_configs.discard(None)
MODEL_CONFIGS.extend(model_configs)
for cfg in model_configs:
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
openapi_cfg_name = model_name + cfg_name
if openapi_cfg_name in vars():
continue
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
__annotations__ = dict(
type=Literal[model_type.value],
),
))
#globals()[openapi_cfg_name] = api_wrapper
vars()[openapi_cfg_name] = api_wrapper
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
def get_model_config_enums():
enums = list()
for model_config in MODEL_CONFIGS:
fields = inspect.get_annotations(model_config)
try:
field = fields["model_format"]
except:
raise Exception("format field not found")
# model_format: None
# model_format: SomeModelFormat
# model_format: Literal[SomeModelFormat.Diffusers]
# model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint]
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
enums.append(field)
elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
enums.append(type(field.__args__[0]))
elif field is None:
pass
else:
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
return enums

View File

@ -48,12 +48,10 @@ class ModelError(str, Enum):
class ModelConfigBase(BaseModel): class ModelConfigBase(BaseModel):
path: str # or Path path: str # or Path
#name: str # not included as present in model key
description: Optional[str] = Field(None) description: Optional[str] = Field(None)
format: Optional[str] = Field(None) model_format: Optional[str] = Field(None)
default: Optional[bool] = Field(False)
# do not save to config # do not save to config
error: Optional[ModelError] = Field(None, exclude=True) error: Optional[ModelError] = Field(None)
class Config: class Config:
use_enum_values = True use_enum_values = True
@ -94,6 +92,11 @@ class ModelBase(metaclass=ABCMeta):
def _hf_definition_to_type(self, subtypes: List[str]) -> Type: def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
if len(subtypes) < 2: if len(subtypes) < 2:
raise Exception("Invalid subfolder definition!") raise Exception("Invalid subfolder definition!")
if all(t is None for t in subtypes):
return None
elif any(t is None for t in subtypes):
raise Exception(f"Unsupported definition: {subtypes}")
if subtypes[0] in ["diffusers", "transformers"]: if subtypes[0] in ["diffusers", "transformers"]:
res_type = sys.modules[subtypes[0]] res_type = sys.modules[subtypes[0]]
subtypes = subtypes[1:] subtypes = subtypes[1:]
@ -122,47 +125,41 @@ class ModelBase(metaclass=ABCMeta):
continue continue
fields = inspect.get_annotations(value) fields = inspect.get_annotations(value)
if "format" not in fields: try:
raise Exception("Invalid config definition - format field not found") field = fields["model_format"]
except:
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
format_type = typing.get_origin(fields["format"]) if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
if format_type not in {None, Literal, Union}: for model_format in field:
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}") configs[model_format.value] = value
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["format"].__args__): elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}") for model_format in field.__args__:
configs[model_format.value] = value
elif field is None:
configs[None] = value
if format_type == Union:
f_fields = fields["format"].__args__
else: else:
f_fields = (fields["format"],) raise Exception(f"Unsupported format definition in {cls.__qualname__}")
for field in f_fields:
if field is None:
format_name = None
else:
format_name = field.__args__[0]
configs[format_name] = value # TODO: error when override(multiple)?
cls.__configs = configs cls.__configs = configs
return cls.__configs return cls.__configs
@classmethod @classmethod
def create_config(cls, **kwargs) -> ModelConfigBase: def create_config(cls, **kwargs) -> ModelConfigBase:
if "format" not in kwargs: if "model_format" not in kwargs:
raise Exception("Field 'format' not found in model config") raise Exception("Field 'model_format' not found in model config")
configs = cls._get_configs() configs = cls._get_configs()
return configs[kwargs["format"]](**kwargs) return configs[kwargs["model_format"]](**kwargs)
@classmethod @classmethod
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase: def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
return cls.create_config( return cls.create_config(
path=path, path=path,
format=cls.detect_format(path), model_format=cls.detect_format(path),
) )
@classmethod @classmethod

View File

@ -1,5 +1,6 @@
import os import os
import torch import torch
from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Optional, Union, Literal from typing import Optional, Union, Literal
from .base import ( from .base import (
@ -14,12 +15,16 @@ from .base import (
classproperty, classproperty,
) )
class ControlNetModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class ControlNetModel(ModelBase): class ControlNetModel(ModelBase):
#model_class: Type #model_class: Type
#model_size: int #model_size: int
class Config(ModelConfigBase): class Config(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]] model_format: ControlNetModelFormat
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.ControlNet assert model_type == ModelType.ControlNet
@ -69,9 +74,9 @@ class ControlNetModel(ModelBase):
@classmethod @classmethod
def detect_format(cls, path: str): def detect_format(cls, path: str):
if os.path.isdir(path): if os.path.isdir(path):
return "diffusers" return ControlNetModelFormat.Diffusers
else: else:
return "checkpoint" return ControlNetModelFormat.Checkpoint
@classmethod @classmethod
def convert_if_required( def convert_if_required(
@ -81,7 +86,7 @@ class ControlNetModel(ModelBase):
config: ModelConfigBase, # empty config or config of parent model config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType, base_model: BaseModelType,
) -> str: ) -> str:
if cls.detect_format(model_path) != "diffusers": if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
raise NotImlemetedError("Checkpoint controlnet models currently unsupported") raise NotImplementedError("Checkpoint controlnet models currently unsupported")
else: else:
return model_path return model_path

View File

@ -1,5 +1,6 @@
import os import os
import torch import torch
from enum import Enum
from typing import Optional, Union, Literal from typing import Optional, Union, Literal
from .base import ( from .base import (
ModelBase, ModelBase,
@ -12,11 +13,15 @@ from .base import (
# TODO: naming # TODO: naming
from ..lora import LoRAModel as LoRAModelRaw from ..lora import LoRAModel as LoRAModelRaw
class LoRAModelFormat(str, Enum):
LyCORIS = "lycoris"
Diffusers = "diffusers"
class LoRAModel(ModelBase): class LoRAModel(ModelBase):
#model_size: int #model_size: int
class Config(ModelConfigBase): class Config(ModelConfigBase):
format: Union[Literal["lycoris"], Literal["diffusers"]] model_format: LoRAModelFormat # TODO:
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora assert model_type == ModelType.Lora
@ -52,9 +57,9 @@ class LoRAModel(ModelBase):
@classmethod @classmethod
def detect_format(cls, path: str): def detect_format(cls, path: str):
if os.path.isdir(path): if os.path.isdir(path):
return "diffusers" return LoRAModelFormat.Diffusers
else: else:
return "lycoris" return LoRAModelFormat.LyCORIS
@classmethod @classmethod
def convert_if_required( def convert_if_required(
@ -64,7 +69,7 @@ class LoRAModel(ModelBase):
config: ModelConfigBase, config: ModelConfigBase,
base_model: BaseModelType, base_model: BaseModelType,
) -> str: ) -> str:
if cls.detect_format(model_path) == "diffusers": if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
# TODO: add diffusers lora when it stabilizes a bit # TODO: add diffusers lora when it stabilizes a bit
raise NotImplementedError("Diffusers lora not supported") raise NotImplementedError("Diffusers lora not supported")
else: else:

View File

@ -1,5 +1,6 @@
import os import os
import json import json
from enum import Enum
from pydantic import Field from pydantic import Field
from pathlib import Path from pathlib import Path
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
@ -19,16 +20,19 @@ from .base import (
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf from omegaconf import OmegaConf
class StableDiffusion1ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion1Model(DiffusersModel): class StableDiffusion1Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase): class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"] model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
variant: ModelVariantType variant: ModelVariantType
class CheckpointConfig(ModelConfigBase): class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"] model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
config: Optional[str] = Field(None) config: Optional[str] = Field(None)
variant: ModelVariantType variant: ModelVariantType
@ -47,7 +51,7 @@ class StableDiffusion1Model(DiffusersModel):
def probe_config(cls, path: str, **kwargs): def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path) model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None) ckpt_config_path = kwargs.get("config", None)
if model_format == "checkpoint": if model_format == StableDiffusion1ModelFormat.Checkpoint:
if ckpt_config_path: if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path) ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
@ -57,7 +61,7 @@ class StableDiffusion1Model(DiffusersModel):
checkpoint = checkpoint.get('state_dict', checkpoint) checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == "diffusers": elif model_format == StableDiffusion1ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json") unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path): if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f: with open(unet_config_path, "r") as f:
@ -80,7 +84,7 @@ class StableDiffusion1Model(DiffusersModel):
return cls.create_config( return cls.create_config(
path=path, path=path,
format=model_format, model_format=model_format,
config=ckpt_config_path, config=ckpt_config_path,
variant=variant, variant=variant,
@ -93,9 +97,9 @@ class StableDiffusion1Model(DiffusersModel):
@classmethod @classmethod
def detect_format(cls, model_path: str): def detect_format(cls, model_path: str):
if os.path.isdir(model_path): if os.path.isdir(model_path):
return "diffusers" return StableDiffusion1ModelFormat.Diffusers
else: else:
return "checkpoint" return StableDiffusion1ModelFormat.Checkpoint
@classmethod @classmethod
def convert_if_required( def convert_if_required(
@ -116,19 +120,22 @@ class StableDiffusion1Model(DiffusersModel):
else: else:
return model_path return model_path
class StableDiffusion2ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion2Model(DiffusersModel): class StableDiffusion2Model(DiffusersModel):
# TODO: check that configs overwriten properly # TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase): class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"] model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
variant: ModelVariantType variant: ModelVariantType
prediction_type: SchedulerPredictionType prediction_type: SchedulerPredictionType
upcast_attention: bool upcast_attention: bool
class CheckpointConfig(ModelConfigBase): class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"] model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
config: Optional[str] = Field(None) config: Optional[str] = Field(None)
variant: ModelVariantType variant: ModelVariantType
@ -149,7 +156,7 @@ class StableDiffusion2Model(DiffusersModel):
def probe_config(cls, path: str, **kwargs): def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path) model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None) ckpt_config_path = kwargs.get("config", None)
if model_format == "checkpoint": if model_format == StableDiffusion2ModelFormat.Checkpoint:
if ckpt_config_path: if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path) ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
@ -159,7 +166,7 @@ class StableDiffusion2Model(DiffusersModel):
checkpoint = checkpoint.get('state_dict', checkpoint) checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == "diffusers": elif model_format == StableDiffusion2ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json") unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path): if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f: with open(unet_config_path, "r") as f:
@ -191,7 +198,7 @@ class StableDiffusion2Model(DiffusersModel):
return cls.create_config( return cls.create_config(
path=path, path=path,
format=model_format, model_format=model_format,
config=ckpt_config_path, config=ckpt_config_path,
variant=variant, variant=variant,
@ -206,9 +213,9 @@ class StableDiffusion2Model(DiffusersModel):
@classmethod @classmethod
def detect_format(cls, model_path: str): def detect_format(cls, model_path: str):
if os.path.isdir(model_path): if os.path.isdir(model_path):
return "diffusers" return StableDiffusion2ModelFormat.Diffusers
else: else:
return "checkpoint" return StableDiffusion2ModelFormat.Checkpoint
@classmethod @classmethod
def convert_if_required( def convert_if_required(
@ -281,8 +288,8 @@ def _convert_ckpt_and_cache(
prediction_type = SchedulerPredictionType.Epsilon prediction_type = SchedulerPredictionType.Epsilon
elif version == BaseModelType.StableDiffusion2: elif version == BaseModelType.StableDiffusion2:
upcast_attention = config.upcast_attention upcast_attention = model_config.upcast_attention
prediction_type = config.prediction_type prediction_type = model_config.prediction_type
else: else:
raise Exception(f"Unknown model provided: {version}") raise Exception(f"Unknown model provided: {version}")

View File

@ -16,7 +16,7 @@ class TextualInversionModel(ModelBase):
#model_size: int #model_size: int
class Config(ModelConfigBase): class Config(ModelConfigBase):
format: None model_format: None
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.TextualInversion assert model_type == ModelType.TextualInversion

View File

@ -1,5 +1,7 @@
import os import os
import torch import torch
import safetensors
from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Optional, Union, Literal from typing import Optional, Union, Literal
from .base import ( from .base import (
@ -18,12 +20,16 @@ from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf from omegaconf import OmegaConf
class VaeModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class VaeModel(ModelBase): class VaeModel(ModelBase):
#vae_class: Type #vae_class: Type
#model_size: int #model_size: int
class Config(ModelConfigBase): class Config(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]] model_format: VaeModelFormat
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Vae assert model_type == ModelType.Vae
@ -70,9 +76,9 @@ class VaeModel(ModelBase):
@classmethod @classmethod
def detect_format(cls, path: str): def detect_format(cls, path: str):
if os.path.isdir(path): if os.path.isdir(path):
return "diffusers" return VaeModelFormat.Diffusers
else: else:
return "checkpoint" return VaeModelFormat.Checkpoint
@classmethod @classmethod
def convert_if_required( def convert_if_required(
@ -82,7 +88,7 @@ class VaeModel(ModelBase):
config: ModelConfigBase, # empty config or config of parent model config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType, base_model: BaseModelType,
) -> str: ) -> str:
if cls.detect_format(model_path) != "diffusers": if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
return _convert_vae_ckpt_and_cache( return _convert_vae_ckpt_and_cache(
weights_path=model_path, weights_path=model_path,
output_path=output_path, output_path=output_path,

View File

@ -24,6 +24,7 @@ import Toaster from './Toaster';
import DeleteImageModal from 'features/gallery/components/DeleteImageModal'; import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal'; import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import { useListModelsQuery } from 'services/apiSlice';
const DEFAULT_CONFIG = {}; const DEFAULT_CONFIG = {};
@ -46,6 +47,18 @@ const App = ({
const isApplicationReady = useIsApplicationReady(); const isApplicationReady = useIsApplicationReady();
const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline',
});
const { data: controlnetModels } = useListModelsQuery({
model_type: 'controlnet',
});
const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' });
const { data: loraModels } = useListModelsQuery({ model_type: 'lora' });
const { data: embeddingModels } = useListModelsQuery({
model_type: 'embedding',
});
const [loadingOverridden, setLoadingOverridden] = useState(false); const [loadingOverridden, setLoadingOverridden] = useState(false);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();

View File

@ -5,7 +5,6 @@ import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersist
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist'; import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist'; import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist'; import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist';
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist'; import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist'; import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
import { omit } from 'lodash-es'; import { omit } from 'lodash-es';
@ -18,7 +17,6 @@ const serializationDenylist: {
gallery: galleryPersistDenylist, gallery: galleryPersistDenylist,
generation: generationPersistDenylist, generation: generationPersistDenylist,
lightbox: lightboxPersistDenylist, lightbox: lightboxPersistDenylist,
models: modelsPersistDenylist,
nodes: nodesPersistDenylist, nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist, postprocessing: postprocessingPersistDenylist,
system: systemPersistDenylist, system: systemPersistDenylist,

View File

@ -7,7 +7,6 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice';
import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice'; import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
import { initialConfigState } from 'features/system/store/configSlice'; import { initialConfigState } from 'features/system/store/configSlice';
import { initialModelsState } from 'features/system/store/modelSlice';
import { initialSystemState } from 'features/system/store/systemSlice'; import { initialSystemState } from 'features/system/store/systemSlice';
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice'; import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
import { initialUIState } from 'features/ui/store/uiSlice'; import { initialUIState } from 'features/ui/store/uiSlice';
@ -21,7 +20,6 @@ const initialStates: {
gallery: initialGalleryState, gallery: initialGalleryState,
generation: initialGenerationState, generation: initialGenerationState,
lightbox: initialLightboxState, lightbox: initialLightboxState,
models: initialModelsState,
nodes: initialNodesState, nodes: initialNodesState,
postprocessing: initialPostprocessingState, postprocessing: initialPostprocessingState,
system: initialSystemState, system: initialSystemState,

View File

@ -1,9 +1,8 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { appSocketConnected, socketConnected } from 'services/events/actions'; import { appSocketConnected, socketConnected } from 'services/events/actions';
import { receivedPageOfImages } from 'services/thunks/image'; import { receivedPageOfImages } from 'services/thunks/image';
import { receivedModels } from 'services/thunks/model';
import { receivedOpenAPISchema } from 'services/thunks/schema'; import { receivedOpenAPISchema } from 'services/thunks/schema';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' }); const moduleLog = log.child({ namespace: 'socketio' });
@ -15,7 +14,7 @@ export const addSocketConnectedEventListener = () => {
moduleLog.debug({ timestamp }, 'Connected'); moduleLog.debug({ timestamp }, 'Connected');
const { models, nodes, config, images } = getState(); const { nodes, config, images } = getState();
const { disabledTabs } = config; const { disabledTabs } = config;
@ -28,10 +27,6 @@ export const addSocketConnectedEventListener = () => {
); );
} }
if (!models.ids.length) {
dispatch(receivedModels());
}
if (!nodes.schema && !disabledTabs.includes('nodes')) { if (!nodes.schema && !disabledTabs.includes('nodes')) {
dispatch(receivedOpenAPISchema()); dispatch(receivedOpenAPISchema());
} }

View File

@ -5,34 +5,32 @@ import {
configureStore, configureStore,
} from '@reduxjs/toolkit'; } from '@reduxjs/toolkit';
import { rememberReducer, rememberEnhancer } from 'redux-remember';
import dynamicMiddlewares from 'redux-dynamic-middlewares'; import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { rememberEnhancer, rememberReducer } from 'redux-remember';
import canvasReducer from 'features/canvas/store/canvasSlice'; import canvasReducer from 'features/canvas/store/canvasSlice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import galleryReducer from 'features/gallery/store/gallerySlice'; import galleryReducer from 'features/gallery/store/gallerySlice';
import imagesReducer from 'features/gallery/store/imagesSlice'; import imagesReducer from 'features/gallery/store/imagesSlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice'; import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import generationReducer from 'features/parameters/store/generationSlice'; import generationReducer from 'features/parameters/store/generationSlice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice'; import systemReducer from 'features/system/store/systemSlice';
// import sessionReducer from 'features/system/store/sessionSlice'; // import sessionReducer from 'features/system/store/sessionSlice';
import configReducer from 'features/system/store/configSlice';
import uiReducer from 'features/ui/store/uiSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import modelsReducer from 'features/system/store/modelSlice';
import nodesReducer from 'features/nodes/store/nodesSlice'; import nodesReducer from 'features/nodes/store/nodesSlice';
import boardsReducer from 'features/gallery/store/boardSlice'; import boardsReducer from 'features/gallery/store/boardSlice';
import configReducer from 'features/system/store/configSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice';
import { listenerMiddleware } from './middleware/listenerMiddleware'; import { listenerMiddleware } from './middleware/listenerMiddleware';
import { actionSanitizer } from './middleware/devtools/actionSanitizer'; import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist'; import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { LOCALSTORAGE_PREFIX } from './constants';
import { serialize } from './enhancers/reduxRemember/serialize'; import { serialize } from './enhancers/reduxRemember/serialize';
import { unserialize } from './enhancers/reduxRemember/unserialize'; import { unserialize } from './enhancers/reduxRemember/unserialize';
import { LOCALSTORAGE_PREFIX } from './constants';
import { api } from 'services/apiSlice'; import { api } from 'services/apiSlice';
const allReducers = { const allReducers = {
@ -40,7 +38,6 @@ const allReducers = {
gallery: galleryReducer, gallery: galleryReducer,
generation: generationReducer, generation: generationReducer,
lightbox: lightboxReducer, lightbox: lightboxReducer,
models: modelsReducer,
nodes: nodesReducer, nodes: nodesReducer,
postprocessing: postprocessingReducer, postprocessing: postprocessingReducer,
system: systemReducer, system: systemReducer,
@ -50,8 +47,8 @@ const allReducers = {
images: imagesReducer, images: imagesReducer,
controlNet: controlNetReducer, controlNet: controlNetReducer,
boards: boardsReducer, boards: boardsReducer,
[api.reducerPath]: api.reducer,
// session: sessionReducer, // session: sessionReducer,
[api.reducerPath]: api.reducer,
}; };
const rootReducer = combineReducers(allReducers); const rootReducer = combineReducers(allReducers);
@ -63,7 +60,6 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'gallery', 'gallery',
'generation', 'generation',
'lightbox', 'lightbox',
// 'models',
'nodes', 'nodes',
'postprocessing', 'postprocessing',
'system', 'system',

View File

@ -1,28 +1,18 @@
import { Select } from '@chakra-ui/react'; import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
ModelInputFieldTemplate, ModelInputFieldTemplate,
ModelInputFieldValue, ModelInputFieldValue,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { selectModelsIds } from 'features/system/store/modelSlice';
import { isEqual } from 'lodash-es';
import { ChangeEvent, memo } from 'react';
import { FieldComponentProps } from './types';
const availableModelsSelector = createSelector( import { memo, useCallback, useEffect, useMemo } from 'react';
[selectModelsIds], import { FieldComponentProps } from './types';
(allModelNames) => { import { forEach, isString } from 'lodash-es';
return { allModelNames }; import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
// return map(modelList, (_, name) => name); import IAIMantineSelect from 'common/components/IAIMantineSelect';
}, import { useTranslation } from 'react-i18next';
{ import { useListModelsQuery } from 'services/apiSlice';
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const ModelInputFieldComponent = ( const ModelInputFieldComponent = (
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate> props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
@ -30,28 +20,82 @@ const ModelInputFieldComponent = (
const { nodeId, field } = props; const { nodeId, field } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation();
const { allModelNames } = useAppSelector(availableModelsSelector); const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline',
});
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => { const data = useMemo(() => {
dispatch( if (!pipelineModels) {
fieldValueChanged({ return [];
nodeId, }
fieldName: field.name,
value: e.target.value, const data: SelectItem[] = [];
})
); forEach(pipelineModels.entities, (model, id) => {
}; if (!model) {
return;
}
data.push({
value: id,
label: model.name,
group: BASE_MODEL_NAME_MAP[model.base_model],
});
});
return data;
}, [pipelineModels]);
const selectedModel = useMemo(
() => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]],
[pipelineModels?.entities, pipelineModels?.ids, field.value]
);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: v,
})
);
},
[dispatch, field.name, nodeId]
);
useEffect(() => {
if (field.value && pipelineModels?.ids.includes(field.value)) {
return;
}
const firstModel = pipelineModels?.ids[0];
if (!isString(firstModel)) {
return;
}
handleValueChanged(firstModel);
}, [field.value, handleValueChanged, pipelineModels?.ids]);
return ( return (
<Select <IAIMantineSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model &&
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
}
value={field.value}
placeholder="Pick one"
data={data}
onChange={handleValueChanged} onChange={handleValueChanged}
value={field.value || allModelNames[0]} />
>
{allModelNames.map((option) => (
<option key={option}>{option}</option>
))}
</Select>
); );
}; };

View File

@ -101,21 +101,6 @@ const nodesSlice = createSlice({
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => { builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
state.schema = action.payload; state.schema = action.payload;
}); });
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload;
state.nodes.forEach((node) => {
forEach(node.data.inputs, (input) => {
if (input.type === 'image') {
if (input.value?.image_name === image_name) {
input.value.image_url = image_url;
input.value.thumbnail_url = thumbnail_url;
}
}
});
});
});
}, },
}); });

View File

@ -23,6 +23,7 @@ import {
} from './constants'; } from './constants';
import { set } from 'lodash-es'; import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -36,7 +37,7 @@ export const buildCanvasImageToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: model_name, model: modelId,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -49,6 +50,8 @@ export const buildCanvasImageToImageGraph = (
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions; const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToPipelineModelField(modelId);
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node * full graph here as a template. Then use the parameters from app state and set friendlier node
@ -85,9 +88,9 @@ export const buildCanvasImageToImageGraph = (
id: NOISE, id: NOISE,
}, },
[MODEL_LOADER]: { [MODEL_LOADER]: {
type: 'sd1_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: MODEL_LOADER,
model_name, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',

View File

@ -17,6 +17,7 @@ import {
INPAINT_GRAPH, INPAINT_GRAPH,
INPAINT, INPAINT,
} from './constants'; } from './constants';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -31,7 +32,7 @@ export const buildCanvasInpaintGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: model_name, model: modelId,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -54,6 +55,8 @@ export const buildCanvasInpaintGraph = (
// We may need to set the inpaint width and height to scale the image // We may need to set the inpaint width and height to scale the image
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas; const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const model = modelIdToPipelineModelField(modelId);
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: INPAINT_GRAPH, id: INPAINT_GRAPH,
nodes: { nodes: {
@ -99,9 +102,9 @@ export const buildCanvasInpaintGraph = (
prompt: negativePrompt, prompt: negativePrompt,
}, },
[MODEL_LOADER]: { [MODEL_LOADER]: {
type: 'sd1_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: MODEL_LOADER,
model_name, model,
}, },
[RANGE_OF_SIZE]: { [RANGE_OF_SIZE]: {
type: 'range_of_size', type: 'range_of_size',

View File

@ -14,6 +14,7 @@ import {
TEXT_TO_LATENTS, TEXT_TO_LATENTS,
} from './constants'; } from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
/** /**
* Builds the Canvas tab's Text to Image graph. * Builds the Canvas tab's Text to Image graph.
@ -24,7 +25,7 @@ export const buildCanvasTextToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: model_name, model: modelId,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -36,6 +37,8 @@ export const buildCanvasTextToImageGraph = (
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions; const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToPipelineModelField(modelId);
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node * full graph here as a template. Then use the parameters from app state and set friendlier node
@ -80,9 +83,9 @@ export const buildCanvasTextToImageGraph = (
steps, steps,
}, },
[MODEL_LOADER]: { [MODEL_LOADER]: {
type: 'sd1_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: MODEL_LOADER,
model_name, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',

View File

@ -22,6 +22,7 @@ import {
} from './constants'; } from './constants';
import { set } from 'lodash-es'; import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -34,7 +35,7 @@ export const buildLinearImageToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: model_name, model: modelId,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -62,6 +63,8 @@ export const buildLinearImageToImageGraph = (
throw new Error('No initial image found in state'); throw new Error('No initial image found in state');
} }
const model = modelIdToPipelineModelField(modelId);
// copy-pasted graph from node editor, filled in with state values & friendly node ids // copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: IMAGE_TO_IMAGE_GRAPH, id: IMAGE_TO_IMAGE_GRAPH,
@ -89,9 +92,9 @@ export const buildLinearImageToImageGraph = (
id: NOISE, id: NOISE,
}, },
[MODEL_LOADER]: { [MODEL_LOADER]: {
type: 'sd1_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: MODEL_LOADER,
model_name, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',

View File

@ -1,6 +1,10 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api'; import {
BaseModelType,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api';
import { import {
ITERATE, ITERATE,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
@ -14,6 +18,7 @@ import {
TEXT_TO_LATENTS, TEXT_TO_LATENTS,
} from './constants'; } from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
type TextToImageGraphOverrides = { type TextToImageGraphOverrides = {
width: number; width: number;
@ -27,7 +32,7 @@ export const buildLinearTextToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: model_name, model: modelId,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -38,6 +43,8 @@ export const buildLinearTextToImageGraph = (
shouldRandomizeSeed, shouldRandomizeSeed,
} = state.generation; } = state.generation;
const model = modelIdToPipelineModelField(modelId);
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node * full graph here as a template. Then use the parameters from app state and set friendlier node
@ -82,9 +89,9 @@ export const buildLinearTextToImageGraph = (
steps, steps,
}, },
[MODEL_LOADER]: { [MODEL_LOADER]: {
type: 'sd1_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: MODEL_LOADER,
model_name, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',

View File

@ -1,9 +1,10 @@
import { Graph } from 'services/api'; import { Graph } from 'services/api';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import { cloneDeep, forEach, omit, reduce, values } from 'lodash-es'; import { cloneDeep, omit, reduce } from 'lodash-es';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { InputFieldValue } from 'features/nodes/types/types'; import { InputFieldValue } from 'features/nodes/types/types';
import { AnyInvocation } from 'services/events/types'; import { AnyInvocation } from 'services/events/types';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
/** /**
* We need to do special handling for some fields * We need to do special handling for some fields
@ -24,6 +25,12 @@ export const parseFieldValue = (field: InputFieldValue) => {
} }
} }
if (field.type === 'model') {
if (field.value) {
return modelIdToPipelineModelField(field.value);
}
}
return field.value; return field.value;
}; };

View File

@ -7,7 +7,7 @@ export const NOISE = 'noise';
export const RANDOM_INT = 'rand_int'; export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size'; export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate'; export const ITERATE = 'iterate';
export const MODEL_LOADER = 'model_loader'; export const MODEL_LOADER = 'pipeline_model_loader';
export const IMAGE_TO_LATENTS = 'image_to_latents'; export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image'; export const RESIZE = 'resize_image';

View File

@ -0,0 +1,18 @@
import { BaseModelType, PipelineModelField } from 'services/api';
/**
* Crudely converts a model id to a pipeline model field
* TODO: Make better
*/
export const modelIdToPipelineModelField = (
modelId: string
): PipelineModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: PipelineModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -6,7 +6,7 @@ import ParamScheduler from './ParamScheduler';
const ParamSchedulerAndModel = () => { const ParamSchedulerAndModel = () => {
return ( return (
<Flex gap={3} w="full"> <Flex gap={3} w="full">
<Box w="20rem"> <Box w="25rem">
<ParamScheduler /> <ParamScheduler />
</Box> </Box>
<Box w="full"> <Box w="full">

View File

@ -1,10 +1,9 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
import { configChanged } from 'features/system/store/configSlice'; import { configChanged } from 'features/system/store/configSlice';
import { clamp, sortBy } from 'lodash-es'; import { clamp } from 'lodash-es';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { imageUrlsReceived } from 'services/thunks/image';
import { receivedModels } from 'services/thunks/model';
import { import {
CfgScaleParam, CfgScaleParam,
HeightParam, HeightParam,
@ -17,7 +16,6 @@ import {
StrengthParam, StrengthParam,
WidthParam, WidthParam,
} from './parameterZodSchemas'; } from './parameterZodSchemas';
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
export interface GenerationState { export interface GenerationState {
cfgScale: CfgScaleParam; cfgScale: CfgScaleParam;
@ -220,28 +218,12 @@ export const generationSlice = createSlice({
}, },
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
builder.addCase(receivedModels.fulfilled, (state, action) => {
if (!state.model) {
const firstModel = sortBy(action.payload, 'name')[0];
state.model = firstModel.name;
}
});
builder.addCase(configChanged, (state, action) => { builder.addCase(configChanged, (state, action) => {
const defaultModel = action.payload.sd?.defaultModel; const defaultModel = action.payload.sd?.defaultModel;
if (defaultModel && !state.model) { if (defaultModel && !state.model) {
state.model = defaultModel; state.model = defaultModel;
} }
}); });
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
// const { image_name, image_url, thumbnail_url } = action.payload;
// if (state.initialImage?.image_name === image_name) {
// state.initialImage.image_url = image_url;
// state.initialImage.thumbnail_url = thumbnail_url;
// }
// });
}, },
}); });

View File

@ -154,3 +154,17 @@ export type StrengthParam = z.infer<typeof zStrength>;
*/ */
export const isValidStrength = (val: unknown): val is StrengthParam => export const isValidStrength = (val: unknown): val is StrengthParam =>
zStrength.safeParse(val).success; zStrength.safeParse(val).success;
// /**
// * Zod schema for BaseModelType
// */
// export const zBaseModelType = z.enum(['sd-1', 'sd-2']);
// /**
// * Type alias for base model type, inferred from its zod schema. Should be identical to the type alias from OpenAPI.
// */
// export type BaseModelType = z.infer<typeof zBaseModelType>;
// /**
// * Validates/type-guards a value as a base model type
// */
// export const isValidBaseModelType = (val: unknown): val is BaseModelType =>
// zBaseModelType.safeParse(val).success;

View File

@ -1,44 +1,59 @@
import { createSelector } from '@reduxjs/toolkit'; import { memo, useCallback, useEffect, useMemo } from 'react';
import { isEqual } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect, { import IAIMantineSelect from 'common/components/IAIMantineSelect';
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { modelSelected } from 'features/parameters/store/generationSlice'; import { modelSelected } from 'features/parameters/store/generationSlice';
import { selectModelsAll, selectModelsById } from '../store/modelSlice';
const selector = createSelector( import { forEach, isString } from 'lodash-es';
[(state: RootState) => state, generationSelector], import { SelectItem } from '@mantine/core';
(state, generation) => { import { RootState } from 'app/store/store';
const selectedModel = selectModelsById(state, generation.model); import { useListModelsQuery } from 'services/apiSlice';
const modelData = selectModelsAll(state) export const MODEL_TYPE_MAP = {
.map<IAISelectDataType>((m) => ({ 'sd-1': 'Stable Diffusion 1.x',
value: m.name, 'sd-2': 'Stable Diffusion 2.x',
label: m.name, };
}))
.sort((a, b) => a.label.localeCompare(b.label));
return {
selectedModel,
modelData,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const ModelSelect = () => { const ModelSelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { selectedModel, modelData } = useAppSelector(selector);
const selectedModelId = useAppSelector(
(state: RootState) => state.generation.model
);
const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline',
});
const data = useMemo(() => {
if (!pipelineModels) {
return [];
}
const data: SelectItem[] = [];
forEach(pipelineModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [pipelineModels]);
const selectedModel = useMemo(
() => pipelineModels?.entities[selectedModelId],
[pipelineModels?.entities, selectedModelId]
);
const handleChangeModel = useCallback( const handleChangeModel = useCallback(
(v: string | null) => { (v: string | null) => {
if (!v) { if (!v) {
@ -49,13 +64,27 @@ const ModelSelect = () => {
[dispatch] [dispatch]
); );
useEffect(() => {
if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) {
return;
}
const firstModel = pipelineModels?.ids[0];
if (!isString(firstModel)) {
return;
}
handleChangeModel(firstModel);
}, [handleChangeModel, pipelineModels?.ids, selectedModelId]);
return ( return (
<IAIMantineSelect <IAIMantineSelect
tooltip={selectedModel?.description} tooltip={selectedModel?.description}
label={t('modelManager.model')} label={t('modelManager.model')}
value={selectedModel?.name ?? ''} value={selectedModelId}
placeholder="Pick one" placeholder="Pick one"
data={modelData} data={data}
onChange={handleChangeModel} onChange={handleChangeModel}
/> />
); );

View File

@ -1,6 +1,5 @@
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants'; import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
@ -16,6 +15,7 @@ const data = map(SCHEDULER_NAMES, (s) => ({
export default function SettingsSchedulers() { export default function SettingsSchedulers() {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const enabledSchedulers = useAppSelector( const enabledSchedulers = useAppSelector(

View File

@ -7,13 +7,12 @@ import { systemSelector } from '../store/systemSelectors';
const isApplicationReadySelector = createSelector( const isApplicationReadySelector = createSelector(
[systemSelector, configSelector], [systemSelector, configSelector],
(system, config) => { (system, config) => {
const { wereModelsReceived, wasSchemaParsed } = system; const { wasSchemaParsed } = system;
const { disabledTabs } = config; const { disabledTabs } = config;
return { return {
disabledTabs, disabledTabs,
wereModelsReceived,
wasSchemaParsed, wasSchemaParsed,
}; };
} }
@ -23,21 +22,17 @@ const isApplicationReadySelector = createSelector(
* Checks if the application is ready to be used, i.e. if the initial startup process is finished. * Checks if the application is ready to be used, i.e. if the initial startup process is finished.
*/ */
export const useIsApplicationReady = () => { export const useIsApplicationReady = () => {
const { disabledTabs, wereModelsReceived, wasSchemaParsed } = useAppSelector( const { disabledTabs, wasSchemaParsed } = useAppSelector(
isApplicationReadySelector isApplicationReadySelector
); );
const isApplicationReady = useMemo(() => { const isApplicationReady = useMemo(() => {
if (!wereModelsReceived) {
return false;
}
if (!disabledTabs.includes('nodes') && !wasSchemaParsed) { if (!disabledTabs.includes('nodes') && !wasSchemaParsed) {
return false; return false;
} }
return true; return true;
}, [disabledTabs, wereModelsReceived, wasSchemaParsed]); }, [disabledTabs, wasSchemaParsed]);
return isApplicationReady; return isApplicationReady;
}; };

View File

@ -1,3 +0,0 @@
import { RootState } from 'app/store/store';
export const modelSelector = (state: RootState) => state.models;

View File

@ -1,47 +0,0 @@
import { createEntityAdapter } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { CkptModelInfo, DiffusersModelInfo } from 'services/api';
import { receivedModels } from 'services/thunks/model';
export type Model = (CkptModelInfo | DiffusersModelInfo) & {
name: string;
};
export const modelsAdapter = createEntityAdapter<Model>({
selectId: (model) => model.name,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const initialModelsState = modelsAdapter.getInitialState();
export type ModelsState = typeof initialModelsState;
export const modelsSlice = createSlice({
name: 'models',
initialState: initialModelsState,
reducers: {
modelAdded: modelsAdapter.upsertOne,
},
extraReducers(builder) {
/**
* Received Models - FULFILLED
*/
builder.addCase(receivedModels.fulfilled, (state, action) => {
const models = action.payload;
modelsAdapter.setAll(state, models);
});
},
});
export const {
selectAll: selectModelsAll,
selectById: selectModelsById,
selectEntities: selectModelsEntities,
selectIds: selectModelsIds,
selectTotal: selectModelsTotal,
} = modelsAdapter.getSelectors<RootState>((state) => state.models);
export const { modelAdded } = modelsSlice.actions;
export default modelsSlice.reducer;

View File

@ -1,6 +0,0 @@
import { ModelsState } from './modelSlice';
/**
* Models slice persist denylist
*/
export const modelsPersistDenylist: (keyof ModelsState)[] = ['entities', 'ids'];

View File

@ -1,20 +1,12 @@
import { UseToastOptions } from '@chakra-ui/react'; import { UseToastOptions } from '@chakra-ui/react';
import { PayloadAction } from '@reduxjs/toolkit'; import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai'; import * as InvokeAI from 'app/types/invokeai';
import { ProgressImage } from 'services/events/types';
import { makeToast } from '../../../app/components/Toaster';
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
import { receivedModels } from 'services/thunks/model';
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
import { LogLevelName } from 'roarr';
import { InvokeLogLevel } from 'app/logging/useLogger'; import { InvokeLogLevel } from 'app/logging/useLogger';
import { TFuncKey } from 'i18next';
import { t } from 'i18next';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import { LANGUAGES } from '../components/LanguagePicker'; import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
import { imageUploaded } from 'services/thunks/image'; import { TFuncKey, t } from 'i18next';
import { LogLevelName } from 'roarr';
import { import {
appSocketConnected, appSocketConnected,
appSocketDisconnected, appSocketDisconnected,
@ -26,6 +18,11 @@ import {
appSocketSubscribed, appSocketSubscribed,
appSocketUnsubscribed, appSocketUnsubscribed,
} from 'services/events/actions'; } from 'services/events/actions';
import { ProgressImage } from 'services/events/types';
import { imageUploaded } from 'services/thunks/image';
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
import { makeToast } from '../../../app/components/Toaster';
import { LANGUAGES } from '../components/LanguagePicker';
export type CancelStrategy = 'immediate' | 'scheduled'; export type CancelStrategy = 'immediate' | 'scheduled';
@ -379,13 +376,6 @@ export const systemSlice = createSlice({
); );
}); });
/**
* Received available models from the backend
*/
builder.addCase(receivedModels.fulfilled, (state) => {
state.wereModelsReceived = true;
});
/** /**
* OpenAPI schema was parsed * OpenAPI schema was parsed
*/ */

View File

@ -25,6 +25,8 @@ export type { ConditioningField } from './models/ConditioningField';
export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation'; export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation';
export type { ControlField } from './models/ControlField'; export type { ControlField } from './models/ControlField';
export type { ControlNetInvocation } from './models/ControlNetInvocation'; export type { ControlNetInvocation } from './models/ControlNetInvocation';
export type { ControlNetModelConfig } from './models/ControlNetModelConfig';
export type { ControlNetModelFormat } from './models/ControlNetModelFormat';
export type { ControlOutput } from './models/ControlOutput'; export type { ControlOutput } from './models/ControlOutput';
export type { CreateModelRequest } from './models/CreateModelRequest'; export type { CreateModelRequest } from './models/CreateModelRequest';
export type { CvInpaintInvocation } from './models/CvInpaintInvocation'; export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
@ -67,14 +69,6 @@ export type { InfillTileInvocation } from './models/InfillTileInvocation';
export type { InpaintInvocation } from './models/InpaintInvocation'; export type { InpaintInvocation } from './models/InpaintInvocation';
export type { IntCollectionOutput } from './models/IntCollectionOutput'; export type { IntCollectionOutput } from './models/IntCollectionOutput';
export type { IntOutput } from './models/IntOutput'; export type { IntOutput } from './models/IntOutput';
export type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config';
export type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './models/invokeai__backend__model_management__models__lora__LoRAModel__Config';
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig';
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig';
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig';
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig';
export type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config';
export type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './models/invokeai__backend__model_management__models__vae__VaeModel__Config';
export type { IterateInvocation } from './models/IterateInvocation'; export type { IterateInvocation } from './models/IterateInvocation';
export type { IterateInvocationOutput } from './models/IterateInvocationOutput'; export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
export type { LatentsField } from './models/LatentsField'; export type { LatentsField } from './models/LatentsField';
@ -87,6 +81,8 @@ export type { LoadImageInvocation } from './models/LoadImageInvocation';
export type { LoraInfo } from './models/LoraInfo'; export type { LoraInfo } from './models/LoraInfo';
export type { LoraLoaderInvocation } from './models/LoraLoaderInvocation'; export type { LoraLoaderInvocation } from './models/LoraLoaderInvocation';
export type { LoraLoaderOutput } from './models/LoraLoaderOutput'; export type { LoraLoaderOutput } from './models/LoraLoaderOutput';
export type { LoRAModelConfig } from './models/LoRAModelConfig';
export type { LoRAModelFormat } from './models/LoRAModelFormat';
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation'; export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
export type { MaskOutput } from './models/MaskOutput'; export type { MaskOutput } from './models/MaskOutput';
export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation'; export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation';
@ -109,6 +105,8 @@ export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedRe
export type { ParamFloatInvocation } from './models/ParamFloatInvocation'; export type { ParamFloatInvocation } from './models/ParamFloatInvocation';
export type { ParamIntInvocation } from './models/ParamIntInvocation'; export type { ParamIntInvocation } from './models/ParamIntInvocation';
export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation'; export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation';
export type { PipelineModelField } from './models/PipelineModelField';
export type { PipelineModelLoaderInvocation } from './models/PipelineModelLoaderInvocation';
export type { PromptCollectionOutput } from './models/PromptCollectionOutput'; export type { PromptCollectionOutput } from './models/PromptCollectionOutput';
export type { PromptOutput } from './models/PromptOutput'; export type { PromptOutput } from './models/PromptOutput';
export type { RandomIntInvocation } from './models/RandomIntInvocation'; export type { RandomIntInvocation } from './models/RandomIntInvocation';
@ -120,16 +118,23 @@ export type { ResourceOrigin } from './models/ResourceOrigin';
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation'; export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation'; export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
export type { SchedulerPredictionType } from './models/SchedulerPredictionType'; export type { SchedulerPredictionType } from './models/SchedulerPredictionType';
export type { SD1ModelLoaderInvocation } from './models/SD1ModelLoaderInvocation';
export type { SD2ModelLoaderInvocation } from './models/SD2ModelLoaderInvocation';
export type { ShowImageInvocation } from './models/ShowImageInvocation'; export type { ShowImageInvocation } from './models/ShowImageInvocation';
export type { StableDiffusion1ModelCheckpointConfig } from './models/StableDiffusion1ModelCheckpointConfig';
export type { StableDiffusion1ModelDiffusersConfig } from './models/StableDiffusion1ModelDiffusersConfig';
export type { StableDiffusion1ModelFormat } from './models/StableDiffusion1ModelFormat';
export type { StableDiffusion2ModelCheckpointConfig } from './models/StableDiffusion2ModelCheckpointConfig';
export type { StableDiffusion2ModelDiffusersConfig } from './models/StableDiffusion2ModelDiffusersConfig';
export type { StableDiffusion2ModelFormat } from './models/StableDiffusion2ModelFormat';
export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation'; export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation';
export type { SubModelType } from './models/SubModelType'; export type { SubModelType } from './models/SubModelType';
export type { SubtractInvocation } from './models/SubtractInvocation'; export type { SubtractInvocation } from './models/SubtractInvocation';
export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation'; export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation';
export type { TextualInversionModelConfig } from './models/TextualInversionModelConfig';
export type { UNetField } from './models/UNetField'; export type { UNetField } from './models/UNetField';
export type { UpscaleInvocation } from './models/UpscaleInvocation'; export type { UpscaleInvocation } from './models/UpscaleInvocation';
export type { VaeField } from './models/VaeField'; export type { VaeField } from './models/VaeField';
export type { VaeModelConfig } from './models/VaeModelConfig';
export type { VaeModelFormat } from './models/VaeModelFormat';
export type { VaeRepo } from './models/VaeRepo'; export type { VaeRepo } from './models/VaeRepo';
export type { ValidationError } from './models/ValidationError'; export type { ValidationError } from './models/ValidationError';
export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation'; export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation';

View File

@ -0,0 +1,18 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { BaseModelType } from './BaseModelType';
import type { ControlNetModelFormat } from './ControlNetModelFormat';
import type { ModelError } from './ModelError';
export type ControlNetModelConfig = {
name: string;
base_model: BaseModelType;
type: 'controlnet';
path: string;
description?: string;
model_format: ControlNetModelFormat;
error?: ModelError;
};

View File

@ -0,0 +1,8 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* An enumeration.
*/
export type ControlNetModelFormat = 'checkpoint' | 'diffusers';

View File

@ -49,6 +49,7 @@ import type { OpenposeImageProcessorInvocation } from './OpenposeImageProcessorI
import type { ParamFloatInvocation } from './ParamFloatInvocation'; import type { ParamFloatInvocation } from './ParamFloatInvocation';
import type { ParamIntInvocation } from './ParamIntInvocation'; import type { ParamIntInvocation } from './ParamIntInvocation';
import type { PidiImageProcessorInvocation } from './PidiImageProcessorInvocation'; import type { PidiImageProcessorInvocation } from './PidiImageProcessorInvocation';
import type { PipelineModelLoaderInvocation } from './PipelineModelLoaderInvocation';
import type { RandomIntInvocation } from './RandomIntInvocation'; import type { RandomIntInvocation } from './RandomIntInvocation';
import type { RandomRangeInvocation } from './RandomRangeInvocation'; import type { RandomRangeInvocation } from './RandomRangeInvocation';
import type { RangeInvocation } from './RangeInvocation'; import type { RangeInvocation } from './RangeInvocation';
@ -56,8 +57,6 @@ import type { RangeOfSizeInvocation } from './RangeOfSizeInvocation';
import type { ResizeLatentsInvocation } from './ResizeLatentsInvocation'; import type { ResizeLatentsInvocation } from './ResizeLatentsInvocation';
import type { RestoreFaceInvocation } from './RestoreFaceInvocation'; import type { RestoreFaceInvocation } from './RestoreFaceInvocation';
import type { ScaleLatentsInvocation } from './ScaleLatentsInvocation'; import type { ScaleLatentsInvocation } from './ScaleLatentsInvocation';
import type { SD1ModelLoaderInvocation } from './SD1ModelLoaderInvocation';
import type { SD2ModelLoaderInvocation } from './SD2ModelLoaderInvocation';
import type { ShowImageInvocation } from './ShowImageInvocation'; import type { ShowImageInvocation } from './ShowImageInvocation';
import type { StepParamEasingInvocation } from './StepParamEasingInvocation'; import type { StepParamEasingInvocation } from './StepParamEasingInvocation';
import type { SubtractInvocation } from './SubtractInvocation'; import type { SubtractInvocation } from './SubtractInvocation';
@ -73,7 +72,7 @@ export type Graph = {
/** /**
* The nodes in this graph * The nodes in this graph
*/ */
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation)>; nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | PipelineModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation)>;
/** /**
* The connections between nodes and their fields in this graph * The connections between nodes and their fields in this graph
*/ */

View File

@ -0,0 +1,18 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { BaseModelType } from './BaseModelType';
import type { LoRAModelFormat } from './LoRAModelFormat';
import type { ModelError } from './ModelError';
export type LoRAModelConfig = {
name: string;
base_model: BaseModelType;
type: 'lora';
path: string;
description?: string;
model_format: LoRAModelFormat;
error?: ModelError;
};

View File

@ -0,0 +1,8 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* An enumeration.
*/
export type LoRAModelFormat = 'lycoris' | 'diffusers';

View File

@ -2,16 +2,16 @@
/* tslint:disable */ /* tslint:disable */
/* eslint-disable */ /* eslint-disable */
import type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './invokeai__backend__model_management__models__controlnet__ControlNetModel__Config'; import type { ControlNetModelConfig } from './ControlNetModelConfig';
import type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './invokeai__backend__model_management__models__lora__LoRAModel__Config'; import type { LoRAModelConfig } from './LoRAModelConfig';
import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig'; import type { StableDiffusion1ModelCheckpointConfig } from './StableDiffusion1ModelCheckpointConfig';
import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig'; import type { StableDiffusion1ModelDiffusersConfig } from './StableDiffusion1ModelDiffusersConfig';
import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig'; import type { StableDiffusion2ModelCheckpointConfig } from './StableDiffusion2ModelCheckpointConfig';
import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig'; import type { StableDiffusion2ModelDiffusersConfig } from './StableDiffusion2ModelDiffusersConfig';
import type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config'; import type { TextualInversionModelConfig } from './TextualInversionModelConfig';
import type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './invokeai__backend__model_management__models__vae__VaeModel__Config'; import type { VaeModelConfig } from './VaeModelConfig';
export type ModelsList = { export type ModelsList = {
models: Record<string, Record<string, Record<string, (invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig | invokeai__backend__model_management__models__controlnet__ControlNetModel__Config | invokeai__backend__model_management__models__lora__LoRAModel__Config | invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig | invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config | invokeai__backend__model_management__models__vae__VaeModel__Config | invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig | invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig)>>>; models: Array<(StableDiffusion1ModelCheckpointConfig | StableDiffusion1ModelDiffusersConfig | VaeModelConfig | LoRAModelConfig | ControlNetModelConfig | TextualInversionModelConfig | StableDiffusion2ModelCheckpointConfig | StableDiffusion2ModelDiffusersConfig)>;
}; };

View File

@ -0,0 +1,20 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { BaseModelType } from './BaseModelType';
/**
* Pipeline model field
*/
export type PipelineModelField = {
/**
* Name of the model
*/
model_name: string;
/**
* Base model
*/
base_model: BaseModelType;
};

View File

@ -2,10 +2,12 @@
/* tslint:disable */ /* tslint:disable */
/* eslint-disable */ /* eslint-disable */
import type { PipelineModelField } from './PipelineModelField';
/** /**
* Loading submodels of selected model. * Loads a pipeline model, outputting its submodels.
*/ */
export type SD2ModelLoaderInvocation = { export type PipelineModelLoaderInvocation = {
/** /**
* The id of this node. Must be unique among all nodes. * The id of this node. Must be unique among all nodes.
*/ */
@ -14,10 +16,10 @@ export type SD2ModelLoaderInvocation = {
* Whether or not this node is an intermediate node. * Whether or not this node is an intermediate node.
*/ */
is_intermediate?: boolean; is_intermediate?: boolean;
type?: 'sd2_model_loader'; type?: 'pipeline_model_loader';
/** /**
* Model to load * The model to load
*/ */
model_name?: string; model: PipelineModelField;
}; };

View File

@ -1,23 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* Loading submodels of selected model.
*/
export type SD1ModelLoaderInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'sd1_model_loader';
/**
* Model to load
*/
model_name?: string;
};

View File

@ -2,14 +2,17 @@
/* tslint:disable */ /* tslint:disable */
/* eslint-disable */ /* eslint-disable */
import type { BaseModelType } from './BaseModelType';
import type { ModelError } from './ModelError'; import type { ModelError } from './ModelError';
import type { ModelVariantType } from './ModelVariantType'; import type { ModelVariantType } from './ModelVariantType';
export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig = { export type StableDiffusion1ModelCheckpointConfig = {
name: string;
base_model: BaseModelType;
type: 'pipeline';
path: string; path: string;
description?: string; description?: string;
format: 'checkpoint'; model_format: 'checkpoint';
default?: boolean;
error?: ModelError; error?: ModelError;
vae?: string; vae?: string;
config?: string; config?: string;

View File

@ -2,14 +2,17 @@
/* tslint:disable */ /* tslint:disable */
/* eslint-disable */ /* eslint-disable */
import type { BaseModelType } from './BaseModelType';
import type { ModelError } from './ModelError'; import type { ModelError } from './ModelError';
import type { ModelVariantType } from './ModelVariantType'; import type { ModelVariantType } from './ModelVariantType';
export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig = { export type StableDiffusion1ModelDiffusersConfig = {
name: string;
base_model: BaseModelType;
type: 'pipeline';
path: string; path: string;
description?: string; description?: string;
format: 'diffusers'; model_format: 'diffusers';
default?: boolean;
error?: ModelError; error?: ModelError;
vae?: string; vae?: string;
variant: ModelVariantType; variant: ModelVariantType;

View File

@ -0,0 +1,8 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* An enumeration.
*/
export type StableDiffusion1ModelFormat = 'checkpoint' | 'diffusers';

View File

@ -2,15 +2,18 @@
/* tslint:disable */ /* tslint:disable */
/* eslint-disable */ /* eslint-disable */
import type { BaseModelType } from './BaseModelType';
import type { ModelError } from './ModelError'; import type { ModelError } from './ModelError';
import type { ModelVariantType } from './ModelVariantType'; import type { ModelVariantType } from './ModelVariantType';
import type { SchedulerPredictionType } from './SchedulerPredictionType'; import type { SchedulerPredictionType } from './SchedulerPredictionType';
export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig = { export type StableDiffusion2ModelCheckpointConfig = {
name: string;
base_model: BaseModelType;
type: 'pipeline';
path: string; path: string;
description?: string; description?: string;
format: 'checkpoint'; model_format: 'checkpoint';
default?: boolean;
error?: ModelError; error?: ModelError;
vae?: string; vae?: string;
config?: string; config?: string;

View File

@ -2,15 +2,18 @@
/* tslint:disable */ /* tslint:disable */
/* eslint-disable */ /* eslint-disable */
import type { BaseModelType } from './BaseModelType';
import type { ModelError } from './ModelError'; import type { ModelError } from './ModelError';
import type { ModelVariantType } from './ModelVariantType'; import type { ModelVariantType } from './ModelVariantType';
import type { SchedulerPredictionType } from './SchedulerPredictionType'; import type { SchedulerPredictionType } from './SchedulerPredictionType';
export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig = { export type StableDiffusion2ModelDiffusersConfig = {
name: string;
base_model: BaseModelType;
type: 'pipeline';
path: string; path: string;
description?: string; description?: string;
format: 'diffusers'; model_format: 'diffusers';
default?: boolean;
error?: ModelError; error?: ModelError;
vae?: string; vae?: string;
variant: ModelVariantType; variant: ModelVariantType;

View File

@ -0,0 +1,8 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* An enumeration.
*/
export type StableDiffusion2ModelFormat = 'checkpoint' | 'diffusers';

View File

@ -0,0 +1,17 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { BaseModelType } from './BaseModelType';
import type { ModelError } from './ModelError';
export type TextualInversionModelConfig = {
name: string;
base_model: BaseModelType;
type: 'embedding';
path: string;
description?: string;
model_format: null;
error?: ModelError;
};

View File

@ -0,0 +1,18 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { BaseModelType } from './BaseModelType';
import type { ModelError } from './ModelError';
import type { VaeModelFormat } from './VaeModelFormat';
export type VaeModelConfig = {
name: string;
base_model: BaseModelType;
type: 'vae';
path: string;
description?: string;
model_format: VaeModelFormat;
error?: ModelError;
};

View File

@ -0,0 +1,8 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* An enumeration.
*/
export type VaeModelFormat = 'checkpoint' | 'diffusers';

View File

@ -1,14 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ModelError } from './ModelError';
export type invokeai__backend__model_management__models__controlnet__ControlNetModel__Config = {
path: string;
description?: string;
format: ('checkpoint' | 'diffusers');
default?: boolean;
error?: ModelError;
};

View File

@ -1,14 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ModelError } from './ModelError';
export type invokeai__backend__model_management__models__lora__LoRAModel__Config = {
path: string;
description?: string;
format: ('lycoris' | 'diffusers');
default?: boolean;
error?: ModelError;
};

View File

@ -1,14 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ModelError } from './ModelError';
export type invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config = {
path: string;
description?: string;
format: null;
default?: boolean;
error?: ModelError;
};

View File

@ -1,14 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ModelError } from './ModelError';
export type invokeai__backend__model_management__models__vae__VaeModel__Config = {
path: string;
description?: string;
format: ('checkpoint' | 'diffusers');
default?: boolean;
error?: ModelError;
};

View File

@ -51,6 +51,7 @@ import type { PaginatedResults_GraphExecutionState_ } from '../models/PaginatedR
import type { ParamFloatInvocation } from '../models/ParamFloatInvocation'; import type { ParamFloatInvocation } from '../models/ParamFloatInvocation';
import type { ParamIntInvocation } from '../models/ParamIntInvocation'; import type { ParamIntInvocation } from '../models/ParamIntInvocation';
import type { PidiImageProcessorInvocation } from '../models/PidiImageProcessorInvocation'; import type { PidiImageProcessorInvocation } from '../models/PidiImageProcessorInvocation';
import type { PipelineModelLoaderInvocation } from '../models/PipelineModelLoaderInvocation';
import type { RandomIntInvocation } from '../models/RandomIntInvocation'; import type { RandomIntInvocation } from '../models/RandomIntInvocation';
import type { RandomRangeInvocation } from '../models/RandomRangeInvocation'; import type { RandomRangeInvocation } from '../models/RandomRangeInvocation';
import type { RangeInvocation } from '../models/RangeInvocation'; import type { RangeInvocation } from '../models/RangeInvocation';
@ -58,8 +59,6 @@ import type { RangeOfSizeInvocation } from '../models/RangeOfSizeInvocation';
import type { ResizeLatentsInvocation } from '../models/ResizeLatentsInvocation'; import type { ResizeLatentsInvocation } from '../models/ResizeLatentsInvocation';
import type { RestoreFaceInvocation } from '../models/RestoreFaceInvocation'; import type { RestoreFaceInvocation } from '../models/RestoreFaceInvocation';
import type { ScaleLatentsInvocation } from '../models/ScaleLatentsInvocation'; import type { ScaleLatentsInvocation } from '../models/ScaleLatentsInvocation';
import type { SD1ModelLoaderInvocation } from '../models/SD1ModelLoaderInvocation';
import type { SD2ModelLoaderInvocation } from '../models/SD2ModelLoaderInvocation';
import type { ShowImageInvocation } from '../models/ShowImageInvocation'; import type { ShowImageInvocation } from '../models/ShowImageInvocation';
import type { StepParamEasingInvocation } from '../models/StepParamEasingInvocation'; import type { StepParamEasingInvocation } from '../models/StepParamEasingInvocation';
import type { SubtractInvocation } from '../models/SubtractInvocation'; import type { SubtractInvocation } from '../models/SubtractInvocation';
@ -175,7 +174,7 @@ export class SessionsService {
* The id of the session * The id of the session
*/ */
sessionId: string, sessionId: string,
requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation), requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | PipelineModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation),
}): CancelablePromise<string> { }): CancelablePromise<string> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'POST', method: 'POST',
@ -212,7 +211,7 @@ export class SessionsService {
* The path to the node in the graph * The path to the node in the graph
*/ */
nodePath: string, nodePath: string,
requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation), requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | PipelineModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation),
}): CancelablePromise<GraphExecutionState> { }): CancelablePromise<GraphExecutionState> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'PUT', method: 'PUT',

View File

@ -13,23 +13,68 @@ import {
TagTypesFrom, TagTypesFrom,
TagTypesFromApi, TagTypesFromApi,
} from '@reduxjs/toolkit/dist/query/endpointDefinitions'; } from '@reduxjs/toolkit/dist/query/endpointDefinitions';
import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
import { BaseModelType } from './api/models/BaseModelType';
import { ModelType } from './api/models/ModelType';
import { ModelsList } from './api/models/ModelsList';
import { keyBy } from 'lodash-es';
type ListBoardsArg = { offset: number; limit: number }; type ListBoardsArg = { offset: number; limit: number };
type UpdateBoardArg = { board_id: string; changes: BoardChanges }; type UpdateBoardArg = { board_id: string; changes: BoardChanges };
type AddImageToBoardArg = { board_id: string; image_name: string }; type AddImageToBoardArg = { board_id: string; image_name: string };
type RemoveImageFromBoardArg = { board_id: string; image_name: string }; type RemoveImageFromBoardArg = { board_id: string; image_name: string };
type ListBoardImagesArg = { board_id: string; offset: number; limit: number }; type ListBoardImagesArg = { board_id: string; offset: number; limit: number };
type ListModelsArg = { base_model?: BaseModelType; model_type?: ModelType };
const tagTypes = ['Board', 'Image']; type ModelConfig = ModelsList['models'][number];
const tagTypes = ['Board', 'Image', 'Model'];
type ApiFullTagDescription = FullTagDescription<(typeof tagTypes)[number]>; type ApiFullTagDescription = FullTagDescription<(typeof tagTypes)[number]>;
const LIST = 'LIST'; const LIST = 'LIST';
const modelsAdapter = createEntityAdapter<ModelConfig>({
selectId: (model) => getModelId(model),
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
const getModelId = ({ base_model, type, name }: ModelConfig) =>
`${base_model}/${type}/${name}`;
export const api = createApi({ export const api = createApi({
baseQuery: fetchBaseQuery({ baseUrl: 'http://localhost:5173/api/v1/' }), baseQuery: fetchBaseQuery({ baseUrl: 'http://localhost:5173/api/v1/' }),
reducerPath: 'api', reducerPath: 'api',
tagTypes, tagTypes,
endpoints: (build) => ({ endpoints: (build) => ({
/**
* Models Queries
*/
listModels: build.query<EntityState<ModelConfig>, ListModelsArg>({
query: (arg) => ({ url: 'models/', params: arg }),
providesTags: (result, error, arg) => {
// any list of boards
const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST }];
if (result) {
// and individual tags for each board
tags.push(
...result.ids.map((id) => ({
type: 'Model' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: ModelsList, meta, arg) => {
return modelsAdapter.addMany(
modelsAdapter.getInitialState(),
keyBy(response.models, getModelId)
);
},
}),
/** /**
* Boards Queries * Boards Queries
*/ */
@ -174,4 +219,5 @@ export const {
useRemoveImageFromBoardMutation, useRemoveImageFromBoardMutation,
useListBoardImagesQuery, useListBoardImagesQuery,
useGetImageDTOQuery, useGetImageDTOQuery,
useListModelsQuery,
} = api; } = api;

View File

@ -1,33 +0,0 @@
import { log } from 'app/logging/useLogger';
import { createAppAsyncThunk } from 'app/store/storeUtils';
import { Model } from 'features/system/store/modelSlice';
import { reduce, size } from 'lodash-es';
import { ModelsService } from 'services/api';
const models = log.child({ namespace: 'model' });
export const IMAGES_PER_PAGE = 20;
export const receivedModels = createAppAsyncThunk(
'models/receivedModels',
async (_) => {
const response = await ModelsService.listModels();
const deserializedModels = reduce(
response.models['sd-1']['pipeline'],
(modelsAccumulator, model, modelName) => {
modelsAccumulator[modelName] = { ...model, name: modelName };
return modelsAccumulator;
},
{} as Record<string, Model>
);
models.info(
{ response },
`Received ${size(response.models['sd-1']['pipeline'])} models`
);
return deserializedModels;
}
);

View File

@ -4,8 +4,8 @@ import { generateColorPalette } from '../util/generateColorPalette';
export const greenTeaThemeColors: InvokeAIThemeColors = { export const greenTeaThemeColors: InvokeAIThemeColors = {
base: generateColorPalette(223, 10), base: generateColorPalette(223, 10),
baseAlpha: generateColorPalette(223, 10, false, true), baseAlpha: generateColorPalette(223, 10, false, true),
accent: generateColorPalette(155, 80), accent: generateColorPalette(160, 60),
accentAlpha: generateColorPalette(155, 80, false, true), accentAlpha: generateColorPalette(160, 60, false, true),
working: generateColorPalette(47, 68), working: generateColorPalette(47, 68),
workingAlpha: generateColorPalette(47, 68, false, true), workingAlpha: generateColorPalette(47, 68, false, true),
warning: generateColorPalette(28, 75), warning: generateColorPalette(28, 75),
@ -14,5 +14,5 @@ export const greenTeaThemeColors: InvokeAIThemeColors = {
okAlpha: generateColorPalette(122, 49, false, true), okAlpha: generateColorPalette(122, 49, false, true),
error: generateColorPalette(0, 50), error: generateColorPalette(0, 50),
errorAlpha: generateColorPalette(0, 50, false, true), errorAlpha: generateColorPalette(0, 50, false, true),
gridLineColor: 'rgba(255, 255, 255, 0.2)', gridLineColor: 'rgba(255, 255, 255, 0.15)',
}; };

View File

@ -2,8 +2,8 @@ import { InvokeAIThemeColors } from 'theme/themeTypes';
import { generateColorPalette } from 'theme/util/generateColorPalette'; import { generateColorPalette } from 'theme/util/generateColorPalette';
export const invokeAIThemeColors: InvokeAIThemeColors = { export const invokeAIThemeColors: InvokeAIThemeColors = {
base: generateColorPalette(225, 15), base: generateColorPalette(220, 15),
baseAlpha: generateColorPalette(225, 15, false, true), baseAlpha: generateColorPalette(220, 15, false, true),
accent: generateColorPalette(250, 50), accent: generateColorPalette(250, 50),
accentAlpha: generateColorPalette(250, 50, false, true), accentAlpha: generateColorPalette(250, 50, false, true),
working: generateColorPalette(47, 67), working: generateColorPalette(47, 67),
@ -14,5 +14,5 @@ export const invokeAIThemeColors: InvokeAIThemeColors = {
okAlpha: generateColorPalette(113, 70, false, true), okAlpha: generateColorPalette(113, 70, false, true),
error: generateColorPalette(0, 76), error: generateColorPalette(0, 76),
errorAlpha: generateColorPalette(0, 76, false, true), errorAlpha: generateColorPalette(0, 76, false, true),
gridLineColor: 'rgba(255, 255, 255, 0.2)', gridLineColor: 'rgba(150, 150, 180, 0.15)',
}; };

View File

@ -14,5 +14,5 @@ export const lightThemeColors: InvokeAIThemeColors = {
okAlpha: generateColorPalette(122, 49, true, true), okAlpha: generateColorPalette(122, 49, true, true),
error: generateColorPalette(0, 50, true), error: generateColorPalette(0, 50, true),
errorAlpha: generateColorPalette(0, 50, true, true), errorAlpha: generateColorPalette(0, 50, true, true),
gridLineColor: 'rgba(0, 0, 0, 0.2)', gridLineColor: 'rgba(0, 0, 0, 0.15)',
}; };

View File

@ -14,5 +14,5 @@ export const oceanBlueColors: InvokeAIThemeColors = {
okAlpha: generateColorPalette(122, 49, false, true), okAlpha: generateColorPalette(122, 49, false, true),
error: generateColorPalette(0, 100), error: generateColorPalette(0, 100),
errorAlpha: generateColorPalette(0, 100, false, true), errorAlpha: generateColorPalette(0, 100, false, true),
gridLineColor: 'rgba(136, 148, 184, 0.2)', gridLineColor: 'rgba(136, 148, 184, 0.15)',
}; };