refactor load_ckpt_from_url()

Lincoln Stein 2024-04-28 11:33:23 -04:00
parent d72f272f16
commit 70903ef057
11 changed files with 235 additions and 72 deletions

View File

@@ -133,7 +133,9 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using the LaMa model"""
def infill(self, image: Image.Image, context: InvocationContext):
lama = LaMA(context)
# Note that this accesses a protected attribute to get to the model manager service.
# Is there a better way?
lama = LaMA(context._services.model_manager)
return lama(image)

View File

@@ -468,7 +468,12 @@ class ModelInstallServiceBase(ABC):
"""
@abstractmethod
def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
def download_and_cache_ckpt(
self,
source: str | AnyHttpUrl,
access_token: Optional[str] = None,
timeout: int = 0,
) -> Path:
"""
Download the model file located at source to the models cache and return its Path.

View File

@@ -399,9 +399,9 @@ class ModelInstallService(ModelInstallServiceBase):
escaped_source = slugify(str(source))
return app_config.download_cache_path / escaped_source
def download_and_cache(
def download_and_cache_ckpt(
self,
source: Union[str, AnyHttpUrl],
source: str | AnyHttpUrl,
access_token: Optional[str] = None,
timeout: int = 0,
) -> Path:
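
For orientation, a hedged usage sketch of the renamed method (not part of the diff); `installer` is assumed to be the application's ModelInstallService and the URL is purely illustrative:

from pathlib import Path

# `installer` is assumed to be a ModelInstallService instance obtained elsewhere.
source = "https://example.com/checkpoints/example.safetensors"  # illustrative URL
ckpt_path: Path = installer.download_and_cache_ckpt(source=source, timeout=600)
# Per the hunk above, the file lands under
# app_config.download_cache_path / slugify(str(source)) and is reused on later calls.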

View File

@@ -2,7 +2,10 @@
"""Base class for model loader."""
from abc import ABC, abstractmethod
from typing import Optional
from pathlib import Path
from typing import Callable, Dict, Optional
from torch import Tensor
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
@@ -38,3 +41,25 @@ class ModelLoadServiceBase(ABC):
@abstractmethod
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
@abstractmethod
def load_ckpt_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
) -> LoadedModel:
"""
Load the checkpoint-format model file located at the indicated Path.
This will load an arbitrary model file into the RAM cache. If the optional loader
argument is provided, the loader will be invoked to load the model into
memory. Otherwise the method will call safetensors.torch.load_file() or
torch.load() as appropriate to the file suffix.
Be aware that the LoadedModel object will have a `config` attribute of None.
Args:
model_path: A pathlib.Path to a checkpoint-style model file
loader: A Callable that expects a Path and returns a Dict[str, Tensor]
Returns:
A LoadedModel object.
"""

View File

@@ -1,7 +1,13 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
"""Implementation of model loader service."""
from typing import Optional, Type
from pathlib import Path
from typing import Callable, Dict, Optional, Type
from picklescan.scanner import scan_file_path
from safetensors.torch import load_file as safetensors_load_file
from torch import Tensor
from torch import load as torch_load
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
@@ -88,6 +94,51 @@ class ModelLoadService(ModelLoadServiceBase):
)
return loaded_model
def load_ckpt_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
) -> LoadedModel:
"""
Load the checkpoint-format model file located at the indicated Path.
This will load an arbitrary model file into the RAM cache. If the optional loader
argument is provided, the loader will be invoked to load the model into
memory. Otherwise the method will call safetensors.torch.load_file() or
torch.load() as appropriate to the file suffix.
Be aware that the LoadedModel object will have a `config` attribute of None.
Args:
model_path: A pathlib.Path to a checkpoint-style model file
loader: A Callable that expects a Path and returns a Dict[str, Tensor]
Returns:
A LoadedModel object.
"""
cache_key = str(model_path)
ram_cache = self.ram_cache
try:
return LoadedModel(_locker=ram_cache.get(key=cache_key))
except IndexError:
pass
def torch_load_file(checkpoint: Path) -> Dict[str, Tensor]:
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
result: Dict[str, Tensor] = torch_load(checkpoint, map_location="cpu")
return result
if loader is None:
loader = (
torch_load_file
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
else lambda path: safetensors_load_file(path, device="cpu")
)
raw_model = loader(model_path)
ram_cache.put(key=cache_key, model=raw_model)
return LoadedModel(_locker=ram_cache.get(key=cache_key))
def _emit_load_event(
self,
context_data: InvocationContextData,

View File

@@ -1,11 +1,15 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Dict, Optional
import torch
from pydantic.networks import AnyHttpUrl
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import LoadedModel
from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase
@@ -66,3 +70,35 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def stop(self, invoker: Invoker) -> None:
pass
@abstractmethod
def load_ckpt_from_url(
self,
source: str | AnyHttpUrl,
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel:
"""
Download, cache, and load the model file located at the indicated URL.
This will check the model download cache for the model designated
by the provided URL and download it if needed using download_and_cache_ckpt().
It will then load the model into the RAM cache. If the optional loader
argument is provided, the loader will be invoked to load the model into
memory. Otherwise the method will call safetensors.torch.load_file() or
torch.load() as appropriate to the file suffix.
Be aware that the LoadedModel object will have a `config` attribute of None.
Args:
source: A URL or a string that can be converted into one. Repo_ids
do not work here.
access_token: Optional access token for restricted resources.
timeout: Wait up to the indicated number of seconds before timing
out long downloads.
loader: A Callable that expects a Path and returns a Dict[str, torch.Tensor]
Returns:
A LoadedModel object.
"""

View File

@@ -1,13 +1,15 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""
from typing import Optional
from pathlib import Path
from typing import Callable, Dict, Optional
import torch
from pydantic.networks import AnyHttpUrl
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.model_manager.load import LoadedModel, ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -100,3 +102,36 @@ class ModelManagerService(ModelManagerServiceBase):
event_bus=events,
)
return cls(store=model_record_service, install=installer, load=loader)
def load_ckpt_from_url(
self,
source: str | AnyHttpUrl,
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel:
"""
Download, cache, and load the model file located at the indicated URL.
This will check the model download cache for the model designated
by the provided URL and download it if needed using download_and_cache_ckpt().
It will then load the model into the RAM cache. If the optional loader
argument is provided, the loader will be invoked to load the model into
memory. Otherwise the method will call safetensors.torch.load_file() or
torch.load() as appropriate to the file suffix.
Be aware that the LoadedModel object will have a `config` attribute of None.
Args:
source: A URL or a string that can be converted into one. Repo_ids
do not work here.
access_token: Optional access token for restricted resources.
timeout: Wait up to the indicated number of seconds before timing
out long downloads.
loader: A Callable that expects a Path and returns a Dict[str, torch.Tensor]
Returns:
A LoadedModel object.
"""
model_path = self.install.download_and_cache_ckpt(source=source, access_token=access_token, timeout=timeout)
return self.load.load_ckpt_from_path(model_path=model_path, loader=loader)
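
The implementation above is a thin composition of download_and_cache_ckpt() and load_ckpt_from_path(); a hedged sketch of the caller-visible behaviour (the `model_manager` handle and the URL are assumed, not part of the diff):

# `model_manager` is assumed to be a ModelManagerService instance.
url = "https://example.com/checkpoints/example.pt"  # illustrative URL
loaded_1 = model_manager.load_ckpt_from_url(source=url, timeout=600)
loaded_2 = model_manager.load_ckpt_from_url(source=url)

# The first call downloads into the checkpoint cache and puts the result into the
# RAM cache; the second call is served from those caches, which is what the test
# at the end of this commit asserts.
with loaded_1 as m1, loaded_2 as m2:
    assert m1 == m2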

View File

@@ -1,14 +1,11 @@
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
from picklescan.scanner import scan_file_path
import torch
from PIL.Image import Image
from pydantic.networks import AnyHttpUrl
from safetensors.torch import load_file as safetensors_load_file
from torch import Tensor
from torch import load as torch_load
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
@@ -263,7 +260,7 @@ class ImagesInterface(InvocationContextInterface):
class TensorsInterface(InvocationContextInterface):
def save(self, tensor: Tensor) -> str:
def save(self, tensor: torch.Tensor) -> str:
"""Saves a tensor, returning its name.
Args:
@@ -276,7 +273,7 @@ class TensorsInterface(InvocationContextInterface):
name = self._services.tensors.save(obj=tensor)
return name
def load(self, name: str) -> Tensor:
def load(self, name: str) -> torch.Tensor:
"""Loads a tensor by name.
Args:
@@ -316,8 +313,10 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface):
"""Common API for loading, downloading and managing models."""
def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
"""Checks if a model exists.
"""Check if a model exists.
Args:
identifier: The key or ModelField representing the model.
@@ -326,14 +325,18 @@
True if the model exists, False if not.
"""
if isinstance(identifier, str):
return self._services.model_manager.store.exists(identifier)
return self._services.model_manager.store.exists(identifier.key)
# For some reason, Mypy is not getting the type annotations for many of
# the model manager service calls and raises a "returning Any in typed
# context" error. Hence the extra typing hints here and below.
result: bool = self._services.model_manager.store.exists(identifier)
else:
result = self._services.model_manager.store.exists(identifier.key)
return result
def load(
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""Loads a model.
"""Load a model.
Args:
identifier: The key or ModelField representing the model.
@@ -342,22 +345,22 @@
Returns:
An object representing the loaded model.
"""
# The model manager emits events as it loads the model. It needs the context data to build
# the event payloads.
if isinstance(identifier, str):
model = self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
result: LoadedModel = self._services.model_manager.load.load_model(model, submodel_type, self._data)
else:
_submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key)
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
result = self._services.model_manager.load.load_model(model, _submodel_type, self._data)
return result
def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""Loads a model by its attributes.
"""Load a model by its attributes.
Args:
name: Name of the model.
@@ -369,7 +372,6 @@
Returns:
An object representing the loaded model.
"""
configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
if len(configs) == 0:
raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
@@ -377,10 +379,11 @@
if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
result: LoadedModel = self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
return result
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
"""Gets a model's config.
"""Get a model's config.
Args:
identifier: The key or ModelField representing the model.
@@ -389,12 +392,13 @@
The model's config.
"""
if isinstance(identifier, str):
return self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.store.get_model(identifier.key)
result: AnyModelConfig = self._services.model_manager.store.get_model(identifier)
else:
result = self._services.model_manager.store.get_model(identifier.key)
return result
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
"""Searches for models by path.
"""Search for models by path.
Args:
path: The path to search for.
@@ -402,7 +406,8 @@
Returns:
A list of models that match the path.
"""
return self._services.model_manager.store.search_by_path(path)
result: list[AnyModelConfig] = self._services.model_manager.store.search_by_path(path)
return result
def search_by_attrs(
self,
@@ -411,7 +416,7 @@
type: Optional[ModelType] = None,
format: Optional[ModelFormat] = None,
) -> list[AnyModelConfig]:
"""Searches for models by attributes.
"""Search for models by attributes.
Args:
name: The name to search for (exact match).
@@ -422,13 +427,13 @@
Returns:
A list of models that match the attributes.
"""
return self._services.model_manager.store.search_by_attr(
result: list[AnyModelConfig] = self._services.model_manager.store.search_by_attr(
model_name=name,
base_model=base,
model_type=type,
model_format=format,
)
return result
def download_and_cache_ckpt(
self,
@@ -451,26 +456,49 @@
out long downloads.
Result:
Path of the downloaded model
Path to the downloaded model
May Raise:
HTTPError
TimeoutError
"""
installer = self._services.model_manager.install
path: Path = installer.download_and_cache(
path: Path = installer.download_and_cache_ckpt(
source=source,
access_token=access_token,
timeout=timeout,
)
return path
def load_ckpt_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None
) -> LoadedModel:
"""
Load the checkpoint-format model file located at the indicated Path.
This will load an arbitrary model file into the RAM cache. If the optional loader
argument is provided, the loader will be invoked to load the model into
memory. Otherwise the method will call safetensors.torch.load_file() or
torch.load() as appropriate to the file suffix.
Be aware that the LoadedModel object will have a `config` attribute of None.
Args:
model_path: A pathlib.Path to a checkpoint-style model file
loader: A Callable that expects a Path and returns a Dict[str, torch.Tensor]
Returns:
A LoadedModel object.
"""
result: LoadedModel = self._services.model_manager.load.load_ckpt_from_path(model_path, loader=loader)
return result
def load_ckpt_from_url(
self,
source: Union[str, AnyHttpUrl],
source: str | AnyHttpUrl,
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel:
"""
Download, cache, and load the model file located at the indicated URL.
@@ -495,29 +523,10 @@
Returns:
A LoadedModel object.
"""
ram_cache = self._services.model_manager.load.ram_cache
try:
return LoadedModel(_locker=ram_cache.get(key=str(source)))
except IndexError:
pass
def torch_load_file(checkpoint: Path) -> Dict[str | int, Any]:
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
return torch_load(path, map_location="cpu")
path = self.download_and_cache_ckpt(source, access_token, timeout)
if loader is None:
loader = (
torch_load_file
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
else lambda path: safetensors_load_file(path, device="cpu")
)
raw_model = loader(path)
ram_cache.put(key=str(source), model=raw_model)
return LoadedModel(_locker=ram_cache.get(key=str(source)))
result: LoadedModel = self._services.model_manager.load_ckpt_from_url(
source=source, access_token=access_token, timeout=timeout, loader=loader
)
return result
class ConfigInterface(InvocationContextInterface):
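
From a node's point of view the same call is reachable through context.models and now simply forwards to the model manager service; a hedged sketch (the helper function and URL are illustrative, not part of the diff):

import torch

from invokeai.app.services.shared.invocation_context import InvocationContext

def run_example(context: InvocationContext) -> None:
    # Forwards to model_manager.load_ckpt_from_url, so the checkpoint is downloaded
    # and loaded into the RAM cache only once.
    loaded = context.models.load_ckpt_from_url(
        source="https://example.com/checkpoints/example.pt",  # illustrative URL
        loader=lambda path: torch.jit.load(path, map_location="cpu"),  # optional custom loader
    )
    with loaded as model:
        model.eval()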

View File

@@ -1,11 +1,13 @@
from typing import Any
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.shared.invocation_context import InvocationContext
if TYPE_CHECKING:
from invokeai.app.services.model_manager import ModelManagerServiceBase
def norm_img(np_img):
@@ -16,20 +18,20 @@ def norm_img(np_img):
return np_img
def load_jit_model(url_or_path, device):
def load_jit_model(url_or_path, device) -> torch.nn.Module:
model_path = url_or_path
logger.info(f"Loading model from: {model_path}")
model = torch.jit.load(model_path, map_location="cpu").to(device)
model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device) # type: ignore
model.eval()
return model
class LaMA:
def __init__(self, context: InvocationContext):
self._context = context
def __init__(self, model_manager: "ModelManagerServiceBase"):
self._model_manager = model_manager
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
loaded_model = self._context.models.load_ckpt_from_url(
loaded_model = self._model_manager.load_ckpt_from_url(
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
loader=lambda path: load_jit_model(path, "cpu"),
)
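
Putting the pieces together, a minimal sketch of the new wiring (LaMA and the big-lama URL come from the diffs above; `context` and `image` are assumed to be in scope inside an invocation):

lama = LaMA(context._services.model_manager)  # pass the model manager service, not the whole context
result = lama(image)  # fetches and caches big-lama.pt on first use, then runs the JIT model on CPU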

View File

@@ -36,7 +36,7 @@ from ..raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
class InvalidModelConfigException(Exception):

View File

@@ -47,5 +47,3 @@ def test_download_and_load(mock_context: InvocationContext):
with loaded_model_1 as model_1, loaded_model_2 as model_2:
assert model_1 == model_2
assert isinstance(model_1, dict)