mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
remove defunct code
This commit is contained in:
commit
8e1a56875e
@ -1,12 +1,12 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
||||
|
||||
from typing import Annotated, Literal, Optional, Union
|
||||
from typing import Annotated, Literal, Optional, Union, Dict
|
||||
|
||||
from fastapi import Query
|
||||
from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend import SDModelType
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
@ -60,7 +60,8 @@ class ConvertedModelResponse(BaseModel):
|
||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
|
||||
models: Dict[BaseModelType, Dict[ModelType, Dict[str, dict]]] # TODO: collect all configs
|
||||
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
|
||||
|
||||
|
||||
@models_router.get(
|
||||
@ -69,9 +70,12 @@ class ModelsList(BaseModel):
|
||||
responses={200: {"model": ModelsList }},
|
||||
)
|
||||
async def list_models(
|
||||
model_type: SDModelType = Query(
|
||||
default=None, description="The type of model to get"
|
||||
),
|
||||
base_model: BaseModelType = Query(
|
||||
default=None, description="Base model"
|
||||
),
|
||||
model_type: ModelType = Query(
|
||||
default=None, description="The type of model to get"
|
||||
),
|
||||
) -> ModelsList:
|
||||
"""Gets a list of models"""
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(model_type)
|
||||
|
@ -8,7 +8,7 @@ from .model import ClipField
|
||||
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||
from ...backend.model_management import SDModelType
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
|
||||
from compel import Compel
|
||||
@ -76,7 +76,11 @@ class CompelInvocation(BaseInvocation):
|
||||
try:
|
||||
ti_list.append(
|
||||
stack.enter_context(
|
||||
context.services.model_manager.get_model(model_name=name, model_type=SDModelType.TextualInversion)
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
|
@ -5,12 +5,13 @@ import copy
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.model_management import SDModelType
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
model_name: str = Field(description="Info to load submodel")
|
||||
model_type: SDModelType = Field(description="Info to load submodel")
|
||||
submodel: Optional[SDModelType] = Field(description="Info to load submodel")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Info to load submodel")
|
||||
submodel: Optional[SubModelType] = Field(description="Info to load submodel")
|
||||
|
||||
class LoraInfo(ModelInfo):
|
||||
weight: float = Field(description="Lora's weight which to use when apply to model")
|
||||
@ -63,10 +64,13 @@ class ModelLoaderInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
|
||||
base_model = BaseModelType.StableDiffusion2 # TODO:
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
):
|
||||
raise Exception(f"Unkown model name: {self.model_name}!")
|
||||
|
||||
@ -104,12 +108,14 @@ class ModelLoaderInvocation(BaseInvocation):
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.UNet,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
submodel=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
submodel=SDModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
@ -117,12 +123,14 @@ class ModelLoaderInvocation(BaseInvocation):
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
submodel=SDModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
submodel=SDModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
@ -130,7 +138,8 @@ class ModelLoaderInvocation(BaseInvocation):
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
submodel=SDModelType.Vae,
|
||||
),
|
||||
)
|
||||
|
@ -3,7 +3,7 @@
|
||||
from typing import Any
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.app.services.model_manager_service import SDModelType, SDModelInfo
|
||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, SDModelInfo
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
|
||||
class EventServiceBase:
|
||||
@ -109,8 +109,9 @@ class EventServiceBase:
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
model_name: str,
|
||||
model_type: SDModelType,
|
||||
submodel: SDModelType,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: SubModelType,
|
||||
) -> None:
|
||||
"""Emitted when a model is requested"""
|
||||
self.__emit_session_event(
|
||||
@ -120,6 +121,7 @@ class EventServiceBase:
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
),
|
||||
@ -131,8 +133,9 @@ class EventServiceBase:
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
model_name: str,
|
||||
model_type: SDModelType,
|
||||
submodel: SDModelType,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: SubModelType,
|
||||
model_info: SDModelInfo,
|
||||
) -> None:
|
||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||
@ -143,6 +146,7 @@ class EventServiceBase:
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info,
|
||||
|
@ -10,7 +10,9 @@ from dataclasses import dataclass
|
||||
|
||||
from invokeai.backend.model_management.model_manager import (
|
||||
ModelManager,
|
||||
SDModelType,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
SDModelInfo,
|
||||
)
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
@ -20,12 +22,6 @@ from ...backend.util import choose_precision, choose_torch_device
|
||||
if TYPE_CHECKING:
|
||||
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||
|
||||
@dataclass
|
||||
class LastUsedModel:
|
||||
model_name: str=None
|
||||
model_type: SDModelType=None
|
||||
|
||||
last_used_model = LastUsedModel()
|
||||
|
||||
class ModelManagerServiceBase(ABC):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
@ -48,8 +44,9 @@ class ModelManagerServiceBase(ABC):
|
||||
def get_model(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: SDModelType = SDModelType.Diffusers,
|
||||
submodel: Optional[SDModelType] = None,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
node: Optional[BaseInvocation] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> SDModelInfo:
|
||||
@ -67,12 +64,13 @@ class ModelManagerServiceBase(ABC):
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: SDModelType,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def default_model(self) -> Optional[Tuple[str, SDModelType]]:
|
||||
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns the name and typeof the default model, or None
|
||||
if none is defined.
|
||||
@ -80,26 +78,26 @@ class ModelManagerServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_default_model(self, model_name: str, model_type: SDModelType):
|
||||
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
|
||||
"""Sets the default model to the indicated name."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_info(self, model_name: str, model_type: SDModelType) -> dict:
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_names(self) -> List[Tuple[str, SDModelType]]:
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_models(self, model_type: SDModelType=None) -> dict:
|
||||
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
|
||||
"""
|
||||
Return a dict of models in the format:
|
||||
{ model_type1:
|
||||
@ -122,7 +120,8 @@ class ModelManagerServiceBase(ABC):
|
||||
def add_model(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: SDModelType,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False
|
||||
) -> None:
|
||||
@ -139,7 +138,8 @@ class ModelManagerServiceBase(ABC):
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: SDModelType,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
delete_files: bool = False,
|
||||
):
|
||||
"""
|
||||
@ -297,8 +297,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
def get_model(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: SDModelType = SDModelType.Diffusers,
|
||||
submodel: Optional[SDModelType] = None,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
node: Optional[BaseInvocation] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> SDModelInfo:
|
||||
@ -306,23 +307,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
Retrieve the indicated model. submodel can be used to get a
|
||||
part (such as the vae) of a diffusers mode.
|
||||
"""
|
||||
|
||||
# Temporary hack here: we remember the last model fetched
|
||||
# so that when executing a graph, the first node called gets
|
||||
# to set default model for subsequent nodes in the event that
|
||||
# they do not set the model explicitly. This should be
|
||||
# displaced by model loader mechanism.
|
||||
# This is to work around lack of model loader at current time,
|
||||
# which was causing inconsistent model usage throughout graph.
|
||||
global last_used_model
|
||||
|
||||
if not model_name:
|
||||
self.logger.debug('No model name provided, defaulting to last loaded model')
|
||||
model_name = last_used_model.model_name
|
||||
model_type = model_type or last_used_model.model_type
|
||||
else:
|
||||
last_used_model.model_name = model_name
|
||||
last_used_model.model_type = model_type
|
||||
|
||||
# if we are called from within a node, then we get to emit
|
||||
# load start and complete events
|
||||
@ -331,12 +315,14 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
node=node,
|
||||
context=context,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel
|
||||
submodel=submodel,
|
||||
)
|
||||
|
||||
model_info = self.mgr.get_model(
|
||||
model_name,
|
||||
base_model,
|
||||
model_type,
|
||||
submodel,
|
||||
)
|
||||
@ -346,6 +332,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
node=node,
|
||||
context=context,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info
|
||||
@ -356,7 +343,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: SDModelType,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> bool:
|
||||
"""
|
||||
Given a model name, returns True if it is a valid
|
||||
@ -364,33 +352,38 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
"""
|
||||
return self.mgr.model_exists(
|
||||
model_name,
|
||||
base_model,
|
||||
model_type,
|
||||
)
|
||||
|
||||
def default_model(self) -> Optional[Tuple[str, SDModelType]]:
|
||||
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns the name of the default model, or None
|
||||
if none is defined.
|
||||
"""
|
||||
return self.mgr.default_model()
|
||||
|
||||
def set_default_model(self, model_name: str, model_type: SDModelType):
|
||||
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
|
||||
"""Sets the default model to the indicated name."""
|
||||
self.mgr.set_default_model(model_name)
|
||||
self.mgr.set_default_model(model_name, base_model, model_type)
|
||||
|
||||
def model_info(self, model_name: str, model_type: SDModelType) -> dict:
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
"""
|
||||
return self.mgr.model_info(model_name)
|
||||
return self.mgr.model_info(model_name, base_model, model_type)
|
||||
|
||||
def model_names(self) -> List[Tuple[str, SDModelType]]:
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
return self.mgr.model_names()
|
||||
|
||||
def list_models(self, model_type: SDModelType=None) -> dict:
|
||||
def list_models(
|
||||
self,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None
|
||||
) -> dict:
|
||||
"""
|
||||
Return a dict of models in the format:
|
||||
{ model_type1:
|
||||
@ -406,12 +399,13 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
{ model_name_n: etc
|
||||
}
|
||||
"""
|
||||
return self.mgr.list_models()
|
||||
return self.mgr.list_models(base_model, model_type)
|
||||
|
||||
def add_model(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: SDModelType,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
)->None:
|
||||
@ -422,13 +416,14 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
return self.mgr.add_model(model_name, model_type, model_attributes, clobber)
|
||||
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
||||
|
||||
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: SDModelType = SDModelType.Diffusers,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
delete_files: bool = False,
|
||||
):
|
||||
"""
|
||||
@ -436,7 +431,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well. Call commit() to write to disk.
|
||||
"""
|
||||
self.mgr.del_model(model_name, model_type, delete_files)
|
||||
self.mgr.del_model(model_name, base_model, model_type, delete_files)
|
||||
|
||||
def import_diffuser_model(
|
||||
self,
|
||||
@ -541,8 +536,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
node,
|
||||
context,
|
||||
model_name: str,
|
||||
model_type: SDModelType,
|
||||
submodel: SDModelType,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: SubModelType,
|
||||
model_info: Optional[SDModelInfo] = None,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
@ -555,6 +551,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
node=node.dict(),
|
||||
source_node_id=source_node_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info
|
||||
@ -565,6 +562,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
node=node.dict(),
|
||||
source_node_id=source_node_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
)
|
||||
|
@ -16,13 +16,14 @@ class RestorationServices:
|
||||
gfpgan, codeformer, esrgan = None, None, None
|
||||
if args.restore or args.esrgan:
|
||||
restoration = Restoration()
|
||||
if args.restore:
|
||||
# TODO: redo for new model structure
|
||||
if False and args.restore:
|
||||
gfpgan, codeformer = restoration.load_face_restore_models(
|
||||
args.gfpgan_model_path
|
||||
)
|
||||
else:
|
||||
logger.info("Face restoration disabled")
|
||||
if args.esrgan:
|
||||
if False and args.esrgan:
|
||||
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
|
||||
else:
|
||||
logger.info("Upscaling disabled")
|
||||
|
@ -9,5 +9,5 @@ from .generator import (
|
||||
Img2Img,
|
||||
Inpaint
|
||||
)
|
||||
from .model_management import ModelManager, ModelCache, ModelType, ModelInfo
|
||||
from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubmodelType, ModelInfo
|
||||
from .safety_checker import SafetyChecker
|
||||
|
@ -3,4 +3,4 @@ Initialization file for invokeai.backend.model_management
|
||||
"""
|
||||
from .model_manager import ModelManager, ModelInfo
|
||||
from .model_cache import ModelCache
|
||||
from .models import ModelType
|
||||
from .models import BaseModelType, ModelType, SubModelType
|
||||
|
@ -30,7 +30,9 @@ import torch
|
||||
from diffusers import logging as diffusers_logging
|
||||
from transformers import logging as transformers_logging
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .models import ModelType, SubModelType, ModelBase
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
from .lora import LoRAModel, TextualInversionModel
|
||||
from .models import BaseModelType, ModelType, SubModelType, ModelBase
|
||||
|
||||
# Maximum size of the cache, in gigs
|
||||
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
||||
@ -122,11 +124,12 @@ class ModelCache(object):
|
||||
def get_key(
|
||||
self,
|
||||
model_path: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel_type: Optional[ModelType] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
):
|
||||
|
||||
key = f"{model_path}:{model_type}"
|
||||
key = f"{model_path}:{base_model}:{model_type}"
|
||||
if submodel_type:
|
||||
key += f":{submodel_type}"
|
||||
return key
|
||||
@ -145,10 +148,13 @@ class ModelCache(object):
|
||||
self,
|
||||
model_path: str,
|
||||
model_class: Type[ModelBase],
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
model_info_key = self.get_key(
|
||||
model_path=model_path,
|
||||
model_type=model_class,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel_type=None,
|
||||
)
|
||||
|
||||
@ -165,6 +171,8 @@ class ModelCache(object):
|
||||
self,
|
||||
model_path: Union[str, Path],
|
||||
model_class: Type[ModelBase],
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
gpu_load: bool = True,
|
||||
) -> Any:
|
||||
@ -178,17 +186,20 @@ class ModelCache(object):
|
||||
model_info = self._get_model_info(
|
||||
model_path=model_path,
|
||||
model_class=model_class,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
key = self.get_key(
|
||||
model_path=model_path,
|
||||
model_type=model_class, # TODO:
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel_type=submodel,
|
||||
)
|
||||
|
||||
# TODO: lock for no copies on simultaneous calls?
|
||||
cache_entry = self._cached_models.get(key, None)
|
||||
if cache_entry is None:
|
||||
self.logger.info(f'Loading model {model_path}, type {model_class}:{submodel}')
|
||||
self.logger.info(f'Loading model {model_path}, type {base_model}:{model_type}:{submodel}')
|
||||
|
||||
# this will remove older cached models until
|
||||
# there is sufficient room to load the requested model
|
||||
|
@ -160,6 +160,8 @@ from huggingface_hub import scan_cache_dir
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util import CUDA_DEVICE, download_with_resume
|
||||
@ -184,9 +186,10 @@ class ModelCache(object):
|
||||
class ModelInfo():
|
||||
context: ModelLocker
|
||||
name: str
|
||||
base_model: BaseModelType
|
||||
type: ModelType
|
||||
hash: str
|
||||
location: Union[Path,str]
|
||||
location: Union[Path, str]
|
||||
precision: torch.dtype
|
||||
revision: str = None
|
||||
_cache: ModelCache = None
|
||||
@ -222,6 +225,9 @@ MAX_CACHE_SIZE = 6.0 # GB
|
||||
# └── realesrgan
|
||||
|
||||
|
||||
class ConfigMeta(BaseModel):
|
||||
version: str
|
||||
|
||||
class ModelManager(object):
|
||||
"""
|
||||
High-level interface to model management.
|
||||
@ -229,6 +235,38 @@ class ModelManager(object):
|
||||
|
||||
logger: types.ModuleType = logger
|
||||
|
||||
# TODO:
|
||||
def _convert_2_3_models(self, config: DictConfig):
|
||||
for model_name, model_config in config.items():
|
||||
if model_config["format"] == "diffusers":
|
||||
pass
|
||||
elif model_config["format"] == "ckpt":
|
||||
|
||||
if any(model_config["config"].endswith(file) for file in {
|
||||
"v1-finetune.yaml",
|
||||
"v1-finetune_style.yaml",
|
||||
"v1-inference.yaml",
|
||||
"v1-inpainting-inference.yaml",
|
||||
"v1-m1-finetune.yaml",
|
||||
}):
|
||||
# copy as as sd1.5
|
||||
pass
|
||||
|
||||
# ~99% accurate should be
|
||||
elif model_config["config"].endswith("v2-inference-v.yaml"):
|
||||
# copy as sd 2.x (768)
|
||||
pass
|
||||
|
||||
# for real don't know how accurate it
|
||||
elif model_config["config"].endswith("v2-inference.yaml"):
|
||||
# copy as sd 2.x-base (512)
|
||||
pass
|
||||
|
||||
else:
|
||||
# TODO:
|
||||
raise Exception("Unknown model")
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Union[Path, DictConfig, str],
|
||||
@ -244,18 +282,29 @@ class ModelManager(object):
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
if isinstance(config, DictConfig):
|
||||
self.config_path = None
|
||||
self.config = config
|
||||
elif isinstance(config,(str,Path)):
|
||||
self.config_path = config
|
||||
self.config = OmegaConf.load(self.config_path)
|
||||
else:
|
||||
|
||||
self.config_path = None
|
||||
if isinstance(config, (str, Path)):
|
||||
self.config_path = Path(config)
|
||||
config = OmegaConf.load(self.config_path)
|
||||
|
||||
elif not isinstance(config, DictConfig):
|
||||
raise ValueError('config argument must be an OmegaConf object, a Path or a string')
|
||||
|
||||
#if "__meta__" not in config:
|
||||
# config = self._convert_2_3_models(config)
|
||||
|
||||
config_meta = ConfigMeta(**config.pop("__metadata__")) # TODO: naming
|
||||
# TODO: metadata not found
|
||||
|
||||
self.models = dict()
|
||||
for model_key, model_config in config.items():
|
||||
model_name, base_model, model_type = self.parse_key(model_key)
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
self.models[model_key] = model_class.build_config(**model_config)
|
||||
|
||||
# check config version number and update on disk/RAM if necessary
|
||||
self.globals = InvokeAIAppConfig.get_config()
|
||||
self._update_config_file_version()
|
||||
self.logger = logger
|
||||
self.cache = ModelCache(
|
||||
max_cache_size=max_cache_size,
|
||||
@ -267,7 +316,7 @@ class ModelManager(object):
|
||||
self.cache_keys = dict()
|
||||
|
||||
# add controlnet, lora and textual_inversion models from disk
|
||||
self.scan_models_directory(include_diffusers=False)
|
||||
self.scan_models_directory()
|
||||
|
||||
def model_exists(
|
||||
self,
|
||||
@ -280,7 +329,7 @@ class ModelManager(object):
|
||||
identifier.
|
||||
"""
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
return model_key in self.config
|
||||
return model_key in self.models
|
||||
|
||||
def create_key(
|
||||
self,
|
||||
@ -350,38 +399,49 @@ class ModelManager(object):
|
||||
|
||||
"""
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
model_dir = self.globals.models_path
|
||||
if not model_class.has_config:
|
||||
model_config = None
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
|
||||
for ext in {"pt", "ckpt", "safetensors"}:
|
||||
model_path = os.path.join(model_dir, base_model, model_type, f"{model_name}.{ext}")
|
||||
if os.path.exists(model_path):
|
||||
break
|
||||
else:
|
||||
model_path = os.path.join(model_dir, base_model, model_type, model_name)
|
||||
if not os.path.exists(model_path):
|
||||
raise InvalidModelError(
|
||||
f"Model not found - \"{base_model}/{model_type}/{model_name}\" "
|
||||
)
|
||||
|
||||
else:
|
||||
# find in config
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
if model_key not in self.config:
|
||||
raise InvalidModelError(
|
||||
f'"{model_key}" is not a known model name. Please check your models.yaml file'
|
||||
# if model not found try to find it (maybe file just pasted)
|
||||
if model_key not in self.models:
|
||||
# TODO: find by mask or try rescan?
|
||||
path_mask = f"/models/{base_model}/{model_type}/{model_name}*"
|
||||
if False: # model_path = next(find_by_mask(path_mask)):
|
||||
model_path = None # TODO:
|
||||
model_config = model_class.build_config(
|
||||
path=model_path,
|
||||
)
|
||||
self.models[model_key] = model_config
|
||||
else:
|
||||
raise Exception(f"Model not found - {model_key}")
|
||||
|
||||
model_config = self.config[model_key]
|
||||
model_path = model_config.path
|
||||
# if it known model check that target path exists (if manualy deleted)
|
||||
else:
|
||||
# logic repeated twice(in rescan too) any way to optimize?
|
||||
if not os.path.exists(self.models[model_key].path):
|
||||
if model_class.save_to_config:
|
||||
self.models[model_key].error = ModelError.NotFound
|
||||
raise Exception(f"Files for model \"{model_key}\" not found")
|
||||
|
||||
# vae/movq override
|
||||
# TODO:
|
||||
if submodel_type is not None and submodel_type in model_config:
|
||||
model_path = model_config[submodel_type]["path"]
|
||||
model_type = submodel_type
|
||||
submodel_type = None
|
||||
else:
|
||||
self.models.pop(model_key, None)
|
||||
raise Exception(f"Model not found - {model_key}")
|
||||
|
||||
# reset model errors?
|
||||
|
||||
|
||||
|
||||
model_config = self.models[model_key]
|
||||
|
||||
# /models/{base_model}/{model_type}/{name}.ckpt or .safentesors
|
||||
# /models/{base_model}/{model_type}/{name}/
|
||||
model_path = model_config.path
|
||||
|
||||
# vae/movq override
|
||||
# TODO:
|
||||
if submodel is not None and submodel in model_config:
|
||||
model_path = model_config[submodel]
|
||||
model_type = submodel
|
||||
submodel = None
|
||||
|
||||
dst_convert_path = None # TODO:
|
||||
model_path = model_class.convert_if_required(
|
||||
@ -414,11 +474,11 @@ class ModelManager(object):
|
||||
Returns the name of the default model, or None
|
||||
if none is defined.
|
||||
"""
|
||||
for model_key, model_config in self.config.items():
|
||||
if model_config.get("default", False):
|
||||
for model_key, model_config in self.models.items():
|
||||
if model_config.default:
|
||||
return self.parse_key(model_key)
|
||||
|
||||
for model_key, _ in self.config.items():
|
||||
for model_key, _ in self.models.items():
|
||||
return self.parse_key(model_key)
|
||||
else:
|
||||
return None # TODO: or redo as (None, None, None)
|
||||
@ -435,14 +495,11 @@ class ModelManager(object):
|
||||
"""
|
||||
|
||||
model_key = self.model_key(model_name, base_model, model_type)
|
||||
if model_key not in self.config:
|
||||
if model_key not in self.models:
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
for cur_model_key, config in self.config.items():
|
||||
if cur_model_key == model_key:
|
||||
config["default"] = True
|
||||
else:
|
||||
config.pop("default", None)
|
||||
for cur_model_key, config in self.models.items():
|
||||
config.default = cur_model_key == model_key
|
||||
|
||||
def model_info(
|
||||
self,
|
||||
@ -454,14 +511,17 @@ class ModelManager(object):
|
||||
Given a model name returns the OmegaConf (dict-like) object describing it.
|
||||
"""
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
return self.config.get(model_key, None)
|
||||
if model_key in self.models:
|
||||
return self.models[model_key].dict(exclude_defaults=True)
|
||||
else:
|
||||
return None # TODO: None or empty dict on not found
|
||||
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Return a list of (str, BaseModelType, ModelType) corresponding to all models
|
||||
known to the configuration.
|
||||
"""
|
||||
return [(self.parse_key(x)) for x in self.config.keys() if isinstance(self.config[x], DictConfig)]
|
||||
return [(self.parse_key(x)) for x in self.models.keys()]
|
||||
|
||||
def list_models(
|
||||
self,
|
||||
@ -479,61 +539,52 @@ class ModelManager(object):
|
||||
assert not(model_type is not None and base_model is None), "model_type must be provided with base_model"
|
||||
|
||||
models = dict()
|
||||
for model_key in sorted(self.config, key=str.casefold):
|
||||
stanza = self.config[model_key]
|
||||
for model_key in sorted(self.models, key=str.casefold):
|
||||
model_config = self.models[model_key]
|
||||
|
||||
if model_key.startswith('_'):
|
||||
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||
if base_model is not None and cur_base_model != base_model:
|
||||
continue
|
||||
if model_type is not None and cur_model_type != model_type:
|
||||
continue
|
||||
|
||||
model_name, m_base_model, stanza_type = self.parse_key(model_key)
|
||||
if base_model is not None and m_base_model != base_model:
|
||||
continue
|
||||
if model_type is not None and model_type != stanza_type:
|
||||
continue
|
||||
if cur_base_model not in models:
|
||||
models[cur_base_model] = dict()
|
||||
if cur_model_type not in models[cur_base_model]:
|
||||
models[cur_base_model][cur_model_type] = dict()
|
||||
|
||||
if m_base_model not in models:
|
||||
models[m_base_model] = dict()
|
||||
if stanza_type not in models:
|
||||
models[m_base_model][stanza_type] = dict()
|
||||
|
||||
model_class = MODEL_CLASSES[m_base_model][stanza_type]
|
||||
models[m_base_model][stanza_type][model_name] = model_class.build_config(
|
||||
**stanza,
|
||||
name=model_name,
|
||||
base_model=base_model,
|
||||
type=stanza_type,
|
||||
models[cur_base_model][cur_model_type][cur_model_name] = dict(
|
||||
**model_config.dict(exclude_defaults=True),
|
||||
name=cur_model_name,
|
||||
base_model=cur_base_model,
|
||||
type=cur_model_type,
|
||||
)
|
||||
#models[m_base_model][stanza_type][model_name] = model_class.Config(
|
||||
# **stanza,
|
||||
# name=model_name,
|
||||
# base_model=base_model,
|
||||
# type=stanza_type,
|
||||
#).dict()
|
||||
|
||||
return models
|
||||
|
||||
def print_models(self) -> None:
|
||||
"""
|
||||
Print a table of models, their descriptions, and load status
|
||||
Print a table of models, their descriptions
|
||||
"""
|
||||
# TODO: redo
|
||||
for model_type, model_dict in self.list_models().items():
|
||||
for model_name, model_info in model_dict.items():
|
||||
line = f'{model_info["name"]:25s} {model_info["status"]:>15s} {model_info["type"]:10s} {model_info["description"]}'
|
||||
if model_info["status"] in ["in gpu","locked in gpu"]:
|
||||
line = f"\033[1m{line}\033[0m"
|
||||
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
|
||||
print(line)
|
||||
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: ModelType.Diffusers,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
delete_files: bool = False,
|
||||
):
|
||||
"""
|
||||
Delete the named model.
|
||||
"""
|
||||
model_key = self.create_key(model_name, model_type)
|
||||
model_cfg = self.pop(model_key, None)
|
||||
raise Exception("TODO: del_model") # TODO: redo
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
model_cfg = self.models.pop(model_key, None)
|
||||
|
||||
if model_cfg is None:
|
||||
self.logger.error(
|
||||
@ -581,27 +632,14 @@ class ModelManager(object):
|
||||
"""
|
||||
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
|
||||
model_class.build_config(
|
||||
**model_attributes,
|
||||
name=model_name,
|
||||
base_model=base_model,
|
||||
type=model_type,
|
||||
)
|
||||
#model_cfg = model_class.Config(
|
||||
# **model_attributes,
|
||||
# name=model_name,
|
||||
# base_model=base_model,
|
||||
# type=model_type,
|
||||
#)
|
||||
|
||||
model_config = model_class.build_config(**model_attributes)
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
|
||||
assert (
|
||||
clobber or model_key not in self.config
|
||||
clobber or model_key not in self.models
|
||||
), f'attempt to overwrite existing model definition "{model_key}"'
|
||||
|
||||
self.config[model_key] = model_attributes
|
||||
self.models[model_key] = model_config
|
||||
|
||||
if clobber and model_key in self.cache_keys:
|
||||
# TODO:
|
||||
@ -633,7 +671,15 @@ class ModelManager(object):
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
"""
|
||||
yaml_str = OmegaConf.to_yaml(self.config)
|
||||
data_to_save = dict()
|
||||
for model_key, model_config in self.models.items():
|
||||
model_name, base_model, model_type = self.parse_key(model_key)
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
if model_class.save_to_config:
|
||||
# TODO: or exclude_unset better fits here?
|
||||
data_to_save[model_key] = model_config.dict(exclude_defaults=True)
|
||||
|
||||
yaml_str = OmegaConf.to_yaml(data_to_save)
|
||||
config_file_path = conf_file or self.config_path
|
||||
assert config_file_path is not None,'no config file path to write to'
|
||||
config_file_path = self.globals.root_dir / config_file_path
|
||||
@ -697,61 +743,3 @@ class ModelManager(object):
|
||||
resolved_path = self.globals.root_dir / source
|
||||
return resolved_path
|
||||
|
||||
def _update_config_file_version(self):
|
||||
"""
|
||||
This gets called at object init time and will update
|
||||
from older versions of the config file to new ones
|
||||
as necessary.
|
||||
"""
|
||||
current_version = self.config.get("_version","1.0.0")
|
||||
if version.parse(current_version) < version.parse(CONFIG_FILE_VERSION):
|
||||
self.logger.warning(f'models.yaml version {current_version} detected. Updating to {CONFIG_FILE_VERSION}')
|
||||
self.logger.warning('The original file will be renamed models.yaml.orig')
|
||||
if self.config_path:
|
||||
old_file = Path(self.config_path)
|
||||
new_name = old_file.parent / 'models.yaml.orig'
|
||||
old_file.replace(new_name)
|
||||
|
||||
new_config = OmegaConf.create()
|
||||
new_config["_version"] = CONFIG_FILE_VERSION
|
||||
|
||||
for model_key in self.config:
|
||||
|
||||
old_stanza = self.config[model_key]
|
||||
if not isinstance(old_stanza,DictConfig):
|
||||
continue
|
||||
|
||||
# ignore old and ugly way of associating a legacy
|
||||
# vae with a legacy checkpont model
|
||||
if old_stanza.get("config") and '/VAE/' in old_stanza.get("config"):
|
||||
continue
|
||||
|
||||
# bare keys are updated to be prefixed with 'diffusers/'
|
||||
if '/' not in model_key:
|
||||
new_key = f'diffusers/{model_key}'
|
||||
else:
|
||||
new_key = model_key
|
||||
|
||||
if old_stanza.get('format')=='diffusers':
|
||||
model_format = 'folder'
|
||||
elif old_stanza.get('weights') and Path(old_stanza.get('weights')).suffix == '.ckpt':
|
||||
model_format = 'ckpt'
|
||||
elif old_stanza.get('weights') and Path(old_stanza.get('weights')).suffix == '.safetensors':
|
||||
model_format = 'safetensors'
|
||||
else:
|
||||
model_format = old_stanza.get('format')
|
||||
|
||||
# copy fields over manually rather than doing a copy() or deepcopy()
|
||||
# in order to avoid bringing in unwanted fields.
|
||||
new_config[new_key] = dict(
|
||||
description = old_stanza.get('description'),
|
||||
format = model_format,
|
||||
)
|
||||
for field in ["repo_id", "path", "weights", "config", "vae"]:
|
||||
if field_value := old_stanza.get(field):
|
||||
new_config[new_key].update({field: field_value})
|
||||
|
||||
self.config = new_config
|
||||
if self.config_path:
|
||||
self.commit()
|
||||
|
||||
|
37
invokeai/backend/model_management/models/__init__.py
Normal file
37
invokeai/backend/model_management/models/__init__.py
Normal file
@ -0,0 +1,37 @@
|
||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase
|
||||
from .stable_diffusion import StableDiffusion15Model, StableDiffusion2Model, StableDiffusion2BaseModel
|
||||
from .vae import VaeModel
|
||||
from .lora import LoRAModel
|
||||
#from .controlnet import ControlNetModel # TODO:
|
||||
from .textual_inversion import TextualInversionModel
|
||||
|
||||
MODEL_CLASSES = {
|
||||
BaseModelType.StableDiffusion1_5: {
|
||||
ModelType.Pipeline: StableDiffusion15Model,
|
||||
ModelType.Vae: VaeModel,
|
||||
ModelType.Lora: LoRAModel,
|
||||
#ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
},
|
||||
BaseModelType.StableDiffusion2: {
|
||||
ModelType.Pipeline: StableDiffusion2Model,
|
||||
ModelType.Vae: VaeModel,
|
||||
ModelType.Lora: LoRAModel,
|
||||
#ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
},
|
||||
BaseModelType.StableDiffusion2Base: {
|
||||
ModelType.Pipeline: StableDiffusion2BaseModel,
|
||||
ModelType.Vae: VaeModel,
|
||||
ModelType.Lora: LoRAModel,
|
||||
#ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
},
|
||||
#BaseModelType.Kandinsky2_1: {
|
||||
# ModelType.Pipeline: Kandinsky2_1Model,
|
||||
# ModelType.MoVQ: MoVQModel,
|
||||
# ModelType.Lora: LoRAModel,
|
||||
# ModelType.ControlNet: ControlNetModel,
|
||||
# ModelType.TextualInversion: TextualInversionModel,
|
||||
#},
|
||||
}
|
297
invokeai/backend/model_management/models/base.py
Normal file
297
invokeai/backend/model_management/models/base.py
Normal file
@ -0,0 +1,297 @@
|
||||
import sys
|
||||
import typing
|
||||
import inspect
|
||||
from enum import Enum
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, ConfigMixin
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional, Type, Literal
|
||||
|
||||
class BaseModelType(str, Enum):
|
||||
#StableDiffusion1_5 = "stable_diffusion_1_5"
|
||||
#StableDiffusion2 = "stable_diffusion_2"
|
||||
#StableDiffusion2Base = "stable_diffusion_2_base"
|
||||
# TODO: maybe then add sample size(512/768)?
|
||||
StableDiffusion1_5 = "sd-1.5"
|
||||
StableDiffusion2Base = "sd-2-base" # 512 pixels; this will have epsilon parameterization
|
||||
StableDiffusion2 = "sd-2" # 768 pixels; this will have v-prediction parameterization
|
||||
#Kandinsky2_1 = "kandinsky_2_1"
|
||||
|
||||
class ModelType(str, Enum):
|
||||
Pipeline = "pipeline"
|
||||
Vae = "vae"
|
||||
|
||||
Lora = "lora"
|
||||
#ControlNet = "controlnet"
|
||||
TextualInversion = "embedding"
|
||||
|
||||
class SubModelType(str, Enum):
|
||||
UNet = "unet"
|
||||
TextEncoder = "text_encoder"
|
||||
Tokenizer = "tokenizer"
|
||||
Vae = "vae"
|
||||
Scheduler = "scheduler"
|
||||
SafetyChecker = "safety_checker"
|
||||
#MoVQ = "movq"
|
||||
|
||||
class ModelError(str, Enum):
|
||||
NotFound = "not_found"
|
||||
|
||||
class ModelConfigBase(BaseModel):
|
||||
path: str # or Path
|
||||
#name: str # not included as present in model key
|
||||
description: Optional[str] = Field(None)
|
||||
format: Optional[str] = Field(None)
|
||||
default: Optional[bool] = Field(False)
|
||||
# do not save to config
|
||||
error: Optional[ModelError] = Field(None, exclude=True)
|
||||
|
||||
|
||||
class EmptyConfigLoader(ConfigMixin):
|
||||
@classmethod
|
||||
def load_config(cls, *args, **kwargs):
|
||||
cls.config_name = kwargs.pop("config_name")
|
||||
return super().load_config(*args, **kwargs)
|
||||
|
||||
class ModelBase:
|
||||
#model_path: str
|
||||
#base_model: BaseModelType
|
||||
#model_type: ModelType
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
self.model_path = model_path
|
||||
self.base_model = base_model
|
||||
self.model_type = model_type
|
||||
|
||||
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
|
||||
if len(subtypes) < 2:
|
||||
raise Exception("Invalid subfolder definition!")
|
||||
if subtypes[0] in ["diffusers", "transformers"]:
|
||||
res_type = sys.modules[subtypes[0]]
|
||||
subtypes = subtypes[1:]
|
||||
|
||||
else:
|
||||
res_type = sys.modules["diffusers"]
|
||||
res_type = getattr(res_type, "pipelines")
|
||||
|
||||
|
||||
for subtype in subtypes:
|
||||
res_type = getattr(res_type, subtype)
|
||||
return res_type
|
||||
|
||||
@classmethod
|
||||
def _get_configs(cls):
|
||||
if not hasattr(cls, "__configs"):
|
||||
configs = dict()
|
||||
for name in dir(cls):
|
||||
if name.startswith("__"):
|
||||
continue
|
||||
|
||||
value = getattr(cls, name)
|
||||
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
|
||||
continue
|
||||
|
||||
fields = inspect.get_annotations(value)
|
||||
if "format" not in fields or typing.get_origin(fields["format"]) != Literal:
|
||||
raise Exception("Invalid config definition - format field not found")
|
||||
|
||||
format_type = typing.get_origin(fields["format"])
|
||||
if format_type not in {None, Literal}:
|
||||
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
||||
|
||||
if format_type is Literal:
|
||||
format = fields["format"].__args__[0]
|
||||
else:
|
||||
format = None
|
||||
configs[format] = value # TODO: error when override(multiple)?
|
||||
|
||||
cls.__configs = configs
|
||||
|
||||
return cls.__configs
|
||||
|
||||
@classmethod
|
||||
def build_config(cls, **kwargs):
|
||||
if "format" not in kwargs:
|
||||
kwargs["format"] = cls.detect_format(kwargs["path"])
|
||||
|
||||
configs = cls._get_configs()
|
||||
return configs[kwargs["format"]](**kwargs)
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
|
||||
class DiffusersModel(ModelBase):
|
||||
#child_types: Dict[str, Type]
|
||||
#child_sizes: Dict[str, int]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
|
||||
self.child_types: Dict[str, Type] = dict()
|
||||
self.child_sizes: Dict[str, int] = dict()
|
||||
|
||||
try:
|
||||
config_data = DiffusionPipeline.load_config(self.model_path)
|
||||
#config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
|
||||
except:
|
||||
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
|
||||
|
||||
config_data.pop("_ignore_files", None)
|
||||
|
||||
# retrieve all folder_names that contain relevant files
|
||||
child_components = [k for k, v in config_data.items() if isinstance(v, list)]
|
||||
|
||||
for child_name in child_components:
|
||||
child_type = self._hf_definition_to_type(config_data[child_name])
|
||||
self.child_types[child_name] = child_type
|
||||
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
|
||||
|
||||
|
||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||
if child_type is None:
|
||||
return sum(self.child_sizes.values())
|
||||
else:
|
||||
return self.child_sizes[child_type]
|
||||
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
child_type: Optional[SubModelType] = None,
|
||||
):
|
||||
# return pipeline in different function to pass more arguments
|
||||
if child_type is None:
|
||||
raise Exception("Child model type can't be null on diffusers model")
|
||||
if child_type not in self.child_types:
|
||||
return None # TODO: or raise
|
||||
|
||||
if torch_dtype == torch.float16:
|
||||
variants = ["fp16", None]
|
||||
else:
|
||||
variants = [None, "fp16"]
|
||||
|
||||
# TODO: better error handling(differentiate not found from others)
|
||||
for variant in variants:
|
||||
try:
|
||||
# TODO: set cache_dir to /dev/null to be sure that cache not used?
|
||||
model = self.child_types[child_type].from_pretrained(
|
||||
self.model_path,
|
||||
subfolder=child_type.value,
|
||||
torch_dtype=torch_dtype,
|
||||
variant=variant,
|
||||
local_files_only=True,
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
print("====ERR LOAD====")
|
||||
print(f"{variant}: {e}")
|
||||
|
||||
# calc more accurate size
|
||||
self.child_sizes[child_type] = calc_model_size_by_data(model)
|
||||
return model
|
||||
|
||||
#def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
|
||||
|
||||
|
||||
|
||||
def calc_model_size_by_fs(
|
||||
model_path: str,
|
||||
subfolder: Optional[str] = None,
|
||||
variant: Optional[str] = None
|
||||
):
|
||||
if subfolder is not None:
|
||||
model_path = os.path.join(model_path, subfolder)
|
||||
|
||||
# this can happen when, for example, the safety checker
|
||||
# is not downloaded.
|
||||
if not os.path.exists(model_path):
|
||||
return 0
|
||||
|
||||
all_files = os.listdir(model_path)
|
||||
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
|
||||
|
||||
fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
|
||||
bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
|
||||
other_files = set(all_files) - fp16_files - bit8_files
|
||||
|
||||
if variant is None:
|
||||
files = other_files
|
||||
elif variant == "fp16":
|
||||
files = fp16_files
|
||||
elif variant == "8bit":
|
||||
files = bit8_files
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown variant: {variant}")
|
||||
|
||||
# try read from index if exists
|
||||
index_postfix = ".index.json"
|
||||
if variant is not None:
|
||||
index_postfix = f".index.{variant}.json"
|
||||
|
||||
for file in files:
|
||||
if not file.endswith(index_postfix):
|
||||
continue
|
||||
try:
|
||||
with open(os.path.join(model_path, file), "r") as f:
|
||||
index_data = json.loads(f.read())
|
||||
return int(index_data["metadata"]["total_size"])
|
||||
except:
|
||||
pass
|
||||
|
||||
# calculate files size if there is no index file
|
||||
formats = [
|
||||
(".safetensors",), # safetensors
|
||||
(".bin",), # torch
|
||||
(".onnx", ".pb"), # onnx
|
||||
(".msgpack",), # flax
|
||||
(".ckpt",), # tf
|
||||
(".h5",), # tf2
|
||||
]
|
||||
|
||||
for file_format in formats:
|
||||
model_files = [f for f in files if f.endswith(file_format)]
|
||||
if len(model_files) == 0:
|
||||
continue
|
||||
|
||||
model_size = 0
|
||||
for model_file in model_files:
|
||||
file_stats = os.stat(os.path.join(model_path, model_file))
|
||||
model_size += file_stats.st_size
|
||||
return model_size
|
||||
|
||||
#raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
|
||||
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
|
||||
|
||||
|
||||
def calc_model_size_by_data(model) -> int:
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
return _calc_pipeline_by_data(model)
|
||||
elif isinstance(model, torch.nn.Module):
|
||||
return _calc_model_by_data(model)
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def _calc_pipeline_by_data(pipeline) -> int:
|
||||
res = 0
|
||||
for submodel_key in pipeline.components.keys():
|
||||
submodel = getattr(pipeline, submodel_key)
|
||||
if submodel is not None and isinstance(submodel, torch.nn.Module):
|
||||
res += _calc_model_by_data(submodel)
|
||||
return res
|
||||
|
||||
|
||||
def _calc_model_by_data(model) -> int:
|
||||
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
|
||||
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
|
||||
mem = mem_params + mem_bufs # in bytes
|
||||
return mem
|
63
invokeai/backend/model_management/models/lora.py
Normal file
63
invokeai/backend/model_management/models/lora.py
Normal file
@ -0,0 +1,63 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
from .base import (
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
# TODO: naming
|
||||
from ..lora import LoRAModel as LoRAModelRaw
|
||||
|
||||
class LoRAModel(ModelBase):
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
format: None
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.Lora
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
|
||||
self.model_size = os.path.getsize(self.model_path)
|
||||
|
||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in lora")
|
||||
return self.model_size
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
child_type: Optional[SubModelType] = None,
|
||||
):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in lora")
|
||||
|
||||
model = LoRAModelRaw.from_checkpoint(
|
||||
file_path=self.model_path,
|
||||
dtype=torch_dtype,
|
||||
)
|
||||
|
||||
self.model_size = model.calc_size()
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def save_to_config(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if os.path.isdir(path):
|
||||
return "diffusers"
|
||||
else:
|
||||
return "lycoris"
|
||||
|
||||
@staticmethod
|
||||
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
|
||||
if cls.detect_format(model_path) == "diffusers":
|
||||
# TODO: add diffusers lora when it stabilizes a bit
|
||||
raise NotImplementedError("Diffusers lora not supported")
|
||||
else:
|
||||
return model_path
|
131
invokeai/backend/model_management/models/stable_diffusion.py
Normal file
131
invokeai/backend/model_management/models/stable_diffusion.py
Normal file
@ -0,0 +1,131 @@
|
||||
import os
|
||||
import torch
|
||||
from pydantic import Field
|
||||
from typing import Literal, Optional
|
||||
from .base import (
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
DiffusersModel,
|
||||
)
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
|
||||
# TODO: how to name properly
|
||||
class StableDiffusion15Model(DiffusersModel):
|
||||
|
||||
# TODO: str -> Path?
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
format: Literal["diffusers"]
|
||||
vae: Optional[str] = Field(None)
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
format: Literal["checkpoint"]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: Optional[str] = Field(None)
|
||||
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model == BaseModelType.StableDiffusion1_5
|
||||
assert model_type == ModelType.Pipeline
|
||||
super().__init__(
|
||||
model_path=model_path,
|
||||
base_model=BaseModelType.StableDiffusion1_5,
|
||||
model_type=ModelType.Pipeline,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def save_to_config(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, model_path: str):
|
||||
if os.path.isdir(model_path):
|
||||
return "diffusers"
|
||||
else:
|
||||
return "checkpoint"
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
|
||||
cfg = cls.build_config(**config)
|
||||
if isinstance(cfg, cls.CheckpointConfig):
|
||||
return _convert_ckpt_and_cache(cfg) # TODO: args
|
||||
else:
|
||||
return model_path
|
||||
|
||||
# all same
|
||||
class StableDiffusion2BaseModel(StableDiffusion15Model):
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
# skip StableDiffusion15Model __init__
|
||||
assert base_model == BaseModelType.StableDiffusion2Base
|
||||
assert model_type == ModelType.Pipeline
|
||||
super(StableDiffusion15Model, self).__init__(
|
||||
model_path=model_path,
|
||||
base_model=BaseModelType.StableDiffusion2Base,
|
||||
model_type=ModelType.Pipeline,
|
||||
)
|
||||
|
||||
class StableDiffusion2Model(DiffusersModel):
|
||||
|
||||
# TODO: str -> Path?
|
||||
# overwrite configs
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
format: Literal["diffusers"]
|
||||
vae: Optional[str] = Field(None)
|
||||
attention_upscale: bool = Field(True)
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
format: Literal["checkpoint"]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: Optional[str] = Field(None)
|
||||
attention_upscale: bool = Field(True)
|
||||
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
# skip StableDiffusion15Model __init__
|
||||
assert base_model == BaseModelType.StableDiffusion2
|
||||
assert model_type == ModelType.Pipeline
|
||||
super().__init__(
|
||||
model_path=model_path,
|
||||
base_model=BaseModelType.StableDiffusion2,
|
||||
model_type=ModelType.Pipeline,
|
||||
)
|
||||
|
||||
|
||||
# TODO: rework
|
||||
DictConfig = dict
|
||||
def _convert_ckpt_and_cache(self, mconfig: DictConfig) -> str:
|
||||
"""
|
||||
Convert the checkpoint model indicated in mconfig into a
|
||||
diffusers, cache it to disk, and return Path to converted
|
||||
file. If already on disk then just returns Path.
|
||||
"""
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
weights = app_config.root_dir / mconfig.path
|
||||
config_file = app_config.root_dir / mconfig.config
|
||||
diffusers_path = app_config.converted_ckpts_dir / weights.stem
|
||||
|
||||
# return cached version if it exists
|
||||
if diffusers_path.exists():
|
||||
return diffusers_path
|
||||
|
||||
# TODO: I think that it more correctly to convert with embedded vae
|
||||
# as if user will delete custom vae he will got not embedded but also custom vae
|
||||
#vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
|
||||
vae_ckpt_path, vae_model = None, None
|
||||
|
||||
# to avoid circular import errors
|
||||
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||
with SilenceWarnings():
|
||||
convert_ckpt_to_diffusers(
|
||||
weights,
|
||||
diffusers_path,
|
||||
extract_ema=True,
|
||||
original_config_file=config_file,
|
||||
vae=vae_model,
|
||||
vae_path=str(app_config.root_dir / vae_ckpt_path) if vae_ckpt_path else None,
|
||||
scan_needed=True,
|
||||
)
|
||||
return diffusers_path
|
@ -0,0 +1,56 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
from .base import (
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
# TODO: naming
|
||||
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
||||
|
||||
class TextualInversionModel(ModelBase):
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
format: None
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.TextualInversion
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
|
||||
self.model_size = os.path.getsize(self.model_path)
|
||||
|
||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in textual inversion")
|
||||
return self.model_size
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
child_type: Optional[SubModelType] = None,
|
||||
):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in textual inversion")
|
||||
|
||||
model = TextualInversionModelRaw.from_checkpoint(
|
||||
file_path=self.model_path,
|
||||
dtype=torch_dtype,
|
||||
)
|
||||
|
||||
self.model_size = model.embedding.nelement() * model.embedding.element_size()
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def save_to_config(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
|
||||
return model_path
|
122
invokeai/backend/model_management/models/vae.py
Normal file
122
invokeai/backend/model_management/models/vae.py
Normal file
@ -0,0 +1,122 @@
|
||||
import os
|
||||
import torch
|
||||
from typing import Optional
|
||||
from .base import (
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
EmptyConfigLoader,
|
||||
calc_model_size_by_fs,
|
||||
calc_model_size_by_data,
|
||||
)
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
class VaeModel(ModelBase):
|
||||
#vae_class: Type
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
format: None
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.Vae
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
|
||||
try:
|
||||
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
|
||||
#config = json.loads(os.path.join(self.model_path, "config.json"))
|
||||
except:
|
||||
raise Exception("Invalid vae model! (config.json not found or invalid)")
|
||||
|
||||
try:
|
||||
vae_class_name = config.get("_class_name", "AutoencoderKL")
|
||||
self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
|
||||
self.model_size = calc_model_size_by_fs(self.model_path)
|
||||
except:
|
||||
raise Exception("Invalid vae model! (Unkown vae type)")
|
||||
|
||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in vae model")
|
||||
return self.model_size
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
child_type: Optional[SubModelType] = None,
|
||||
):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in vae model")
|
||||
|
||||
model = self.vae_class.from_pretrained(
|
||||
self.model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
# calc more accurate size
|
||||
self.model_size = calc_model_size_by_data(model)
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def save_to_config(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if os.path.isdir(path):
|
||||
return "diffusers"
|
||||
else:
|
||||
return "checkpoint"
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
|
||||
if cls.detect_format(model_path) != "diffusers":
|
||||
# TODO:
|
||||
#_convert_vae_ckpt_and_cache
|
||||
raise NotImplementedError("TODO: vae convert")
|
||||
else:
|
||||
return model_path
|
||||
|
||||
# TODO: rework
|
||||
DictConfig = dict
|
||||
def _convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> str:
|
||||
"""
|
||||
Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
|
||||
object, cache it to disk, and return Path to converted
|
||||
file. If already on disk then just returns Path.
|
||||
"""
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
root = app_config.root_dir
|
||||
weights_file = root / mconfig.path
|
||||
config_file = root / mconfig.config
|
||||
diffusers_path = app_config.converted_ckpts_dir / weights_file.stem
|
||||
image_size = mconfig.get('width') or mconfig.get('height') or 512
|
||||
|
||||
# return cached version if it exists
|
||||
if diffusers_path.exists():
|
||||
return diffusers_path
|
||||
|
||||
# this avoids circular import error
|
||||
from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||
if weights_file.suffix == '.safetensors':
|
||||
checkpoint = safetensors.torch.load_file(weights_file)
|
||||
else:
|
||||
checkpoint = torch.load(weights_file, map_location="cpu")
|
||||
|
||||
# sometimes weights are hidden under "state_dict", and sometimes not
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
config = OmegaConf.load(config_file)
|
||||
|
||||
vae_model = convert_ldm_vae_to_diffusers(
|
||||
checkpoint = checkpoint,
|
||||
vae_config = config,
|
||||
image_size = image_size
|
||||
)
|
||||
vae_model.save_pretrained(
|
||||
diffusers_path,
|
||||
safe_serialization=is_safetensors_available()
|
||||
)
|
||||
return diffusers_path
|
@ -14,7 +14,7 @@ export const receivedModels = createAppAsyncThunk(
|
||||
const response = await ModelsService.listModels();
|
||||
|
||||
const deserializedModels = reduce(
|
||||
response.models['diffusers'],
|
||||
response.models['sd-1.5']['pipeline'],
|
||||
(modelsAccumulator, model, modelName) => {
|
||||
modelsAccumulator[modelName] = { ...model, name: modelName };
|
||||
|
||||
@ -25,7 +25,7 @@ export const receivedModels = createAppAsyncThunk(
|
||||
|
||||
models.info(
|
||||
{ response },
|
||||
`Received ${size(response.models['diffusers'])} models`
|
||||
`Received ${size(response.models['sd-1.5']['pipeline'])} models`
|
||||
);
|
||||
|
||||
return deserializedModels;
|
||||
|
Loading…
Reference in New Issue
Block a user