replace load_and_cache_model() with load_remote_model() and load_local_model()

Lincoln Stein 2024-06-06 00:31:41 -04:00 committed by psychedelicious
parent 9f9379682e
commit dc134935c8
12 changed files with 106 additions and 69 deletions

View File

@ -1585,9 +1585,9 @@ Within invocations, the following methods are available from the
### context.download_and_cache_model(source) -> Path
This method accepts a `source` of a model, downloads and caches it
locally, and returns a Path to the local model. The source can be a
local file or directory, a URL, or a HuggingFace repo_id.
This method accepts a `source` of a remote model, downloads and caches
it locally, and then returns a Path to the local model. The source can
be a direct download URL or a HuggingFace repo_id.
In the case of HuggingFace repo_id, the following variants are
recognized:
@ -1602,16 +1602,34 @@ directory using this syntax:
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
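
For illustration only (not part of this commit), here is a minimal sketch of calling this method from inside an invocation's `invoke()`, via the `context.models` interface used elsewhere in this changeset; the URL and helper are assumptions:

```python
from pathlib import Path

# Inside a custom invocation class; `context` is the InvocationContext passed to invoke().
def invoke(self, context) -> None:
    # Returns a Path inside the models cache; the file is only downloaded on first use.
    model_path: Path = context.models.download_and_cache_model(
        "https://example.com/checkpoints/example.safetensors"  # illustrative URL
    )
    self.do_something_with(model_path)  # hypothetical helper
```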
### context.load_and_cache_model(source, [loader]) -> LoadedModel
### context.load_local_model(model_path, [loader]) -> LoadedModel
This method takes a model source, downloads it, caches it, and then
loads it into the RAM cache for use in inference. The optional loader
is a Callable that accepts a Path to the object, and returns a
`Dict[str, torch.Tensor]`. If no loader is provided, then the method
will use `torch.load()` for a .ckpt or .bin checkpoint file,
`safetensors.torch.load_file()` for a safetensors checkpoint file, or
`*.from_pretrained()` for a directory that looks like a
diffusers directory.
This method loads a local model from the indicated path, returning a
`LoadedModel`. The optional loader is a Callable that accepts a Path
to the object and returns an `AnyModel` object. If no loader is
provided, then the method will use `torch.load()` for a .ckpt or .bin
checkpoint file, `safetensors.torch.load_file()` for a safetensors
checkpoint file, or `cls.from_pretrained()` for a directory that looks
like a diffusers directory.
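
A hedged sketch of supplying a custom loader; the path and loader body are illustrative assumptions, not part of this commit:

```python
from pathlib import Path

import torch

def checkpoint_loader(path: Path):
    # A plain torch checkpoint load; omit `loader` entirely to get the
    # default behaviour described above.
    return torch.load(path, map_location="cpu")

# Inside invoke(self, context):
loaded = context.models.load_local_model(
    model_path=Path("/opt/models/example/embedding.pt"),  # illustrative path
    loader=checkpoint_loader,
)
raw_model = loaded.model  # the object returned by the loader, held in the RAM cache
```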
### context.load_remote_model(source, [loader]) -> LoadedModel
This method accepts a `source` of a remote model, downloads and caches
it locally, loads it, and returns a `LoadedModel`. The source can be a
direct download URL or a HuggingFace repo_id.
In the case of HuggingFace repo_id, the following variants are
recognized:
* stabilityai/stable-diffusion-v4 -- default model
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
You can also point at an arbitrary individual file within a repo_id
directory using this syntax:
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
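
A sketch of the one-call download-and-load pattern used by the invocations updated in this commit; the URL and helper function are assumptions:

```python
import torch

# Inside invoke(self, context):
with context.models.load_remote_model(
    source="https://example.com/checkpoints/example.pt",  # illustrative URL
    loader=lambda path: torch.load(path, map_location="cpu"),
) as model:
    # `model` is the raw AnyModel object while the context manager keeps it
    # locked in the model cache.
    result = run_inference(model)  # hypothetical helper
```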

View File

@ -611,7 +611,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
)
with self._context.models.load_and_cache_model(
with self._context.models.load_remote_model(
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
) as model:
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())

View File

@ -134,7 +134,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using the LaMa model"""
def infill(self, image: Image.Image):
with self._context.models.load_and_cache_model(
with self._context.models.load_remote_model(
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
loader=LaMA.load_jit_model,
) as model:

View File

@ -91,7 +91,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
context.logger.error(msg)
raise ValueError(msg)
loadnet = context.models.load_and_cache_model(
loadnet = context.models.load_remote_model(
source=ESRGAN_MODEL_URLS[self.model_name],
)

View File

@ -5,6 +5,8 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
@ -241,7 +243,7 @@ class ModelInstallServiceBase(ABC):
"""
@abstractmethod
def download_and_cache_model(self, source: str) -> Path:
def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
"""
Download the model file located at source to the models cache and return its Path.

View File

@ -15,6 +15,7 @@ import torch
import yaml
from huggingface_hub import HfFolder
from pydantic.networks import AnyHttpUrl
from pydantic_core import Url
from requests import Session
from invokeai.app.services.config import InvokeAIAppConfig
@ -374,7 +375,7 @@ class ModelInstallService(ModelInstallServiceBase):
def download_and_cache_model(
self,
source: str,
source: str | AnyHttpUrl,
) -> Path:
"""Download the model file located at source to the models cache and return its Path."""
model_path = self._download_cache_path(str(source), self._app_config)
@ -388,7 +389,7 @@ class ModelInstallService(ModelInstallServiceBase):
return contents[0]
model_path.mkdir(parents=True, exist_ok=True)
model_source = self._guess_source(source)
model_source = self._guess_source(str(source))
remote_files, _ = self._remote_files_from_source(model_source)
job = self._multifile_download(
dest=model_path,
@ -447,7 +448,7 @@ class ModelInstallService(ModelInstallServiceBase):
)
elif re.match(r"^https?://[^/]+", source):
source_obj = URLModelSource(
url=AnyHttpUrl(source),
url=Url(source),
)
else:
raise ValueError(f"Unsupported model source: '{source}'")

View File

@ -3,9 +3,7 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Dict, Optional
from torch import Tensor
from typing import Callable, Optional
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
@ -37,7 +35,7 @@ class ModelLoadServiceBase(ABC):
@abstractmethod
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
"""
Load the model file or directory located at the indicated Path.

View File

@ -2,11 +2,10 @@
"""Implementation of model loader service."""
from pathlib import Path
from typing import Callable, Dict, Optional, Type
from typing import Callable, Optional, Type
from picklescan.scanner import scan_file_path
from safetensors.torch import load_file as safetensors_load_file
from torch import Tensor
from torch import load as torch_load
from invokeai.app.services.config import InvokeAIAppConfig
@ -86,7 +85,7 @@ class ModelLoadService(ModelLoadServiceBase):
return loaded_model
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor] | AnyModel]] = None
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
cache_key = str(model_path)
ram_cache = self.ram_cache
@ -95,11 +94,11 @@ class ModelLoadService(ModelLoadServiceBase):
except IndexError:
pass
def torch_load_file(checkpoint: Path) -> Dict[str, Tensor]:
def torch_load_file(checkpoint: Path) -> AnyModel:
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
result: Dict[str, Tensor] = torch_load(checkpoint, map_location="cpu")
result = torch_load(checkpoint, map_location="cpu")
return result
def diffusers_load_directory(directory: Path) -> AnyModel:
@ -109,18 +108,16 @@ class ModelLoadService(ModelLoadServiceBase):
ram_cache=self._ram_cache,
convert_cache=self.convert_cache,
).get_hf_load_class(directory)
result: AnyModel = load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
return result
if loader is None:
loader = (
diffusers_load_directory
if model_path.is_dir()
else torch_load_file
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
else lambda path: safetensors_load_file(path, device="cpu")
)
return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
loader = loader or (
diffusers_load_directory
if model_path.is_dir()
else torch_load_file
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
else lambda path: safetensors_load_file(path, device="cpu")
)
assert loader is not None
raw_model = loader(model_path)
ram_cache.put(key=cache_key, model=raw_model)
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))

View File

@ -15,7 +15,14 @@ from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@ -449,21 +456,42 @@ class ModelsInterface(InvocationContextInterface):
installed, the cached path will be returned. Otherwise it will be downloaded.
Args:
source: A model path, URL or repo_id.
source: A URL that points to the model, or a huggingface repo_id.
Returns:
Path to the downloaded model
"""
return self._services.model_manager.install.download_and_cache_model(source=source)
def load_and_cache_model(
def load_local_model(
self,
source: Path | str | AnyHttpUrl,
loader: Optional[Callable[[Path], dict[str, Tensor]]] = None,
model_path: Path,
loader: Optional[Callable[[Path], AnyModel]] = None,
) -> LoadedModelWithoutConfig:
"""
Download, cache, and load the model file located at the indicated URL.
Load the model file located at the indicated path
If a loader callable is provided, it will be invoked to load the model. Otherwise,
`safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
Be aware that the LoadedModelWithoutConfig object has no `config` attribute
Args:
model_path: A Path to the model
loader: A Callable that expects a Path and returns an AnyModel object
Returns:
A LoadedModelWithoutConfig object.
"""
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
def load_remote_model(
self,
source: str | AnyHttpUrl,
loader: Optional[Callable[[Path], AnyModel]] = None,
) -> LoadedModelWithoutConfig:
"""
Download, cache, and load the model file located at the indicated URL or repo_id.
If the model is already downloaded, it will be loaded from the cache.
@ -473,18 +501,14 @@ class ModelsInterface(InvocationContextInterface):
Be aware that the LoadedModelWithoutConfig object has no `config` attribute
Args:
source: A model Path, URL, or repoid.
source: A URL or huggingface repo_id.
loader: A Callable that expects a Path and returns an AnyModel object
Returns:
A LoadedModelWithoutConfig object.
"""
if isinstance(source, Path):
return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
else:
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
class ConfigInterface(InvocationContextInterface):

View File

@ -59,14 +59,12 @@ class Migration11Callback:
def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
"""
Build the migration from database version 9 to 10.
Build the migration from database version 10 to 11.
This migration does the following:
- Moves "core" models previously downloaded with download_with_progress_bar() into new
"models/.download_cache" directory.
- Renames "models/.cache" to "models/.convert_cache".
- Adds `error_type` and `error_message` columns to the session queue table.
- Renames the `error` column to `error_traceback`.
"""
migration_11 = Migration(
from_version=10,

View File

@ -43,14 +43,14 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
downloaded_path = mock_context.models.download_and_cache_model(
"https://www.test.foo/download/test_embedding.safetensors"
)
loaded_model_1 = mock_context.models.load_and_cache_model(downloaded_path)
loaded_model_1 = mock_context.models.load_local_model(downloaded_path)
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
loaded_model_2 = mock_context.models.load_and_cache_model(downloaded_path)
loaded_model_2 = mock_context.models.load_local_model(downloaded_path)
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
assert loaded_model_1.model is loaded_model_2.model
loaded_model_3 = mock_context.models.load_and_cache_model(embedding_file)
loaded_model_3 = mock_context.models.load_local_model(embedding_file)
assert isinstance(loaded_model_3, LoadedModelWithoutConfig)
assert loaded_model_1.model is not loaded_model_3.model
assert isinstance(loaded_model_1.model, dict)
@ -58,21 +58,18 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"])
@pytest.mark.skip(reason="This requires a test model to load")
def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None:
loaded_model = mock_context.models.load_and_cache_model(vae_directory)
loaded_model = mock_context.models.load_local_model(vae_directory)
assert isinstance(loaded_model, LoadedModelWithoutConfig)
assert isinstance(loaded_model.model, AutoencoderTiny)
def test_download_and_load(mock_context: InvocationContext) -> None:
loaded_model_1 = mock_context.models.load_and_cache_model(
"https://www.test.foo/download/test_embedding.safetensors"
)
loaded_model_1 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
loaded_model_2 = mock_context.models.load_and_cache_model(
"https://www.test.foo/download/test_embedding.safetensors"
)
loaded_model_2 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
assert loaded_model_1.model is loaded_model_2.model # should be cached copy

View File

@ -61,9 +61,11 @@ def embedding_file(mm2_model_files: Path) -> Path:
return mm2_model_files / "test_embedding.safetensors"
@pytest.fixture
def vae_directory(mm2_model_files: Path) -> Path:
return mm2_model_files / "taesdxl"
# Can be used to test diffusers model directory loading, but
# the test file adds ~10MB of space.
# @pytest.fixture
# def vae_directory(mm2_model_files: Path) -> Path:
# return mm2_model_files / "taesdxl"
@pytest.fixture