diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 9359484416..71a31f7b42 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -2,7 +2,6 @@ import pathlib -import traceback from typing import List, Literal, Optional, Union from fastapi import Body, Path, Query, Response @@ -13,12 +12,13 @@ from starlette.exceptions import HTTPException from invokeai.backend import BaseModelType, ModelType from invokeai.backend.model_manager import ( OPENAPI_MODEL_CONFIGS, + DuplicateModelException, InvalidModelException, - MergeInterpolationMethod, ModelConfigBase, SchedulerPredictionType, UnknownModelException, ) +from invokeai.backend.model_manager.merge import MergeInterpolationMethod from ..dependencies import ApiDependencies @@ -225,15 +225,13 @@ async def convert_model( ), ) -> ConvertModelResponse: """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none.""" - logger = ApiDependencies.invoker.services.logger - info = ApiDependencies.invoker.services.model_manager.model_info(key) try: dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None ApiDependencies.invoker.services.model_manager.convert_model(key, convert_dest_directory=dest) model_raw = ApiDependencies.invoker.services.model_manager.model_info(key).dict() response = parse_obj_as(ConvertModelResponse, model_raw) except UnknownModelException as e: - raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}") + raise HTTPException(status_code=404, detail=f"Model '{key}' not found: {str(e)}") except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return response @@ -252,6 +250,7 @@ async def convert_model( async def search_for_models( search_path: pathlib.Path = Query(description="Directory path to search for models"), ) -> List[pathlib.Path]: + """Search for all models in a server-local path.""" if not search_path.is_dir(): raise HTTPException( status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory" @@ -283,27 +282,31 @@ async def list_ckpt_configs() -> List[pathlib.Path]: response_model=bool, ) async def sync_to_config() -> bool: - """Call after making changes to models.yaml, autoimport directories or models directory to synchronize - in-memory data structures with disk data structures.""" + """ + Synchronize model in-memory data structures with disk. + + Call after making changes to models.yaml, autoimport directories + or models directory. + """ ApiDependencies.invoker.services.model_manager.sync_to_config() return True @models_router.put( - "/merge/{base_model}", + "/merge", operation_id="merge_models", responses={ 200: {"description": "Model converted successfully"}, 400: {"description": "Incompatible models"}, 404: {"description": "One or more models not found"}, + 409: {"description": "An identical merged model is already installed"}, }, status_code=200, response_model=MergeModelResponse, ) async def merge_models( - base_model: BaseModelType = Path(description="Base model"), - model_names: List[str] = Body(description="model name", min_items=2, max_items=3), - merged_model_name: Optional[str] = Body(description="Name of destination model"), + keys: List[str] = Body(description="model name", min_items=2, max_items=3), + merged_model_name: Optional[str] = Body(description="Name of destination model", default=None), alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"), force: Optional[bool] = Body( @@ -314,28 +317,24 @@ async def merge_models( default=None, ), ) -> MergeModelResponse: - """Convert a checkpoint model into a diffusers model""" + """Merge the indicated diffusers model.""" logger = ApiDependencies.invoker.services.logger try: - logger.info(f"Merging models: {model_names} into {merge_dest_directory or ''}/{merged_model_name}") + logger.info(f"Merging models: {keys} into {merge_dest_directory or ''}/{merged_model_name}") dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None - result = ApiDependencies.invoker.services.model_manager.merge_models( - model_names, - base_model, - merged_model_name=merged_model_name or "+".join(model_names), + result: ModelConfigBase = ApiDependencies.invoker.services.model_manager.merge_models( + model_keys=keys, + merged_model_name=merged_model_name, alpha=alpha, interp=interp, force=force, merge_dest_directory=dest, ) - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - result.name, - base_model=base_model, - model_type=ModelType.Main, - ) - response = parse_obj_as(ConvertModelResponse, model_raw) + response = parse_obj_as(ConvertModelResponse, result.dict()) + except DuplicateModelException as e: + raise HTTPException(status_code=409, detail=str(e)) except UnknownModelException: - raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found") + raise HTTPException(status_code=404, detail=f"One or more of the models '{keys}' not found") except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return response diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index ec0a49cd95..44c5eb10c0 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -14,18 +14,17 @@ from invokeai.app.models.exceptions import CanceledException from invokeai.backend.model_manager import ( BaseModelType, DuplicateModelException, - MergeInterpolationMethod, ModelConfigBase, ModelInfo, ModelInstallJob, - ModelLoader, - ModelMerger, + ModelLoad, ModelSearch, ModelType, SubModelType, UnknownModelException, ) from invokeai.backend.model_manager.cache import CacheStats +from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger from .config import InvokeAIAppConfig @@ -291,7 +290,7 @@ class ModelManagerServiceBase(ABC): class ModelManagerService(ModelManagerServiceBase): """Responsible for managing models on disk and in memory.""" - _loader: ModelLoader = Field(description="InvokeAIAppConfig object for the current process") + _loader: ModelLoad = Field(description="InvokeAIAppConfig object for the current process") _event_bus: "EventServiceBase" = Field(description="an event bus to send install events to", default=None) def __init__(self, config: InvokeAIAppConfig, event_bus: Optional["EventServiceBase"] = None): @@ -304,7 +303,7 @@ class ModelManagerService(ModelManagerServiceBase): """ self._event_bus = event_bus handlers = [self._event_bus.emit_model_event] if self._event_bus else None - self._loader = ModelLoader(config, event_handlers=handlers) + self._loader = ModelLoad(config, event_handlers=handlers) def get_model( self, @@ -500,7 +499,7 @@ class ModelManagerService(ModelManagerServiceBase): model_keys: List[str] = Field( default=None, min_items=2, max_items=3, description="List of model keys to merge" ), - merged_model_name: str = Field(default=None, description="Name of destination model after merging"), + merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"), alpha: Optional[float] = 0.5, interp: Optional[MergeInterpolationMethod] = None, force: Optional[bool] = False, @@ -514,8 +513,12 @@ class ModelManagerService(ModelManagerServiceBase): :param interp: Interpolation method. None (default) :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) """ - merger = ModelMerger(self._loader) + merger = ModelMerger(self._loader.store) try: + if not merged_model_name: + merged_model_name = "+".join([self._loader.store.get_model(x).name for x in model_keys]) + raise Exception("not implemented") + self.logger.error("ModelMerger needs to be rewritten.") result = merger.merge_diffusion_models_and_save( model_keys=model_keys, diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py index 72d5113864..997867dc04 100644 --- a/invokeai/backend/__init__.py +++ b/invokeai/backend/__init__.py @@ -7,7 +7,7 @@ from .model_manager import ( # noqa F401 InvalidModelException, ModelConfigStore, ModelInstall, - ModelLoader, + ModelLoad, ModelType, ModelVariantType, SchedulerPredictionType, diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py index 6fc8922b33..98d6c9f06c 100644 --- a/invokeai/backend/model_manager/__init__.py +++ b/invokeai/backend/model_manager/__init__.py @@ -14,11 +14,10 @@ from .config import ( # noqa F401 SubModelType, ) from .install import ModelInstall, ModelInstallJob # noqa F401 -from .loader import ModelInfo, ModelLoader # noqa F401 +from .loader import ModelInfo, ModelLoad # noqa F401 from .lora import ModelPatcher, ONNXModelPatcher -from .merge import MergeInterpolationMethod, ModelMerger from .models import OPENAPI_MODEL_CONFIGS, read_checkpoint_meta # noqa F401 -from .probe import InvalidModelException, ModelProbe # noqa F401 +from .probe import InvalidModelException, ModelProbeInfo # noqa F401 from .search import ModelSearch # noqa F401 from .storage import ( # noqa F401 DuplicateModelException, diff --git a/invokeai/backend/model_manager/download/base.py b/invokeai/backend/model_manager/download/base.py index 4a02ac02ce..bb017ae6b4 100644 --- a/invokeai/backend/model_manager/download/base.py +++ b/invokeai/backend/model_manager/download/base.py @@ -128,6 +128,12 @@ class DownloadQueueBase(ABC): :param variant: Variant to download, such as "fp16" (repo_ids only). :param event_handlers: Optional callables that will be called whenever job status changes. :returns the job: job.id will be a non-negative value after execution + + Known variants currently are: + 1. onnx + 2. openvino + 3. fp16 + 4. None (usually returns fp32 model) """ pass diff --git a/invokeai/backend/model_manager/download/queue.py b/invokeai/backend/model_manager/download/queue.py index c3654b77bc..0faad6fc61 100644 --- a/invokeai/backend/model_manager/download/queue.py +++ b/invokeai/backend/model_manager/download/queue.py @@ -118,7 +118,9 @@ class DownloadQueue(DownloadQueueBase): access_token: Optional[str] = None, event_handlers: Optional[List[DownloadEventHandler]] = None, ) -> DownloadJobBase: - """Create a download job and return its ID.""" + """ + Create a download job and return its ID. + """ kwargs = dict() if Path(source).exists(): @@ -503,8 +505,8 @@ class DownloadQueue(DownloadQueueBase): repo_id = job.source variant = job.variant urls_to_download, metadata = self._get_repo_info(repo_id, variant) - if job.destination.stem != Path(repo_id).stem: - job.destination = job.destination / Path(repo_id).stem + if job.destination.name != Path(repo_id).name: + job.destination = job.destination / Path(repo_id).name job.metadata = metadata bytes_downloaded = dict() job.total_bytes = 0 @@ -535,7 +537,15 @@ class DownloadQueue(DownloadQueueBase): repo_id: str, variant: Optional[str] = None, ) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], ModelSourceMetadata]: - """Given a repo_id and an optional variant, return list of URLs to download to get the model.""" + """ + Given a repo_id and an optional variant, return list of URLs to download to get the model. + + Known variants currently are: + 1. onnx + 2. openvino + 3. fp16 + 4. None (usually returns fp32 model) + """ model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True) sibs = model_info.siblings paths = [x.rfilename for x in sibs] @@ -564,7 +574,19 @@ class DownloadQueue(DownloadQueueBase): basenames = dict() for p in paths: path = Path(p) - if path.suffix in [".bin", ".safetensors", ".pt"]: + + if path.suffix == ".onnx": + if variant == "onnx": + result.add(path) + + elif path.name.startswith("openvino_model"): + if variant == "openvino": + result.add(path) + + elif path.suffix in [".json", ".txt"]: + result.add(path) + + elif path.suffix in [".bin", ".safetensors", ".pt"] and variant in ["fp16", None]: parent = path.parent suffixes = path.suffixes if len(suffixes) == 2: @@ -584,10 +606,13 @@ class DownloadQueue(DownloadQueueBase): basenames[basename] = path else: basenames[basename] = path + else: - result.add(path) + continue + for v in basenames.values(): result.add(v) + return result def _download_path(self, job: DownloadJobBase): diff --git a/invokeai/backend/model_manager/install.py b/invokeai/backend/model_manager/install.py index 5fbba5b4a1..719a768abe 100644 --- a/invokeai/backend/model_manager/install.py +++ b/invokeai/backend/model_manager/install.py @@ -146,13 +146,19 @@ class ModelInstallBase(ABC): """Return the download queue used by the installer.""" pass + @property @abstractmethod - def register_path(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str: + def store(self) -> ModelConfigStore: + """Return the storage backend used by the installer.""" + pass + + @abstractmethod + def register_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]]) -> str: """ Probe and register the model at model_path. :param model_path: Filesystem Path to the model. - :param info: Optional ModelProbeInfo object. If not provided, model will be probed. + :param overrides: Dict of attributes that will override probed values. :returns id: The string ID of the registered model. """ pass @@ -201,6 +207,12 @@ class ModelInstallBase(ABC): The `inplace` flag does not affect the behavior of downloaded models, which are always moved into the `models` directory. + + Variants recognized by HuggingFace currently are: + 1. onnx + 2. openvino + 3. fp16 + 4. None (usually returns fp32 model) """ pass @@ -349,6 +361,11 @@ class ModelInstall(ModelInstallBase): """Return the queue.""" return self._download_queue + @property + def store(self) -> ModelConfigStore: + """Return the storage backend used by the installer.""" + return self._store + def register_path( self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None ) -> str: # noqa D102 @@ -360,7 +377,7 @@ class ModelInstall(ModelInstallBase): key: str = FastModelHash.hash(model_path) registration_data = dict( path=model_path.as_posix(), - name=model_path.stem, + name=model_path.name if model_path.is_dir() else model_path.stem, base_model=info.base_type, model_type=info.model_type, model_format=info.format, @@ -581,7 +598,7 @@ class ModelInstall(ModelInstallBase): This will raise a ValueError unless the model is a checkpoint. This will raise an UnknownModelException if key is unknown. """ - from .loader import ModelInfo, ModelLoader # to avoid circular imports + from .loader import ModelInfo, ModelLoad # to avoid circular imports new_diffusers_path = None @@ -594,7 +611,7 @@ class ModelInstall(ModelInstallBase): # We are taking advantage of a side effect of get_model() that converts check points # into cached diffusers directories stored at `path`. It doesn't matter # what submodel type we request here, so we get the smallest. - loader = ModelLoader(self._config) + loader = ModelLoad(self._config) submodel = {"submodel_type": SubModelType.Scheduler} if info.model_type == ModelType.Main else {} converted_model: ModelInfo = loader.get_model(key, **submodel) diff --git a/invokeai/backend/model_manager/loader.py b/invokeai/backend/model_manager/loader.py index eb4f5b6743..0275c989e5 100644 --- a/invokeai/backend/model_manager/loader.py +++ b/invokeai/backend/model_manager/loader.py @@ -42,7 +42,7 @@ class ModelInfo: self.context.__exit__(*args, **kwargs) -class ModelLoaderBase(ABC): +class ModelLoadBase(ABC): """Abstract base class for a model loader which works with the ModelConfigStore backend.""" @abstractmethod @@ -113,8 +113,8 @@ class ModelLoaderBase(ABC): pass -class ModelLoader(ModelLoaderBase): - """Implementation of ModelLoaderBase.""" +class ModelLoad(ModelLoadBase): + """Implementation of ModelLoadBase.""" _app_config: InvokeAIAppConfig _store: ModelConfigStore @@ -130,7 +130,7 @@ class ModelLoader(ModelLoaderBase): event_handlers: Optional[List[DownloadEventHandler]] = None, ): """ - Initialize ModelLoader object. + Initialize ModelLoad object. :param config: The app's InvokeAIAppConfig object. """ diff --git a/invokeai/backend/model_manager/merge.py b/invokeai/backend/model_manager/merge.py index a220184e3a..11274d9e70 100644 --- a/invokeai/backend/model_manager/merge.py +++ b/invokeai/backend/model_manager/merge.py @@ -9,14 +9,16 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team import warnings from enum import Enum from pathlib import Path -from typing import List, Optional, Union +from typing import List, Optional from diffusers import DiffusionPipeline from diffusers import logging as dlogging import invokeai.backend.util.logging as logger +from invokeai.app.services.config import InvokeAIAppConfig -from . import BaseModelType, ModelConfigBase, ModelLoader, ModelType, ModelVariantType +from . import ModelConfigBase, ModelConfigStore, ModelInstall, ModelType +from .probe import ModelProbe, ModelProbeInfo class MergeInterpolationMethod(str, Enum): @@ -27,8 +29,18 @@ class MergeInterpolationMethod(str, Enum): class ModelMerger(object): - def __init__(self, manager: ModelLoader): - self.manager = manager + _store: ModelConfigStore + _config: InvokeAIAppConfig + + def __init__(self, store: ModelConfigStore, config: Optional[InvokeAIAppConfig] = None): + """ + Initialize a ModelMerger object. + + :param store: Underlying storage manager for the running process. + :param config: InvokeAIAppConfig object (if not provided, default will be selected). + """ + self._store = store + self._config = config or InvokeAIAppConfig.get_config() def merge_diffusion_models( self, @@ -70,8 +82,7 @@ class ModelMerger(object): def merge_diffusion_models_and_save( self, - model_names: List[str], - base_model: Union[BaseModelType, str], + model_keys: List[str], merged_model_name: str, alpha: float = 0.5, interp: Optional[MergeInterpolationMethod] = None, @@ -93,24 +104,36 @@ class ModelMerger(object): cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map """ model_paths = list() - config = self.manager.app_config - base_model = BaseModelType(base_model) + model_names = list() + config = self._config + store = self._store + base_models = set() vae = None - for mod in model_names: - info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main) - assert info, f"model {mod}, base_model {base_model}, is unknown" + assert ( + len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference + ), "When merging three models, only the 'add_difference' merge method is supported" + + for key in model_keys: + info = store.get_model(key) + model_names.append(info.name) assert ( - info["model_format"] == "diffusers" - ), f"{mod} is not a diffusers model. It must be optimized before merging" - assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged" + info.model_format == "diffusers" + ), f"{info.name} ({info.key}) is not a diffusers model. It must be optimized before merging" assert ( - len(model_names) <= 2 or interp == MergeInterpolationMethod.AddDifference - ), "When merging three models, only the 'add_difference' merge method is supported" + info.variant == "normal" + ), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged" + # pick up the first model's vae - if mod == model_names[0]: - vae = info.get("vae") - model_paths.extend([(config.root_path / info["path"]).as_posix()]) + if key == model_keys[0]: + vae = info.vae + + # tally base models used + base_models.add(info.base_model) + model_paths.extend([(config.models_path / info.path).as_posix()]) + + assert len(base_models) == 1, f"All models to merge must have same base model, but found bases {base_models}" + base_model = base_models.pop() merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp) logger.debug(f"interp = {interp}, merge_method={merge_method}") @@ -126,18 +149,11 @@ class ModelMerger(object): merged_pipe.save_pretrained(dump_path, safe_serialization=True) # register model and get its unique key - info = ModelProbeInfo( - model_type=ModelType.Main, - base_type=base_model, - format="diffusers", - ) - key = self.manager.installer.register_path( - model_path=dump_path, - info=info, - ) + installer = ModelInstall(store=self._store, config=self._config) + key = installer.register_path(dump_path) # update model's config - model_config = self.manager.store.get_model(key) + model_config = self._store.get_model(key) model_config.update( dict( name=merged_model_name, @@ -145,5 +161,5 @@ class ModelMerger(object): vae=vae, ) ) - self.manager.store.update_model(key, model_config) + self._store.update_model(key, model_config) return model_config diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index f78c700147..4e54c8b869 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -15,26 +15,28 @@ from typing import Callable, Optional import safetensors.torch import torch from picklescan.scanner import scan_file_path +from pydantic import BaseModel from .config import BaseModelType, ModelFormat, ModelType, ModelVariantType, SchedulerPredictionType -from .util import SilenceWarnings, lora_token_vector_length, read_checkpoint_meta +from .hash import FastModelHash +from .util import lora_token_vector_length, read_checkpoint_meta class InvalidModelException(Exception): """Raised when an invalid model is encountered.""" -@dataclass -class ModelProbeInfo(object): +class ModelProbeInfo(BaseModel): """Fields describing a probed model.""" model_type: ModelType base_type: BaseModelType format: ModelFormat - variant_type: ModelVariantType = "normal" - prediction_type: SchedulerPredictionType = "v_prediction" - upcast_attention: bool = False - image_size: int = None + hash: str + variant_type: Optional[ModelVariantType] = "normal" + prediction_type: Optional[SchedulerPredictionType] = "v_prediction" + upcast_attention: Optional[bool] = False + image_size: Optional[int] = None class ModelProbeBase(ABC): @@ -131,6 +133,7 @@ class ModelProbe(ModelProbeBase): variant_type = probe.get_variant_type() prediction_type = probe.get_scheduler_prediction_type() format = probe.get_format() + hash = FastModelHash.hash(model) model_info = ModelProbeInfo( model_type=model_type, @@ -142,6 +145,7 @@ class ModelProbe(ModelProbeBase): and prediction_type == SchedulerPredictionType.VPrediction ), format=format, + hash=hash, image_size=1024 if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}) else 768 diff --git a/scripts/probe-model.py b/scripts/probe-model.py index 14af866634..7067572fca 100755 --- a/scripts/probe-model.py +++ b/scripts/probe-model.py @@ -3,6 +3,7 @@ """Little command-line utility for probing a model on disk.""" import argparse +import json import sys from pathlib import Path @@ -25,6 +26,6 @@ args = parser.parse_args() for path in args.model_path: try: info = ModelProbe().probe(path, helper) - print(f"{path}: {info}") + print(f"{path}:{json.dumps(info.dict(), sort_keys=True, indent=4)}") except InvalidModelException as exc: print(exc) diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py index 9402f81c5f..ceed1abe7d 100644 --- a/tests/test_model_manager.py +++ b/tests/test_model_manager.py @@ -4,7 +4,7 @@ import pytest from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend import BaseModelType, ModelConfigStore, ModelType, SubModelType -from invokeai.backend.model_manager import ModelLoader +from invokeai.backend.model_manager import ModelLoad BASIC_MODEL_NAME = "sdxl-base-1-0" VAE_OVERRIDE_MODEL_NAME = "sdxl-base-with-custom-vae-1-0" @@ -12,18 +12,18 @@ VAE_NULL_OVERRIDE_MODEL_NAME = "sdxl-base-with-empty-vae-1-0" @pytest.fixture -def model_manager(datadir) -> ModelLoader: +def model_manager(datadir) -> ModelLoad: config = InvokeAIAppConfig(root=datadir, conf_path="configs/relative_sub.models.yaml") - return ModelLoader(config=config) + return ModelLoad(config=config) -def test_get_model_names(model_manager: ModelLoader): +def test_get_model_names(model_manager: ModelLoad): store = model_manager.store names = [x.name for x in store.all_models()] assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME] -def test_get_model_path_for_diffusers(model_manager: ModelLoader, datadir: Path): +def test_get_model_path_for_diffusers(model_manager: ModelLoad, datadir: Path): models = model_manager.store.search_by_name(model_name=BASIC_MODEL_NAME) assert len(models) == 1 model_config = models[0] @@ -33,7 +33,7 @@ def test_get_model_path_for_diffusers(model_manager: ModelLoader, datadir: Path) assert not is_override -def test_get_model_path_for_overridden_vae(model_manager: ModelLoader, datadir: Path): +def test_get_model_path_for_overridden_vae(model_manager: ModelLoad, datadir: Path): models = model_manager.store.search_by_name(model_name=VAE_OVERRIDE_MODEL_NAME) assert len(models) == 1 model_config = models[0] @@ -43,7 +43,7 @@ def test_get_model_path_for_overridden_vae(model_manager: ModelLoader, datadir: assert is_override -def test_get_model_path_for_null_overridden_vae(model_manager: ModelLoader, datadir: Path): +def test_get_model_path_for_null_overridden_vae(model_manager: ModelLoad, datadir: Path): model_config = model_manager.store.search_by_name(model_name=VAE_NULL_OVERRIDE_MODEL_NAME)[0] vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae) assert not is_override