mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
replace load_and_cache_model() with load_remote_model() and load_local_odel()
This commit is contained in:
parent
9f9379682e
commit
dc134935c8
@ -1585,9 +1585,9 @@ Within invocations, the following methods are available from the
|
||||
|
||||
### context.download_and_cache_model(source) -> Path
|
||||
|
||||
This method accepts a `source` of a model, downloads and caches it
|
||||
locally, and returns a Path to the local model. The source can be a
|
||||
local file or directory, a URL, or a HuggingFace repo_id.
|
||||
This method accepts a `source` of a remote model, downloads and caches
|
||||
it locally, and then returns a Path to the local model. The source can
|
||||
be a direct download URL or a HuggingFace repo_id.
|
||||
|
||||
In the case of HuggingFace repo_id, the following variants are
|
||||
recognized:
|
||||
@ -1602,16 +1602,34 @@ directory using this syntax:
|
||||
|
||||
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
|
||||
|
||||
### context.load_and_cache_model(source, [loader]) -> LoadedModel
|
||||
### context.load_local_model(model_path, [loader]) -> LoadedModel
|
||||
|
||||
This method takes a model source, downloads it, caches it, and then
|
||||
loads it into the RAM cache for use in inference. The optional loader
|
||||
is a Callable that accepts a Path to the object, and returns a
|
||||
`Dict[str, torch.Tensor]`. If no loader is provided, then the method
|
||||
will use `torch.load()` for a .ckpt or .bin checkpoint file,
|
||||
`safetensors.torch.load_file()` for a safetensors checkpoint file, or
|
||||
`*.from_pretrained()` for a directory that looks like a
|
||||
diffusers directory.
|
||||
This method loads a local model from the indicated path, returning a
|
||||
`LoadedModel`. The optional loader is a Callable that accepts a Path
|
||||
to the object, and returns a `AnyModel` object. If no loader is
|
||||
provided, then the method will use `torch.load()` for a .ckpt or .bin
|
||||
checkpoint file, `safetensors.torch.load_file()` for a safetensors
|
||||
checkpoint file, or `cls.from_pretrained()` for a directory that looks
|
||||
like a diffusers directory.
|
||||
|
||||
### context.load_remote_model(source, [loader]) -> LoadedModel
|
||||
|
||||
This method accepts a `source` of a remote model, downloads and caches
|
||||
it locally, loads it, and returns a `LoadedModel`. The source can be a
|
||||
direct download URL or a HuggingFace repo_id.
|
||||
|
||||
In the case of HuggingFace repo_id, the following variants are
|
||||
recognized:
|
||||
|
||||
* stabilityai/stable-diffusion-v4 -- default model
|
||||
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
|
||||
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
|
||||
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
|
||||
|
||||
You can also point at an arbitrary individual file within a repo_id
|
||||
directory using this syntax:
|
||||
|
||||
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
|
||||
|
||||
|
||||
|
||||
|
@ -611,7 +611,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
|
||||
)
|
||||
|
||||
with self._context.models.load_and_cache_model(
|
||||
with self._context.models.load_remote_model(
|
||||
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
|
||||
) as model:
|
||||
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
|
||||
|
@ -134,7 +134,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
with self._context.models.load_and_cache_model(
|
||||
with self._context.models.load_remote_model(
|
||||
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
loader=LaMA.load_jit_model,
|
||||
) as model:
|
||||
|
@ -91,7 +91,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
context.logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
loadnet = context.models.load_and_cache_model(
|
||||
loadnet = context.models.load_remote_model(
|
||||
source=ESRGAN_MODEL_URLS[self.model_name],
|
||||
)
|
||||
|
||||
|
@ -5,6 +5,8 @@ from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
@ -241,7 +243,7 @@ class ModelInstallServiceBase(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def download_and_cache_model(self, source: str) -> Path:
|
||||
def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
|
||||
"""
|
||||
Download the model file located at source to the models cache and return its Path.
|
||||
|
||||
|
@ -15,6 +15,7 @@ import torch
|
||||
import yaml
|
||||
from huggingface_hub import HfFolder
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from pydantic_core import Url
|
||||
from requests import Session
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
@ -374,7 +375,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def download_and_cache_model(
|
||||
self,
|
||||
source: str,
|
||||
source: str | AnyHttpUrl,
|
||||
) -> Path:
|
||||
"""Download the model file located at source to the models cache and return its Path."""
|
||||
model_path = self._download_cache_path(str(source), self._app_config)
|
||||
@ -388,7 +389,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
return contents[0]
|
||||
|
||||
model_path.mkdir(parents=True, exist_ok=True)
|
||||
model_source = self._guess_source(source)
|
||||
model_source = self._guess_source(str(source))
|
||||
remote_files, _ = self._remote_files_from_source(model_source)
|
||||
job = self._multifile_download(
|
||||
dest=model_path,
|
||||
@ -447,7 +448,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
)
|
||||
elif re.match(r"^https?://[^/]+", source):
|
||||
source_obj = URLModelSource(
|
||||
url=AnyHttpUrl(source),
|
||||
url=Url(source),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model source: '{source}'")
|
||||
|
@ -3,9 +3,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
from torch import Tensor
|
||||
from typing import Callable, Optional
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
|
||||
@ -37,7 +35,7 @@ class ModelLoadServiceBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def load_model_from_path(
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
|
||||
) -> LoadedModelWithoutConfig:
|
||||
"""
|
||||
Load the model file or directory located at the indicated Path.
|
||||
|
@ -2,11 +2,10 @@
|
||||
"""Implementation of model loader service."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Optional, Type
|
||||
from typing import Callable, Optional, Type
|
||||
|
||||
from picklescan.scanner import scan_file_path
|
||||
from safetensors.torch import load_file as safetensors_load_file
|
||||
from torch import Tensor
|
||||
from torch import load as torch_load
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
@ -86,7 +85,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
return loaded_model
|
||||
|
||||
def load_model_from_path(
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor] | AnyModel]] = None
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
|
||||
) -> LoadedModelWithoutConfig:
|
||||
cache_key = str(model_path)
|
||||
ram_cache = self.ram_cache
|
||||
@ -95,11 +94,11 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
def torch_load_file(checkpoint: Path) -> Dict[str, Tensor]:
|
||||
def torch_load_file(checkpoint: Path) -> AnyModel:
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
|
||||
result: Dict[str, Tensor] = torch_load(checkpoint, map_location="cpu")
|
||||
result = torch_load(checkpoint, map_location="cpu")
|
||||
return result
|
||||
|
||||
def diffusers_load_directory(directory: Path) -> AnyModel:
|
||||
@ -109,18 +108,16 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
ram_cache=self._ram_cache,
|
||||
convert_cache=self.convert_cache,
|
||||
).get_hf_load_class(directory)
|
||||
result: AnyModel = load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
|
||||
return result
|
||||
|
||||
if loader is None:
|
||||
loader = (
|
||||
diffusers_load_directory
|
||||
if model_path.is_dir()
|
||||
else torch_load_file
|
||||
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
|
||||
else lambda path: safetensors_load_file(path, device="cpu")
|
||||
)
|
||||
return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
|
||||
|
||||
loader = loader or (
|
||||
diffusers_load_directory
|
||||
if model_path.is_dir()
|
||||
else torch_load_file
|
||||
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
|
||||
else lambda path: safetensors_load_file(path, device="cpu")
|
||||
)
|
||||
assert loader is not None
|
||||
raw_model = loader(model_path)
|
||||
ram_cache.put(key=cache_key, model=raw_model)
|
||||
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
|
||||
|
@ -15,7 +15,14 @@ from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
@ -449,21 +456,42 @@ class ModelsInterface(InvocationContextInterface):
|
||||
installed, the cached path will be returned. Otherwise it will be downloaded.
|
||||
|
||||
Args:
|
||||
source: A model path, URL or repo_id.
|
||||
source: A URL that points to the model, or a huggingface repo_id.
|
||||
|
||||
Returns:
|
||||
Path to the downloaded model
|
||||
"""
|
||||
|
||||
return self._services.model_manager.install.download_and_cache_model(source=source)
|
||||
|
||||
def load_and_cache_model(
|
||||
def load_local_model(
|
||||
self,
|
||||
source: Path | str | AnyHttpUrl,
|
||||
loader: Optional[Callable[[Path], dict[str, Tensor]]] = None,
|
||||
model_path: Path,
|
||||
loader: Optional[Callable[[Path], AnyModel]] = None,
|
||||
) -> LoadedModelWithoutConfig:
|
||||
"""
|
||||
Download, cache, and load the model file located at the indicated URL.
|
||||
Load the model file located at the indicated path
|
||||
|
||||
If a loader callable is provided, it will be invoked to load the model. Otherwise,
|
||||
`safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
|
||||
|
||||
Be aware that the LoadedModelWithoutConfig object has no `config` attribute
|
||||
|
||||
Args:
|
||||
path: A model Path
|
||||
loader: A Callable that expects a Path and returns a dict[str|int, Any]
|
||||
|
||||
Returns:
|
||||
A LoadedModelWithoutConfig object.
|
||||
"""
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||
|
||||
def load_remote_model(
|
||||
self,
|
||||
source: str | AnyHttpUrl,
|
||||
loader: Optional[Callable[[Path], AnyModel]] = None,
|
||||
) -> LoadedModelWithoutConfig:
|
||||
"""
|
||||
Download, cache, and load the model file located at the indicated URL or repo_id.
|
||||
|
||||
If the model is already downloaded, it will be loaded from the cache.
|
||||
|
||||
@ -473,18 +501,14 @@ class ModelsInterface(InvocationContextInterface):
|
||||
Be aware that the LoadedModelWithoutConfig object has no `config` attribute
|
||||
|
||||
Args:
|
||||
source: A model Path, URL, or repoid.
|
||||
source: A URL or huggingface repoid.
|
||||
loader: A Callable that expects a Path and returns a dict[str|int, Any]
|
||||
|
||||
Returns:
|
||||
A LoadedModelWithoutConfig object.
|
||||
"""
|
||||
|
||||
if isinstance(source, Path):
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
|
||||
else:
|
||||
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||
|
||||
|
||||
class ConfigInterface(InvocationContextInterface):
|
||||
|
@ -59,14 +59,12 @@ class Migration11Callback:
|
||||
|
||||
def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
|
||||
"""
|
||||
Build the migration from database version 9 to 10.
|
||||
Build the migration from database version 10 to 11.
|
||||
|
||||
This migration does the following:
|
||||
- Moves "core" models previously downloaded with download_with_progress_bar() into new
|
||||
"models/.download_cache" directory.
|
||||
- Renames "models/.cache" to "models/.convert_cache".
|
||||
- Adds `error_type` and `error_message` columns to the session queue table.
|
||||
- Renames the `error` column to `error_traceback`.
|
||||
"""
|
||||
migration_11 = Migration(
|
||||
from_version=10,
|
||||
|
@ -43,14 +43,14 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
|
||||
downloaded_path = mock_context.models.download_and_cache_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
loaded_model_1 = mock_context.models.load_and_cache_model(downloaded_path)
|
||||
loaded_model_1 = mock_context.models.load_local_model(downloaded_path)
|
||||
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
|
||||
|
||||
loaded_model_2 = mock_context.models.load_and_cache_model(downloaded_path)
|
||||
loaded_model_2 = mock_context.models.load_local_model(downloaded_path)
|
||||
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
|
||||
assert loaded_model_1.model is loaded_model_2.model
|
||||
|
||||
loaded_model_3 = mock_context.models.load_and_cache_model(embedding_file)
|
||||
loaded_model_3 = mock_context.models.load_local_model(embedding_file)
|
||||
assert isinstance(loaded_model_3, LoadedModelWithoutConfig)
|
||||
assert loaded_model_1.model is not loaded_model_3.model
|
||||
assert isinstance(loaded_model_1.model, dict)
|
||||
@ -58,21 +58,18 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
|
||||
assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"])
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This requires a test model to load")
|
||||
def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None:
|
||||
loaded_model = mock_context.models.load_and_cache_model(vae_directory)
|
||||
loaded_model = mock_context.models.load_local_model(vae_directory)
|
||||
assert isinstance(loaded_model, LoadedModelWithoutConfig)
|
||||
assert isinstance(loaded_model.model, AutoencoderTiny)
|
||||
|
||||
|
||||
def test_download_and_load(mock_context: InvocationContext) -> None:
|
||||
loaded_model_1 = mock_context.models.load_and_cache_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
loaded_model_1 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
|
||||
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
|
||||
|
||||
loaded_model_2 = mock_context.models.load_and_cache_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
loaded_model_2 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
|
||||
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
|
||||
assert loaded_model_1.model is loaded_model_2.model # should be cached copy
|
||||
|
||||
|
@ -61,9 +61,11 @@ def embedding_file(mm2_model_files: Path) -> Path:
|
||||
return mm2_model_files / "test_embedding.safetensors"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vae_directory(mm2_model_files: Path) -> Path:
|
||||
return mm2_model_files / "taesdxl"
|
||||
# Can be used to test diffusers model directory loading, but
|
||||
# the test file adds ~10MB of space.
|
||||
# @pytest.fixture
|
||||
# def vae_directory(mm2_model_files: Path) -> Path:
|
||||
# return mm2_model_files / "taesdxl"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
Loading…
Reference in New Issue
Block a user