mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
final tidying before marking PR as ready for review
- Replace AnyModelLoader with ModelLoaderRegistry - Fix type check errors in multiple files - Remove apparently unneeded `get_model_config_enum()` method from model manager - Remove last vestiges of old model manager - Updated tests and documentation resolve conflict with seamless.py
This commit is contained in:
@ -1,13 +1,9 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Default implementation of model loading in InvokeAI."""
|
||||
|
||||
import sys
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from diffusers import ModelMixin
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import (
|
||||
@ -25,17 +21,6 @@ from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_
|
||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||
|
||||
|
||||
class ConfigLoader(ConfigMixin):
|
||||
"""Subclass of ConfigMixin for loading diffusers configuration files."""
|
||||
|
||||
@classmethod
|
||||
def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""Load a diffusrs ConfigMixin configuration."""
|
||||
cls.config_name = kwargs.pop("config_name")
|
||||
# Diffusers doesn't provide typing info
|
||||
return super().load_config(*args, **kwargs) # type: ignore
|
||||
|
||||
|
||||
# TO DO: The loader is not thread safe!
|
||||
class ModelLoader(ModelLoaderBase):
|
||||
"""Default implementation of ModelLoaderBase."""
|
||||
@ -137,43 +122,6 @@ class ModelLoader(ModelLoaderBase):
|
||||
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
|
||||
)
|
||||
|
||||
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
|
||||
return ConfigLoader.load_config(model_path, config_name=config_name)
|
||||
|
||||
# TO DO: Add exception handling
|
||||
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
|
||||
if module in ["diffusers", "transformers"]:
|
||||
res_type = sys.modules[module]
|
||||
else:
|
||||
res_type = sys.modules["diffusers"].pipelines
|
||||
result: ModelMixin = getattr(res_type, class_name)
|
||||
return result
|
||||
|
||||
# TO DO: Add exception handling
|
||||
def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
|
||||
if submodel_type:
|
||||
try:
|
||||
config = self._load_diffusers_config(model_path, config_name="model_index.json")
|
||||
module, class_name = config[submodel_type.value]
|
||||
return self._hf_definition_to_type(module=module, class_name=class_name)
|
||||
except KeyError as e:
|
||||
raise InvalidModelConfigException(
|
||||
f'The "{submodel_type}" submodel is not available for this model.'
|
||||
) from e
|
||||
else:
|
||||
try:
|
||||
config = self._load_diffusers_config(model_path, config_name="config.json")
|
||||
class_name = config.get("_class_name", None)
|
||||
if class_name:
|
||||
return self._hf_definition_to_type(module="diffusers", class_name=class_name)
|
||||
if config.get("model_type", None) == "clip_vision_model":
|
||||
class_name = config.get("architectures")[0]
|
||||
return self._hf_definition_to_type(module="transformers", class_name=class_name)
|
||||
if not class_name:
|
||||
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
|
||||
except KeyError as e:
|
||||
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
||||
|
||||
# This needs to be implemented in subclasses that handle checkpoints
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
raise NotImplementedError
|
||||
|
Reference in New Issue
Block a user