Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
model loading and conversion implemented for vaes

commit 8ba5360269, parent b8e875bb73, committed by psychedelicious
@@ -8,15 +8,14 @@ from typing import Any, Dict, Optional, Tuple
 
 from diffusers import ModelMixin
 from diffusers.configuration_utils import ConfigMixin
-from injector import inject
 
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.model_manager import AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
-from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase
+from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.load_base import AnyModel, LoadedModel, ModelLoaderBase
 from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
-from invokeai.backend.model_manager.ram_cache import ModelCacheBase, ModelLockerBase
+from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
 from invokeai.backend.util.devices import choose_torch_device, torch_dtype
 
 
@@ -35,7 +34,6 @@ class ConfigLoader(ConfigMixin):
 class ModelLoader(ModelLoaderBase):
     """Default implementation of ModelLoaderBase."""
 
-    @inject  # can inject instances of each of the classes in the call signature
     def __init__(
         self,
         app_config: InvokeAIAppConfig,
@@ -87,18 +85,15 @@ class ModelLoader(ModelLoaderBase):
     def _convert_if_needed(
         self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
     ) -> Path:
-        if not self._needs_conversion(config):
-            return model_path
+        cache_path: Path = self._convert_cache.cache_path(config.key)
+
+        if not self._needs_conversion(config, model_path, cache_path):
+            return cache_path if cache_path.exists() else model_path
 
         self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type))
-        cache_path: Path = self._convert_cache.cache_path(config.key)
-        if cache_path.exists():
-            return cache_path
+        return self._convert_model(config, model_path, cache_path)
 
-        self._convert_model(model_path, cache_path)
-        return cache_path
-
-    def _needs_conversion(self, config: AnyModelConfig) -> bool:
+    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
         return False
 
     def _load_if_needed(
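With this change, _convert_if_needed computes the conversion cache path up front and hands both the source path and the cache path to _needs_conversion, so a subclass can decide based on the actual files rather than on the config alone. Below is a minimal sketch of how a subclass might use the new hook, assuming ModelLoader is in scope; the class name and the mtime-based staleness check are illustrative assumptions, not code from this commit.

from pathlib import Path

from invokeai.backend.model_manager import AnyModelConfig


class CheckpointVaeLoader(ModelLoader):  # hypothetical subclass, for illustration only
    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
        # Convert when there is no cached copy yet, or when the source
        # weights were modified after the cached conversion was written.
        if not cache_path.exists():
            return True
        return model_path.stat().st_mtime > cache_path.stat().st_mtime

Because _convert_if_needed already falls back to cache_path if it exists (and to model_path otherwise) when no conversion is needed, a hook like this only has to answer the staleness question.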
@@ -133,7 +128,7 @@ class ModelLoader(ModelLoaderBase):
             variant=config.repo_variant if hasattr(config, "repo_variant") else None,
         )
 
-    def _convert_model(self, model_path: Path, cache_path: Path) -> None:
+    def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
         raise NotImplementedError
 
     def _load_model(
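_convert_model now receives the model config and returns the path to load from instead of None, which is what lets the base class end _convert_if_needed with a single return self._convert_model(config, model_path, cache_path). A sketch of an implementation honoring the new contract follows; the conversion helper is a hypothetical placeholder, not an API from this commit.

from pathlib import Path

from invokeai.backend.model_manager import AnyModelConfig


class VaeLoader(ModelLoader):  # hypothetical subclass, for illustration only
    def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
        # Contract after this commit: write the converted model under
        # output_path and return the path the caller should load.
        converted = self._checkpoint_to_diffusers(weights_path)  # hypothetical helper
        converted.save_pretrained(output_path)  # diffusers-style serialization
        return output_path

Returning a Path rather than None also leaves room for a converter to hand back a different location, e.g. the original weights_path when conversion turns out to be a no-op.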