2024-02-01 04:37:59 +00:00
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""
Base class for model loading in InvokeAI.

Use like this:
    loader = AnyModelLoader(...)
    loaded_model = loader.load_model(model_config)  # model_config: the model's configuration record
    with loaded_model as model:  # context manager moves model into VRAM
        # do something with model
"""
2024-02-04 22:23:10 +00:00
import hashlib
2024-02-01 04:37:59 +00:00
from abc import ABC , abstractmethod
from dataclasses import dataclass
from logging import Logger
from pathlib import Path
2024-02-04 22:23:10 +00:00
from typing import Any , Callable , Dict , Optional , Tuple , Type
2024-02-01 04:37:59 +00:00
from invokeai . app . services . config import InvokeAIAppConfig
2024-02-10 23:09:45 +00:00
from invokeai . backend . model_manager . config import (
AnyModel ,
AnyModelConfig ,
BaseModelType ,
2024-02-13 05:26:49 +00:00
ModelConfigBase ,
2024-02-10 23:09:45 +00:00
ModelFormat ,
ModelType ,
SubModelType ,
VaeCheckpointConfig ,
VaeDiffusersConfig ,
)
2024-02-04 03:55:09 +00:00
from invokeai . backend . model_manager . load . convert_cache . convert_cache_base import ModelConvertCacheBase
2024-02-04 22:23:10 +00:00
from invokeai . backend . model_manager . load . model_cache . model_cache_base import ModelCacheBase , ModelLockerBase
2024-02-06 02:55:11 +00:00
from invokeai . backend . util . logging import InvokeAILogger
2024-02-04 22:23:10 +00:00
2024-02-01 04:37:59 +00:00
@dataclass
class LoadedModel:
    """
    Context manager object that mediates transfer from RAM<->VRAM.

    Entering the ``with`` block calls ``locker.lock()`` and yields the model;
    leaving it calls ``locker.unlock()``. Use the ``model`` property to reach
    the model without locking it.
    """

    config: AnyModelConfig  # configuration record describing this model
    locker: ModelLockerBase  # performs lock/unlock and holds the model object

    @property
    def model(self) -> AnyModel:
        """Model held by the locker, returned without locking it."""
        return self.locker.model

    def __enter__(self) -> AnyModel:
        """Lock the model, then hand it to the body of the ``with`` block."""
        locker = self.locker
        locker.lock()
        return self.model

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        """Unlock the model on context exit; exceptions are not suppressed."""
        locker = self.locker
        locker.unlock()
2024-02-01 04:37:59 +00:00
class ModelLoaderBase(ABC):
    """Abstract base class for loading models into RAM/VRAM."""

    @abstractmethod
    def __init__(
        self,
        app_config: InvokeAIAppConfig,
        logger: Logger,
        ram_cache: ModelCacheBase[AnyModel],
        convert_cache: ModelConvertCacheBase,
    ):
        """
        Initialize the loader.

        :param app_config: application configuration
        :param logger: logger for diagnostic messages
        :param ram_cache: RAM model cache shared by the loaders
        :param convert_cache: cache of converted models shared by the loaders
        """
        pass

    @abstractmethod
    def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
        """
        Return a model given its configuration.

        Given a model identified in the model configuration backend,
        return a LoadedModel object that can be used to retrieve the model.

        :param model_config: Model configuration, as returned by ModelConfigRecordStore
        :param submodel_type: a SubModelType enum indicating the portion of
           the model to retrieve (e.g. SubModelType.Vae)
        :return: a LoadedModel wrapping the requested model
        """
        pass

    @abstractmethod
    def get_size_fs(
        self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
    ) -> int:
        """
        Return size in bytes of the model, calculated before loading.

        :param config: configuration of the model to measure
        :param model_path: location of the model
        :param submodel_type: optional portion of the model to measure
        :return: size in bytes
        """
        pass
# TODO: Better name?
class AnyModelLoader:
    """This class manages the model loaders and invokes the correct one to load a model of given base and type."""

    # Class-level registry shared by every instance: maps "<base>-<type>-<format>"
    # keys (see `_to_registry_key`) to the loader subclass registered via `register()`.
    _registry: Dict[str, Type[ModelLoaderBase]] = {}
    # Class-level logger used by `register()`, which runs on the class itself;
    # instances replace it with the logger passed to `__init__`.
    _logger: Logger = InvokeAILogger.get_logger()

    def __init__(
        self,
        app_config: InvokeAIAppConfig,
        logger: Logger,
        ram_cache: ModelCacheBase[AnyModel],
        convert_cache: ModelConvertCacheBase,
    ):
        """Initialize AnyModelLoader with its dependencies."""
        self._app_config = app_config
        self._logger = logger
        self._ram_cache = ram_cache
        self._convert_cache = convert_cache

    @property
    def ram_cache(self) -> ModelCacheBase[AnyModel]:
        """Return the RAM cache used by the loaders."""
        return self._ram_cache

    @property
    def convert_cache(self) -> ModelConvertCacheBase:
        """Return the convert cache used by the loaders."""
        return self._convert_cache

    def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
        """
        Return a model given its configuration.

        :param model_config: model configuration, as known to the config backend
        :param submodel_type: a SubModelType enum indicating the portion of
           the model to retrieve (e.g. SubModelType.Vae)
        :return: the LoadedModel produced by the registered loader implementation
        :raises NotImplementedError: if no loader is registered for the model
        """
        implementation, model_config, submodel_type = self.__class__.get_implementation(model_config, submodel_type)
        # A new loader instance is built for each call; it shares this object's caches.
        return implementation(
            app_config=self._app_config,
            logger=self._logger,
            ram_cache=self._ram_cache,
            convert_cache=self._convert_cache,
        ).load_model(model_config, submodel_type)

    @staticmethod
    def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
        """Build the "<base>-<type>-<format>" key used to index `_registry`."""
        return "-".join([base.value, type.value, format.value])

    @classmethod
    def get_implementation(
        cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
    ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
        """
        Get subclass of ModelLoaderBase registered to handle base, type and format.

        A loader registered for the config's specific base type takes precedence
        over one registered for the wildcard BaseModelType.Any.

        :return: (loader class, effective config, effective submodel_type); the config
           and submodel_type may differ from the inputs when a VAE override applies
        :raises NotImplementedError: if no loader is registered for the effective config
        """
        # We have to handle VAE overrides here because this will change the model type
        # and the corresponding implementation returned.
        conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
        key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format)  # for a specific base type
        key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format)  # with wildcard Any
        implementation = cls._registry.get(key1) or cls._registry.get(key2)
        if not implementation:
            # Report the config that was actually looked up (post-override), not the caller's.
            raise NotImplementedError(
                f"No subclass of ModelLoaderBase is registered for base={conf2.base}, type={conf2.type}, format={conf2.format}"
            )
        return implementation, conf2, submodel_type

    @classmethod
    def _handle_subtype_overrides(
        cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
    ) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
        """
        Substitute a standalone VAE config when the VAE submodel is overridden.

        If `submodel_type` is SubModelType.Vae and `config` has a non-None `vae`
        attribute, a VaeCheckpointConfig / VaeDiffusersConfig is synthesized from that
        path and the submodel type is cleared to None. Otherwise the inputs are
        returned unchanged.
        """
        if submodel_type != SubModelType.Vae:
            return config, submodel_type
        vae_override = getattr(config, "vae", None)
        if vae_override is None:
            return config, submodel_type

        model_path = Path(vae_override)
        # Single-file weights are checkpoints; anything else is treated as a diffusers
        # folder. Compare case-insensitively so e.g. "vae.SafeTensors" is a checkpoint.
        is_checkpoint = model_path.suffix.lower() in (".pt", ".safetensors", ".ckpt")
        config_class = VaeCheckpointConfig if is_checkpoint else VaeDiffusersConfig
        # Deterministic synthetic key derived from the VAE path (identifier only).
        path_key = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
        new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=path_key)
        return new_conf, None

    @classmethod
    def register(
        cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
    ) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
        """
        Define a decorator which registers the subclass of loader.

        :param type: model type handled by the decorated loader
        :param format: model format handled by the decorated loader
        :param base: base model type; defaults to the wildcard BaseModelType.Any
        :return: decorator that records the subclass in `_registry` and returns it unchanged
        :raises Exception: (when the decorator runs) if a loader is already
           registered for the same base/type/format
        """

        def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
            cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}")
            key = cls._to_registry_key(base, type, format)
            if key in cls._registry:
                raise Exception(
                    f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
                )
            cls._registry[key] = subclass
            return subclass

        return decorator