Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
[mm] Do not write diffuser model to disk when convert_cache set to zero (#6072)
* pass model config to _load_model
* make conversion work again
* do not write diffusers to disk when convert_cache set to 0
* adding same model to cache twice is a no-op, not an assertion error
* fix issues identified by psychedelicious during pr review
* following conversion, avoid redundant read of cached submodels
* fix error introduced while merging

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
parent 0ac1c0f339
commit 3d6d89feb4
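For orientation, here is a minimal sketch of the behavior this change introduces, using simplified stand-in classes rather than the actual InvokeAI APIs: when the convert cache is capped at zero, checkpoint conversion happens purely in memory and no diffusers folder is written to disk; the convert-model API route temporarily raises the cap so a persistent copy is produced, and putting a model into the RAM cache under an existing key is a quiet no-op rather than an assertion failure. The real code paths appear in the diff below.

# Sketch only: RamCache, ConvertCache and convert_checkpoint are illustrative
# stand-ins, not the InvokeAI classes touched by this commit.
from typing import Any, Dict, Optional


class RamCache:
    def __init__(self) -> None:
        self._models: Dict[str, Any] = {}

    def put(self, key: str, model: Any) -> None:
        # Putting the same key twice is a no-op rather than an assertion error.
        if key in self._models:
            return
        self._models[key] = model


class ConvertCache:
    def __init__(self, max_size: float = 0.0) -> None:
        self.max_size = max_size  # GB; 0 means "never persist converted models"


def convert_checkpoint(cache: ConvertCache, checkpoint: Any, dump_path: Optional[str]) -> dict:
    pipeline = {"converted_from": checkpoint}  # stand-in for the diffusers pipeline
    if cache.max_size > 0 and dump_path is not None:
        # Only when the cache is allowed to grow is the pipeline saved to disk.
        print(f"writing diffusers folder to {dump_path}")
    return pipeline  # the in-memory result is always returned


# Pattern used by the convert-model API route: temporarily raise the cap,
# convert (which then writes to disk), and restore the original setting.
convert_cache = ConvertCache(max_size=0.0)
saved_size = convert_cache.max_size
try:
    convert_cache.max_size = 1.0
    convert_checkpoint(convert_cache, checkpoint="example-checkpoint.safetensors", dump_path="/tmp/converted")
finally:
    convert_cache.max_size = saved_size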
@@ -614,8 +614,8 @@ async def convert_model(
The return value is the model configuration for the converted model.
"""
model_manager = ApiDependencies.invoker.services.model_manager
loader = model_manager.load
logger = ApiDependencies.invoker.services.logger
loader = ApiDependencies.invoker.services.model_manager.load
store = ApiDependencies.invoker.services.model_manager.store
installer = ApiDependencies.invoker.services.model_manager.install

@@ -630,7 +630,13 @@ async def convert_model(
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")

# loading the model will convert it into a cached diffusers file
model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler)
try:
cc_size = loader.convert_cache.max_size
if cc_size == 0:  # temporary set the convert cache to a positive number so that cached model is written
loader._convert_cache.max_size = 1.0
loader.load_model(model_config, submodel_type=SubModelType.Scheduler)
finally:
loader._convert_cache.max_size = cc_size

# Get the path of the converted model from the loader
cache_path = loader.convert_cache.cache_path(key)
@@ -33,42 +33,3 @@ __all__ = [
"SchedulerPredictionType",
"SubModelType",
]

########## to help populate the openapi_schema with format enums for each config ###########
# This code is no longer necessary?
# leave it here just in case
#
# import inspect
# from enum import Enum
# from typing import Any, Iterable, Dict, get_args, Set
# def _expand(something: Any) -> Iterable[type]:
# if isinstance(something, type):
# yield something
# else:
# for x in get_args(something):
# for y in _expand(x):
# yield y

# def _find_format(cls: type) -> Iterable[Enum]:
# if hasattr(inspect, "get_annotations"):
# fields = inspect.get_annotations(cls)
# else:
# fields = cls.__annotations__
# if "format" in fields:
# for x in get_args(fields["format"]):
# yield x
# for parent_class in cls.__bases__:
# for x in _find_format(parent_class):
# yield x
# return None

# def get_model_config_formats() -> Dict[str, Set[Enum]]:
# result: Dict[str, Set[Enum]] = {}
# for model_config in _expand(AnyModelConfig):
# for field in _find_format(model_config):
# if field is None:
# continue
# if not result.get(model_config.__qualname__):
# result[model_config.__qualname__] = set()
# result[model_config.__qualname__].add(field)
# return result
@@ -3,7 +3,7 @@
"""Conversion script for the Stable Diffusion checkpoints."""

from pathlib import Path
from typing import Dict
from typing import Dict, Optional

import torch
from diffusers import AutoencoderKL
@@ -15,6 +15,8 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
)
from omegaconf import DictConfig

from . import AnyModel


def convert_ldm_vae_to_diffusers(
checkpoint: Dict[str, torch.Tensor],
@@ -33,11 +35,11 @@ def convert_ldm_vae_to_diffusers(

def convert_ckpt_to_diffusers(
checkpoint_path: str | Path,
dump_path: str | Path,
dump_path: Optional[str | Path] = None,
precision: torch.dtype = torch.float16,
use_safetensors: bool = True,
**kwargs,
):
) -> AnyModel:
"""
Takes all the arguments of download_from_original_stable_diffusion_ckpt(),
and in addition a path-like object indicating the location of the desired diffusers
@@ -47,18 +49,20 @@ def convert_ckpt_to_diffusers(
pipe = pipe.to(precision)

# TO DO: save correct repo variant
pipe.save_pretrained(
dump_path,
safe_serialization=use_safetensors,
)
if dump_path:
pipe.save_pretrained(
dump_path,
safe_serialization=use_safetensors,
)
return pipe


def convert_controlnet_to_diffusers(
checkpoint_path: Path,
dump_path: Path,
dump_path: Optional[Path] = None,
precision: torch.dtype = torch.float16,
**kwargs,
):
) -> AnyModel:
"""
Takes all the arguments of download_controlnet_from_original_ckpt(),
and in addition a path-like object indicating the location of the desired diffusers
@@ -68,4 +72,6 @@ def convert_controlnet_to_diffusers(
pipe = pipe.to(precision)

# TO DO: save correct repo variant
pipe.save_pretrained(dump_path, safe_serialization=True)
if dump_path:
pipe.save_pretrained(dump_path, safe_serialization=True)
return pipe
@@ -19,11 +19,20 @@ class ModelConvertCache(ModelConvertCacheBase):
self._cache_path = cache_path
self._max_size = max_size

# adjust cache size at startup in case it has been changed
if self._cache_path.exists():
self.make_room(0.0)

@property
def max_size(self) -> float:
"""Return the maximum size of this cache directory (GB)."""
return self._max_size

@max_size.setter
def max_size(self, value: float) -> None:
"""Set the maximum size of this cache directory (GB)."""
self._max_size = value

def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
return self._cache_path / key
@@ -83,3 +83,15 @@ class ModelLoaderBase(ABC):
) -> int:
"""Return size in bytes of the model, calculated before loading."""
pass

@property
@abstractmethod
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the convert cache associated with this loader."""
pass

@property
@abstractmethod
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the ram cache associated with this loader."""
pass
@@ -3,14 +3,13 @@

from logging import Logger
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
InvalidModelConfigException,
ModelRepoVariant,
SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
@@ -54,51 +53,43 @@ class ModelLoader(ModelLoaderBase):
if model_config.type is ModelType.Main and not submodel_type:
raise InvalidModelConfigException("submodel_type is required when loading a main model")

model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
model_path = self._get_model_path(model_config)

if not model_path.exists():
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")

model_path = self._convert_if_needed(model_config, model_path, submodel_type)
locker = self._load_if_needed(model_config, model_path, submodel_type)
with skip_torch_weight_init():
locker = self._convert_and_load(model_config, model_path, submodel_type)
return LoadedModel(config=model_config, _locker=locker)

def _get_model_path(
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the convert cache associated with this loader."""
return self._convert_cache

@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the ram cache associated with this loader."""
return self._ram_cache

def _get_model_path(self, config: AnyModelConfig) -> Path:
model_base = self._app_config.models_path
result = (model_base / config.path).resolve(), config, submodel_type
return result
return (model_base / config.path).resolve()

def _convert_if_needed(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
) -> Path:
cache_path: Path = self._convert_cache.cache_path(config.key)

if not self._needs_conversion(config, model_path, cache_path):
return cache_path if cache_path.exists() else model_path

self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
return self._convert_model(config, model_path, cache_path)

def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
return False

def _load_if_needed(
def _convert_and_load(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
) -> ModelLockerBase:
# TO DO: This is not thread safe!
try:
return self._ram_cache.get(config.key, submodel_type)
except IndexError:
pass

model_variant = getattr(config, "repo_variant", None)
self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type))

# This is where the model is actually loaded!
with skip_torch_weight_init():
loaded_model = self._load_model(model_path, model_variant=model_variant, submodel_type=submodel_type)
cache_path: Path = self._convert_cache.cache_path(config.key)
if self._needs_conversion(config, model_path, cache_path):
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
else:
config.path = str(cache_path) if cache_path.exists() else str(self._get_model_path(config))
loaded_model = self._load_model(config, submodel_type)

self._ram_cache.put(
config.key,
@@ -123,15 +114,34 @@ class ModelLoader(ModelLoaderBase):
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
)

def _do_convert(
self, config: AnyModelConfig, model_path: Path, cache_path: Path, submodel_type: Optional[SubModelType] = None
) -> AnyModel:
self.convert_cache.make_room(calc_model_size_by_fs(model_path))
pipeline = self._convert_model(config, model_path, cache_path if self.convert_cache.max_size > 0 else None)
if submodel_type:
# Proactively load the various submodels into the RAM cache so that we don't have to re-convert
# the entire pipeline every time a new submodel is needed.
for subtype in SubModelType:
if subtype == submodel_type:
continue
if submodel := getattr(pipeline, subtype.value, None):
self._ram_cache.put(
config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
)
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline

def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
return False

# This needs to be implemented in subclasses that handle checkpoints
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
raise NotImplementedError

# This needs to be implemented in the subclass
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
raise NotImplementedError
@@ -122,6 +122,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""Return the cap on cache size."""
return self._max_cache_size

@max_cache_size.setter
def max_cache_size(self, value: float) -> None:
"""Set the cap on cache size."""
self._max_cache_size = value

@property
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""
@@ -157,8 +162,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
) -> None:
"""Store model under key and optional submodel_type."""
key = self._make_cache_key(key, submodel_type)
assert key not in self._cached_models

if key in self._cached_models:
return
self.make_room(size)
cache_record = CacheRecord(key, model, size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
@@ -405,6 +411,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
#
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
# immediately when their reference count hits 0.
if self.stats:
self.stats.cleared = models_cleared
gc.collect()

torch.cuda.empty_cache()
@@ -2,8 +2,10 @@
"""Class for ControlNet model loading in InvokeAI."""

from pathlib import Path
from typing import Optional

from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
@@ -33,7 +35,7 @@ class ControlNetLoader(GenericDiffusersLoader):
else:
return True

def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
assert isinstance(config, CheckpointConfigBase)
image_size = (
512
@@ -45,7 +47,7 @@ class ControlNetLoader(GenericDiffusersLoader):

self._logger.info(f"Converting {model_path} to diffusers format")
with open(self._app_config.legacy_conf_path / config.config_path, "r") as config_stream:
convert_controlnet_to_diffusers(
result = convert_controlnet_to_diffusers(
model_path,
output_path,
original_config_file=config_stream,
@@ -53,4 +55,4 @@ class ControlNetLoader(GenericDiffusersLoader):
precision=self._torch_dtype,
from_safetensors=model_path.suffix == ".safetensors",
)
return output_path
return result
@@ -10,13 +10,14 @@ from diffusers.models.modeling_utils import ModelMixin

from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
InvalidModelConfigException,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase

from .. import ModelLoader, ModelLoaderRegistry

@@ -28,14 +29,15 @@ class GenericDiffusersLoader(ModelLoader):

def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
model_path = Path(config.path)
model_class = self.get_hf_load_class(model_path)
if submodel_type is not None:
raise Exception(f"There are no submodels in models of type {model_class}")
variant = model_variant.value if model_variant else None
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
variant = repo_variant.value if repo_variant else None
try:
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant)
except OSError as e:
@@ -9,13 +9,14 @@ import torch
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
from invokeai.backend.raw_model import RawModel


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
@@ -24,13 +25,13 @@ class IPAdapterInvokeAILoader(ModelLoader):

def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise ValueError("There are no submodels in an IP-Adapter model.")
model = build_ip_adapter(
model_path = Path(config.path)
model: RawModel = build_ip_adapter(
ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
device=torch.device("cpu"),
dtype=self._torch_dtype,
@@ -3,7 +3,7 @@

from logging import Logger
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.lora import LoRAModelRaw
@@ -12,7 +12,6 @@ from invokeai.backend.model_manager import (
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
@@ -41,12 +40,12 @@ class LoRALoader(ModelLoader):

def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise ValueError("There are no submodels in a LoRA model.")
model_path = Path(config.path)
assert self._model_base is not None
model = LoRAModelRaw.from_checkpoint(
file_path=model_path,
@@ -56,12 +55,9 @@ class LoRALoader(ModelLoader):
return model

# override
def _get_model_path(
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
self._model_base = (
config.base
) # cheating a little - we remember this variable for using in the subsequent call to _load_model()
def _get_model_path(self, config: AnyModelConfig) -> Path:
# cheating a little - we remember this variable for using in the subsequent call to _load_model()
self._model_base = config.base

model_base_path = self._app_config.models_path
model_path = model_base_path / config.path
@@ -73,5 +69,4 @@ class LoRALoader(ModelLoader):
model_path = path
break

result = model_path.resolve(), config, submodel_type
return result
return model_path.resolve()
@@ -7,9 +7,9 @@ from typing import Optional

from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
@@ -25,18 +25,19 @@ class OnnyxDiffusersModel(GenericDiffusersLoader):

def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not submodel_type is not None:
raise Exception("A submodel type must be provided when loading onnx pipelines.")
model_path = Path(config.path)
load_class = self.get_hf_load_class(model_path, submodel_type)
variant = model_variant.value if model_variant else None
repo_variant = getattr(config, "repo_variant", None)
variant = repo_variant.value if repo_variant else None
model_path = model_path / submodel_type.value
result: AnyModel = load_class.from_pretrained(
model_path,
torch_dtype=self._torch_dtype,
variant=variant,
) # type: ignore
)
return result
@@ -9,12 +9,16 @@ from invokeai.backend.model_manager import (
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig, ModelVariantType
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
DiffusersConfigBase,
MainCheckpointConfig,
ModelVariantType,
)
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers

from .. import ModelLoaderRegistry
@@ -41,14 +45,15 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):

def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not submodel_type is not None:
raise Exception("A submodel type must be provided when loading main pipelines.")
model_path = Path(config.path)
load_class = self.get_hf_load_class(model_path, submodel_type)
variant = model_variant.value if model_variant else None
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
variant = repo_variant.value if repo_variant else None
model_path = model_path / submodel_type.value
try:
result: AnyModel = load_class.from_pretrained(
@@ -78,7 +83,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
else:
return True

def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
assert isinstance(config, MainCheckpointConfig)
base = config.base

@@ -94,7 +99,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):

self._logger.info(f"Converting {model_path} to diffusers format")

convert_ckpt_to_diffusers(
loaded_model = convert_ckpt_to_diffusers(
model_path,
output_path,
model_type=self.model_base_to_model_type[base],
@@ -108,4 +113,4 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
load_safety_checker=False,
num_in_channels=VARIANT_TO_IN_CHANNEL_MAP[config.variant],
)
return output_path
return loaded_model
@@ -2,14 +2,13 @@
"""Class for TI model loading in InvokeAI."""

from pathlib import Path
from typing import Optional, Tuple
from typing import Optional

from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
@@ -27,22 +26,19 @@ class TextualInversionLoader(ModelLoader):

def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise ValueError("There are no submodels in a TI model.")
model = TextualInversionModelRaw.from_checkpoint(
file_path=model_path,
file_path=config.path,
dtype=self._torch_dtype,
)
return model

# override
def _get_model_path(
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
def _get_model_path(self, config: AnyModelConfig) -> Path:
model_path = self._app_config.models_path / config.path

if config.format == ModelFormat.EmbeddingFolder:
@@ -53,4 +49,4 @@ class TextualInversionLoader(ModelLoader):
if not path.exists():
raise OSError(f"The embedding file at {path} was not found")

return path, config, submodel_type
return path