replace load_and_cache_model() with load_remote_model() and load_local_model()

Lincoln Stein 2024-06-06 00:31:41 -04:00 committed by psychedelicious
parent 9f9379682e
commit dc134935c8
12 changed files with 106 additions and 69 deletions

@@ -1585,9 +1585,9 @@ Within invocations, the following methods are available from the

 ### context.download_and_cache_model(source) -> Path

-This method accepts a `source` of a model, downloads and caches it
-locally, and returns a Path to the local model. The source can be a
-local file or directory, a URL, or a HuggingFace repo_id.
+This method accepts a `source` of a remote model, downloads and caches
+it locally, and then returns a Path to the local model. The source can
+be a direct download URL or a HuggingFace repo_id.

 In the case of HuggingFace repo_id, the following variants are
 recognized:

@@ -1602,16 +1602,34 @@ directory using this syntax:

 * stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors

-### context.load_and_cache_model(source, [loader]) -> LoadedModel
+### context.load_local_model(model_path, [loader]) -> LoadedModel

-This method takes a model source, downloads it, caches it, and then
-loads it into the RAM cache for use in inference. The optional loader
-is a Callable that accepts a Path to the object, and returns a
-`Dict[str, torch.Tensor]`. If no loader is provided, then the method
-will use `torch.load()` for a .ckpt or .bin checkpoint file,
-`safetensors.torch.load_file()` for a safetensors checkpoint file, or
-`*.from_pretrained()` for a directory that looks like a
-diffusers directory.
+This method loads a local model from the indicated path, returning a
+`LoadedModel`. The optional loader is a Callable that accepts a Path
+to the object, and returns a `AnyModel` object. If no loader is
+provided, then the method will use `torch.load()` for a .ckpt or .bin
+checkpoint file, `safetensors.torch.load_file()` for a safetensors
+checkpoint file, or `cls.from_pretrained()` for a directory that looks
+like a diffusers directory.
+
+### context.load_remote_model(source, [loader]) -> LoadedModel
+
+This method accepts a `source` of a remote model, downloads and caches
+it locally, loads it, and returns a `LoadedModel`. The source can be a
+direct download URL or a HuggingFace repo_id.
+
+In the case of HuggingFace repo_id, the following variants are
+recognized:
+
+* stabilityai/stable-diffusion-v4 -- default model
+* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
+* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
+* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
+
+You can also point at an arbitrary individual file within a repo_id
+directory using this syntax:
+
+* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
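Taken together, the documentation above boils down to the following usage pattern inside a node. This is an illustrative sketch only; the function shape, the example URL, and the repo_id are placeholders, not part of this commit:

```python
# Hypothetical node code; URLs and repo_ids below are placeholders.
from pathlib import Path

def run(self, context) -> None:  # `context` is the InvocationContext passed to invoke()
    # Fetch the weights (or reuse the cached copy) and get a local Path only:
    weights: Path = context.models.download_and_cache_model(
        "https://example.com/models/some_weights.safetensors"
    )

    # Load a model that is already on disk:
    with context.models.load_local_model(weights) as model:
        ...  # run inference with `model`

    # Download, cache, and load in one call; repo_id variants also work:
    with context.models.load_remote_model("stabilityai/stable-diffusion-v4:fp16") as model:
        ...
```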

@@ -611,7 +611,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
                 model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
             )

-        with self._context.models.load_and_cache_model(
+        with self._context.models.load_remote_model(
             source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
         ) as model:
             depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())

@@ -134,7 +134,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using the LaMa model"""

     def infill(self, image: Image.Image):
-        with self._context.models.load_and_cache_model(
+        with self._context.models.load_remote_model(
             source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
             loader=LaMA.load_jit_model,
         ) as model:
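Both call sites above use the `loader` argument to control deserialization of the cached file. A minimal sketch of such a callable, assuming a TorchScript checkpoint similar to the LaMa weights (this is not the actual `LaMA.load_jit_model` implementation):

```python
from pathlib import Path

import torch

def load_jit_checkpoint(checkpoint: Path):
    # load_remote_model() hands the cached checkpoint Path to the loader;
    # here we deserialize a TorchScript archive on the CPU.
    model = torch.jit.load(checkpoint, map_location="cpu")
    model.eval()
    return model
```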

@@ -91,7 +91,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
             context.logger.error(msg)
             raise ValueError(msg)

-        loadnet = context.models.load_and_cache_model(
+        loadnet = context.models.load_remote_model(
             source=ESRGAN_MODEL_URLS[self.model_name],
         )

@@ -5,6 +5,8 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union

+from pydantic.networks import AnyHttpUrl
+
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.download import DownloadQueueServiceBase
 from invokeai.app.services.events.events_base import EventServiceBase

@@ -241,7 +243,7 @@ class ModelInstallServiceBase(ABC):
         """

     @abstractmethod
-    def download_and_cache_model(self, source: str) -> Path:
+    def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
         """
         Download the model file located at source to the models cache and return its Path.
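On the caller side, the widened annotation means both a plain string and a value already validated as `AnyHttpUrl` are accepted. A hedged sketch; the `installer` handle and URL are placeholders:

```python
# `installer` stands in for the concrete ModelInstallServiceBase implementation.
cached_path = installer.download_and_cache_model(
    source="https://example.com/models/foo.safetensors"  # an AnyHttpUrl value works here too
)
print(cached_path)  # a Path inside the models/.download_cache directory
```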

@@ -15,6 +15,7 @@ import torch
 import yaml
 from huggingface_hub import HfFolder
 from pydantic.networks import AnyHttpUrl
+from pydantic_core import Url
 from requests import Session

 from invokeai.app.services.config import InvokeAIAppConfig

@@ -374,7 +375,7 @@ class ModelInstallService(ModelInstallServiceBase):

     def download_and_cache_model(
         self,
-        source: str,
+        source: str | AnyHttpUrl,
     ) -> Path:
         """Download the model file located at source to the models cache and return its Path."""
         model_path = self._download_cache_path(str(source), self._app_config)

@@ -388,7 +389,7 @@ class ModelInstallService(ModelInstallServiceBase):
             return contents[0]

         model_path.mkdir(parents=True, exist_ok=True)
-        model_source = self._guess_source(source)
+        model_source = self._guess_source(str(source))
         remote_files, _ = self._remote_files_from_source(model_source)
         job = self._multifile_download(
             dest=model_path,

@@ -447,7 +448,7 @@ class ModelInstallService(ModelInstallServiceBase):
             )
         elif re.match(r"^https?://[^/]+", source):
             source_obj = URLModelSource(
-                url=AnyHttpUrl(source),
+                url=Url(source),
             )
         else:
             raise ValueError(f"Unsupported model source: '{source}'")

@@ -3,9 +3,7 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Callable, Dict, Optional
-
-from torch import Tensor
+from typing import Callable, Optional

 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
 from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig

@@ -37,7 +35,7 @@ class ModelLoadServiceBase(ABC):

     @abstractmethod
     def load_model_from_path(
-        self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
+        self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
     ) -> LoadedModelWithoutConfig:
         """
         Load the model file or directory located at the indicated Path.

@@ -2,11 +2,10 @@
 """Implementation of model loader service."""

 from pathlib import Path
-from typing import Callable, Dict, Optional, Type
+from typing import Callable, Optional, Type

 from picklescan.scanner import scan_file_path
 from safetensors.torch import load_file as safetensors_load_file
-from torch import Tensor
 from torch import load as torch_load

 from invokeai.app.services.config import InvokeAIAppConfig

@@ -86,7 +85,7 @@ class ModelLoadService(ModelLoadServiceBase):
         return loaded_model

     def load_model_from_path(
-        self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor] | AnyModel]] = None
+        self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
     ) -> LoadedModelWithoutConfig:
         cache_key = str(model_path)
         ram_cache = self.ram_cache

@@ -95,11 +94,11 @@ class ModelLoadService(ModelLoadServiceBase):
         except IndexError:
             pass

-        def torch_load_file(checkpoint: Path) -> Dict[str, Tensor]:
+        def torch_load_file(checkpoint: Path) -> AnyModel:
             scan_result = scan_file_path(checkpoint)
             if scan_result.infected_files != 0:
                 raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
-            result: Dict[str, Tensor] = torch_load(checkpoint, map_location="cpu")
+            result = torch_load(checkpoint, map_location="cpu")
             return result

         def diffusers_load_directory(directory: Path) -> AnyModel:

@@ -109,18 +108,16 @@ class ModelLoadService(ModelLoadServiceBase):
                 ram_cache=self._ram_cache,
                 convert_cache=self.convert_cache,
             ).get_hf_load_class(directory)
-            result: AnyModel = load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
-            return result
+            return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())

-        if loader is None:
-            loader = (
-                diffusers_load_directory
-                if model_path.is_dir()
-                else torch_load_file
-                if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
-                else lambda path: safetensors_load_file(path, device="cpu")
-            )
-        assert loader is not None
+        loader = loader or (
+            diffusers_load_directory
+            if model_path.is_dir()
+            else torch_load_file
+            if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
+            else lambda path: safetensors_load_file(path, device="cpu")
+        )

         raw_model = loader(model_path)
         ram_cache.put(key=cache_key, model=raw_model)
         return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
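The net effect of the rewritten block is a suffix-based default loader. Roughly, and omitting the malware scan, the diffusers directory branch, and the RAM-cache bookkeeping, the default behaviour is:

```python
from pathlib import Path

import torch
from safetensors.torch import load_file as safetensors_load_file

def default_load(model_path: Path):
    # Directories are handled separately via the diffusers load class and from_pretrained().
    if model_path.suffix in (".ckpt", ".pt", ".pth", ".bin"):
        return torch.load(model_path, map_location="cpu")
    # Any other file is assumed to be a safetensors checkpoint.
    return safetensors_load_file(model_path, device="cpu")
```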

@@ -15,7 +15,14 @@ from invokeai.app.services.images.images_common import ImageDTO
 from invokeai.app.services.invocation_services import InvocationServices
 from invokeai.app.services.model_records.model_records_base import UnknownModelException
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
+from invokeai.backend.model_manager.config import (
+    AnyModel,
+    AnyModelConfig,
+    BaseModelType,
+    ModelFormat,
+    ModelType,
+    SubModelType,
+)
 from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData

@@ -449,21 +456,42 @@ class ModelsInterface(InvocationContextInterface):
         installed, the cached path will be returned. Otherwise it will be downloaded.

         Args:
-            source: A model path, URL or repo_id.
+            source: A URL that points to the model, or a huggingface repo_id.

         Returns:
             Path to the downloaded model
         """
         return self._services.model_manager.install.download_and_cache_model(source=source)

-    def load_and_cache_model(
+    def load_local_model(
         self,
-        source: Path | str | AnyHttpUrl,
-        loader: Optional[Callable[[Path], dict[str, Tensor]]] = None,
+        model_path: Path,
+        loader: Optional[Callable[[Path], AnyModel]] = None,
     ) -> LoadedModelWithoutConfig:
         """
-        Download, cache, and load the model file located at the indicated URL.
+        Load the model file located at the indicated path
+
+        If a loader callable is provided, it will be invoked to load the model. Otherwise,
+        `safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
+
+        Be aware that the LoadedModelWithoutConfig object has no `config` attribute
+
+        Args:
+            path: A model Path
+            loader: A Callable that expects a Path and returns a dict[str|int, Any]
+
+        Returns:
+            A LoadedModelWithoutConfig object.
+        """
+        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+
+    def load_remote_model(
+        self,
+        source: str | AnyHttpUrl,
+        loader: Optional[Callable[[Path], AnyModel]] = None,
+    ) -> LoadedModelWithoutConfig:
+        """
+        Download, cache, and load the model file located at the indicated URL or repo_id.

         If the model is already downloaded, it will be loaded from the cache.

@@ -473,16 +501,12 @@ class ModelsInterface(InvocationContextInterface):
         Be aware that the LoadedModelWithoutConfig object has no `config` attribute

         Args:
-            source: A model Path, URL, or repoid.
+            source: A URL or huggingface repoid.
             loader: A Callable that expects a Path and returns a dict[str|int, Any]

         Returns:
             A LoadedModelWithoutConfig object.
         """
-        if isinstance(source, Path):
-            return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
-        else:
-            model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
-            return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+        model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
+        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
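For node authors migrating off the removed method, the choice of entry point now happens at the call site rather than inside the context API. An illustrative before/after, where the source value and `context` handle are placeholders:

```python
from pathlib import Path

source = "https://example.com/models/foo.safetensors"  # or a Path to a local file

# Before this commit (removed API):
#     loaded = context.models.load_and_cache_model(source)

# After this commit, dispatch on the source type yourself:
if isinstance(source, Path):
    loaded = context.models.load_local_model(source)
else:
    loaded = context.models.load_remote_model(source)
```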

@@ -59,14 +59,12 @@ class Migration11Callback:

 def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
     """
-    Build the migration from database version 9 to 10.
+    Build the migration from database version 10 to 11.

     This migration does the following:
     - Moves "core" models previously downloaded with download_with_progress_bar() into new
       "models/.download_cache" directory.
     - Renames "models/.cache" to "models/.convert_cache".
-    - Adds `error_type` and `error_message` columns to the session queue table.
-    - Renames the `error` column to `error_traceback`.
     """
     migration_11 = Migration(
         from_version=10,

@@ -43,14 +43,14 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
     downloaded_path = mock_context.models.download_and_cache_model(
         "https://www.test.foo/download/test_embedding.safetensors"
     )
-    loaded_model_1 = mock_context.models.load_and_cache_model(downloaded_path)
+    loaded_model_1 = mock_context.models.load_local_model(downloaded_path)
     assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
-    loaded_model_2 = mock_context.models.load_and_cache_model(downloaded_path)
+    loaded_model_2 = mock_context.models.load_local_model(downloaded_path)
     assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
     assert loaded_model_1.model is loaded_model_2.model
-    loaded_model_3 = mock_context.models.load_and_cache_model(embedding_file)
+    loaded_model_3 = mock_context.models.load_local_model(embedding_file)
     assert isinstance(loaded_model_3, LoadedModelWithoutConfig)
     assert loaded_model_1.model is not loaded_model_3.model
     assert isinstance(loaded_model_1.model, dict)

@@ -58,21 +58,18 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
     assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"])

+@pytest.mark.skip(reason="This requires a test model to load")
 def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None:
-    loaded_model = mock_context.models.load_and_cache_model(vae_directory)
+    loaded_model = mock_context.models.load_local_model(vae_directory)
     assert isinstance(loaded_model, LoadedModelWithoutConfig)
     assert isinstance(loaded_model.model, AutoencoderTiny)

 def test_download_and_load(mock_context: InvocationContext) -> None:
-    loaded_model_1 = mock_context.models.load_and_cache_model(
-        "https://www.test.foo/download/test_embedding.safetensors"
-    )
+    loaded_model_1 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
     assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
-    loaded_model_2 = mock_context.models.load_and_cache_model(
-        "https://www.test.foo/download/test_embedding.safetensors"
-    )
+    loaded_model_2 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
     assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
     assert loaded_model_1.model is loaded_model_2.model  # should be cached copy

@@ -61,9 +61,11 @@ def embedding_file(mm2_model_files: Path) -> Path:
     return mm2_model_files / "test_embedding.safetensors"

-@pytest.fixture
-def vae_directory(mm2_model_files: Path) -> Path:
-    return mm2_model_files / "taesdxl"
+# Can be used to test diffusers model directory loading, but
+# the test file adds ~10MB of space.
+# @pytest.fixture
+# def vae_directory(mm2_model_files: Path) -> Path:
+#     return mm2_model_files / "taesdxl"

 @pytest.fixture