mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add support for generic loading of diffusers directories
This commit is contained in:
committed by
psychedelicious
parent
a9962fd104
commit
f81b8bc9f6
@ -7,7 +7,7 @@ from importlib import import_module
|
||||
from pathlib import Path
|
||||
|
||||
from .convert_cache.convert_cache_default import ModelConvertCache
|
||||
from .load_base import LoadedModel, ModelLoaderBase
|
||||
from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
|
||||
from .load_default import ModelLoader
|
||||
from .model_cache.model_cache_default import ModelCache
|
||||
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
|
||||
@ -19,6 +19,7 @@ for module in loaders:
|
||||
|
||||
__all__ = [
|
||||
"LoadedModel",
|
||||
"LoadedModelWithoutConfig",
|
||||
"ModelCache",
|
||||
"ModelConvertCache",
|
||||
"ModelLoaderBase",
|
||||
|
@ -20,11 +20,10 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedModel:
|
||||
class LoadedModelWithoutConfig:
|
||||
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
||||
|
||||
_locker: ModelLockerBase
|
||||
config: Optional[AnyModelConfig] = None
|
||||
|
||||
def __enter__(self) -> AnyModel:
|
||||
"""Context entry."""
|
||||
@ -41,6 +40,13 @@ class LoadedModel:
|
||||
return self._locker.model
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedModel(LoadedModelWithoutConfig):
|
||||
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
||||
|
||||
config: Optional[AnyModelConfig] = None
|
||||
|
||||
|
||||
# TODO(MM2):
|
||||
# Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
|
||||
# know about. I think the problem may be related to this class being an ABC.
|
||||
|
@ -65,14 +65,11 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
else:
|
||||
try:
|
||||
config = self._load_diffusers_config(model_path, config_name="config.json")
|
||||
class_name = config.get("_class_name", None)
|
||||
if class_name:
|
||||
if class_name := config.get("_class_name"):
|
||||
result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
|
||||
if config.get("model_type", None) == "clip_vision_model":
|
||||
class_name = config.get("architectures")
|
||||
assert class_name is not None
|
||||
elif class_name := config.get("architectures"):
|
||||
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
|
||||
if not class_name:
|
||||
else:
|
||||
raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
|
||||
except KeyError as e:
|
||||
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
||||
|
Reference in New Issue
Block a user