Mirror of https://github.com/invoke-ai/InvokeAI
replace load_and_cache_model() with load_remote_model() and load_local_model()
parent 9f9379682e
commit dc134935c8
@@ -1585,9 +1585,9 @@ Within invocations, the following methods are available from the

 ### context.download_and_cache_model(source) -> Path

-This method accepts a `source` of a model, downloads and caches it
-locally, and returns a Path to the local model. The source can be a
-local file or directory, a URL, or a HuggingFace repo_id.
+This method accepts a `source` of a remote model, downloads and caches
+it locally, and then returns a Path to the local model. The source can
+be a direct download URL or a HuggingFace repo_id.

 In the case of HuggingFace repo_id, the following variants are
 recognized:
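A minimal sketch of how an invocation might call the documented method, assuming the `context.models` interface used by the tests further down; the URL is a placeholder, not a real model:

from pathlib import Path

def fetch_weights(context) -> Path:
    # Sketch only: `context` is the InvocationContext handed to an invocation's
    # invoke() method. The file is downloaded on the first call and the cached
    # path is returned on subsequent calls.
    return context.models.download_and_cache_model(
        "https://example.com/checkpoints/some_weights.safetensors"
    )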
@@ -1602,16 +1602,34 @@ directory using this syntax:

 * stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors

-### context.load_and_cache_model(source, [loader]) -> LoadedModel
+### context.load_local_model(model_path, [loader]) -> LoadedModel

-This method takes a model source, downloads it, caches it, and then
-loads it into the RAM cache for use in inference. The optional loader
-is a Callable that accepts a Path to the object, and returns a
-`Dict[str, torch.Tensor]`. If no loader is provided, then the method
-will use `torch.load()` for a .ckpt or .bin checkpoint file,
-`safetensors.torch.load_file()` for a safetensors checkpoint file, or
-`*.from_pretrained()` for a directory that looks like a
-diffusers directory.
+This method loads a local model from the indicated path, returning a
+`LoadedModel`. The optional loader is a Callable that accepts a Path
+to the object, and returns a `AnyModel` object. If no loader is
+provided, then the method will use `torch.load()` for a .ckpt or .bin
+checkpoint file, `safetensors.torch.load_file()` for a safetensors
+checkpoint file, or `cls.from_pretrained()` for a directory that looks
+like a diffusers directory.
+
+### context.load_remote_model(source, [loader]) -> LoadedModel
+
+This method accepts a `source` of a remote model, downloads and caches
+it locally, loads it, and returns a `LoadedModel`. The source can be a
+direct download URL or a HuggingFace repo_id.
+
+In the case of HuggingFace repo_id, the following variants are
+recognized:
+
+* stabilityai/stable-diffusion-v4 -- default model
+* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
+* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
+* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
+
+You can also point at an arbitrary individual file within a repo_id
+directory using this syntax:
+
+* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
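The two new methods split the old load_and_cache_model() behaviour in two. A minimal sketch of both calls as they might appear inside an invocation, assuming the `context.models` interface shown in the invocation hunks below; the local path and URL are placeholders:

from pathlib import Path

def run(context):
    # Sketch only. A local file needs no download; with no loader given, the
    # default is chosen from the file suffix (torch.load vs. safetensors).
    local = context.models.load_local_model(Path("/opt/models/embedding.safetensors"))
    tensors = local.model

    # A remote source is downloaded to the models cache on first use, then
    # loaded. A custom loader may be supplied, as the LaMa invocation below does.
    with context.models.load_remote_model(
        source="https://example.com/checkpoints/some_weights.safetensors"
    ) as model:
        ...  # run inference with the in-RAM model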
@@ -611,7 +611,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
 model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
 )

-with self._context.models.load_and_cache_model(
+with self._context.models.load_remote_model(
 source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
 ) as model:
 depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
@@ -134,7 +134,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
 """Infills transparent areas of an image using the LaMa model"""

 def infill(self, image: Image.Image):
-with self._context.models.load_and_cache_model(
+with self._context.models.load_remote_model(
 source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
 loader=LaMA.load_jit_model,
 ) as model:
@@ -91,7 +91,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
 context.logger.error(msg)
 raise ValueError(msg)

-loadnet = context.models.load_and_cache_model(
+loadnet = context.models.load_remote_model(
 source=ESRGAN_MODEL_URLS[self.model_name],
 )
@@ -5,6 +5,8 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union

+from pydantic.networks import AnyHttpUrl
+
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.download import DownloadQueueServiceBase
 from invokeai.app.services.events.events_base import EventServiceBase
@@ -241,7 +243,7 @@ class ModelInstallServiceBase(ABC):
 """

 @abstractmethod
-def download_and_cache_model(self, source: str) -> Path:
+def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
 """
 Download the model file located at source to the models cache and return its Path.
@@ -15,6 +15,7 @@ import torch
 import yaml
 from huggingface_hub import HfFolder
 from pydantic.networks import AnyHttpUrl
+from pydantic_core import Url
 from requests import Session

 from invokeai.app.services.config import InvokeAIAppConfig
@@ -374,7 +375,7 @@ class ModelInstallService(ModelInstallServiceBase):

 def download_and_cache_model(
 self,
-source: str,
+source: str | AnyHttpUrl,
 ) -> Path:
 """Download the model file located at source to the models cache and return its Path."""
 model_path = self._download_cache_path(str(source), self._app_config)
@@ -388,7 +389,7 @@ class ModelInstallService(ModelInstallServiceBase):
 return contents[0]

 model_path.mkdir(parents=True, exist_ok=True)
-model_source = self._guess_source(source)
+model_source = self._guess_source(str(source))
 remote_files, _ = self._remote_files_from_source(model_source)
 job = self._multifile_download(
 dest=model_path,
@@ -447,7 +448,7 @@ class ModelInstallService(ModelInstallServiceBase):
 )
 elif re.match(r"^https?://[^/]+", source):
 source_obj = URLModelSource(
-url=AnyHttpUrl(source),
+url=Url(source),
 )
 else:
 raise ValueError(f"Unsupported model source: '{source}'")
@@ -3,9 +3,7 @@

 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Callable, Dict, Optional
+from typing import Callable, Optional

-from torch import Tensor
-
 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
 from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
@@ -37,7 +35,7 @@ class ModelLoadServiceBase(ABC):

 @abstractmethod
 def load_model_from_path(
-self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
+self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
 ) -> LoadedModelWithoutConfig:
 """
 Load the model file or directory located at the indicated Path.
@@ -2,11 +2,10 @@
 """Implementation of model loader service."""

 from pathlib import Path
-from typing import Callable, Dict, Optional, Type
+from typing import Callable, Optional, Type

 from picklescan.scanner import scan_file_path
 from safetensors.torch import load_file as safetensors_load_file
-from torch import Tensor
 from torch import load as torch_load

 from invokeai.app.services.config import InvokeAIAppConfig
@@ -86,7 +85,7 @@ class ModelLoadService(ModelLoadServiceBase):
 return loaded_model

 def load_model_from_path(
-self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor] | AnyModel]] = None
+self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
 ) -> LoadedModelWithoutConfig:
 cache_key = str(model_path)
 ram_cache = self.ram_cache
@@ -95,11 +94,11 @@
 except IndexError:
 pass

-def torch_load_file(checkpoint: Path) -> Dict[str, Tensor]:
+def torch_load_file(checkpoint: Path) -> AnyModel:
 scan_result = scan_file_path(checkpoint)
 if scan_result.infected_files != 0:
 raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
-result: Dict[str, Tensor] = torch_load(checkpoint, map_location="cpu")
+result = torch_load(checkpoint, map_location="cpu")
 return result

 def diffusers_load_directory(directory: Path) -> AnyModel:
@@ -109,18 +108,16 @@ class ModelLoadService(ModelLoadServiceBase):
 ram_cache=self._ram_cache,
 convert_cache=self.convert_cache,
 ).get_hf_load_class(directory)
-result: AnyModel = load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
-return result
+return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())

-if loader is None:
-loader = (
+loader = loader or (
 diffusers_load_directory
 if model_path.is_dir()
 else torch_load_file
 if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
 else lambda path: safetensors_load_file(path, device="cpu")
 )
+assert loader is not None
 raw_model = loader(model_path)
 ram_cache.put(key=cache_key, model=raw_model)
 return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
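A hedged sketch of how that loader fallback looks from the caller's side; `load_service` stands in for a ModelLoadService instance and the paths are illustrative:

from pathlib import Path

from torch import load as torch_load

def warm_cache(load_service):
    # Sketch only. With loader=None, the dispatch above picks torch_load_file
    # for .ckpt/.pt/.pth/.bin files, safetensors_load_file for other files, and
    # the diffusers load class for directories; the result is cached in RAM
    # under the stringified path.
    checkpoint = load_service.load_model_from_path(Path("models/.download_cache/big-lama.pt"))
    embedding = load_service.load_model_from_path(Path("models/.download_cache/test_embedding.safetensors"))

    # An explicit loader (any Callable[[Path], AnyModel]) bypasses the dispatch.
    jit = load_service.load_model_from_path(
        Path("models/.download_cache/big-lama.pt"),
        loader=lambda p: torch_load(p, map_location="cpu"),
    )
    return checkpoint, embedding, jit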
@@ -15,7 +15,14 @@ from invokeai.app.services.images.images_common import ImageDTO
 from invokeai.app.services.invocation_services import InvocationServices
 from invokeai.app.services.model_records.model_records_base import UnknownModelException
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
+from invokeai.backend.model_manager.config import (
+AnyModel,
+AnyModelConfig,
+BaseModelType,
+ModelFormat,
+ModelType,
+SubModelType,
+)
 from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@@ -449,21 +456,42 @@ class ModelsInterface(InvocationContextInterface):
 installed, the cached path will be returned. Otherwise it will be downloaded.

 Args:
-source: A model path, URL or repo_id.
+source: A URL that points to the model, or a huggingface repo_id.

 Returns:
 Path to the downloaded model
 """

 return self._services.model_manager.install.download_and_cache_model(source=source)

-def load_and_cache_model(
+def load_local_model(
 self,
-source: Path | str | AnyHttpUrl,
-loader: Optional[Callable[[Path], dict[str, Tensor]]] = None,
+model_path: Path,
+loader: Optional[Callable[[Path], AnyModel]] = None,
 ) -> LoadedModelWithoutConfig:
 """
-Download, cache, and load the model file located at the indicated URL.
+Load the model file located at the indicated path
+
+If a loader callable is provided, it will be invoked to load the model. Otherwise,
+`safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
+
+Be aware that the LoadedModelWithoutConfig object has no `config` attribute
+
+Args:
+path: A model Path
+loader: A Callable that expects a Path and returns a dict[str|int, Any]
+
+Returns:
+A LoadedModelWithoutConfig object.
+"""
+return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+
+def load_remote_model(
+self,
+source: str | AnyHttpUrl,
+loader: Optional[Callable[[Path], AnyModel]] = None,
+) -> LoadedModelWithoutConfig:
+"""
+Download, cache, and load the model file located at the indicated URL or repo_id.

 If the model is already downloaded, it will be loaded from the cache.
@@ -473,16 +501,12 @@ class ModelsInterface(InvocationContextInterface):
 Be aware that the LoadedModelWithoutConfig object has no `config` attribute

 Args:
-source: A model Path, URL, or repoid.
+source: A URL or huggingface repoid.
 loader: A Callable that expects a Path and returns a dict[str|int, Any]

 Returns:
 A LoadedModelWithoutConfig object.
 """
-
-if isinstance(source, Path):
-return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
-else:
 model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
 return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
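With the isinstance() branch removed above, callers now choose the method themselves. A minimal sketch of that caller-side choice, using a hypothetical `resolve_model` helper (not part of the InvokeAI API) that treats anything other than an existing local path as a remote source:

from pathlib import Path

def resolve_model(context, source: str):
    # Sketch only: hypothetical helper. Existing local files or directories go
    # through load_local_model(); URLs and HuggingFace repo_ids go through
    # load_remote_model(), which downloads and caches before loading.
    path = Path(source)
    if path.exists():
        return context.models.load_local_model(path)
    return context.models.load_remote_model(source=source)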
@@ -59,14 +59,12 @@ class Migration11Callback:

 def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
 """
-Build the migration from database version 9 to 10.
+Build the migration from database version 10 to 11.

 This migration does the following:
 - Moves "core" models previously downloaded with download_with_progress_bar() into new
 "models/.download_cache" directory.
 - Renames "models/.cache" to "models/.convert_cache".
-- Adds `error_type` and `error_message` columns to the session queue table.
-- Renames the `error` column to `error_traceback`.
 """
 migration_11 = Migration(
 from_version=10,
@@ -43,14 +43,14 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
 downloaded_path = mock_context.models.download_and_cache_model(
 "https://www.test.foo/download/test_embedding.safetensors"
 )
-loaded_model_1 = mock_context.models.load_and_cache_model(downloaded_path)
+loaded_model_1 = mock_context.models.load_local_model(downloaded_path)
 assert isinstance(loaded_model_1, LoadedModelWithoutConfig)

-loaded_model_2 = mock_context.models.load_and_cache_model(downloaded_path)
+loaded_model_2 = mock_context.models.load_local_model(downloaded_path)
 assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
 assert loaded_model_1.model is loaded_model_2.model

-loaded_model_3 = mock_context.models.load_and_cache_model(embedding_file)
+loaded_model_3 = mock_context.models.load_local_model(embedding_file)
 assert isinstance(loaded_model_3, LoadedModelWithoutConfig)
 assert loaded_model_1.model is not loaded_model_3.model
 assert isinstance(loaded_model_1.model, dict)
@@ -58,21 +58,18 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
 assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"])


+@pytest.mark.skip(reason="This requires a test model to load")
 def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None:
-loaded_model = mock_context.models.load_and_cache_model(vae_directory)
+loaded_model = mock_context.models.load_local_model(vae_directory)
 assert isinstance(loaded_model, LoadedModelWithoutConfig)
 assert isinstance(loaded_model.model, AutoencoderTiny)


 def test_download_and_load(mock_context: InvocationContext) -> None:
-loaded_model_1 = mock_context.models.load_and_cache_model(
-"https://www.test.foo/download/test_embedding.safetensors"
-)
+loaded_model_1 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
 assert isinstance(loaded_model_1, LoadedModelWithoutConfig)

-loaded_model_2 = mock_context.models.load_and_cache_model(
-"https://www.test.foo/download/test_embedding.safetensors"
-)
+loaded_model_2 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
 assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
 assert loaded_model_1.model is loaded_model_2.model # should be cached copy
@@ -61,9 +61,11 @@ def embedding_file(mm2_model_files: Path) -> Path:
 return mm2_model_files / "test_embedding.safetensors"


-@pytest.fixture
-def vae_directory(mm2_model_files: Path) -> Path:
-return mm2_model_files / "taesdxl"
+# Can be used to test diffusers model directory loading, but
+# the test file adds ~10MB of space.
+# @pytest.fixture
+# def vae_directory(mm2_model_files: Path) -> Path:
+# return mm2_model_files / "taesdxl"


 @pytest.fixture