mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Update UI To Use New Model Manager (#3548)
PR for the Model Manager UI work related to 3.0 [DONE] - Update ModelType Config names to be specific so that the front end can parse them correctly. - Rebuild frontend schema to reflect these changes. - Update Linear UI Text To Image and Image to Image to work with the new model loader. - Updated the ModelInput component in the Node Editor to work with the new changes. [TODO REMEMBER] - Add proper types for ModelLoaderType in `ModelSelect.tsx` [TODO] - Everything else.
This commit is contained in:
commit
22c337b1aa
@ -7,8 +7,8 @@ from fastapi.routing import APIRouter, HTTPException
|
|||||||
from pydantic import BaseModel, Field, parse_obj_as
|
from pydantic import BaseModel, Field, parse_obj_as
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
from invokeai.backend import BaseModelType, ModelType
|
from invokeai.backend import BaseModelType, ModelType
|
||||||
from invokeai.backend.model_management.models import get_all_model_configs
|
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
|
||||||
MODEL_CONFIGS = Union[tuple(get_all_model_configs())]
|
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
|
||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
|
|
||||||
@ -62,8 +62,7 @@ class ConvertedModelResponse(BaseModel):
|
|||||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend
|
models: list[MODEL_CONFIGS]
|
||||||
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
|
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
@models_router.get(
|
||||||
@ -72,10 +71,10 @@ class ModelsList(BaseModel):
|
|||||||
responses={200: {"model": ModelsList }},
|
responses={200: {"model": ModelsList }},
|
||||||
)
|
)
|
||||||
async def list_models(
|
async def list_models(
|
||||||
base_model: BaseModelType = Query(
|
base_model: Optional[BaseModelType] = Query(
|
||||||
default=None, description="Base model"
|
default=None, description="Base model"
|
||||||
),
|
),
|
||||||
model_type: ModelType = Query(
|
model_type: Optional[ModelType] = Query(
|
||||||
default=None, description="The type of model to get"
|
default=None, description="The type of model to get"
|
||||||
),
|
),
|
||||||
) -> ModelsList:
|
) -> ModelsList:
|
||||||
|
@ -120,6 +120,22 @@ def custom_openapi():
|
|||||||
|
|
||||||
invoker_schema["output"] = outputs_ref
|
invoker_schema["output"] = outputs_ref
|
||||||
|
|
||||||
|
from invokeai.backend.model_management.models import get_model_config_enums
|
||||||
|
for model_config_format_enum in set(get_model_config_enums()):
|
||||||
|
name = model_config_format_enum.__qualname__
|
||||||
|
|
||||||
|
if name in openapi_schema["components"]["schemas"]:
|
||||||
|
# print(f"Config with name {name} already defined")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
|
||||||
|
openapi_schema["components"]["schemas"][name] = dict(
|
||||||
|
title=name,
|
||||||
|
description="An enumeration.",
|
||||||
|
type="string",
|
||||||
|
enum=list(v.value for v in model_config_format_enum),
|
||||||
|
)
|
||||||
|
|
||||||
app.openapi_schema = openapi_schema
|
app.openapi_schema = openapi_schema
|
||||||
return app.openapi_schema
|
return app.openapi_schema
|
||||||
|
|
||||||
|
@ -43,12 +43,19 @@ class ModelLoaderOutput(BaseInvocationOutput):
|
|||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
class SD1ModelLoaderInvocation(BaseInvocation):
|
class PipelineModelField(BaseModel):
|
||||||
"""Loading submodels of selected model."""
|
"""Pipeline model field"""
|
||||||
|
|
||||||
type: Literal["sd1_model_loader"] = "sd1_model_loader"
|
model_name: str = Field(description="Name of the model")
|
||||||
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
model_name: str = Field(default="", description="Model to load")
|
|
||||||
|
class PipelineModelLoaderInvocation(BaseInvocation):
|
||||||
|
"""Loads a pipeline model, outputting its submodels."""
|
||||||
|
|
||||||
|
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
|
||||||
|
|
||||||
|
model: PipelineModelField = Field(description="The model to load")
|
||||||
# TODO: precision?
|
# TODO: precision?
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
@ -57,22 +64,24 @@ class SD1ModelLoaderInvocation(BaseInvocation):
|
|||||||
"ui": {
|
"ui": {
|
||||||
"tags": ["model", "loader"],
|
"tags": ["model", "loader"],
|
||||||
"type_hints": {
|
"type_hints": {
|
||||||
"model_name": "model" # TODO: rename to model_name?
|
"model": "model"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||||
|
|
||||||
base_model = BaseModelType.StableDiffusion1 # TODO:
|
base_model = self.model.base_model
|
||||||
|
model_name = self.model.model_name
|
||||||
|
model_type = ModelType.Pipeline
|
||||||
|
|
||||||
# TODO: not found exceptions
|
# TODO: not found exceptions
|
||||||
if not context.services.model_manager.model_exists(
|
if not context.services.model_manager.model_exists(
|
||||||
model_name=self.model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=ModelType.Pipeline,
|
model_type=model_type,
|
||||||
):
|
):
|
||||||
raise Exception(f"Unkown model name: {self.model_name}!")
|
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not context.services.model_manager.model_exists(
|
if not context.services.model_manager.model_exists(
|
||||||
@ -107,142 +116,39 @@ class SD1ModelLoaderInvocation(BaseInvocation):
|
|||||||
return ModelLoaderOutput(
|
return ModelLoaderOutput(
|
||||||
unet=UNetField(
|
unet=UNetField(
|
||||||
unet=ModelInfo(
|
unet=ModelInfo(
|
||||||
model_name=self.model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=ModelType.Pipeline,
|
model_type=model_type,
|
||||||
submodel=SubModelType.UNet,
|
submodel=SubModelType.UNet,
|
||||||
),
|
),
|
||||||
scheduler=ModelInfo(
|
scheduler=ModelInfo(
|
||||||
model_name=self.model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=ModelType.Pipeline,
|
model_type=model_type,
|
||||||
submodel=SubModelType.Scheduler,
|
submodel=SubModelType.Scheduler,
|
||||||
),
|
),
|
||||||
loras=[],
|
loras=[],
|
||||||
),
|
),
|
||||||
clip=ClipField(
|
clip=ClipField(
|
||||||
tokenizer=ModelInfo(
|
tokenizer=ModelInfo(
|
||||||
model_name=self.model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=ModelType.Pipeline,
|
model_type=model_type,
|
||||||
submodel=SubModelType.Tokenizer,
|
submodel=SubModelType.Tokenizer,
|
||||||
),
|
),
|
||||||
text_encoder=ModelInfo(
|
text_encoder=ModelInfo(
|
||||||
model_name=self.model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=ModelType.Pipeline,
|
model_type=model_type,
|
||||||
submodel=SubModelType.TextEncoder,
|
submodel=SubModelType.TextEncoder,
|
||||||
),
|
),
|
||||||
loras=[],
|
loras=[],
|
||||||
),
|
),
|
||||||
vae=VaeField(
|
vae=VaeField(
|
||||||
vae=ModelInfo(
|
vae=ModelInfo(
|
||||||
model_name=self.model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=ModelType.Pipeline,
|
model_type=model_type,
|
||||||
submodel=SubModelType.Vae,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: optimize(less code copy)
|
|
||||||
class SD2ModelLoaderInvocation(BaseInvocation):
|
|
||||||
"""Loading submodels of selected model."""
|
|
||||||
|
|
||||||
type: Literal["sd2_model_loader"] = "sd2_model_loader"
|
|
||||||
|
|
||||||
model_name: str = Field(default="", description="Model to load")
|
|
||||||
# TODO: precision?
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["model", "loader"],
|
|
||||||
"type_hints": {
|
|
||||||
"model_name": "model" # TODO: rename to model_name?
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
|
||||||
|
|
||||||
base_model = BaseModelType.StableDiffusion2 # TODO:
|
|
||||||
|
|
||||||
# TODO: not found exceptions
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
):
|
|
||||||
raise Exception(f"Unkown model name: {self.model_name}!")
|
|
||||||
|
|
||||||
"""
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
model_type=SDModelType.Diffusers,
|
|
||||||
submodel=SDModelType.Tokenizer,
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
model_type=SDModelType.Diffusers,
|
|
||||||
submodel=SDModelType.TextEncoder,
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
model_type=SDModelType.Diffusers,
|
|
||||||
submodel=SDModelType.UNet,
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
return ModelLoaderOutput(
|
|
||||||
unet=UNetField(
|
|
||||||
unet=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.UNet,
|
|
||||||
),
|
|
||||||
scheduler=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.Scheduler,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
clip=ClipField(
|
|
||||||
tokenizer=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.Tokenizer,
|
|
||||||
),
|
|
||||||
text_encoder=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.TextEncoder,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
vae=VaeField(
|
|
||||||
vae=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.Vae,
|
submodel=SubModelType.Vae,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
import torch
|
import torch
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING
|
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from invokeai.backend.model_management.model_manager import (
|
from invokeai.backend.model_management.model_manager import (
|
||||||
@ -69,19 +69,6 @@ class ModelManagerServiceBase(ABC):
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
|
|
||||||
"""
|
|
||||||
Returns the name and typeof the default model, or None
|
|
||||||
if none is defined.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
"""Sets the default model to the indicated name."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||||
"""
|
"""
|
||||||
@ -270,17 +257,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
model_type,
|
model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
|
|
||||||
"""
|
|
||||||
Returns the name of the default model, or None
|
|
||||||
if none is defined.
|
|
||||||
"""
|
|
||||||
return self.mgr.default_model()
|
|
||||||
|
|
||||||
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
"""Sets the default model to the indicated name."""
|
|
||||||
self.mgr.set_default_model(model_name, base_model, model_type)
|
|
||||||
|
|
||||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||||
"""
|
"""
|
||||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||||
@ -297,21 +273,10 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
self,
|
self,
|
||||||
base_model: Optional[BaseModelType] = None,
|
base_model: Optional[BaseModelType] = None,
|
||||||
model_type: Optional[ModelType] = None
|
model_type: Optional[ModelType] = None
|
||||||
) -> dict:
|
) -> list[dict]:
|
||||||
|
# ) -> dict:
|
||||||
"""
|
"""
|
||||||
Return a dict of models in the format:
|
Return a list of models.
|
||||||
{ model_type1:
|
|
||||||
{ model_name1: {'status': 'active'|'cached'|'not loaded',
|
|
||||||
'model_name' : name,
|
|
||||||
'model_type' : SDModelType,
|
|
||||||
'description': description,
|
|
||||||
'format': 'folder'|'safetensors'|'ckpt'
|
|
||||||
},
|
|
||||||
model_name2: { etc }
|
|
||||||
},
|
|
||||||
model_type2:
|
|
||||||
{ model_name_n: etc
|
|
||||||
}
|
|
||||||
"""
|
"""
|
||||||
return self.mgr.list_models(base_model, model_type)
|
return self.mgr.list_models(base_model, model_type)
|
||||||
|
|
||||||
|
@ -266,6 +266,8 @@ class ModelManager(object):
|
|||||||
for model_key, model_config in config.items():
|
for model_key, model_config in config.items():
|
||||||
model_name, base_model, model_type = self.parse_key(model_key)
|
model_name, base_model, model_type = self.parse_key(model_key)
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
|
# alias for config file
|
||||||
|
model_config["model_format"] = model_config.pop("format")
|
||||||
self.models[model_key] = model_class.create_config(**model_config)
|
self.models[model_key] = model_class.create_config(**model_config)
|
||||||
|
|
||||||
# check config version number and update on disk/RAM if necessary
|
# check config version number and update on disk/RAM if necessary
|
||||||
@ -445,38 +447,6 @@ class ModelManager(object):
|
|||||||
_cache = self.cache,
|
_cache = self.cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
|
|
||||||
"""
|
|
||||||
Returns the name of the default model, or None
|
|
||||||
if none is defined.
|
|
||||||
"""
|
|
||||||
for model_key, model_config in self.models.items():
|
|
||||||
if model_config.default:
|
|
||||||
return self.parse_key(model_key)
|
|
||||||
|
|
||||||
for model_key, _ in self.models.items():
|
|
||||||
return self.parse_key(model_key)
|
|
||||||
else:
|
|
||||||
return None # TODO: or redo as (None, None, None)
|
|
||||||
|
|
||||||
def set_default_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Set the default model. The change will not take
|
|
||||||
effect until you call model_manager.commit()
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_key = self.model_key(model_name, base_model, model_type)
|
|
||||||
if model_key not in self.models:
|
|
||||||
raise Exception(f"Unknown model: {model_key}")
|
|
||||||
|
|
||||||
for cur_model_key, config in self.models.items():
|
|
||||||
config.default = cur_model_key == model_key
|
|
||||||
|
|
||||||
def model_info(
|
def model_info(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@ -503,9 +473,9 @@ class ModelManager(object):
|
|||||||
self,
|
self,
|
||||||
base_model: Optional[BaseModelType] = None,
|
base_model: Optional[BaseModelType] = None,
|
||||||
model_type: Optional[ModelType] = None,
|
model_type: Optional[ModelType] = None,
|
||||||
) -> Dict[str, Dict[str, str]]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Return a dict of models, in format [base_model][model_type][model_name]
|
Return a list of models.
|
||||||
|
|
||||||
Please use model_manager.models() to get all the model names,
|
Please use model_manager.models() to get all the model names,
|
||||||
model_manager.model_info('model-name') to get the stanza for the model
|
model_manager.model_info('model-name') to get the stanza for the model
|
||||||
@ -513,7 +483,7 @@ class ModelManager(object):
|
|||||||
object derived from models.yaml
|
object derived from models.yaml
|
||||||
"""
|
"""
|
||||||
|
|
||||||
models = dict()
|
models = []
|
||||||
for model_key in sorted(self.models, key=str.casefold):
|
for model_key in sorted(self.models, key=str.casefold):
|
||||||
model_config = self.models[model_key]
|
model_config = self.models[model_key]
|
||||||
|
|
||||||
@ -523,18 +493,16 @@ class ModelManager(object):
|
|||||||
if model_type is not None and cur_model_type != model_type:
|
if model_type is not None and cur_model_type != model_type:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if cur_base_model not in models:
|
model_dict = dict(
|
||||||
models[cur_base_model] = dict()
|
|
||||||
if cur_model_type not in models[cur_base_model]:
|
|
||||||
models[cur_base_model][cur_model_type] = dict()
|
|
||||||
|
|
||||||
models[cur_base_model][cur_model_type][cur_model_name] = dict(
|
|
||||||
**model_config.dict(exclude_defaults=True),
|
**model_config.dict(exclude_defaults=True),
|
||||||
|
# OpenAPIModelInfoBase
|
||||||
name=cur_model_name,
|
name=cur_model_name,
|
||||||
base_model=cur_base_model,
|
base_model=cur_base_model,
|
||||||
type=cur_model_type,
|
type=cur_model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
models.append(model_dict)
|
||||||
|
|
||||||
return models
|
return models
|
||||||
|
|
||||||
def print_models(self) -> None:
|
def print_models(self) -> None:
|
||||||
@ -646,7 +614,9 @@ class ModelManager(object):
|
|||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
if model_class.save_to_config:
|
if model_class.save_to_config:
|
||||||
# TODO: or exclude_unset better fits here?
|
# TODO: or exclude_unset better fits here?
|
||||||
data_to_save[model_key] = model_config.dict(exclude_defaults=True)
|
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
|
||||||
|
# alias for config file
|
||||||
|
data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format")
|
||||||
|
|
||||||
yaml_str = OmegaConf.to_yaml(data_to_save)
|
yaml_str = OmegaConf.to_yaml(data_to_save)
|
||||||
config_file_path = conf_file or self.config_path
|
config_file_path = conf_file or self.config_path
|
||||||
|
@ -1,3 +1,7 @@
|
|||||||
|
import inspect
|
||||||
|
from enum import Enum
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Literal, get_origin
|
||||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
|
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
|
||||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||||
from .vae import VaeModel
|
from .vae import VaeModel
|
||||||
@ -29,10 +33,63 @@ MODEL_CLASSES = {
|
|||||||
#},
|
#},
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_all_model_configs():
|
MODEL_CONFIGS = list()
|
||||||
configs = set()
|
OPENAPI_MODEL_CONFIGS = list()
|
||||||
for models in MODEL_CLASSES.values():
|
|
||||||
for _, model in models.items():
|
class OpenAPIModelInfoBase(BaseModel):
|
||||||
configs.update(model._get_configs().values())
|
name: str
|
||||||
configs.discard(None)
|
base_model: BaseModelType
|
||||||
return list(configs) # TODO: set, list or tuple
|
type: ModelType
|
||||||
|
|
||||||
|
|
||||||
|
for base_model, models in MODEL_CLASSES.items():
|
||||||
|
for model_type, model_class in models.items():
|
||||||
|
model_configs = set(model_class._get_configs().values())
|
||||||
|
model_configs.discard(None)
|
||||||
|
MODEL_CONFIGS.extend(model_configs)
|
||||||
|
|
||||||
|
for cfg in model_configs:
|
||||||
|
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
|
||||||
|
openapi_cfg_name = model_name + cfg_name
|
||||||
|
if openapi_cfg_name in vars():
|
||||||
|
continue
|
||||||
|
|
||||||
|
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
|
||||||
|
__annotations__ = dict(
|
||||||
|
type=Literal[model_type.value],
|
||||||
|
),
|
||||||
|
))
|
||||||
|
|
||||||
|
#globals()[openapi_cfg_name] = api_wrapper
|
||||||
|
vars()[openapi_cfg_name] = api_wrapper
|
||||||
|
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
|
||||||
|
|
||||||
|
def get_model_config_enums():
|
||||||
|
enums = list()
|
||||||
|
|
||||||
|
for model_config in MODEL_CONFIGS:
|
||||||
|
fields = inspect.get_annotations(model_config)
|
||||||
|
try:
|
||||||
|
field = fields["model_format"]
|
||||||
|
except:
|
||||||
|
raise Exception("format field not found")
|
||||||
|
|
||||||
|
# model_format: None
|
||||||
|
# model_format: SomeModelFormat
|
||||||
|
# model_format: Literal[SomeModelFormat.Diffusers]
|
||||||
|
# model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint]
|
||||||
|
|
||||||
|
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
||||||
|
enums.append(field)
|
||||||
|
|
||||||
|
elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
|
||||||
|
enums.append(type(field.__args__[0]))
|
||||||
|
|
||||||
|
elif field is None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
|
||||||
|
|
||||||
|
return enums
|
||||||
|
|
||||||
|
@ -48,12 +48,10 @@ class ModelError(str, Enum):
|
|||||||
|
|
||||||
class ModelConfigBase(BaseModel):
|
class ModelConfigBase(BaseModel):
|
||||||
path: str # or Path
|
path: str # or Path
|
||||||
#name: str # not included as present in model key
|
|
||||||
description: Optional[str] = Field(None)
|
description: Optional[str] = Field(None)
|
||||||
format: Optional[str] = Field(None)
|
model_format: Optional[str] = Field(None)
|
||||||
default: Optional[bool] = Field(False)
|
|
||||||
# do not save to config
|
# do not save to config
|
||||||
error: Optional[ModelError] = Field(None, exclude=True)
|
error: Optional[ModelError] = Field(None)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
use_enum_values = True
|
use_enum_values = True
|
||||||
@ -94,6 +92,11 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
|
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
|
||||||
if len(subtypes) < 2:
|
if len(subtypes) < 2:
|
||||||
raise Exception("Invalid subfolder definition!")
|
raise Exception("Invalid subfolder definition!")
|
||||||
|
if all(t is None for t in subtypes):
|
||||||
|
return None
|
||||||
|
elif any(t is None for t in subtypes):
|
||||||
|
raise Exception(f"Unsupported definition: {subtypes}")
|
||||||
|
|
||||||
if subtypes[0] in ["diffusers", "transformers"]:
|
if subtypes[0] in ["diffusers", "transformers"]:
|
||||||
res_type = sys.modules[subtypes[0]]
|
res_type = sys.modules[subtypes[0]]
|
||||||
subtypes = subtypes[1:]
|
subtypes = subtypes[1:]
|
||||||
@ -122,47 +125,41 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
fields = inspect.get_annotations(value)
|
fields = inspect.get_annotations(value)
|
||||||
if "format" not in fields:
|
try:
|
||||||
raise Exception("Invalid config definition - format field not found")
|
field = fields["model_format"]
|
||||||
|
except:
|
||||||
|
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
|
||||||
|
|
||||||
format_type = typing.get_origin(fields["format"])
|
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
||||||
if format_type not in {None, Literal, Union}:
|
for model_format in field:
|
||||||
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
configs[model_format.value] = value
|
||||||
|
|
||||||
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["format"].__args__):
|
elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
|
||||||
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
for model_format in field.__args__:
|
||||||
|
configs[model_format.value] = value
|
||||||
|
|
||||||
|
elif field is None:
|
||||||
|
configs[None] = value
|
||||||
|
|
||||||
if format_type == Union:
|
|
||||||
f_fields = fields["format"].__args__
|
|
||||||
else:
|
else:
|
||||||
f_fields = (fields["format"],)
|
raise Exception(f"Unsupported format definition in {cls.__qualname__}")
|
||||||
|
|
||||||
|
|
||||||
for field in f_fields:
|
|
||||||
if field is None:
|
|
||||||
format_name = None
|
|
||||||
else:
|
|
||||||
format_name = field.__args__[0]
|
|
||||||
|
|
||||||
configs[format_name] = value # TODO: error when override(multiple)?
|
|
||||||
|
|
||||||
|
|
||||||
cls.__configs = configs
|
cls.__configs = configs
|
||||||
return cls.__configs
|
return cls.__configs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_config(cls, **kwargs) -> ModelConfigBase:
|
def create_config(cls, **kwargs) -> ModelConfigBase:
|
||||||
if "format" not in kwargs:
|
if "model_format" not in kwargs:
|
||||||
raise Exception("Field 'format' not found in model config")
|
raise Exception("Field 'model_format' not found in model config")
|
||||||
|
|
||||||
configs = cls._get_configs()
|
configs = cls._get_configs()
|
||||||
return configs[kwargs["format"]](**kwargs)
|
return configs[kwargs["model_format"]](**kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
|
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
|
||||||
return cls.create_config(
|
return cls.create_config(
|
||||||
path=path,
|
path=path,
|
||||||
format=cls.detect_format(path),
|
model_format=cls.detect_format(path),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union, Literal
|
from typing import Optional, Union, Literal
|
||||||
from .base import (
|
from .base import (
|
||||||
@ -14,12 +15,16 @@ from .base import (
|
|||||||
classproperty,
|
classproperty,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class ControlNetModelFormat(str, Enum):
|
||||||
|
Checkpoint = "checkpoint"
|
||||||
|
Diffusers = "diffusers"
|
||||||
|
|
||||||
class ControlNetModel(ModelBase):
|
class ControlNetModel(ModelBase):
|
||||||
#model_class: Type
|
#model_class: Type
|
||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
model_format: ControlNetModelFormat
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert model_type == ModelType.ControlNet
|
assert model_type == ModelType.ControlNet
|
||||||
@ -69,9 +74,9 @@ class ControlNetModel(ModelBase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def detect_format(cls, path: str):
|
def detect_format(cls, path: str):
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
return "diffusers"
|
return ControlNetModelFormat.Diffusers
|
||||||
else:
|
else:
|
||||||
return "checkpoint"
|
return ControlNetModelFormat.Checkpoint
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_if_required(
|
def convert_if_required(
|
||||||
@ -81,7 +86,7 @@ class ControlNetModel(ModelBase):
|
|||||||
config: ModelConfigBase, # empty config or config of parent model
|
config: ModelConfigBase, # empty config or config of parent model
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
) -> str:
|
) -> str:
|
||||||
if cls.detect_format(model_path) != "diffusers":
|
if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
|
||||||
raise NotImlemetedError("Checkpoint controlnet models currently unsupported")
|
raise NotImplementedError("Checkpoint controlnet models currently unsupported")
|
||||||
else:
|
else:
|
||||||
return model_path
|
return model_path
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
from enum import Enum
|
||||||
from typing import Optional, Union, Literal
|
from typing import Optional, Union, Literal
|
||||||
from .base import (
|
from .base import (
|
||||||
ModelBase,
|
ModelBase,
|
||||||
@ -12,11 +13,15 @@ from .base import (
|
|||||||
# TODO: naming
|
# TODO: naming
|
||||||
from ..lora import LoRAModel as LoRAModelRaw
|
from ..lora import LoRAModel as LoRAModelRaw
|
||||||
|
|
||||||
|
class LoRAModelFormat(str, Enum):
|
||||||
|
LyCORIS = "lycoris"
|
||||||
|
Diffusers = "diffusers"
|
||||||
|
|
||||||
class LoRAModel(ModelBase):
|
class LoRAModel(ModelBase):
|
||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
format: Union[Literal["lycoris"], Literal["diffusers"]]
|
model_format: LoRAModelFormat # TODO:
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert model_type == ModelType.Lora
|
assert model_type == ModelType.Lora
|
||||||
@ -52,9 +57,9 @@ class LoRAModel(ModelBase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def detect_format(cls, path: str):
|
def detect_format(cls, path: str):
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
return "diffusers"
|
return LoRAModelFormat.Diffusers
|
||||||
else:
|
else:
|
||||||
return "lycoris"
|
return LoRAModelFormat.LyCORIS
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_if_required(
|
def convert_if_required(
|
||||||
@ -64,7 +69,7 @@ class LoRAModel(ModelBase):
|
|||||||
config: ModelConfigBase,
|
config: ModelConfigBase,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
) -> str:
|
) -> str:
|
||||||
if cls.detect_format(model_path) == "diffusers":
|
if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
|
||||||
# TODO: add diffusers lora when it stabilizes a bit
|
# TODO: add diffusers lora when it stabilizes a bit
|
||||||
raise NotImplementedError("Diffusers lora not supported")
|
raise NotImplementedError("Diffusers lora not supported")
|
||||||
else:
|
else:
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
from enum import Enum
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
@ -19,16 +20,19 @@ from .base import (
|
|||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
class StableDiffusion1ModelFormat(str, Enum):
|
||||||
|
Checkpoint = "checkpoint"
|
||||||
|
Diffusers = "diffusers"
|
||||||
|
|
||||||
class StableDiffusion1Model(DiffusersModel):
|
class StableDiffusion1Model(DiffusersModel):
|
||||||
|
|
||||||
class DiffusersConfig(ModelConfigBase):
|
class DiffusersConfig(ModelConfigBase):
|
||||||
format: Literal["diffusers"]
|
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
class CheckpointConfig(ModelConfigBase):
|
||||||
format: Literal["checkpoint"]
|
model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
config: Optional[str] = Field(None)
|
config: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
@ -47,7 +51,7 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
def probe_config(cls, path: str, **kwargs):
|
def probe_config(cls, path: str, **kwargs):
|
||||||
model_format = cls.detect_format(path)
|
model_format = cls.detect_format(path)
|
||||||
ckpt_config_path = kwargs.get("config", None)
|
ckpt_config_path = kwargs.get("config", None)
|
||||||
if model_format == "checkpoint":
|
if model_format == StableDiffusion1ModelFormat.Checkpoint:
|
||||||
if ckpt_config_path:
|
if ckpt_config_path:
|
||||||
ckpt_config = OmegaConf.load(ckpt_config_path)
|
ckpt_config = OmegaConf.load(ckpt_config_path)
|
||||||
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
||||||
@ -57,7 +61,7 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
checkpoint = checkpoint.get('state_dict', checkpoint)
|
checkpoint = checkpoint.get('state_dict', checkpoint)
|
||||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||||
|
|
||||||
elif model_format == "diffusers":
|
elif model_format == StableDiffusion1ModelFormat.Diffusers:
|
||||||
unet_config_path = os.path.join(path, "unet", "config.json")
|
unet_config_path = os.path.join(path, "unet", "config.json")
|
||||||
if os.path.exists(unet_config_path):
|
if os.path.exists(unet_config_path):
|
||||||
with open(unet_config_path, "r") as f:
|
with open(unet_config_path, "r") as f:
|
||||||
@ -80,7 +84,7 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
|
|
||||||
return cls.create_config(
|
return cls.create_config(
|
||||||
path=path,
|
path=path,
|
||||||
format=model_format,
|
model_format=model_format,
|
||||||
|
|
||||||
config=ckpt_config_path,
|
config=ckpt_config_path,
|
||||||
variant=variant,
|
variant=variant,
|
||||||
@ -93,9 +97,9 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def detect_format(cls, model_path: str):
|
def detect_format(cls, model_path: str):
|
||||||
if os.path.isdir(model_path):
|
if os.path.isdir(model_path):
|
||||||
return "diffusers"
|
return StableDiffusion1ModelFormat.Diffusers
|
||||||
else:
|
else:
|
||||||
return "checkpoint"
|
return StableDiffusion1ModelFormat.Checkpoint
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_if_required(
|
def convert_if_required(
|
||||||
@ -116,19 +120,22 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
else:
|
else:
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
|
class StableDiffusion2ModelFormat(str, Enum):
|
||||||
|
Checkpoint = "checkpoint"
|
||||||
|
Diffusers = "diffusers"
|
||||||
|
|
||||||
class StableDiffusion2Model(DiffusersModel):
|
class StableDiffusion2Model(DiffusersModel):
|
||||||
|
|
||||||
# TODO: check that configs overwriten properly
|
# TODO: check that configs overwriten properly
|
||||||
class DiffusersConfig(ModelConfigBase):
|
class DiffusersConfig(ModelConfigBase):
|
||||||
format: Literal["diffusers"]
|
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
prediction_type: SchedulerPredictionType
|
prediction_type: SchedulerPredictionType
|
||||||
upcast_attention: bool
|
upcast_attention: bool
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
class CheckpointConfig(ModelConfigBase):
|
||||||
format: Literal["checkpoint"]
|
model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
config: Optional[str] = Field(None)
|
config: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
@ -149,7 +156,7 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
def probe_config(cls, path: str, **kwargs):
|
def probe_config(cls, path: str, **kwargs):
|
||||||
model_format = cls.detect_format(path)
|
model_format = cls.detect_format(path)
|
||||||
ckpt_config_path = kwargs.get("config", None)
|
ckpt_config_path = kwargs.get("config", None)
|
||||||
if model_format == "checkpoint":
|
if model_format == StableDiffusion2ModelFormat.Checkpoint:
|
||||||
if ckpt_config_path:
|
if ckpt_config_path:
|
||||||
ckpt_config = OmegaConf.load(ckpt_config_path)
|
ckpt_config = OmegaConf.load(ckpt_config_path)
|
||||||
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
||||||
@ -159,7 +166,7 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
checkpoint = checkpoint.get('state_dict', checkpoint)
|
checkpoint = checkpoint.get('state_dict', checkpoint)
|
||||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||||
|
|
||||||
elif model_format == "diffusers":
|
elif model_format == StableDiffusion2ModelFormat.Diffusers:
|
||||||
unet_config_path = os.path.join(path, "unet", "config.json")
|
unet_config_path = os.path.join(path, "unet", "config.json")
|
||||||
if os.path.exists(unet_config_path):
|
if os.path.exists(unet_config_path):
|
||||||
with open(unet_config_path, "r") as f:
|
with open(unet_config_path, "r") as f:
|
||||||
@ -191,7 +198,7 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
|
|
||||||
return cls.create_config(
|
return cls.create_config(
|
||||||
path=path,
|
path=path,
|
||||||
format=model_format,
|
model_format=model_format,
|
||||||
|
|
||||||
config=ckpt_config_path,
|
config=ckpt_config_path,
|
||||||
variant=variant,
|
variant=variant,
|
||||||
@ -206,9 +213,9 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def detect_format(cls, model_path: str):
|
def detect_format(cls, model_path: str):
|
||||||
if os.path.isdir(model_path):
|
if os.path.isdir(model_path):
|
||||||
return "diffusers"
|
return StableDiffusion2ModelFormat.Diffusers
|
||||||
else:
|
else:
|
||||||
return "checkpoint"
|
return StableDiffusion2ModelFormat.Checkpoint
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_if_required(
|
def convert_if_required(
|
||||||
@ -281,8 +288,8 @@ def _convert_ckpt_and_cache(
|
|||||||
prediction_type = SchedulerPredictionType.Epsilon
|
prediction_type = SchedulerPredictionType.Epsilon
|
||||||
|
|
||||||
elif version == BaseModelType.StableDiffusion2:
|
elif version == BaseModelType.StableDiffusion2:
|
||||||
upcast_attention = config.upcast_attention
|
upcast_attention = model_config.upcast_attention
|
||||||
prediction_type = config.prediction_type
|
prediction_type = model_config.prediction_type
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Unknown model provided: {version}")
|
raise Exception(f"Unknown model provided: {version}")
|
||||||
|
@ -16,7 +16,7 @@ class TextualInversionModel(ModelBase):
|
|||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
format: None
|
model_format: None
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert model_type == ModelType.TextualInversion
|
assert model_type == ModelType.TextualInversion
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
import safetensors
|
||||||
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union, Literal
|
from typing import Optional, Union, Literal
|
||||||
from .base import (
|
from .base import (
|
||||||
@ -18,12 +20,16 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
|||||||
from diffusers.utils import is_safetensors_available
|
from diffusers.utils import is_safetensors_available
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
class VaeModelFormat(str, Enum):
|
||||||
|
Checkpoint = "checkpoint"
|
||||||
|
Diffusers = "diffusers"
|
||||||
|
|
||||||
class VaeModel(ModelBase):
|
class VaeModel(ModelBase):
|
||||||
#vae_class: Type
|
#vae_class: Type
|
||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
model_format: VaeModelFormat
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert model_type == ModelType.Vae
|
assert model_type == ModelType.Vae
|
||||||
@ -70,9 +76,9 @@ class VaeModel(ModelBase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def detect_format(cls, path: str):
|
def detect_format(cls, path: str):
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
return "diffusers"
|
return VaeModelFormat.Diffusers
|
||||||
else:
|
else:
|
||||||
return "checkpoint"
|
return VaeModelFormat.Checkpoint
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_if_required(
|
def convert_if_required(
|
||||||
@ -82,7 +88,7 @@ class VaeModel(ModelBase):
|
|||||||
config: ModelConfigBase, # empty config or config of parent model
|
config: ModelConfigBase, # empty config or config of parent model
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
) -> str:
|
) -> str:
|
||||||
if cls.detect_format(model_path) != "diffusers":
|
if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
|
||||||
return _convert_vae_ckpt_and_cache(
|
return _convert_vae_ckpt_and_cache(
|
||||||
weights_path=model_path,
|
weights_path=model_path,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
|
@ -24,6 +24,7 @@ import Toaster from './Toaster';
|
|||||||
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
|
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
|
||||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||||
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
||||||
|
import { useListModelsQuery } from 'services/apiSlice';
|
||||||
|
|
||||||
const DEFAULT_CONFIG = {};
|
const DEFAULT_CONFIG = {};
|
||||||
|
|
||||||
@ -46,6 +47,18 @@ const App = ({
|
|||||||
|
|
||||||
const isApplicationReady = useIsApplicationReady();
|
const isApplicationReady = useIsApplicationReady();
|
||||||
|
|
||||||
|
const { data: pipelineModels } = useListModelsQuery({
|
||||||
|
model_type: 'pipeline',
|
||||||
|
});
|
||||||
|
const { data: controlnetModels } = useListModelsQuery({
|
||||||
|
model_type: 'controlnet',
|
||||||
|
});
|
||||||
|
const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' });
|
||||||
|
const { data: loraModels } = useListModelsQuery({ model_type: 'lora' });
|
||||||
|
const { data: embeddingModels } = useListModelsQuery({
|
||||||
|
model_type: 'embedding',
|
||||||
|
});
|
||||||
|
|
||||||
const [loadingOverridden, setLoadingOverridden] = useState(false);
|
const [loadingOverridden, setLoadingOverridden] = useState(false);
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
@ -5,7 +5,6 @@ import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersist
|
|||||||
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
||||||
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
|
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
|
||||||
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
|
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
|
||||||
import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist';
|
|
||||||
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
|
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
|
||||||
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
|
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
|
||||||
import { omit } from 'lodash-es';
|
import { omit } from 'lodash-es';
|
||||||
@ -18,7 +17,6 @@ const serializationDenylist: {
|
|||||||
gallery: galleryPersistDenylist,
|
gallery: galleryPersistDenylist,
|
||||||
generation: generationPersistDenylist,
|
generation: generationPersistDenylist,
|
||||||
lightbox: lightboxPersistDenylist,
|
lightbox: lightboxPersistDenylist,
|
||||||
models: modelsPersistDenylist,
|
|
||||||
nodes: nodesPersistDenylist,
|
nodes: nodesPersistDenylist,
|
||||||
postprocessing: postprocessingPersistDenylist,
|
postprocessing: postprocessingPersistDenylist,
|
||||||
system: systemPersistDenylist,
|
system: systemPersistDenylist,
|
||||||
|
@ -7,7 +7,6 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice';
|
|||||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||||
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
|
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
|
||||||
import { initialConfigState } from 'features/system/store/configSlice';
|
import { initialConfigState } from 'features/system/store/configSlice';
|
||||||
import { initialModelsState } from 'features/system/store/modelSlice';
|
|
||||||
import { initialSystemState } from 'features/system/store/systemSlice';
|
import { initialSystemState } from 'features/system/store/systemSlice';
|
||||||
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
|
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
|
||||||
import { initialUIState } from 'features/ui/store/uiSlice';
|
import { initialUIState } from 'features/ui/store/uiSlice';
|
||||||
@ -21,7 +20,6 @@ const initialStates: {
|
|||||||
gallery: initialGalleryState,
|
gallery: initialGalleryState,
|
||||||
generation: initialGenerationState,
|
generation: initialGenerationState,
|
||||||
lightbox: initialLightboxState,
|
lightbox: initialLightboxState,
|
||||||
models: initialModelsState,
|
|
||||||
nodes: initialNodesState,
|
nodes: initialNodesState,
|
||||||
postprocessing: initialPostprocessingState,
|
postprocessing: initialPostprocessingState,
|
||||||
system: initialSystemState,
|
system: initialSystemState,
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
import { startAppListening } from '../..';
|
|
||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
||||||
import { receivedPageOfImages } from 'services/thunks/image';
|
import { receivedPageOfImages } from 'services/thunks/image';
|
||||||
import { receivedModels } from 'services/thunks/model';
|
|
||||||
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
||||||
|
import { startAppListening } from '../..';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'socketio' });
|
const moduleLog = log.child({ namespace: 'socketio' });
|
||||||
|
|
||||||
@ -15,7 +14,7 @@ export const addSocketConnectedEventListener = () => {
|
|||||||
|
|
||||||
moduleLog.debug({ timestamp }, 'Connected');
|
moduleLog.debug({ timestamp }, 'Connected');
|
||||||
|
|
||||||
const { models, nodes, config, images } = getState();
|
const { nodes, config, images } = getState();
|
||||||
|
|
||||||
const { disabledTabs } = config;
|
const { disabledTabs } = config;
|
||||||
|
|
||||||
@ -28,10 +27,6 @@ export const addSocketConnectedEventListener = () => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!models.ids.length) {
|
|
||||||
dispatch(receivedModels());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
||||||
dispatch(receivedOpenAPISchema());
|
dispatch(receivedOpenAPISchema());
|
||||||
}
|
}
|
||||||
|
@ -5,34 +5,32 @@ import {
|
|||||||
configureStore,
|
configureStore,
|
||||||
} from '@reduxjs/toolkit';
|
} from '@reduxjs/toolkit';
|
||||||
|
|
||||||
import { rememberReducer, rememberEnhancer } from 'redux-remember';
|
|
||||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||||
|
import { rememberEnhancer, rememberReducer } from 'redux-remember';
|
||||||
|
|
||||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||||
|
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
||||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||||
import imagesReducer from 'features/gallery/store/imagesSlice';
|
import imagesReducer from 'features/gallery/store/imagesSlice';
|
||||||
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
||||||
import generationReducer from 'features/parameters/store/generationSlice';
|
import generationReducer from 'features/parameters/store/generationSlice';
|
||||||
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
|
||||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||||
import systemReducer from 'features/system/store/systemSlice';
|
import systemReducer from 'features/system/store/systemSlice';
|
||||||
// import sessionReducer from 'features/system/store/sessionSlice';
|
// import sessionReducer from 'features/system/store/sessionSlice';
|
||||||
import configReducer from 'features/system/store/configSlice';
|
|
||||||
import uiReducer from 'features/ui/store/uiSlice';
|
|
||||||
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
|
||||||
import modelsReducer from 'features/system/store/modelSlice';
|
|
||||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||||
import boardsReducer from 'features/gallery/store/boardSlice';
|
import boardsReducer from 'features/gallery/store/boardSlice';
|
||||||
|
import configReducer from 'features/system/store/configSlice';
|
||||||
|
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
||||||
|
import uiReducer from 'features/ui/store/uiSlice';
|
||||||
|
|
||||||
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
||||||
|
|
||||||
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
||||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
|
||||||
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
||||||
|
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||||
|
import { LOCALSTORAGE_PREFIX } from './constants';
|
||||||
import { serialize } from './enhancers/reduxRemember/serialize';
|
import { serialize } from './enhancers/reduxRemember/serialize';
|
||||||
import { unserialize } from './enhancers/reduxRemember/unserialize';
|
import { unserialize } from './enhancers/reduxRemember/unserialize';
|
||||||
import { LOCALSTORAGE_PREFIX } from './constants';
|
|
||||||
import { api } from 'services/apiSlice';
|
import { api } from 'services/apiSlice';
|
||||||
|
|
||||||
const allReducers = {
|
const allReducers = {
|
||||||
@ -40,7 +38,6 @@ const allReducers = {
|
|||||||
gallery: galleryReducer,
|
gallery: galleryReducer,
|
||||||
generation: generationReducer,
|
generation: generationReducer,
|
||||||
lightbox: lightboxReducer,
|
lightbox: lightboxReducer,
|
||||||
models: modelsReducer,
|
|
||||||
nodes: nodesReducer,
|
nodes: nodesReducer,
|
||||||
postprocessing: postprocessingReducer,
|
postprocessing: postprocessingReducer,
|
||||||
system: systemReducer,
|
system: systemReducer,
|
||||||
@ -50,8 +47,8 @@ const allReducers = {
|
|||||||
images: imagesReducer,
|
images: imagesReducer,
|
||||||
controlNet: controlNetReducer,
|
controlNet: controlNetReducer,
|
||||||
boards: boardsReducer,
|
boards: boardsReducer,
|
||||||
[api.reducerPath]: api.reducer,
|
|
||||||
// session: sessionReducer,
|
// session: sessionReducer,
|
||||||
|
[api.reducerPath]: api.reducer,
|
||||||
};
|
};
|
||||||
|
|
||||||
const rootReducer = combineReducers(allReducers);
|
const rootReducer = combineReducers(allReducers);
|
||||||
@ -63,7 +60,6 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
|
|||||||
'gallery',
|
'gallery',
|
||||||
'generation',
|
'generation',
|
||||||
'lightbox',
|
'lightbox',
|
||||||
// 'models',
|
|
||||||
'nodes',
|
'nodes',
|
||||||
'postprocessing',
|
'postprocessing',
|
||||||
'system',
|
'system',
|
||||||
|
@ -1,28 +1,18 @@
|
|||||||
import { Select } from '@chakra-ui/react';
|
import { SelectItem } from '@mantine/core';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import {
|
import {
|
||||||
ModelInputFieldTemplate,
|
ModelInputFieldTemplate,
|
||||||
ModelInputFieldValue,
|
ModelInputFieldValue,
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { selectModelsIds } from 'features/system/store/modelSlice';
|
|
||||||
import { isEqual } from 'lodash-es';
|
|
||||||
import { ChangeEvent, memo } from 'react';
|
|
||||||
import { FieldComponentProps } from './types';
|
|
||||||
|
|
||||||
const availableModelsSelector = createSelector(
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
[selectModelsIds],
|
import { FieldComponentProps } from './types';
|
||||||
(allModelNames) => {
|
import { forEach, isString } from 'lodash-es';
|
||||||
return { allModelNames };
|
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||||
// return map(modelList, (_, name) => name);
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
},
|
import { useTranslation } from 'react-i18next';
|
||||||
{
|
import { useListModelsQuery } from 'services/apiSlice';
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const ModelInputFieldComponent = (
|
const ModelInputFieldComponent = (
|
||||||
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
|
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
|
||||||
@ -30,28 +20,82 @@ const ModelInputFieldComponent = (
|
|||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { allModelNames } = useAppSelector(availableModelsSelector);
|
const { data: pipelineModels } = useListModelsQuery({
|
||||||
|
model_type: 'pipeline',
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!pipelineModels) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(pipelineModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.name,
|
||||||
|
group: BASE_MODEL_NAME_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [pipelineModels]);
|
||||||
|
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]],
|
||||||
|
[pipelineModels?.entities, pipelineModels?.ids, field.value]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleValueChanged = useCallback(
|
||||||
|
(v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
|
|
||||||
dispatch(
|
dispatch(
|
||||||
fieldValueChanged({
|
fieldValueChanged({
|
||||||
nodeId,
|
nodeId,
|
||||||
fieldName: field.name,
|
fieldName: field.name,
|
||||||
value: e.target.value,
|
value: v,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
};
|
},
|
||||||
|
[dispatch, field.name, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (field.value && pipelineModels?.ids.includes(field.value)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const firstModel = pipelineModels?.ids[0];
|
||||||
|
|
||||||
|
if (!isString(firstModel)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
handleValueChanged(firstModel);
|
||||||
|
}, [field.value, handleValueChanged, pipelineModels?.ids]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Select
|
<IAIMantineSelect
|
||||||
|
tooltip={selectedModel?.description}
|
||||||
|
label={
|
||||||
|
selectedModel?.base_model &&
|
||||||
|
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
||||||
|
}
|
||||||
|
value={field.value}
|
||||||
|
placeholder="Pick one"
|
||||||
|
data={data}
|
||||||
onChange={handleValueChanged}
|
onChange={handleValueChanged}
|
||||||
value={field.value || allModelNames[0]}
|
/>
|
||||||
>
|
|
||||||
{allModelNames.map((option) => (
|
|
||||||
<option key={option}>{option}</option>
|
|
||||||
))}
|
|
||||||
</Select>
|
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -101,21 +101,6 @@ const nodesSlice = createSlice({
|
|||||||
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
|
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
|
||||||
state.schema = action.payload;
|
state.schema = action.payload;
|
||||||
});
|
});
|
||||||
|
|
||||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
|
||||||
const { image_name, image_url, thumbnail_url } = action.payload;
|
|
||||||
|
|
||||||
state.nodes.forEach((node) => {
|
|
||||||
forEach(node.data.inputs, (input) => {
|
|
||||||
if (input.type === 'image') {
|
|
||||||
if (input.value?.image_name === image_name) {
|
|
||||||
input.value.image_url = image_url;
|
|
||||||
input.value.thumbnail_url = thumbnail_url;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ import {
|
|||||||
} from './constants';
|
} from './constants';
|
||||||
import { set } from 'lodash-es';
|
import { set } from 'lodash-es';
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'nodes' });
|
const moduleLog = log.child({ namespace: 'nodes' });
|
||||||
|
|
||||||
@ -36,7 +37,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: model_name,
|
model: modelId,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -49,6 +50,8 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
// The bounding box determines width and height, not the width and height params
|
// The bounding box determines width and height, not the width and height params
|
||||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||||
|
|
||||||
|
const model = modelIdToPipelineModelField(modelId);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||||
@ -85,9 +88,9 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
id: NOISE,
|
id: NOISE,
|
||||||
},
|
},
|
||||||
[MODEL_LOADER]: {
|
[MODEL_LOADER]: {
|
||||||
type: 'sd1_model_loader',
|
type: 'pipeline_model_loader',
|
||||||
id: MODEL_LOADER,
|
id: MODEL_LOADER,
|
||||||
model_name,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
|
@ -17,6 +17,7 @@ import {
|
|||||||
INPAINT_GRAPH,
|
INPAINT_GRAPH,
|
||||||
INPAINT,
|
INPAINT,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
|
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'nodes' });
|
const moduleLog = log.child({ namespace: 'nodes' });
|
||||||
|
|
||||||
@ -31,7 +32,7 @@ export const buildCanvasInpaintGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: model_name,
|
model: modelId,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -54,6 +55,8 @@ export const buildCanvasInpaintGraph = (
|
|||||||
// We may need to set the inpaint width and height to scale the image
|
// We may need to set the inpaint width and height to scale the image
|
||||||
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
|
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
|
||||||
|
|
||||||
|
const model = modelIdToPipelineModelField(modelId);
|
||||||
|
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
id: INPAINT_GRAPH,
|
id: INPAINT_GRAPH,
|
||||||
nodes: {
|
nodes: {
|
||||||
@ -99,9 +102,9 @@ export const buildCanvasInpaintGraph = (
|
|||||||
prompt: negativePrompt,
|
prompt: negativePrompt,
|
||||||
},
|
},
|
||||||
[MODEL_LOADER]: {
|
[MODEL_LOADER]: {
|
||||||
type: 'sd1_model_loader',
|
type: 'pipeline_model_loader',
|
||||||
id: MODEL_LOADER,
|
id: MODEL_LOADER,
|
||||||
model_name,
|
model,
|
||||||
},
|
},
|
||||||
[RANGE_OF_SIZE]: {
|
[RANGE_OF_SIZE]: {
|
||||||
type: 'range_of_size',
|
type: 'range_of_size',
|
||||||
|
@ -14,6 +14,7 @@ import {
|
|||||||
TEXT_TO_LATENTS,
|
TEXT_TO_LATENTS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Canvas tab's Text to Image graph.
|
* Builds the Canvas tab's Text to Image graph.
|
||||||
@ -24,7 +25,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: model_name,
|
model: modelId,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -36,6 +37,8 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
// The bounding box determines width and height, not the width and height params
|
// The bounding box determines width and height, not the width and height params
|
||||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||||
|
|
||||||
|
const model = modelIdToPipelineModelField(modelId);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||||
@ -80,9 +83,9 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
steps,
|
steps,
|
||||||
},
|
},
|
||||||
[MODEL_LOADER]: {
|
[MODEL_LOADER]: {
|
||||||
type: 'sd1_model_loader',
|
type: 'pipeline_model_loader',
|
||||||
id: MODEL_LOADER,
|
id: MODEL_LOADER,
|
||||||
model_name,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
|
@ -22,6 +22,7 @@ import {
|
|||||||
} from './constants';
|
} from './constants';
|
||||||
import { set } from 'lodash-es';
|
import { set } from 'lodash-es';
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'nodes' });
|
const moduleLog = log.child({ namespace: 'nodes' });
|
||||||
|
|
||||||
@ -34,7 +35,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: model_name,
|
model: modelId,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -62,6 +63,8 @@ export const buildLinearImageToImageGraph = (
|
|||||||
throw new Error('No initial image found in state');
|
throw new Error('No initial image found in state');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const model = modelIdToPipelineModelField(modelId);
|
||||||
|
|
||||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
id: IMAGE_TO_IMAGE_GRAPH,
|
id: IMAGE_TO_IMAGE_GRAPH,
|
||||||
@ -89,9 +92,9 @@ export const buildLinearImageToImageGraph = (
|
|||||||
id: NOISE,
|
id: NOISE,
|
||||||
},
|
},
|
||||||
[MODEL_LOADER]: {
|
[MODEL_LOADER]: {
|
||||||
type: 'sd1_model_loader',
|
type: 'pipeline_model_loader',
|
||||||
id: MODEL_LOADER,
|
id: MODEL_LOADER,
|
||||||
model_name,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api';
|
import {
|
||||||
|
BaseModelType,
|
||||||
|
RandomIntInvocation,
|
||||||
|
RangeOfSizeInvocation,
|
||||||
|
} from 'services/api';
|
||||||
import {
|
import {
|
||||||
ITERATE,
|
ITERATE,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
@ -14,6 +18,7 @@ import {
|
|||||||
TEXT_TO_LATENTS,
|
TEXT_TO_LATENTS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||||
|
|
||||||
type TextToImageGraphOverrides = {
|
type TextToImageGraphOverrides = {
|
||||||
width: number;
|
width: number;
|
||||||
@ -27,7 +32,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: model_name,
|
model: modelId,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -38,6 +43,8 @@ export const buildLinearTextToImageGraph = (
|
|||||||
shouldRandomizeSeed,
|
shouldRandomizeSeed,
|
||||||
} = state.generation;
|
} = state.generation;
|
||||||
|
|
||||||
|
const model = modelIdToPipelineModelField(modelId);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||||
@ -82,9 +89,9 @@ export const buildLinearTextToImageGraph = (
|
|||||||
steps,
|
steps,
|
||||||
},
|
},
|
||||||
[MODEL_LOADER]: {
|
[MODEL_LOADER]: {
|
||||||
type: 'sd1_model_loader',
|
type: 'pipeline_model_loader',
|
||||||
id: MODEL_LOADER,
|
id: MODEL_LOADER,
|
||||||
model_name,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
import { Graph } from 'services/api';
|
import { Graph } from 'services/api';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
import { cloneDeep, forEach, omit, reduce, values } from 'lodash-es';
|
import { cloneDeep, omit, reduce } from 'lodash-es';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { InputFieldValue } from 'features/nodes/types/types';
|
import { InputFieldValue } from 'features/nodes/types/types';
|
||||||
import { AnyInvocation } from 'services/events/types';
|
import { AnyInvocation } from 'services/events/types';
|
||||||
|
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* We need to do special handling for some fields
|
* We need to do special handling for some fields
|
||||||
@ -24,6 +25,12 @@ export const parseFieldValue = (field: InputFieldValue) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (field.type === 'model') {
|
||||||
|
if (field.value) {
|
||||||
|
return modelIdToPipelineModelField(field.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return field.value;
|
return field.value;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ export const NOISE = 'noise';
|
|||||||
export const RANDOM_INT = 'rand_int';
|
export const RANDOM_INT = 'rand_int';
|
||||||
export const RANGE_OF_SIZE = 'range_of_size';
|
export const RANGE_OF_SIZE = 'range_of_size';
|
||||||
export const ITERATE = 'iterate';
|
export const ITERATE = 'iterate';
|
||||||
export const MODEL_LOADER = 'model_loader';
|
export const MODEL_LOADER = 'pipeline_model_loader';
|
||||||
export const IMAGE_TO_LATENTS = 'image_to_latents';
|
export const IMAGE_TO_LATENTS = 'image_to_latents';
|
||||||
export const LATENTS_TO_LATENTS = 'latents_to_latents';
|
export const LATENTS_TO_LATENTS = 'latents_to_latents';
|
||||||
export const RESIZE = 'resize_image';
|
export const RESIZE = 'resize_image';
|
||||||
|
@ -0,0 +1,18 @@
|
|||||||
|
import { BaseModelType, PipelineModelField } from 'services/api';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Crudely converts a model id to a pipeline model field
|
||||||
|
* TODO: Make better
|
||||||
|
*/
|
||||||
|
export const modelIdToPipelineModelField = (
|
||||||
|
modelId: string
|
||||||
|
): PipelineModelField => {
|
||||||
|
const [base_model, model_type, model_name] = modelId.split('/');
|
||||||
|
|
||||||
|
const field: PipelineModelField = {
|
||||||
|
base_model: base_model as BaseModelType,
|
||||||
|
model_name,
|
||||||
|
};
|
||||||
|
|
||||||
|
return field;
|
||||||
|
};
|
@ -6,7 +6,7 @@ import ParamScheduler from './ParamScheduler';
|
|||||||
const ParamSchedulerAndModel = () => {
|
const ParamSchedulerAndModel = () => {
|
||||||
return (
|
return (
|
||||||
<Flex gap={3} w="full">
|
<Flex gap={3} w="full">
|
||||||
<Box w="20rem">
|
<Box w="25rem">
|
||||||
<ParamScheduler />
|
<ParamScheduler />
|
||||||
</Box>
|
</Box>
|
||||||
<Box w="full">
|
<Box w="full">
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
|
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
|
||||||
import { configChanged } from 'features/system/store/configSlice';
|
import { configChanged } from 'features/system/store/configSlice';
|
||||||
import { clamp, sortBy } from 'lodash-es';
|
import { clamp } from 'lodash-es';
|
||||||
import { ImageDTO } from 'services/api';
|
import { ImageDTO } from 'services/api';
|
||||||
import { imageUrlsReceived } from 'services/thunks/image';
|
|
||||||
import { receivedModels } from 'services/thunks/model';
|
|
||||||
import {
|
import {
|
||||||
CfgScaleParam,
|
CfgScaleParam,
|
||||||
HeightParam,
|
HeightParam,
|
||||||
@ -17,7 +16,6 @@ import {
|
|||||||
StrengthParam,
|
StrengthParam,
|
||||||
WidthParam,
|
WidthParam,
|
||||||
} from './parameterZodSchemas';
|
} from './parameterZodSchemas';
|
||||||
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
|
|
||||||
|
|
||||||
export interface GenerationState {
|
export interface GenerationState {
|
||||||
cfgScale: CfgScaleParam;
|
cfgScale: CfgScaleParam;
|
||||||
@ -220,28 +218,12 @@ export const generationSlice = createSlice({
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
extraReducers: (builder) => {
|
extraReducers: (builder) => {
|
||||||
builder.addCase(receivedModels.fulfilled, (state, action) => {
|
|
||||||
if (!state.model) {
|
|
||||||
const firstModel = sortBy(action.payload, 'name')[0];
|
|
||||||
state.model = firstModel.name;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
builder.addCase(configChanged, (state, action) => {
|
builder.addCase(configChanged, (state, action) => {
|
||||||
const defaultModel = action.payload.sd?.defaultModel;
|
const defaultModel = action.payload.sd?.defaultModel;
|
||||||
if (defaultModel && !state.model) {
|
if (defaultModel && !state.model) {
|
||||||
state.model = defaultModel;
|
state.model = defaultModel;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
|
||||||
// const { image_name, image_url, thumbnail_url } = action.payload;
|
|
||||||
|
|
||||||
// if (state.initialImage?.image_name === image_name) {
|
|
||||||
// state.initialImage.image_url = image_url;
|
|
||||||
// state.initialImage.thumbnail_url = thumbnail_url;
|
|
||||||
// }
|
|
||||||
// });
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -154,3 +154,17 @@ export type StrengthParam = z.infer<typeof zStrength>;
|
|||||||
*/
|
*/
|
||||||
export const isValidStrength = (val: unknown): val is StrengthParam =>
|
export const isValidStrength = (val: unknown): val is StrengthParam =>
|
||||||
zStrength.safeParse(val).success;
|
zStrength.safeParse(val).success;
|
||||||
|
|
||||||
|
// /**
|
||||||
|
// * Zod schema for BaseModelType
|
||||||
|
// */
|
||||||
|
// export const zBaseModelType = z.enum(['sd-1', 'sd-2']);
|
||||||
|
// /**
|
||||||
|
// * Type alias for base model type, inferred from its zod schema. Should be identical to the type alias from OpenAPI.
|
||||||
|
// */
|
||||||
|
// export type BaseModelType = z.infer<typeof zBaseModelType>;
|
||||||
|
// /**
|
||||||
|
// * Validates/type-guards a value as a base model type
|
||||||
|
// */
|
||||||
|
// export const isValidBaseModelType = (val: unknown): val is BaseModelType =>
|
||||||
|
// zBaseModelType.safeParse(val).success;
|
||||||
|
@ -1,44 +1,59 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
import { isEqual } from 'lodash-es';
|
|
||||||
import { memo, useCallback } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIMantineSelect, {
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
IAISelectDataType,
|
|
||||||
} from 'common/components/IAIMantineSelect';
|
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
|
||||||
import { modelSelected } from 'features/parameters/store/generationSlice';
|
import { modelSelected } from 'features/parameters/store/generationSlice';
|
||||||
import { selectModelsAll, selectModelsById } from '../store/modelSlice';
|
|
||||||
|
|
||||||
const selector = createSelector(
|
import { forEach, isString } from 'lodash-es';
|
||||||
[(state: RootState) => state, generationSelector],
|
import { SelectItem } from '@mantine/core';
|
||||||
(state, generation) => {
|
import { RootState } from 'app/store/store';
|
||||||
const selectedModel = selectModelsById(state, generation.model);
|
import { useListModelsQuery } from 'services/apiSlice';
|
||||||
|
|
||||||
const modelData = selectModelsAll(state)
|
export const MODEL_TYPE_MAP = {
|
||||||
.map<IAISelectDataType>((m) => ({
|
'sd-1': 'Stable Diffusion 1.x',
|
||||||
value: m.name,
|
'sd-2': 'Stable Diffusion 2.x',
|
||||||
label: m.name,
|
};
|
||||||
}))
|
|
||||||
.sort((a, b) => a.label.localeCompare(b.label));
|
|
||||||
return {
|
|
||||||
selectedModel,
|
|
||||||
modelData,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const ModelSelect = () => {
|
const ModelSelect = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { selectedModel, modelData } = useAppSelector(selector);
|
|
||||||
|
const selectedModelId = useAppSelector(
|
||||||
|
(state: RootState) => state.generation.model
|
||||||
|
);
|
||||||
|
|
||||||
|
const { data: pipelineModels } = useListModelsQuery({
|
||||||
|
model_type: 'pipeline',
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!pipelineModels) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(pipelineModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.name,
|
||||||
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [pipelineModels]);
|
||||||
|
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() => pipelineModels?.entities[selectedModelId],
|
||||||
|
[pipelineModels?.entities, selectedModelId]
|
||||||
|
);
|
||||||
|
|
||||||
const handleChangeModel = useCallback(
|
const handleChangeModel = useCallback(
|
||||||
(v: string | null) => {
|
(v: string | null) => {
|
||||||
if (!v) {
|
if (!v) {
|
||||||
@ -49,13 +64,27 @@ const ModelSelect = () => {
|
|||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const firstModel = pipelineModels?.ids[0];
|
||||||
|
|
||||||
|
if (!isString(firstModel)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
handleChangeModel(firstModel);
|
||||||
|
}, [handleChangeModel, pipelineModels?.ids, selectedModelId]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIMantineSelect
|
<IAIMantineSelect
|
||||||
tooltip={selectedModel?.description}
|
tooltip={selectedModel?.description}
|
||||||
label={t('modelManager.model')}
|
label={t('modelManager.model')}
|
||||||
value={selectedModel?.name ?? ''}
|
value={selectedModelId}
|
||||||
placeholder="Pick one"
|
placeholder="Pick one"
|
||||||
data={modelData}
|
data={data}
|
||||||
onChange={handleChangeModel}
|
onChange={handleChangeModel}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
|
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||||
@ -16,6 +15,7 @@ const data = map(SCHEDULER_NAMES, (s) => ({
|
|||||||
|
|
||||||
export default function SettingsSchedulers() {
|
export default function SettingsSchedulers() {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const enabledSchedulers = useAppSelector(
|
const enabledSchedulers = useAppSelector(
|
||||||
|
@ -7,13 +7,12 @@ import { systemSelector } from '../store/systemSelectors';
|
|||||||
const isApplicationReadySelector = createSelector(
|
const isApplicationReadySelector = createSelector(
|
||||||
[systemSelector, configSelector],
|
[systemSelector, configSelector],
|
||||||
(system, config) => {
|
(system, config) => {
|
||||||
const { wereModelsReceived, wasSchemaParsed } = system;
|
const { wasSchemaParsed } = system;
|
||||||
|
|
||||||
const { disabledTabs } = config;
|
const { disabledTabs } = config;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
disabledTabs,
|
disabledTabs,
|
||||||
wereModelsReceived,
|
|
||||||
wasSchemaParsed,
|
wasSchemaParsed,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -23,21 +22,17 @@ const isApplicationReadySelector = createSelector(
|
|||||||
* Checks if the application is ready to be used, i.e. if the initial startup process is finished.
|
* Checks if the application is ready to be used, i.e. if the initial startup process is finished.
|
||||||
*/
|
*/
|
||||||
export const useIsApplicationReady = () => {
|
export const useIsApplicationReady = () => {
|
||||||
const { disabledTabs, wereModelsReceived, wasSchemaParsed } = useAppSelector(
|
const { disabledTabs, wasSchemaParsed } = useAppSelector(
|
||||||
isApplicationReadySelector
|
isApplicationReadySelector
|
||||||
);
|
);
|
||||||
|
|
||||||
const isApplicationReady = useMemo(() => {
|
const isApplicationReady = useMemo(() => {
|
||||||
if (!wereModelsReceived) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!disabledTabs.includes('nodes') && !wasSchemaParsed) {
|
if (!disabledTabs.includes('nodes') && !wasSchemaParsed) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}, [disabledTabs, wereModelsReceived, wasSchemaParsed]);
|
}, [disabledTabs, wasSchemaParsed]);
|
||||||
|
|
||||||
return isApplicationReady;
|
return isApplicationReady;
|
||||||
};
|
};
|
||||||
|
@ -1,3 +0,0 @@
|
|||||||
import { RootState } from 'app/store/store';
|
|
||||||
|
|
||||||
export const modelSelector = (state: RootState) => state.models;
|
|
@ -1,47 +0,0 @@
|
|||||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { CkptModelInfo, DiffusersModelInfo } from 'services/api';
|
|
||||||
import { receivedModels } from 'services/thunks/model';
|
|
||||||
|
|
||||||
export type Model = (CkptModelInfo | DiffusersModelInfo) & {
|
|
||||||
name: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const modelsAdapter = createEntityAdapter<Model>({
|
|
||||||
selectId: (model) => model.name,
|
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
||||||
});
|
|
||||||
|
|
||||||
export const initialModelsState = modelsAdapter.getInitialState();
|
|
||||||
|
|
||||||
export type ModelsState = typeof initialModelsState;
|
|
||||||
|
|
||||||
export const modelsSlice = createSlice({
|
|
||||||
name: 'models',
|
|
||||||
initialState: initialModelsState,
|
|
||||||
reducers: {
|
|
||||||
modelAdded: modelsAdapter.upsertOne,
|
|
||||||
},
|
|
||||||
extraReducers(builder) {
|
|
||||||
/**
|
|
||||||
* Received Models - FULFILLED
|
|
||||||
*/
|
|
||||||
builder.addCase(receivedModels.fulfilled, (state, action) => {
|
|
||||||
const models = action.payload;
|
|
||||||
modelsAdapter.setAll(state, models);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
export const {
|
|
||||||
selectAll: selectModelsAll,
|
|
||||||
selectById: selectModelsById,
|
|
||||||
selectEntities: selectModelsEntities,
|
|
||||||
selectIds: selectModelsIds,
|
|
||||||
selectTotal: selectModelsTotal,
|
|
||||||
} = modelsAdapter.getSelectors<RootState>((state) => state.models);
|
|
||||||
|
|
||||||
export const { modelAdded } = modelsSlice.actions;
|
|
||||||
|
|
||||||
export default modelsSlice.reducer;
|
|
@ -1,6 +0,0 @@
|
|||||||
import { ModelsState } from './modelSlice';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Models slice persist denylist
|
|
||||||
*/
|
|
||||||
export const modelsPersistDenylist: (keyof ModelsState)[] = ['entities', 'ids'];
|
|
@ -1,20 +1,12 @@
|
|||||||
import { UseToastOptions } from '@chakra-ui/react';
|
import { UseToastOptions } from '@chakra-ui/react';
|
||||||
import { PayloadAction } from '@reduxjs/toolkit';
|
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
|
||||||
import * as InvokeAI from 'app/types/invokeai';
|
import * as InvokeAI from 'app/types/invokeai';
|
||||||
|
|
||||||
import { ProgressImage } from 'services/events/types';
|
|
||||||
import { makeToast } from '../../../app/components/Toaster';
|
|
||||||
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
|
|
||||||
import { receivedModels } from 'services/thunks/model';
|
|
||||||
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
|
|
||||||
import { LogLevelName } from 'roarr';
|
|
||||||
import { InvokeLogLevel } from 'app/logging/useLogger';
|
import { InvokeLogLevel } from 'app/logging/useLogger';
|
||||||
import { TFuncKey } from 'i18next';
|
|
||||||
import { t } from 'i18next';
|
|
||||||
import { userInvoked } from 'app/store/actions';
|
import { userInvoked } from 'app/store/actions';
|
||||||
import { LANGUAGES } from '../components/LanguagePicker';
|
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
|
||||||
import { imageUploaded } from 'services/thunks/image';
|
import { TFuncKey, t } from 'i18next';
|
||||||
|
import { LogLevelName } from 'roarr';
|
||||||
import {
|
import {
|
||||||
appSocketConnected,
|
appSocketConnected,
|
||||||
appSocketDisconnected,
|
appSocketDisconnected,
|
||||||
@ -26,6 +18,11 @@ import {
|
|||||||
appSocketSubscribed,
|
appSocketSubscribed,
|
||||||
appSocketUnsubscribed,
|
appSocketUnsubscribed,
|
||||||
} from 'services/events/actions';
|
} from 'services/events/actions';
|
||||||
|
import { ProgressImage } from 'services/events/types';
|
||||||
|
import { imageUploaded } from 'services/thunks/image';
|
||||||
|
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
|
||||||
|
import { makeToast } from '../../../app/components/Toaster';
|
||||||
|
import { LANGUAGES } from '../components/LanguagePicker';
|
||||||
|
|
||||||
export type CancelStrategy = 'immediate' | 'scheduled';
|
export type CancelStrategy = 'immediate' | 'scheduled';
|
||||||
|
|
||||||
@ -379,13 +376,6 @@ export const systemSlice = createSlice({
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
/**
|
|
||||||
* Received available models from the backend
|
|
||||||
*/
|
|
||||||
builder.addCase(receivedModels.fulfilled, (state) => {
|
|
||||||
state.wereModelsReceived = true;
|
|
||||||
});
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* OpenAPI schema was parsed
|
* OpenAPI schema was parsed
|
||||||
*/
|
*/
|
||||||
|
@ -25,6 +25,8 @@ export type { ConditioningField } from './models/ConditioningField';
|
|||||||
export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation';
|
export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation';
|
||||||
export type { ControlField } from './models/ControlField';
|
export type { ControlField } from './models/ControlField';
|
||||||
export type { ControlNetInvocation } from './models/ControlNetInvocation';
|
export type { ControlNetInvocation } from './models/ControlNetInvocation';
|
||||||
|
export type { ControlNetModelConfig } from './models/ControlNetModelConfig';
|
||||||
|
export type { ControlNetModelFormat } from './models/ControlNetModelFormat';
|
||||||
export type { ControlOutput } from './models/ControlOutput';
|
export type { ControlOutput } from './models/ControlOutput';
|
||||||
export type { CreateModelRequest } from './models/CreateModelRequest';
|
export type { CreateModelRequest } from './models/CreateModelRequest';
|
||||||
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
|
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
|
||||||
@ -67,14 +69,6 @@ export type { InfillTileInvocation } from './models/InfillTileInvocation';
|
|||||||
export type { InpaintInvocation } from './models/InpaintInvocation';
|
export type { InpaintInvocation } from './models/InpaintInvocation';
|
||||||
export type { IntCollectionOutput } from './models/IntCollectionOutput';
|
export type { IntCollectionOutput } from './models/IntCollectionOutput';
|
||||||
export type { IntOutput } from './models/IntOutput';
|
export type { IntOutput } from './models/IntOutput';
|
||||||
export type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config';
|
|
||||||
export type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './models/invokeai__backend__model_management__models__lora__LoRAModel__Config';
|
|
||||||
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig';
|
|
||||||
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig';
|
|
||||||
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig';
|
|
||||||
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig';
|
|
||||||
export type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config';
|
|
||||||
export type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './models/invokeai__backend__model_management__models__vae__VaeModel__Config';
|
|
||||||
export type { IterateInvocation } from './models/IterateInvocation';
|
export type { IterateInvocation } from './models/IterateInvocation';
|
||||||
export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
|
export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
|
||||||
export type { LatentsField } from './models/LatentsField';
|
export type { LatentsField } from './models/LatentsField';
|
||||||
@ -87,6 +81,8 @@ export type { LoadImageInvocation } from './models/LoadImageInvocation';
|
|||||||
export type { LoraInfo } from './models/LoraInfo';
|
export type { LoraInfo } from './models/LoraInfo';
|
||||||
export type { LoraLoaderInvocation } from './models/LoraLoaderInvocation';
|
export type { LoraLoaderInvocation } from './models/LoraLoaderInvocation';
|
||||||
export type { LoraLoaderOutput } from './models/LoraLoaderOutput';
|
export type { LoraLoaderOutput } from './models/LoraLoaderOutput';
|
||||||
|
export type { LoRAModelConfig } from './models/LoRAModelConfig';
|
||||||
|
export type { LoRAModelFormat } from './models/LoRAModelFormat';
|
||||||
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
|
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
|
||||||
export type { MaskOutput } from './models/MaskOutput';
|
export type { MaskOutput } from './models/MaskOutput';
|
||||||
export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation';
|
export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation';
|
||||||
@ -109,6 +105,8 @@ export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedRe
|
|||||||
export type { ParamFloatInvocation } from './models/ParamFloatInvocation';
|
export type { ParamFloatInvocation } from './models/ParamFloatInvocation';
|
||||||
export type { ParamIntInvocation } from './models/ParamIntInvocation';
|
export type { ParamIntInvocation } from './models/ParamIntInvocation';
|
||||||
export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation';
|
export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation';
|
||||||
|
export type { PipelineModelField } from './models/PipelineModelField';
|
||||||
|
export type { PipelineModelLoaderInvocation } from './models/PipelineModelLoaderInvocation';
|
||||||
export type { PromptCollectionOutput } from './models/PromptCollectionOutput';
|
export type { PromptCollectionOutput } from './models/PromptCollectionOutput';
|
||||||
export type { PromptOutput } from './models/PromptOutput';
|
export type { PromptOutput } from './models/PromptOutput';
|
||||||
export type { RandomIntInvocation } from './models/RandomIntInvocation';
|
export type { RandomIntInvocation } from './models/RandomIntInvocation';
|
||||||
@ -120,16 +118,23 @@ export type { ResourceOrigin } from './models/ResourceOrigin';
|
|||||||
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
|
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
|
||||||
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
|
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
|
||||||
export type { SchedulerPredictionType } from './models/SchedulerPredictionType';
|
export type { SchedulerPredictionType } from './models/SchedulerPredictionType';
|
||||||
export type { SD1ModelLoaderInvocation } from './models/SD1ModelLoaderInvocation';
|
|
||||||
export type { SD2ModelLoaderInvocation } from './models/SD2ModelLoaderInvocation';
|
|
||||||
export type { ShowImageInvocation } from './models/ShowImageInvocation';
|
export type { ShowImageInvocation } from './models/ShowImageInvocation';
|
||||||
|
export type { StableDiffusion1ModelCheckpointConfig } from './models/StableDiffusion1ModelCheckpointConfig';
|
||||||
|
export type { StableDiffusion1ModelDiffusersConfig } from './models/StableDiffusion1ModelDiffusersConfig';
|
||||||
|
export type { StableDiffusion1ModelFormat } from './models/StableDiffusion1ModelFormat';
|
||||||
|
export type { StableDiffusion2ModelCheckpointConfig } from './models/StableDiffusion2ModelCheckpointConfig';
|
||||||
|
export type { StableDiffusion2ModelDiffusersConfig } from './models/StableDiffusion2ModelDiffusersConfig';
|
||||||
|
export type { StableDiffusion2ModelFormat } from './models/StableDiffusion2ModelFormat';
|
||||||
export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation';
|
export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation';
|
||||||
export type { SubModelType } from './models/SubModelType';
|
export type { SubModelType } from './models/SubModelType';
|
||||||
export type { SubtractInvocation } from './models/SubtractInvocation';
|
export type { SubtractInvocation } from './models/SubtractInvocation';
|
||||||
export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation';
|
export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation';
|
||||||
|
export type { TextualInversionModelConfig } from './models/TextualInversionModelConfig';
|
||||||
export type { UNetField } from './models/UNetField';
|
export type { UNetField } from './models/UNetField';
|
||||||
export type { UpscaleInvocation } from './models/UpscaleInvocation';
|
export type { UpscaleInvocation } from './models/UpscaleInvocation';
|
||||||
export type { VaeField } from './models/VaeField';
|
export type { VaeField } from './models/VaeField';
|
||||||
|
export type { VaeModelConfig } from './models/VaeModelConfig';
|
||||||
|
export type { VaeModelFormat } from './models/VaeModelFormat';
|
||||||
export type { VaeRepo } from './models/VaeRepo';
|
export type { VaeRepo } from './models/VaeRepo';
|
||||||
export type { ValidationError } from './models/ValidationError';
|
export type { ValidationError } from './models/ValidationError';
|
||||||
export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation';
|
export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation';
|
||||||
|
@ -0,0 +1,18 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { BaseModelType } from './BaseModelType';
|
||||||
|
import type { ControlNetModelFormat } from './ControlNetModelFormat';
|
||||||
|
import type { ModelError } from './ModelError';
|
||||||
|
|
||||||
|
export type ControlNetModelConfig = {
|
||||||
|
name: string;
|
||||||
|
base_model: BaseModelType;
|
||||||
|
type: 'controlnet';
|
||||||
|
path: string;
|
||||||
|
description?: string;
|
||||||
|
model_format: ControlNetModelFormat;
|
||||||
|
error?: ModelError;
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,8 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An enumeration.
|
||||||
|
*/
|
||||||
|
export type ControlNetModelFormat = 'checkpoint' | 'diffusers';
|
@ -49,6 +49,7 @@ import type { OpenposeImageProcessorInvocation } from './OpenposeImageProcessorI
|
|||||||
import type { ParamFloatInvocation } from './ParamFloatInvocation';
|
import type { ParamFloatInvocation } from './ParamFloatInvocation';
|
||||||
import type { ParamIntInvocation } from './ParamIntInvocation';
|
import type { ParamIntInvocation } from './ParamIntInvocation';
|
||||||
import type { PidiImageProcessorInvocation } from './PidiImageProcessorInvocation';
|
import type { PidiImageProcessorInvocation } from './PidiImageProcessorInvocation';
|
||||||
|
import type { PipelineModelLoaderInvocation } from './PipelineModelLoaderInvocation';
|
||||||
import type { RandomIntInvocation } from './RandomIntInvocation';
|
import type { RandomIntInvocation } from './RandomIntInvocation';
|
||||||
import type { RandomRangeInvocation } from './RandomRangeInvocation';
|
import type { RandomRangeInvocation } from './RandomRangeInvocation';
|
||||||
import type { RangeInvocation } from './RangeInvocation';
|
import type { RangeInvocation } from './RangeInvocation';
|
||||||
@ -56,8 +57,6 @@ import type { RangeOfSizeInvocation } from './RangeOfSizeInvocation';
|
|||||||
import type { ResizeLatentsInvocation } from './ResizeLatentsInvocation';
|
import type { ResizeLatentsInvocation } from './ResizeLatentsInvocation';
|
||||||
import type { RestoreFaceInvocation } from './RestoreFaceInvocation';
|
import type { RestoreFaceInvocation } from './RestoreFaceInvocation';
|
||||||
import type { ScaleLatentsInvocation } from './ScaleLatentsInvocation';
|
import type { ScaleLatentsInvocation } from './ScaleLatentsInvocation';
|
||||||
import type { SD1ModelLoaderInvocation } from './SD1ModelLoaderInvocation';
|
|
||||||
import type { SD2ModelLoaderInvocation } from './SD2ModelLoaderInvocation';
|
|
||||||
import type { ShowImageInvocation } from './ShowImageInvocation';
|
import type { ShowImageInvocation } from './ShowImageInvocation';
|
||||||
import type { StepParamEasingInvocation } from './StepParamEasingInvocation';
|
import type { StepParamEasingInvocation } from './StepParamEasingInvocation';
|
||||||
import type { SubtractInvocation } from './SubtractInvocation';
|
import type { SubtractInvocation } from './SubtractInvocation';
|
||||||
@ -73,7 +72,7 @@ export type Graph = {
|
|||||||
/**
|
/**
|
||||||
* The nodes in this graph
|
* The nodes in this graph
|
||||||
*/
|
*/
|
||||||
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation)>;
|
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | PipelineModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation)>;
|
||||||
/**
|
/**
|
||||||
* The connections between nodes and their fields in this graph
|
* The connections between nodes and their fields in this graph
|
||||||
*/
|
*/
|
||||||
|
@ -0,0 +1,18 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { BaseModelType } from './BaseModelType';
|
||||||
|
import type { LoRAModelFormat } from './LoRAModelFormat';
|
||||||
|
import type { ModelError } from './ModelError';
|
||||||
|
|
||||||
|
export type LoRAModelConfig = {
|
||||||
|
name: string;
|
||||||
|
base_model: BaseModelType;
|
||||||
|
type: 'lora';
|
||||||
|
path: string;
|
||||||
|
description?: string;
|
||||||
|
model_format: LoRAModelFormat;
|
||||||
|
error?: ModelError;
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,8 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An enumeration.
|
||||||
|
*/
|
||||||
|
export type LoRAModelFormat = 'lycoris' | 'diffusers';
|
@ -2,16 +2,16 @@
|
|||||||
/* tslint:disable */
|
/* tslint:disable */
|
||||||
/* eslint-disable */
|
/* eslint-disable */
|
||||||
|
|
||||||
import type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './invokeai__backend__model_management__models__controlnet__ControlNetModel__Config';
|
import type { ControlNetModelConfig } from './ControlNetModelConfig';
|
||||||
import type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './invokeai__backend__model_management__models__lora__LoRAModel__Config';
|
import type { LoRAModelConfig } from './LoRAModelConfig';
|
||||||
import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig';
|
import type { StableDiffusion1ModelCheckpointConfig } from './StableDiffusion1ModelCheckpointConfig';
|
||||||
import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig';
|
import type { StableDiffusion1ModelDiffusersConfig } from './StableDiffusion1ModelDiffusersConfig';
|
||||||
import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig';
|
import type { StableDiffusion2ModelCheckpointConfig } from './StableDiffusion2ModelCheckpointConfig';
|
||||||
import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig';
|
import type { StableDiffusion2ModelDiffusersConfig } from './StableDiffusion2ModelDiffusersConfig';
|
||||||
import type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config';
|
import type { TextualInversionModelConfig } from './TextualInversionModelConfig';
|
||||||
import type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './invokeai__backend__model_management__models__vae__VaeModel__Config';
|
import type { VaeModelConfig } from './VaeModelConfig';
|
||||||
|
|
||||||
export type ModelsList = {
|
export type ModelsList = {
|
||||||
models: Record<string, Record<string, Record<string, (invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig | invokeai__backend__model_management__models__controlnet__ControlNetModel__Config | invokeai__backend__model_management__models__lora__LoRAModel__Config | invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig | invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config | invokeai__backend__model_management__models__vae__VaeModel__Config | invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig | invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig)>>>;
|
models: Array<(StableDiffusion1ModelCheckpointConfig | StableDiffusion1ModelDiffusersConfig | VaeModelConfig | LoRAModelConfig | ControlNetModelConfig | TextualInversionModelConfig | StableDiffusion2ModelCheckpointConfig | StableDiffusion2ModelDiffusersConfig)>;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -0,0 +1,20 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { BaseModelType } from './BaseModelType';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Pipeline model field
|
||||||
|
*/
|
||||||
|
export type PipelineModelField = {
|
||||||
|
/**
|
||||||
|
* Name of the model
|
||||||
|
*/
|
||||||
|
model_name: string;
|
||||||
|
/**
|
||||||
|
* Base model
|
||||||
|
*/
|
||||||
|
base_model: BaseModelType;
|
||||||
|
};
|
||||||
|
|
@ -2,10 +2,12 @@
|
|||||||
/* tslint:disable */
|
/* tslint:disable */
|
||||||
/* eslint-disable */
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { PipelineModelField } from './PipelineModelField';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loading submodels of selected model.
|
* Loads a pipeline model, outputting its submodels.
|
||||||
*/
|
*/
|
||||||
export type SD2ModelLoaderInvocation = {
|
export type PipelineModelLoaderInvocation = {
|
||||||
/**
|
/**
|
||||||
* The id of this node. Must be unique among all nodes.
|
* The id of this node. Must be unique among all nodes.
|
||||||
*/
|
*/
|
||||||
@ -14,10 +16,10 @@ export type SD2ModelLoaderInvocation = {
|
|||||||
* Whether or not this node is an intermediate node.
|
* Whether or not this node is an intermediate node.
|
||||||
*/
|
*/
|
||||||
is_intermediate?: boolean;
|
is_intermediate?: boolean;
|
||||||
type?: 'sd2_model_loader';
|
type?: 'pipeline_model_loader';
|
||||||
/**
|
/**
|
||||||
* Model to load
|
* The model to load
|
||||||
*/
|
*/
|
||||||
model_name?: string;
|
model: PipelineModelField;
|
||||||
};
|
};
|
||||||
|
|
@ -1,23 +0,0 @@
|
|||||||
/* istanbul ignore file */
|
|
||||||
/* tslint:disable */
|
|
||||||
/* eslint-disable */
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Loading submodels of selected model.
|
|
||||||
*/
|
|
||||||
export type SD1ModelLoaderInvocation = {
|
|
||||||
/**
|
|
||||||
* The id of this node. Must be unique among all nodes.
|
|
||||||
*/
|
|
||||||
id: string;
|
|
||||||
/**
|
|
||||||
* Whether or not this node is an intermediate node.
|
|
||||||
*/
|
|
||||||
is_intermediate?: boolean;
|
|
||||||
type?: 'sd1_model_loader';
|
|
||||||
/**
|
|
||||||
* Model to load
|
|
||||||
*/
|
|
||||||
model_name?: string;
|
|
||||||
};
|
|
||||||
|
|
@ -2,14 +2,17 @@
|
|||||||
/* tslint:disable */
|
/* tslint:disable */
|
||||||
/* eslint-disable */
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { BaseModelType } from './BaseModelType';
|
||||||
import type { ModelError } from './ModelError';
|
import type { ModelError } from './ModelError';
|
||||||
import type { ModelVariantType } from './ModelVariantType';
|
import type { ModelVariantType } from './ModelVariantType';
|
||||||
|
|
||||||
export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig = {
|
export type StableDiffusion1ModelCheckpointConfig = {
|
||||||
|
name: string;
|
||||||
|
base_model: BaseModelType;
|
||||||
|
type: 'pipeline';
|
||||||
path: string;
|
path: string;
|
||||||
description?: string;
|
description?: string;
|
||||||
format: 'checkpoint';
|
model_format: 'checkpoint';
|
||||||
default?: boolean;
|
|
||||||
error?: ModelError;
|
error?: ModelError;
|
||||||
vae?: string;
|
vae?: string;
|
||||||
config?: string;
|
config?: string;
|
@ -2,14 +2,17 @@
|
|||||||
/* tslint:disable */
|
/* tslint:disable */
|
||||||
/* eslint-disable */
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { BaseModelType } from './BaseModelType';
|
||||||
import type { ModelError } from './ModelError';
|
import type { ModelError } from './ModelError';
|
||||||
import type { ModelVariantType } from './ModelVariantType';
|
import type { ModelVariantType } from './ModelVariantType';
|
||||||
|
|
||||||
export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig = {
|
export type StableDiffusion1ModelDiffusersConfig = {
|
||||||
|
name: string;
|
||||||
|
base_model: BaseModelType;
|
||||||
|
type: 'pipeline';
|
||||||
path: string;
|
path: string;
|
||||||
description?: string;
|
description?: string;
|
||||||
format: 'diffusers';
|
model_format: 'diffusers';
|
||||||
default?: boolean;
|
|
||||||
error?: ModelError;
|
error?: ModelError;
|
||||||
vae?: string;
|
vae?: string;
|
||||||
variant: ModelVariantType;
|
variant: ModelVariantType;
|
@ -0,0 +1,8 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An enumeration.
|
||||||
|
*/
|
||||||
|
export type StableDiffusion1ModelFormat = 'checkpoint' | 'diffusers';
|
@ -2,15 +2,18 @@
|
|||||||
/* tslint:disable */
|
/* tslint:disable */
|
||||||
/* eslint-disable */
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { BaseModelType } from './BaseModelType';
|
||||||
import type { ModelError } from './ModelError';
|
import type { ModelError } from './ModelError';
|
||||||
import type { ModelVariantType } from './ModelVariantType';
|
import type { ModelVariantType } from './ModelVariantType';
|
||||||
import type { SchedulerPredictionType } from './SchedulerPredictionType';
|
import type { SchedulerPredictionType } from './SchedulerPredictionType';
|
||||||
|
|
||||||
export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig = {
|
export type StableDiffusion2ModelCheckpointConfig = {
|
||||||
|
name: string;
|
||||||
|
base_model: BaseModelType;
|
||||||
|
type: 'pipeline';
|
||||||
path: string;
|
path: string;
|
||||||
description?: string;
|
description?: string;
|
||||||
format: 'checkpoint';
|
model_format: 'checkpoint';
|
||||||
default?: boolean;
|
|
||||||
error?: ModelError;
|
error?: ModelError;
|
||||||
vae?: string;
|
vae?: string;
|
||||||
config?: string;
|
config?: string;
|
@ -2,15 +2,18 @@
|
|||||||
/* tslint:disable */
|
/* tslint:disable */
|
||||||
/* eslint-disable */
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { BaseModelType } from './BaseModelType';
|
||||||
import type { ModelError } from './ModelError';
|
import type { ModelError } from './ModelError';
|
||||||
import type { ModelVariantType } from './ModelVariantType';
|
import type { ModelVariantType } from './ModelVariantType';
|
||||||
import type { SchedulerPredictionType } from './SchedulerPredictionType';
|
import type { SchedulerPredictionType } from './SchedulerPredictionType';
|
||||||
|
|
||||||
export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig = {
|
export type StableDiffusion2ModelDiffusersConfig = {
|
||||||
|
name: string;
|
||||||
|
base_model: BaseModelType;
|
||||||
|
type: 'pipeline';
|
||||||
path: string;
|
path: string;
|
||||||
description?: string;
|
description?: string;
|
||||||
format: 'diffusers';
|
model_format: 'diffusers';
|
||||||
default?: boolean;
|
|
||||||
error?: ModelError;
|
error?: ModelError;
|
||||||
vae?: string;
|
vae?: string;
|
||||||
variant: ModelVariantType;
|
variant: ModelVariantType;
|
@ -0,0 +1,8 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An enumeration.
|
||||||
|
*/
|
||||||
|
export type StableDiffusion2ModelFormat = 'checkpoint' | 'diffusers';
|
@ -0,0 +1,17 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { BaseModelType } from './BaseModelType';
|
||||||
|
import type { ModelError } from './ModelError';
|
||||||
|
|
||||||
|
export type TextualInversionModelConfig = {
|
||||||
|
name: string;
|
||||||
|
base_model: BaseModelType;
|
||||||
|
type: 'embedding';
|
||||||
|
path: string;
|
||||||
|
description?: string;
|
||||||
|
model_format: null;
|
||||||
|
error?: ModelError;
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,18 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { BaseModelType } from './BaseModelType';
|
||||||
|
import type { ModelError } from './ModelError';
|
||||||
|
import type { VaeModelFormat } from './VaeModelFormat';
|
||||||
|
|
||||||
|
export type VaeModelConfig = {
|
||||||
|
name: string;
|
||||||
|
base_model: BaseModelType;
|
||||||
|
type: 'vae';
|
||||||
|
path: string;
|
||||||
|
description?: string;
|
||||||
|
model_format: VaeModelFormat;
|
||||||
|
error?: ModelError;
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,8 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An enumeration.
|
||||||
|
*/
|
||||||
|
export type VaeModelFormat = 'checkpoint' | 'diffusers';
|
@ -1,14 +0,0 @@
|
|||||||
/* istanbul ignore file */
|
|
||||||
/* tslint:disable */
|
|
||||||
/* eslint-disable */
|
|
||||||
|
|
||||||
import type { ModelError } from './ModelError';
|
|
||||||
|
|
||||||
export type invokeai__backend__model_management__models__controlnet__ControlNetModel__Config = {
|
|
||||||
path: string;
|
|
||||||
description?: string;
|
|
||||||
format: ('checkpoint' | 'diffusers');
|
|
||||||
default?: boolean;
|
|
||||||
error?: ModelError;
|
|
||||||
};
|
|
||||||
|
|
@ -1,14 +0,0 @@
|
|||||||
/* istanbul ignore file */
|
|
||||||
/* tslint:disable */
|
|
||||||
/* eslint-disable */
|
|
||||||
|
|
||||||
import type { ModelError } from './ModelError';
|
|
||||||
|
|
||||||
export type invokeai__backend__model_management__models__lora__LoRAModel__Config = {
|
|
||||||
path: string;
|
|
||||||
description?: string;
|
|
||||||
format: ('lycoris' | 'diffusers');
|
|
||||||
default?: boolean;
|
|
||||||
error?: ModelError;
|
|
||||||
};
|
|
||||||
|
|
@ -1,14 +0,0 @@
|
|||||||
/* istanbul ignore file */
|
|
||||||
/* tslint:disable */
|
|
||||||
/* eslint-disable */
|
|
||||||
|
|
||||||
import type { ModelError } from './ModelError';
|
|
||||||
|
|
||||||
export type invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config = {
|
|
||||||
path: string;
|
|
||||||
description?: string;
|
|
||||||
format: null;
|
|
||||||
default?: boolean;
|
|
||||||
error?: ModelError;
|
|
||||||
};
|
|
||||||
|
|
@ -1,14 +0,0 @@
|
|||||||
/* istanbul ignore file */
|
|
||||||
/* tslint:disable */
|
|
||||||
/* eslint-disable */
|
|
||||||
|
|
||||||
import type { ModelError } from './ModelError';
|
|
||||||
|
|
||||||
export type invokeai__backend__model_management__models__vae__VaeModel__Config = {
|
|
||||||
path: string;
|
|
||||||
description?: string;
|
|
||||||
format: ('checkpoint' | 'diffusers');
|
|
||||||
default?: boolean;
|
|
||||||
error?: ModelError;
|
|
||||||
};
|
|
||||||
|
|
@ -51,6 +51,7 @@ import type { PaginatedResults_GraphExecutionState_ } from '../models/PaginatedR
|
|||||||
import type { ParamFloatInvocation } from '../models/ParamFloatInvocation';
|
import type { ParamFloatInvocation } from '../models/ParamFloatInvocation';
|
||||||
import type { ParamIntInvocation } from '../models/ParamIntInvocation';
|
import type { ParamIntInvocation } from '../models/ParamIntInvocation';
|
||||||
import type { PidiImageProcessorInvocation } from '../models/PidiImageProcessorInvocation';
|
import type { PidiImageProcessorInvocation } from '../models/PidiImageProcessorInvocation';
|
||||||
|
import type { PipelineModelLoaderInvocation } from '../models/PipelineModelLoaderInvocation';
|
||||||
import type { RandomIntInvocation } from '../models/RandomIntInvocation';
|
import type { RandomIntInvocation } from '../models/RandomIntInvocation';
|
||||||
import type { RandomRangeInvocation } from '../models/RandomRangeInvocation';
|
import type { RandomRangeInvocation } from '../models/RandomRangeInvocation';
|
||||||
import type { RangeInvocation } from '../models/RangeInvocation';
|
import type { RangeInvocation } from '../models/RangeInvocation';
|
||||||
@ -58,8 +59,6 @@ import type { RangeOfSizeInvocation } from '../models/RangeOfSizeInvocation';
|
|||||||
import type { ResizeLatentsInvocation } from '../models/ResizeLatentsInvocation';
|
import type { ResizeLatentsInvocation } from '../models/ResizeLatentsInvocation';
|
||||||
import type { RestoreFaceInvocation } from '../models/RestoreFaceInvocation';
|
import type { RestoreFaceInvocation } from '../models/RestoreFaceInvocation';
|
||||||
import type { ScaleLatentsInvocation } from '../models/ScaleLatentsInvocation';
|
import type { ScaleLatentsInvocation } from '../models/ScaleLatentsInvocation';
|
||||||
import type { SD1ModelLoaderInvocation } from '../models/SD1ModelLoaderInvocation';
|
|
||||||
import type { SD2ModelLoaderInvocation } from '../models/SD2ModelLoaderInvocation';
|
|
||||||
import type { ShowImageInvocation } from '../models/ShowImageInvocation';
|
import type { ShowImageInvocation } from '../models/ShowImageInvocation';
|
||||||
import type { StepParamEasingInvocation } from '../models/StepParamEasingInvocation';
|
import type { StepParamEasingInvocation } from '../models/StepParamEasingInvocation';
|
||||||
import type { SubtractInvocation } from '../models/SubtractInvocation';
|
import type { SubtractInvocation } from '../models/SubtractInvocation';
|
||||||
@ -175,7 +174,7 @@ export class SessionsService {
|
|||||||
* The id of the session
|
* The id of the session
|
||||||
*/
|
*/
|
||||||
sessionId: string,
|
sessionId: string,
|
||||||
requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation),
|
requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | PipelineModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation),
|
||||||
}): CancelablePromise<string> {
|
}): CancelablePromise<string> {
|
||||||
return __request(OpenAPI, {
|
return __request(OpenAPI, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
@ -212,7 +211,7 @@ export class SessionsService {
|
|||||||
* The path to the node in the graph
|
* The path to the node in the graph
|
||||||
*/
|
*/
|
||||||
nodePath: string,
|
nodePath: string,
|
||||||
requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation),
|
requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | PipelineModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation),
|
||||||
}): CancelablePromise<GraphExecutionState> {
|
}): CancelablePromise<GraphExecutionState> {
|
||||||
return __request(OpenAPI, {
|
return __request(OpenAPI, {
|
||||||
method: 'PUT',
|
method: 'PUT',
|
||||||
|
@ -13,23 +13,68 @@ import {
|
|||||||
TagTypesFrom,
|
TagTypesFrom,
|
||||||
TagTypesFromApi,
|
TagTypesFromApi,
|
||||||
} from '@reduxjs/toolkit/dist/query/endpointDefinitions';
|
} from '@reduxjs/toolkit/dist/query/endpointDefinitions';
|
||||||
|
import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
|
||||||
|
import { BaseModelType } from './api/models/BaseModelType';
|
||||||
|
import { ModelType } from './api/models/ModelType';
|
||||||
|
import { ModelsList } from './api/models/ModelsList';
|
||||||
|
import { keyBy } from 'lodash-es';
|
||||||
|
|
||||||
type ListBoardsArg = { offset: number; limit: number };
|
type ListBoardsArg = { offset: number; limit: number };
|
||||||
type UpdateBoardArg = { board_id: string; changes: BoardChanges };
|
type UpdateBoardArg = { board_id: string; changes: BoardChanges };
|
||||||
type AddImageToBoardArg = { board_id: string; image_name: string };
|
type AddImageToBoardArg = { board_id: string; image_name: string };
|
||||||
type RemoveImageFromBoardArg = { board_id: string; image_name: string };
|
type RemoveImageFromBoardArg = { board_id: string; image_name: string };
|
||||||
type ListBoardImagesArg = { board_id: string; offset: number; limit: number };
|
type ListBoardImagesArg = { board_id: string; offset: number; limit: number };
|
||||||
|
type ListModelsArg = { base_model?: BaseModelType; model_type?: ModelType };
|
||||||
|
|
||||||
const tagTypes = ['Board', 'Image'];
|
type ModelConfig = ModelsList['models'][number];
|
||||||
|
|
||||||
|
const tagTypes = ['Board', 'Image', 'Model'];
|
||||||
type ApiFullTagDescription = FullTagDescription<(typeof tagTypes)[number]>;
|
type ApiFullTagDescription = FullTagDescription<(typeof tagTypes)[number]>;
|
||||||
|
|
||||||
const LIST = 'LIST';
|
const LIST = 'LIST';
|
||||||
|
|
||||||
|
const modelsAdapter = createEntityAdapter<ModelConfig>({
|
||||||
|
selectId: (model) => getModelId(model),
|
||||||
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
|
});
|
||||||
|
|
||||||
|
const getModelId = ({ base_model, type, name }: ModelConfig) =>
|
||||||
|
`${base_model}/${type}/${name}`;
|
||||||
|
|
||||||
export const api = createApi({
|
export const api = createApi({
|
||||||
baseQuery: fetchBaseQuery({ baseUrl: 'http://localhost:5173/api/v1/' }),
|
baseQuery: fetchBaseQuery({ baseUrl: 'http://localhost:5173/api/v1/' }),
|
||||||
reducerPath: 'api',
|
reducerPath: 'api',
|
||||||
tagTypes,
|
tagTypes,
|
||||||
endpoints: (build) => ({
|
endpoints: (build) => ({
|
||||||
|
/**
|
||||||
|
* Models Queries
|
||||||
|
*/
|
||||||
|
|
||||||
|
listModels: build.query<EntityState<ModelConfig>, ListModelsArg>({
|
||||||
|
query: (arg) => ({ url: 'models/', params: arg }),
|
||||||
|
providesTags: (result, error, arg) => {
|
||||||
|
// any list of boards
|
||||||
|
const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST }];
|
||||||
|
|
||||||
|
if (result) {
|
||||||
|
// and individual tags for each board
|
||||||
|
tags.push(
|
||||||
|
...result.ids.map((id) => ({
|
||||||
|
type: 'Model' as const,
|
||||||
|
id,
|
||||||
|
}))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return tags;
|
||||||
|
},
|
||||||
|
transformResponse: (response: ModelsList, meta, arg) => {
|
||||||
|
return modelsAdapter.addMany(
|
||||||
|
modelsAdapter.getInitialState(),
|
||||||
|
keyBy(response.models, getModelId)
|
||||||
|
);
|
||||||
|
},
|
||||||
|
}),
|
||||||
/**
|
/**
|
||||||
* Boards Queries
|
* Boards Queries
|
||||||
*/
|
*/
|
||||||
@ -174,4 +219,5 @@ export const {
|
|||||||
useRemoveImageFromBoardMutation,
|
useRemoveImageFromBoardMutation,
|
||||||
useListBoardImagesQuery,
|
useListBoardImagesQuery,
|
||||||
useGetImageDTOQuery,
|
useGetImageDTOQuery,
|
||||||
|
useListModelsQuery,
|
||||||
} = api;
|
} = api;
|
||||||
|
@ -1,33 +0,0 @@
|
|||||||
import { log } from 'app/logging/useLogger';
|
|
||||||
import { createAppAsyncThunk } from 'app/store/storeUtils';
|
|
||||||
import { Model } from 'features/system/store/modelSlice';
|
|
||||||
import { reduce, size } from 'lodash-es';
|
|
||||||
import { ModelsService } from 'services/api';
|
|
||||||
|
|
||||||
const models = log.child({ namespace: 'model' });
|
|
||||||
|
|
||||||
export const IMAGES_PER_PAGE = 20;
|
|
||||||
|
|
||||||
export const receivedModels = createAppAsyncThunk(
|
|
||||||
'models/receivedModels',
|
|
||||||
async (_) => {
|
|
||||||
const response = await ModelsService.listModels();
|
|
||||||
|
|
||||||
const deserializedModels = reduce(
|
|
||||||
response.models['sd-1']['pipeline'],
|
|
||||||
(modelsAccumulator, model, modelName) => {
|
|
||||||
modelsAccumulator[modelName] = { ...model, name: modelName };
|
|
||||||
|
|
||||||
return modelsAccumulator;
|
|
||||||
},
|
|
||||||
{} as Record<string, Model>
|
|
||||||
);
|
|
||||||
|
|
||||||
models.info(
|
|
||||||
{ response },
|
|
||||||
`Received ${size(response.models['sd-1']['pipeline'])} models`
|
|
||||||
);
|
|
||||||
|
|
||||||
return deserializedModels;
|
|
||||||
}
|
|
||||||
);
|
|
@ -4,8 +4,8 @@ import { generateColorPalette } from '../util/generateColorPalette';
|
|||||||
export const greenTeaThemeColors: InvokeAIThemeColors = {
|
export const greenTeaThemeColors: InvokeAIThemeColors = {
|
||||||
base: generateColorPalette(223, 10),
|
base: generateColorPalette(223, 10),
|
||||||
baseAlpha: generateColorPalette(223, 10, false, true),
|
baseAlpha: generateColorPalette(223, 10, false, true),
|
||||||
accent: generateColorPalette(155, 80),
|
accent: generateColorPalette(160, 60),
|
||||||
accentAlpha: generateColorPalette(155, 80, false, true),
|
accentAlpha: generateColorPalette(160, 60, false, true),
|
||||||
working: generateColorPalette(47, 68),
|
working: generateColorPalette(47, 68),
|
||||||
workingAlpha: generateColorPalette(47, 68, false, true),
|
workingAlpha: generateColorPalette(47, 68, false, true),
|
||||||
warning: generateColorPalette(28, 75),
|
warning: generateColorPalette(28, 75),
|
||||||
@ -14,5 +14,5 @@ export const greenTeaThemeColors: InvokeAIThemeColors = {
|
|||||||
okAlpha: generateColorPalette(122, 49, false, true),
|
okAlpha: generateColorPalette(122, 49, false, true),
|
||||||
error: generateColorPalette(0, 50),
|
error: generateColorPalette(0, 50),
|
||||||
errorAlpha: generateColorPalette(0, 50, false, true),
|
errorAlpha: generateColorPalette(0, 50, false, true),
|
||||||
gridLineColor: 'rgba(255, 255, 255, 0.2)',
|
gridLineColor: 'rgba(255, 255, 255, 0.15)',
|
||||||
};
|
};
|
||||||
|
@ -2,8 +2,8 @@ import { InvokeAIThemeColors } from 'theme/themeTypes';
|
|||||||
import { generateColorPalette } from 'theme/util/generateColorPalette';
|
import { generateColorPalette } from 'theme/util/generateColorPalette';
|
||||||
|
|
||||||
export const invokeAIThemeColors: InvokeAIThemeColors = {
|
export const invokeAIThemeColors: InvokeAIThemeColors = {
|
||||||
base: generateColorPalette(225, 15),
|
base: generateColorPalette(220, 15),
|
||||||
baseAlpha: generateColorPalette(225, 15, false, true),
|
baseAlpha: generateColorPalette(220, 15, false, true),
|
||||||
accent: generateColorPalette(250, 50),
|
accent: generateColorPalette(250, 50),
|
||||||
accentAlpha: generateColorPalette(250, 50, false, true),
|
accentAlpha: generateColorPalette(250, 50, false, true),
|
||||||
working: generateColorPalette(47, 67),
|
working: generateColorPalette(47, 67),
|
||||||
@ -14,5 +14,5 @@ export const invokeAIThemeColors: InvokeAIThemeColors = {
|
|||||||
okAlpha: generateColorPalette(113, 70, false, true),
|
okAlpha: generateColorPalette(113, 70, false, true),
|
||||||
error: generateColorPalette(0, 76),
|
error: generateColorPalette(0, 76),
|
||||||
errorAlpha: generateColorPalette(0, 76, false, true),
|
errorAlpha: generateColorPalette(0, 76, false, true),
|
||||||
gridLineColor: 'rgba(255, 255, 255, 0.2)',
|
gridLineColor: 'rgba(150, 150, 180, 0.15)',
|
||||||
};
|
};
|
||||||
|
@ -14,5 +14,5 @@ export const lightThemeColors: InvokeAIThemeColors = {
|
|||||||
okAlpha: generateColorPalette(122, 49, true, true),
|
okAlpha: generateColorPalette(122, 49, true, true),
|
||||||
error: generateColorPalette(0, 50, true),
|
error: generateColorPalette(0, 50, true),
|
||||||
errorAlpha: generateColorPalette(0, 50, true, true),
|
errorAlpha: generateColorPalette(0, 50, true, true),
|
||||||
gridLineColor: 'rgba(0, 0, 0, 0.2)',
|
gridLineColor: 'rgba(0, 0, 0, 0.15)',
|
||||||
};
|
};
|
||||||
|
@ -14,5 +14,5 @@ export const oceanBlueColors: InvokeAIThemeColors = {
|
|||||||
okAlpha: generateColorPalette(122, 49, false, true),
|
okAlpha: generateColorPalette(122, 49, false, true),
|
||||||
error: generateColorPalette(0, 100),
|
error: generateColorPalette(0, 100),
|
||||||
errorAlpha: generateColorPalette(0, 100, false, true),
|
errorAlpha: generateColorPalette(0, 100, false, true),
|
||||||
gridLineColor: 'rgba(136, 148, 184, 0.2)',
|
gridLineColor: 'rgba(136, 148, 184, 0.15)',
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user