[mm] Do not write diffuser model to disk when convert_cache set to zero (#6072)

* pass model config to _load_model

* make conversion work again

* do not write diffusers to disk when convert_cache set to 0

* adding same model to cache twice is a no-op, not an assertion error

* fix issues identified by psychedelicious during pr review

* following conversion, avoid redundant read of cached submodels

* fix error introduced while merging

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Lincoln Stein authored on 2024-03-29 16:11:08 -04:00; committed by GitHub
parent 0ac1c0f339
commit 3d6d89feb4
14 changed files with 147 additions and 133 deletions

View File

@@ -614,8 +614,8 @@ async def convert_model(
     The return value is the model configuration for the converted model.
     """
     model_manager = ApiDependencies.invoker.services.model_manager
+    loader = model_manager.load
     logger = ApiDependencies.invoker.services.logger
-    loader = ApiDependencies.invoker.services.model_manager.load
     store = ApiDependencies.invoker.services.model_manager.store
     installer = ApiDependencies.invoker.services.model_manager.install
@@ -630,7 +630,13 @@ async def convert_model(
         raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")

     # loading the model will convert it into a cached diffusers file
-    model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler)
+    try:
+        cc_size = loader.convert_cache.max_size
+        if cc_size == 0:  # temporary set the convert cache to a positive number so that cached model is written
+            loader._convert_cache.max_size = 1.0
+        loader.load_model(model_config, submodel_type=SubModelType.Scheduler)
+    finally:
+        loader._convert_cache.max_size = cc_size

     # Get the path of the converted model from the loader
     cache_path = loader.convert_cache.cache_path(key)
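
Note: the route above saves the configured convert-cache size, raises it temporarily when it is zero so the converted scheduler actually lands on disk where the follow-up install step expects it, and restores the setting in a finally block. Below is a minimal sketch of that save-and-restore pattern as a reusable context manager; the helper name and the min_size_gb default are illustrative and not part of this commit.

from contextlib import contextmanager

@contextmanager
def temporarily_writable_convert_cache(loader, min_size_gb: float = 1.0):
    """Ensure a conversion can be written to the convert cache, then restore the user's setting."""
    original = loader.convert_cache.max_size
    try:
        if original == 0:
            loader.convert_cache.max_size = min_size_gb  # allow the converted model to be persisted
        yield loader
    finally:
        loader.convert_cache.max_size = original  # restore even if the conversion raised

# Hypothetical usage mirroring the route:
# with temporarily_writable_convert_cache(loader):
#     loader.load_model(model_config, submodel_type=SubModelType.Scheduler)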

View File

@@ -33,42 +33,3 @@ __all__ = [
     "SchedulerPredictionType",
     "SubModelType",
 ]
-
-########## to help populate the openapi_schema with format enums for each config ###########
-# This code is no longer necessary?
-# leave it here just in case
-#
-# import inspect
-# from enum import Enum
-# from typing import Any, Iterable, Dict, get_args, Set
-# def _expand(something: Any) -> Iterable[type]:
-#     if isinstance(something, type):
-#         yield something
-#     else:
-#         for x in get_args(something):
-#             for y in _expand(x):
-#                 yield y
-# def _find_format(cls: type) -> Iterable[Enum]:
-#     if hasattr(inspect, "get_annotations"):
-#         fields = inspect.get_annotations(cls)
-#     else:
-#         fields = cls.__annotations__
-#     if "format" in fields:
-#         for x in get_args(fields["format"]):
-#             yield x
-#     for parent_class in cls.__bases__:
-#         for x in _find_format(parent_class):
-#             yield x
-#     return None
-# def get_model_config_formats() -> Dict[str, Set[Enum]]:
-#     result: Dict[str, Set[Enum]] = {}
-#     for model_config in _expand(AnyModelConfig):
-#         for field in _find_format(model_config):
-#             if field is None:
-#                 continue
-#             if not result.get(model_config.__qualname__):
-#                 result[model_config.__qualname__] = set()
-#             result[model_config.__qualname__].add(field)
-#     return result

View File

@@ -3,7 +3,7 @@
 """Conversion script for the Stable Diffusion checkpoints."""

 from pathlib import Path
-from typing import Dict
+from typing import Dict, Optional

 import torch
 from diffusers import AutoencoderKL
@@ -15,6 +15,8 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
 )
 from omegaconf import DictConfig

+from . import AnyModel
+

 def convert_ldm_vae_to_diffusers(
     checkpoint: Dict[str, torch.Tensor],
@@ -33,11 +35,11 @@ def convert_ldm_vae_to_diffusers(

 def convert_ckpt_to_diffusers(
     checkpoint_path: str | Path,
-    dump_path: str | Path,
+    dump_path: Optional[str | Path] = None,
     precision: torch.dtype = torch.float16,
     use_safetensors: bool = True,
     **kwargs,
-):
+) -> AnyModel:
     """
     Takes all the arguments of download_from_original_stable_diffusion_ckpt(),
     and in addition a path-like object indicating the location of the desired diffusers
@@ -47,18 +49,20 @@ def convert_ckpt_to_diffusers(
     pipe = pipe.to(precision)

     # TO DO: save correct repo variant
-    pipe.save_pretrained(
-        dump_path,
-        safe_serialization=use_safetensors,
-    )
+    if dump_path:
+        pipe.save_pretrained(
+            dump_path,
+            safe_serialization=use_safetensors,
+        )
+    return pipe


 def convert_controlnet_to_diffusers(
     checkpoint_path: Path,
-    dump_path: Path,
+    dump_path: Optional[Path] = None,
     precision: torch.dtype = torch.float16,
     **kwargs,
-):
+) -> AnyModel:
     """
     Takes all the arguments of download_controlnet_from_original_ckpt(),
     and in addition a path-like object indicating the location of the desired diffusers
@@ -68,4 +72,6 @@ def convert_controlnet_to_diffusers(
     pipe = pipe.to(precision)

     # TO DO: save correct repo variant
-    pipe.save_pretrained(dump_path, safe_serialization=True)
+    if dump_path:
+        pipe.save_pretrained(dump_path, safe_serialization=True)
+    return pipe
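
Note: with dump_path now optional, both converters return the in-memory diffusers pipeline and only write it out when a destination is supplied. A hedged usage sketch follows; the checkpoint filename and output directory are placeholders, and any further keyword arguments the underlying diffusers conversion helper needs for a given checkpoint are assumed to be passed through **kwargs.

import torch

from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers

# No dump_path: nothing is written to disk, the converted pipeline lives only in RAM.
pipe = convert_ckpt_to_diffusers(
    checkpoint_path="sd-v1-5.safetensors",  # placeholder path
    precision=torch.float16,
    # plus whatever **kwargs the underlying diffusers helper requires for this checkpoint
)

# Persisting becomes an explicit, separate decision made by the caller.
pipe.save_pretrained("converted/sd-v1-5", safe_serialization=True)  # placeholder destination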

View File

@@ -19,11 +19,20 @@ class ModelConvertCache(ModelConvertCacheBase):
         self._cache_path = cache_path
         self._max_size = max_size

+        # adjust cache size at startup in case it has been changed
+        if self._cache_path.exists():
+            self.make_room(0.0)
+
     @property
     def max_size(self) -> float:
         """Return the maximum size of this cache directory (GB)."""
         return self._max_size

+    @max_size.setter
+    def max_size(self, value: float) -> None:
+        """Set the maximum size of this cache directory (GB)."""
+        self._max_size = value
+
     def cache_path(self, key: str) -> Path:
         """Return the path for a model with the indicated key."""
         return self._cache_path / key
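
Note: the new setter is what lets callers such as the conversion route lift a zero-sized convert cache for the duration of one conversion, while the make_room(0.0) call at construction prunes the directory down to the configured limit at startup. A small sketch of the cache object in isolation; the import path and constructor keywords are assumed from the fields shown above.

from pathlib import Path

from invokeai.backend.model_manager.load.convert_cache.convert_cache_default import ModelConvertCache  # assumed import path

cache = ModelConvertCache(cache_path=Path("models/.convert_cache"), max_size=0)  # 0 GB: conversions are never persisted
assert cache.max_size == 0

cache.max_size = 1.0  # temporarily allow one conversion to be written to disk
# ... perform the conversion ...
cache.max_size = 0    # restore the user's zero-size setting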

View File

@@ -83,3 +83,15 @@ class ModelLoaderBase(ABC):
     ) -> int:
         """Return size in bytes of the model, calculated before loading."""
         pass
+
+    @property
+    @abstractmethod
+    def convert_cache(self) -> ModelConvertCacheBase:
+        """Return the convert cache associated with this loader."""
+        pass
+
+    @property
+    @abstractmethod
+    def ram_cache(self) -> ModelCacheBase[AnyModel]:
+        """Return the ram cache associated with this loader."""
+        pass

View File

@@ -3,14 +3,13 @@

 from logging import Logger
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional

 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.model_manager import (
     AnyModel,
     AnyModelConfig,
     InvalidModelConfigException,
-    ModelRepoVariant,
     SubModelType,
 )
 from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
@@ -54,51 +53,43 @@ class ModelLoader(ModelLoaderBase):
         if model_config.type is ModelType.Main and not submodel_type:
             raise InvalidModelConfigException("submodel_type is required when loading a main model")

-        model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
+        model_path = self._get_model_path(model_config)

         if not model_path.exists():
             raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")

-        model_path = self._convert_if_needed(model_config, model_path, submodel_type)
-        locker = self._load_if_needed(model_config, model_path, submodel_type)
+        with skip_torch_weight_init():
+            locker = self._convert_and_load(model_config, model_path, submodel_type)
         return LoadedModel(config=model_config, _locker=locker)

-    def _get_model_path(
-        self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
-    ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
+    @property
+    def convert_cache(self) -> ModelConvertCacheBase:
+        """Return the convert cache associated with this loader."""
+        return self._convert_cache
+
+    @property
+    def ram_cache(self) -> ModelCacheBase[AnyModel]:
+        """Return the ram cache associated with this loader."""
+        return self._ram_cache
+
+    def _get_model_path(self, config: AnyModelConfig) -> Path:
         model_base = self._app_config.models_path
-        result = (model_base / config.path).resolve(), config, submodel_type
-        return result
+        return (model_base / config.path).resolve()

-    def _convert_if_needed(
-        self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
-    ) -> Path:
-        cache_path: Path = self._convert_cache.cache_path(config.key)
-
-        if not self._needs_conversion(config, model_path, cache_path):
-            return cache_path if cache_path.exists() else model_path
-
-        self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
-        return self._convert_model(config, model_path, cache_path)
-
-    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
-        return False
-
-    def _load_if_needed(
+    def _convert_and_load(
         self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
     ) -> ModelLockerBase:
-        # TO DO: This is not thread safe!
         try:
             return self._ram_cache.get(config.key, submodel_type)
         except IndexError:
             pass

-        model_variant = getattr(config, "repo_variant", None)
-        self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
-
-        # This is where the model is actually loaded!
-        with skip_torch_weight_init():
-            loaded_model = self._load_model(model_path, model_variant=model_variant, submodel_type=submodel_type)
+        cache_path: Path = self._convert_cache.cache_path(config.key)
+        if self._needs_conversion(config, model_path, cache_path):
+            loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
+        else:
+            config.path = str(cache_path) if cache_path.exists() else str(self._get_model_path(config))
+            loaded_model = self._load_model(config, submodel_type)

         self._ram_cache.put(
             config.key,
@@ -123,15 +114,34 @@ class ModelLoader(ModelLoaderBase):
             variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
         )

+    def _do_convert(
+        self, config: AnyModelConfig, model_path: Path, cache_path: Path, submodel_type: Optional[SubModelType] = None
+    ) -> AnyModel:
+        self.convert_cache.make_room(calc_model_size_by_fs(model_path))
+        pipeline = self._convert_model(config, model_path, cache_path if self.convert_cache.max_size > 0 else None)
+        if submodel_type:
+            # Proactively load the various submodels into the RAM cache so that we don't have to re-convert
+            # the entire pipeline every time a new submodel is needed.
+            for subtype in SubModelType:
+                if subtype == submodel_type:
+                    continue
+                if submodel := getattr(pipeline, subtype.value, None):
+                    self._ram_cache.put(
+                        config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
+                    )
+        return getattr(pipeline, submodel_type.value) if submodel_type else pipeline
+
+    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
+        return False
+
     # This needs to be implemented in subclasses that handle checkpoints
-    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
+    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
         raise NotImplementedError

     # This needs to be implemented in the subclass
     def _load_model(
         self,
-        model_path: Path,
-        model_variant: Optional[ModelRepoVariant] = None,
+        config: AnyModelConfig,
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         raise NotImplementedError
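
Note: the new _do_convert converts the checkpoint once (discarding the on-disk copy when the convert cache is sized at zero) and then stashes the pipeline's sibling submodels in the RAM cache, so asking for the tokenizer after the UNet does not trigger a second conversion. A stripped-down sketch of that sibling-caching idea; the cache and size-estimation arguments are stand-ins for the real InvokeAI objects.

from typing import Any, Callable

from invokeai.backend.model_manager import SubModelType

def cache_sibling_submodels(
    pipeline: Any,
    requested: SubModelType,
    key: str,
    ram_cache: Any,
    size_fn: Callable[[Any], int],
) -> None:
    """Put every submodel except the requested one into the RAM cache after a conversion."""
    for subtype in SubModelType:
        if subtype == requested:
            continue  # the caller returns and caches the requested submodel itself
        submodel = getattr(pipeline, subtype.value, None)
        if submodel is not None:
            ram_cache.put(key, submodel_type=subtype, model=submodel, size=size_fn(submodel))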

View File

@@ -122,6 +122,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
         """Return the cap on cache size."""
         return self._max_cache_size

+    @max_cache_size.setter
+    def max_cache_size(self, value: float) -> None:
+        """Set the cap on cache size."""
+        self._max_cache_size = value
+
     @property
     def stats(self) -> Optional[CacheStats]:
         """Return collected CacheStats object."""
@@ -157,8 +162,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
     ) -> None:
         """Store model under key and optional submodel_type."""
         key = self._make_cache_key(key, submodel_type)
-        assert key not in self._cached_models
-
+        if key in self._cached_models:
+            return
+        self.make_room(size)
         cache_record = CacheRecord(key, model, size)
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)
@@ -405,6 +411,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
         #
         # Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
         # immediately when their reference count hits 0.
+        if self.stats:
+            self.stats.cleared = models_cleared
         gc.collect()
         torch.cuda.empty_cache()
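
Note: put() now silently ignores a key that is already cached (and makes its own room for the new entry), which matters now that conversion proactively seeds the RAM cache and the same submodel may be offered to the cache more than once. A toy stand-in showing only the new contract; this is not the real ModelCache.

from typing import Any, Dict

class TinyCache:
    """Toy illustration: duplicate keys are a no-op, not an AssertionError."""

    def __init__(self) -> None:
        self._cached_models: Dict[str, Any] = {}

    def put(self, key: str, model: Any) -> None:
        if key in self._cached_models:
            return  # previously: assert key not in self._cached_models
        self._cached_models[key] = model

cache = TinyCache()
cache.put("sd15:scheduler", object())
cache.put("sd15:scheduler", object())  # second call is silently ignored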

View File

@@ -2,8 +2,10 @@
 """Class for ControlNet model loading in InvokeAI."""

 from pathlib import Path
+from typing import Optional

 from invokeai.backend.model_manager import (
+    AnyModel,
     AnyModelConfig,
     BaseModelType,
     ModelFormat,
@@ -33,7 +35,7 @@ class ControlNetLoader(GenericDiffusersLoader):
         else:
             return True

-    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
+    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
         assert isinstance(config, CheckpointConfigBase)
         image_size = (
             512
@@ -45,7 +47,7 @@ class ControlNetLoader(GenericDiffusersLoader):

         self._logger.info(f"Converting {model_path} to diffusers format")
         with open(self._app_config.legacy_conf_path / config.config_path, "r") as config_stream:
-            convert_controlnet_to_diffusers(
+            result = convert_controlnet_to_diffusers(
                 model_path,
                 output_path,
                 original_config_file=config_stream,
@@ -53,4 +55,4 @@ class ControlNetLoader(GenericDiffusersLoader):
                 precision=self._torch_dtype,
                 from_safetensors=model_path.suffix == ".safetensors",
             )
-        return output_path
+        return result

View File

@@ -10,13 +10,14 @@ from diffusers.models.modeling_utils import ModelMixin

 from invokeai.backend.model_manager import (
     AnyModel,
+    AnyModelConfig,
     BaseModelType,
     InvalidModelConfigException,
     ModelFormat,
-    ModelRepoVariant,
     ModelType,
     SubModelType,
 )
+from invokeai.backend.model_manager.config import DiffusersConfigBase

 from .. import ModelLoader, ModelLoaderRegistry
@@ -28,14 +29,15 @@ class GenericDiffusersLoader(ModelLoader):
     def _load_model(
         self,
-        model_path: Path,
-        model_variant: Optional[ModelRepoVariant] = None,
+        config: AnyModelConfig,
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
+        model_path = Path(config.path)
         model_class = self.get_hf_load_class(model_path)
         if submodel_type is not None:
             raise Exception(f"There are no submodels in models of type {model_class}")
-        variant = model_variant.value if model_variant else None
+        repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
+        variant = repo_variant.value if repo_variant else None
         try:
             result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant)
         except OSError as e:

View File

@@ -9,13 +9,14 @@ import torch

 from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
 from invokeai.backend.model_manager import (
     AnyModel,
+    AnyModelConfig,
     BaseModelType,
     ModelFormat,
-    ModelRepoVariant,
     ModelType,
     SubModelType,
 )
 from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
+from invokeai.backend.raw_model import RawModel


 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
@@ -24,13 +25,13 @@ class IPAdapterInvokeAILoader(ModelLoader):
     def _load_model(
         self,
-        model_path: Path,
-        model_variant: Optional[ModelRepoVariant] = None,
+        config: AnyModelConfig,
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         if submodel_type is not None:
             raise ValueError("There are no submodels in an IP-Adapter model.")
-        model = build_ip_adapter(
+        model_path = Path(config.path)
+        model: RawModel = build_ip_adapter(
             ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
             device=torch.device("cpu"),
             dtype=self._torch_dtype,

View File

@@ -3,7 +3,7 @@

 from logging import Logger
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional

 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.lora import LoRAModelRaw
@@ -12,7 +12,6 @@ from invokeai.backend.model_manager import (
     AnyModelConfig,
     BaseModelType,
     ModelFormat,
-    ModelRepoVariant,
     ModelType,
     SubModelType,
 )
@@ -41,12 +40,12 @@ class LoRALoader(ModelLoader):
     def _load_model(
         self,
-        model_path: Path,
-        model_variant: Optional[ModelRepoVariant] = None,
+        config: AnyModelConfig,
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         if submodel_type is not None:
             raise ValueError("There are no submodels in a LoRA model.")
+        model_path = Path(config.path)
         assert self._model_base is not None
         model = LoRAModelRaw.from_checkpoint(
             file_path=model_path,
@@ -56,12 +55,9 @@ class LoRALoader(ModelLoader):
         return model

     # override
-    def _get_model_path(
-        self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
-    ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
-        self._model_base = (
-            config.base
-        )  # cheating a little - we remember this variable for using in the subsequent call to _load_model()
+    def _get_model_path(self, config: AnyModelConfig) -> Path:
+        # cheating a little - we remember this variable for using in the subsequent call to _load_model()
+        self._model_base = config.base

         model_base_path = self._app_config.models_path
         model_path = model_base_path / config.path
@@ -73,5 +69,4 @@ class LoRALoader(ModelLoader):
             model_path = path
             break

-        result = model_path.resolve(), config, submodel_type
-        return result
+        return model_path.resolve()

View File

@@ -7,9 +7,9 @@ from typing import Optional

 from invokeai.backend.model_manager import (
     AnyModel,
+    AnyModelConfig,
     BaseModelType,
     ModelFormat,
-    ModelRepoVariant,
     ModelType,
     SubModelType,
 )
@@ -25,18 +25,19 @@ class OnnyxDiffusersModel(GenericDiffusersLoader):
     def _load_model(
         self,
-        model_path: Path,
-        model_variant: Optional[ModelRepoVariant] = None,
+        config: AnyModelConfig,
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         if not submodel_type is not None:
             raise Exception("A submodel type must be provided when loading onnx pipelines.")
+        model_path = Path(config.path)
         load_class = self.get_hf_load_class(model_path, submodel_type)
-        variant = model_variant.value if model_variant else None
+        repo_variant = getattr(config, "repo_variant", None)
+        variant = repo_variant.value if repo_variant else None
         model_path = model_path / submodel_type.value
         result: AnyModel = load_class.from_pretrained(
             model_path,
             torch_dtype=self._torch_dtype,
             variant=variant,
-        )  # type: ignore
+        )
         return result

View File

@@ -9,12 +9,16 @@ from invokeai.backend.model_manager import (
     AnyModelConfig,
     BaseModelType,
     ModelFormat,
-    ModelRepoVariant,
     ModelType,
     SchedulerPredictionType,
     SubModelType,
 )
-from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig, ModelVariantType
+from invokeai.backend.model_manager.config import (
+    CheckpointConfigBase,
+    DiffusersConfigBase,
+    MainCheckpointConfig,
+    ModelVariantType,
+)
 from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers

 from .. import ModelLoaderRegistry
@@ -41,14 +45,15 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
     def _load_model(
         self,
-        model_path: Path,
-        model_variant: Optional[ModelRepoVariant] = None,
+        config: AnyModelConfig,
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         if not submodel_type is not None:
             raise Exception("A submodel type must be provided when loading main pipelines.")
+        model_path = Path(config.path)
         load_class = self.get_hf_load_class(model_path, submodel_type)
-        variant = model_variant.value if model_variant else None
+        repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
+        variant = repo_variant.value if repo_variant else None
         model_path = model_path / submodel_type.value
         try:
             result: AnyModel = load_class.from_pretrained(
@@ -78,7 +83,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
         else:
             return True

-    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
+    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
         assert isinstance(config, MainCheckpointConfig)
         base = config.base
@@ -94,7 +99,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):

         self._logger.info(f"Converting {model_path} to diffusers format")
-        convert_ckpt_to_diffusers(
+        loaded_model = convert_ckpt_to_diffusers(
             model_path,
             output_path,
             model_type=self.model_base_to_model_type[base],
@@ -108,4 +113,4 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
             load_safety_checker=False,
             num_in_channels=VARIANT_TO_IN_CHANNEL_MAP[config.variant],
         )
-        return output_path
+        return loaded_model

View File

@@ -2,14 +2,13 @@
 """Class for TI model loading in InvokeAI."""

 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional

 from invokeai.backend.model_manager import (
     AnyModel,
     AnyModelConfig,
     BaseModelType,
     ModelFormat,
-    ModelRepoVariant,
     ModelType,
     SubModelType,
 )
@@ -27,22 +26,19 @@ class TextualInversionLoader(ModelLoader):
     def _load_model(
         self,
-        model_path: Path,
-        model_variant: Optional[ModelRepoVariant] = None,
+        config: AnyModelConfig,
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         if submodel_type is not None:
             raise ValueError("There are no submodels in a TI model.")
         model = TextualInversionModelRaw.from_checkpoint(
-            file_path=model_path,
+            file_path=config.path,
             dtype=self._torch_dtype,
         )
         return model

     # override
-    def _get_model_path(
-        self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
-    ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
+    def _get_model_path(self, config: AnyModelConfig) -> Path:
         model_path = self._app_config.models_path / config.path

         if config.format == ModelFormat.EmbeddingFolder:
@@ -53,4 +49,4 @@ class TextualInversionLoader(ModelLoader):
         if not path.exists():
             raise OSError(f"The embedding file at {path} was not found")

-        return path, config, submodel_type
+        return path