model loading and conversion implemented for vaes

This commit is contained in:
Lincoln Stein
2024-02-03 22:55:09 -05:00
committed by psychedelicious
parent b8e875bb73
commit 8ba5360269
29 changed files with 2382 additions and 237 deletions

View File

@ -8,15 +8,14 @@ from typing import Any, Dict, Optional, Tuple
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin
from injector import inject
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import AnyModel, LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.model_manager.ram_cache import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
@ -35,7 +34,6 @@ class ConfigLoader(ConfigMixin):
class ModelLoader(ModelLoaderBase):
"""Default implementation of ModelLoaderBase."""
@inject # can inject instances of each of the classes in the call signature
def __init__(
self,
app_config: InvokeAIAppConfig,
@ -87,18 +85,15 @@ class ModelLoader(ModelLoaderBase):
def _convert_if_needed(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
) -> Path:
if not self._needs_conversion(config):
return model_path
cache_path: Path = self._convert_cache.cache_path(config.key)
if not self._needs_conversion(config, model_path, cache_path):
return cache_path if cache_path.exists() else model_path
self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type))
cache_path: Path = self._convert_cache.cache_path(config.key)
if cache_path.exists():
return cache_path
return self._convert_model(config, model_path, cache_path)
self._convert_model(model_path, cache_path)
return cache_path
def _needs_conversion(self, config: AnyModelConfig) -> bool:
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
return False
def _load_if_needed(
@ -133,7 +128,7 @@ class ModelLoader(ModelLoaderBase):
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
)
def _convert_model(self, model_path: Path, cache_path: Path) -> None:
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
raise NotImplementedError
def _load_model(