refactor load_ckpt_from_url()

Lincoln Stein
2024-04-28 11:33:23 -04:00
parent d72f272f16
commit 70903ef057
11 changed files with 235 additions and 72 deletions

View File

@@ -133,7 +133,9 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using the LaMa model"""

     def infill(self, image: Image.Image, context: InvocationContext):
-        lama = LaMA(context)
+        # Note that this accesses a protected attribute to get to the model manager service.
+        # Is there a better way?
+        lama = LaMA(context._services.model_manager)
         return lama(image)

View File

@@ -468,7 +468,12 @@ class ModelInstallServiceBase(ABC):
        """

    @abstractmethod
-   def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
+   def download_and_cache_ckpt(
+       self,
+       source: str | AnyHttpUrl,
+       access_token: Optional[str] = None,
+       timeout: int = 0,
+   ) -> Path:
        """
        Download the model file located at source to the models cache and return its Path.
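A minimal usage sketch of the renamed method, assuming `installer` is an already-wired ModelInstallServiceBase instance; the URL and timeout value are illustrative and not part of this commit:

    from pathlib import Path

    # Downloads on first call; later calls for the same source return the cached file.
    ckpt_path: Path = installer.download_and_cache_ckpt(
        source="https://example.com/weights/some-checkpoint.safetensors",  # hypothetical URL
        timeout=600,  # seconds before a long download is abandoned
    )
    # The file lands under download_cache_path, keyed by a slugified form of the source.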

View File

@@ -399,9 +399,9 @@ class ModelInstallService(ModelInstallServiceBase):
        escaped_source = slugify(str(source))
        return app_config.download_cache_path / escaped_source

-   def download_and_cache(
+   def download_and_cache_ckpt(
        self,
-       source: Union[str, AnyHttpUrl],
+       source: str | AnyHttpUrl,
        access_token: Optional[str] = None,
        timeout: int = 0,
    ) -> Path:

View File

@@ -2,7 +2,10 @@
 """Base class for model loader."""

 from abc import ABC, abstractmethod
-from typing import Optional
+from pathlib import Path
+from typing import Callable, Dict, Optional
+
+from torch import Tensor

 from invokeai.app.services.shared.invocation_context import InvocationContextData
 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
@@ -38,3 +41,25 @@ class ModelLoadServiceBase(ABC):
    @abstractmethod
    def convert_cache(self) -> ModelConvertCacheBase:
        """Return the checkpoint convert cache used by this loader."""
+
+   @abstractmethod
+   def load_ckpt_from_path(
+       self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
+   ) -> LoadedModel:
+       """
+       Load the checkpoint-format model file located at the indicated Path.
+
+       This will load an arbitrary model file into the RAM cache. If the optional loader
+       argument is provided, the loader will be invoked to load the model into
+       memory. Otherwise the method will call safetensors.torch.load_file() or
+       torch.load() as appropriate to the file suffix.
+
+       Be aware that the LoadedModel object will have a `config` attribute of None.
+
+       Args:
+         model_path: A pathlib.Path to a checkpoint-style model file
+         loader: A Callable that expects a Path and returns a Dict[str, Tensor]
+
+       Returns:
+         A LoadedModel object.
+       """

View File

@@ -1,7 +1,13 @@
 # Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
 """Implementation of model loader service."""

-from typing import Optional, Type
+from pathlib import Path
+from typing import Callable, Dict, Optional, Type
+
+from picklescan.scanner import scan_file_path
+from safetensors.torch import load_file as safetensors_load_file
+from torch import Tensor
+from torch import load as torch_load

 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.invoker import Invoker
@@ -88,6 +94,51 @@ class ModelLoadService(ModelLoadServiceBase):
            )
        return loaded_model

+   def load_ckpt_from_path(
+       self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
+   ) -> LoadedModel:
+       """
+       Load the checkpoint-format model file located at the indicated Path.
+
+       This will load an arbitrary model file into the RAM cache. If the optional loader
+       argument is provided, the loader will be invoked to load the model into
+       memory. Otherwise the method will call safetensors.torch.load_file() or
+       torch.load() as appropriate to the file suffix.
+
+       Be aware that the LoadedModel object will have a `config` attribute of None.
+
+       Args:
+         model_path: A pathlib.Path to a checkpoint-style model file
+         loader: A Callable that expects a Path and returns a Dict[str, Tensor]
+
+       Returns:
+         A LoadedModel object.
+       """
+       cache_key = str(model_path)
+       ram_cache = self.ram_cache
+       try:
+           return LoadedModel(_locker=ram_cache.get(key=cache_key))
+       except IndexError:
+           pass
+
+       def torch_load_file(checkpoint: Path) -> Dict[str, Tensor]:
+           scan_result = scan_file_path(checkpoint)
+           if scan_result.infected_files != 0:
+               raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
+           result: Dict[str, Tensor] = torch_load(checkpoint, map_location="cpu")
+           return result
+
+       if loader is None:
+           loader = (
+               torch_load_file
+               if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
+               else lambda path: safetensors_load_file(path, device="cpu")
+           )
+
+       raw_model = loader(model_path)
+       ram_cache.put(key=cache_key, model=raw_model)
+       return LoadedModel(_locker=ram_cache.get(key=cache_key))

    def _emit_load_event(
        self,
        context_data: InvocationContextData,
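A small sketch of the RAM-cache behaviour implemented above, assuming `load_service` is a live ModelLoadService and the path (illustrative) points at an already-downloaded checkpoint:

    from pathlib import Path

    path = Path("/path/to/models/.download_cache/some-model.safetensors")  # illustrative
    first = load_service.load_ckpt_from_path(path)
    second = load_service.load_ckpt_from_path(path)  # hits the RAM cache; the file is not re-read
    with first as m1, second as m2:
        assert m1 == m2  # the same cached state dict is returned both times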

View File

@@ -1,11 +1,15 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team

 from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Callable, Dict, Optional

 import torch
+from pydantic.networks import AnyHttpUrl
 from typing_extensions import Self

 from invokeai.app.services.invoker import Invoker
+from invokeai.backend.model_manager.load import LoadedModel

 from ..config import InvokeAIAppConfig
 from ..download import DownloadQueueServiceBase
@@ -66,3 +70,35 @@ class ModelManagerServiceBase(ABC):
    @abstractmethod
    def stop(self, invoker: Invoker) -> None:
        pass
+
+   @abstractmethod
+   def load_ckpt_from_url(
+       self,
+       source: str | AnyHttpUrl,
+       access_token: Optional[str] = None,
+       timeout: Optional[int] = 0,
+       loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
+   ) -> LoadedModel:
+       """
+       Download, cache, and load the model file located at the indicated URL.
+
+       This will check the model download cache for the model designated
+       by the provided URL and download it if needed using download_and_cache_ckpt().
+       It will then load the model into the RAM cache. If the optional loader
+       argument is provided, the loader will be invoked to load the model into
+       memory. Otherwise the method will call safetensors.torch.load_file() or
+       torch.load() as appropriate to the file suffix.
+
+       Be aware that the LoadedModel object will have a `config` attribute of None.
+
+       Args:
+         source: A URL or a string that can be converted into one. Repo_ids
+           do not work here.
+         access_token: Optional access token for restricted resources.
+         timeout: Wait up to the indicated number of seconds before timing
+           out long downloads.
+         loader: A Callable that expects a Path and returns a Dict[str, torch.Tensor]
+
+       Returns:
+         A LoadedModel object.
+       """

View File

@@ -1,13 +1,15 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
 """Implementation of ModelManagerServiceBase."""

-from typing import Optional
+from pathlib import Path
+from typing import Callable, Dict, Optional

 import torch
+from pydantic.networks import AnyHttpUrl
 from typing_extensions import Self

 from invokeai.app.services.invoker import Invoker
-from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
+from invokeai.backend.model_manager.load import LoadedModel, ModelCache, ModelConvertCache, ModelLoaderRegistry
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
@@ -100,3 +102,36 @@ class ModelManagerService(ModelManagerServiceBase):
            event_bus=events,
        )
        return cls(store=model_record_service, install=installer, load=loader)
+
+   def load_ckpt_from_url(
+       self,
+       source: str | AnyHttpUrl,
+       access_token: Optional[str] = None,
+       timeout: Optional[int] = 0,
+       loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
+   ) -> LoadedModel:
+       """
+       Download, cache, and load the model file located at the indicated URL.
+
+       This will check the model download cache for the model designated
+       by the provided URL and download it if needed using download_and_cache_ckpt().
+       It will then load the model into the RAM cache. If the optional loader
+       argument is provided, the loader will be invoked to load the model into
+       memory. Otherwise the method will call safetensors.torch.load_file() or
+       torch.load() as appropriate to the file suffix.
+
+       Be aware that the LoadedModel object will have a `config` attribute of None.
+
+       Args:
+         source: A URL or a string that can be converted into one. Repo_ids
+           do not work here.
+         access_token: Optional access token for restricted resources.
+         timeout: Wait up to the indicated number of seconds before timing
+           out long downloads.
+         loader: A Callable that expects a Path and returns a Dict[str, torch.Tensor]
+
+       Returns:
+         A LoadedModel object.
+       """
+       model_path = self.install.download_and_cache_ckpt(source=source, access_token=access_token, timeout=timeout)
+       return self.load.load_ckpt_from_path(model_path=model_path, loader=loader)
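The body above is just the composition of the install and load services; a sketch of that equivalence, assuming `mm` is a constructed ModelManagerService and the URL is illustrative:

    url = "https://example.com/weights/some-checkpoint.safetensors"

    loaded = mm.load_ckpt_from_url(source=url)                 # one call

    path = mm.install.download_and_cache_ckpt(source=url)      # ...or the same two steps by hand
    loaded_again = mm.load.load_ckpt_from_path(model_path=path)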

View File

@@ -1,14 +1,11 @@
 import threading
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Union

-from picklescan.scanner import scan_file_path
+import torch
 from PIL.Image import Image
 from pydantic.networks import AnyHttpUrl
-from safetensors.torch import load_file as safetensors_load_file
-from torch import Tensor
-from torch import load as torch_load

 from invokeai.app.invocations.constants import IMAGE_MODES
 from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
@@ -263,7 +260,7 @@ class ImagesInterface(InvocationContextInterface):

 class TensorsInterface(InvocationContextInterface):
-   def save(self, tensor: Tensor) -> str:
+   def save(self, tensor: torch.Tensor) -> str:
        """Saves a tensor, returning its name.

        Args:
@@ -276,7 +273,7 @@ class TensorsInterface(InvocationContextInterface):
        name = self._services.tensors.save(obj=tensor)
        return name

-   def load(self, name: str) -> Tensor:
+   def load(self, name: str) -> torch.Tensor:
        """Loads a tensor by name.

        Args:
@@ -316,8 +313,10 @@ class ConditioningInterface(InvocationContextInterface):

 class ModelsInterface(InvocationContextInterface):
+   """Common API for loading, downloading and managing models."""
+
    def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
-       """Checks if a model exists.
+       """Check if a model exists.

        Args:
            identifier: The key or ModelField representing the model.
@@ -326,14 +325,18 @@ class ModelsInterface(InvocationContextInterface):
            True if the model exists, False if not.
        """
        if isinstance(identifier, str):
-           return self._services.model_manager.store.exists(identifier)
-
-       return self._services.model_manager.store.exists(identifier.key)
+           # For some reason, Mypy is not getting the type annotations for many of
+           # the model manager service calls and raises a "returning Any in typed
+           # context" error. Hence the extra typing hints here and below.
+           result: bool = self._services.model_manager.store.exists(identifier)
+       else:
+           result = self._services.model_manager.store.exists(identifier.key)
+       return result

    def load(
        self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
    ) -> LoadedModel:
-       """Loads a model.
+       """Load a model.

        Args:
            identifier: The key or ModelField representing the model.
@@ -342,22 +345,22 @@ class ModelsInterface(InvocationContextInterface):
        Returns:
            An object representing the loaded model.
        """
        # The model manager emits events as it loads the model. It needs the context data to build
        # the event payloads.
        if isinstance(identifier, str):
            model = self._services.model_manager.store.get_model(identifier)
-           return self._services.model_manager.load.load_model(model, submodel_type, self._data)
+           result: LoadedModel = self._services.model_manager.load.load_model(model, submodel_type, self._data)
        else:
            _submodel_type = submodel_type or identifier.submodel_type
            model = self._services.model_manager.store.get_model(identifier.key)
-           return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
+           result = self._services.model_manager.load.load_model(model, _submodel_type, self._data)
+       return result

    def load_by_attrs(
        self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
    ) -> LoadedModel:
-       """Loads a model by its attributes.
+       """Load a model by its attributes.

        Args:
            name: Name of the model.
@@ -369,7 +372,6 @@ class ModelsInterface(InvocationContextInterface):
        Returns:
            An object representing the loaded model.
        """
-
        configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
        if len(configs) == 0:
            raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
@@ -377,10 +379,11 @@ class ModelsInterface(InvocationContextInterface):
        if len(configs) > 1:
            raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")

-       return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
+       result: LoadedModel = self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
+       return result

    def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
-       """Gets a model's config.
+       """Get a model's config.

        Args:
            identifier: The key or ModelField representing the model.
@@ -389,12 +392,13 @@ class ModelsInterface(InvocationContextInterface):
            The model's config.
        """
        if isinstance(identifier, str):
-           return self._services.model_manager.store.get_model(identifier)
-
-       return self._services.model_manager.store.get_model(identifier.key)
+           result: AnyModelConfig = self._services.model_manager.store.get_model(identifier)
+       else:
+           result = self._services.model_manager.store.get_model(identifier.key)
+       return result

    def search_by_path(self, path: Path) -> list[AnyModelConfig]:
-       """Searches for models by path.
+       """Search for models by path.

        Args:
            path: The path to search for.
@@ -402,7 +406,8 @@ class ModelsInterface(InvocationContextInterface):
        Returns:
            A list of models that match the path.
        """
-       return self._services.model_manager.store.search_by_path(path)
+       result: list[AnyModelConfig] = self._services.model_manager.store.search_by_path(path)
+       return result

    def search_by_attrs(
        self,
@@ -411,7 +416,7 @@ class ModelsInterface(InvocationContextInterface):
        type: Optional[ModelType] = None,
        format: Optional[ModelFormat] = None,
    ) -> list[AnyModelConfig]:
-       """Searches for models by attributes.
+       """Search for models by attributes.

        Args:
            name: The name to search for (exact match).
@@ -422,13 +427,13 @@ class ModelsInterface(InvocationContextInterface):
        Returns:
            A list of models that match the attributes.
        """
-
-       return self._services.model_manager.store.search_by_attr(
+       result: list[AnyModelConfig] = self._services.model_manager.store.search_by_attr(
            model_name=name,
            base_model=base,
            model_type=type,
            model_format=format,
        )
+       return result

    def download_and_cache_ckpt(
        self,
@@ -451,26 +456,49 @@ class ModelsInterface(InvocationContextInterface):
              out long downloads.

        Result:
-             Path of the downloaded model
+             Path to the downloaded model

        May Raise:
              HTTPError
              TimeoutError
        """
        installer = self._services.model_manager.install
-       path: Path = installer.download_and_cache(
+       path: Path = installer.download_and_cache_ckpt(
            source=source,
            access_token=access_token,
            timeout=timeout,
        )
        return path
+
+   def load_ckpt_from_path(
+       self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None
+   ) -> LoadedModel:
+       """
+       Load the checkpoint-format model file located at the indicated Path.
+
+       This will load an arbitrary model file into the RAM cache. If the optional loader
+       argument is provided, the loader will be invoked to load the model into
+       memory. Otherwise the method will call safetensors.torch.load_file() or
+       torch.load() as appropriate to the file suffix.
+
+       Be aware that the LoadedModel object will have a `config` attribute of None.
+
+       Args:
+         model_path: A pathlib.Path to a checkpoint-style model file
+         loader: A Callable that expects a Path and returns a Dict[str, torch.Tensor]
+
+       Returns:
+         A LoadedModel object.
+       """
+       result: LoadedModel = self._services.model_manager.load.load_ckpt_from_path(model_path, loader=loader)
+       return result
    def load_ckpt_from_url(
        self,
-       source: Union[str, AnyHttpUrl],
+       source: str | AnyHttpUrl,
        access_token: Optional[str] = None,
        timeout: Optional[int] = 0,
-       loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None,
+       loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
    ) -> LoadedModel:
        """
        Download, cache, and Load the model file located at the indicated URL.
@@ -495,29 +523,10 @@ class ModelsInterface(InvocationContextInterface):
        Returns:
            A LoadedModel object.
        """
-       ram_cache = self._services.model_manager.load.ram_cache
-       try:
-           return LoadedModel(_locker=ram_cache.get(key=str(source)))
-       except IndexError:
-           pass
-
-       def torch_load_file(checkpoint: Path) -> Dict[str | int, Any]:
-           scan_result = scan_file_path(checkpoint)
-           if scan_result.infected_files != 0:
-               raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
-           return torch_load(path, map_location="cpu")
-
-       path = self.download_and_cache_ckpt(source, access_token, timeout)
-       if loader is None:
-           loader = (
-               torch_load_file
-               if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
-               else lambda path: safetensors_load_file(path, device="cpu")
-           )
-
-       raw_model = loader(path)
-       ram_cache.put(key=str(source), model=raw_model)
-       return LoadedModel(_locker=ram_cache.get(key=str(source)))
+       result: LoadedModel = self._services.model_manager.load_ckpt_from_url(
+           source=source, access_token=access_token, timeout=timeout, loader=loader
+       )
+       return result

 class ConfigInterface(InvocationContextInterface):
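From inside a node the same capability is still reachable through the invocation context, which now just delegates to the model manager; a minimal sketch (the loader mirrors the LaMA change below, the forward call is hypothetical):

    # Inside an invocation's invoke() method:
    loaded = context.models.load_ckpt_from_url(
        source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
        loader=lambda path: load_jit_model(path, "cpu"),
    )
    with loaded as model:
        output = model(image_tensor, mask_tensor)  # hypothetical forward pass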

View File

@@ -1,11 +1,13 @@
-from typing import Any
+from typing import TYPE_CHECKING, Any

 import numpy as np
 import torch
 from PIL import Image

 import invokeai.backend.util.logging as logger
-from invokeai.app.services.shared.invocation_context import InvocationContext
+
+if TYPE_CHECKING:
+    from invokeai.app.services.model_manager import ModelManagerServiceBase


 def norm_img(np_img):
@@ -16,20 +18,20 @@ def norm_img(np_img):
    return np_img


-def load_jit_model(url_or_path, device):
+def load_jit_model(url_or_path, device) -> torch.nn.Module:
    model_path = url_or_path
    logger.info(f"Loading model from: {model_path}")
-   model = torch.jit.load(model_path, map_location="cpu").to(device)
+   model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device)  # type: ignore
    model.eval()
    return model


 class LaMA:
-   def __init__(self, context: InvocationContext):
-       self._context = context
+   def __init__(self, model_manager: "ModelManagerServiceBase"):
+       self._model_manager = model_manager

    def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
-       loaded_model = self._context.models.load_ckpt_from_url(
+       loaded_model = self._model_manager.load_ckpt_from_url(
            source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
            loader=lambda path: load_jit_model(path, "cpu"),
        )
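With this change LaMA depends only on the model-manager service rather than the full invocation context, so it can be exercised with any ModelManagerServiceBase implementation; a hedged sketch (the `model_manager` instance and image path are placeholders):

    from PIL import Image

    lama = LaMA(model_manager)                          # was: LaMA(context)
    patched = lama(Image.open("/tmp/with_holes.png"))   # infills the transparent areas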

View File

@@ -36,7 +36,7 @@ from ..raw_model import RawModel
 # ModelMixin is the base class for all diffusers and transformers models
 # RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
-AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
+AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]


 class InvalidModelConfigException(Exception):
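The widened union means the raw state dicts returned by load_ckpt_from_path/_url now type-check as AnyModel and can live in the RAM cache; a minimal typing sketch:

    from typing import Dict

    import torch

    from invokeai.backend.model_manager import AnyModel

    state_dict: Dict[str, torch.Tensor] = {"weight": torch.zeros(4, 4)}
    model: AnyModel = state_dict  # valid after this change; previously a type error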

View File

@@ -47,5 +47,3 @@ def test_download_and_load(mock_context: InvocationContext):
    with loaded_model_1 as model_1, loaded_model_2 as model_2:
        assert model_1 == model_2
        assert isinstance(model_1, dict)