diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py
index edee275e72..50fec22994 100644
--- a/invokeai/app/invocations/infill.py
+++ b/invokeai/app/invocations/infill.py
@@ -133,7 +133,9 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using the LaMa model"""

     def infill(self, image: Image.Image, context: InvocationContext):
-        lama = LaMA(context)
+        # Note that this accesses a protected attribute to get to the model manager service.
+        # Is there a better way?
+        lama = LaMA(context._services.model_manager)
         return lama(image)


diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py
index 0ea901fb46..388f4a5ba2 100644
--- a/invokeai/app/services/model_install/model_install_base.py
+++ b/invokeai/app/services/model_install/model_install_base.py
@@ -468,7 +468,12 @@ class ModelInstallServiceBase(ABC):
         """

     @abstractmethod
-    def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
+    def download_and_cache_ckpt(
+        self,
+        source: str | AnyHttpUrl,
+        access_token: Optional[str] = None,
+        timeout: int = 0,
+    ) -> Path:
         """
         Download the model file located at source to the models cache and return its Path.

diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py
index c4127acf7a..32c86ad3a3 100644
--- a/invokeai/app/services/model_install/model_install_default.py
+++ b/invokeai/app/services/model_install/model_install_default.py
@@ -399,9 +399,9 @@ class ModelInstallService(ModelInstallServiceBase):
         escaped_source = slugify(str(source))
         return app_config.download_cache_path / escaped_source

-    def download_and_cache(
+    def download_and_cache_ckpt(
         self,
-        source: Union[str, AnyHttpUrl],
+        source: str | AnyHttpUrl,
         access_token: Optional[str] = None,
         timeout: int = 0,
     ) -> Path:
diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py
index cc80333e93..d59f7a370d 100644
--- a/invokeai/app/services/model_load/model_load_base.py
+++ b/invokeai/app/services/model_load/model_load_base.py
@@ -2,7 +2,10 @@
 """Base class for model loader."""

 from abc import ABC, abstractmethod
-from typing import Optional
+from pathlib import Path
+from typing import Callable, Dict, Optional
+
+from torch import Tensor

 from invokeai.app.services.shared.invocation_context import InvocationContextData
 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
@@ -38,3 +41,25 @@ class ModelLoadServiceBase(ABC):
     @abstractmethod
     def convert_cache(self) -> ModelConvertCacheBase:
         """Return the checkpoint convert cache used by this loader."""
+
+    @abstractmethod
+    def load_ckpt_from_path(
+        self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
+    ) -> LoadedModel:
+        """
+        Load the checkpoint-format model file located at the indicated Path.
+
+        This will load an arbitrary model file into the RAM cache. If the optional loader
+        argument is provided, the loader will be invoked to load the model into
+        memory. Otherwise the method will call safetensors.torch.load_file() or
+        torch.load() as appropriate to the file suffix.
+
+        Be aware that the LoadedModel object will have a `config` attribute of None.
+
+        Args:
+          model_path: A pathlib.Path to a checkpoint-style models file
+          loader: A Callable that expects a Path and returns a Dict[str|int, Any]
+
+        Returns:
+          A LoadedModel object.
+        """
diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py
index 21d3c56f36..a87b6123ce 100644
--- a/invokeai/app/services/model_load/model_load_default.py
+++ b/invokeai/app/services/model_load/model_load_default.py
@@ -1,7 +1,13 @@
 # Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
 """Implementation of model loader service."""

-from typing import Optional, Type
+from pathlib import Path
+from typing import Callable, Dict, Optional, Type
+
+from picklescan.scanner import scan_file_path
+from safetensors.torch import load_file as safetensors_load_file
+from torch import Tensor
+from torch import load as torch_load

 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.invoker import Invoker
@@ -88,6 +94,51 @@ class ModelLoadService(ModelLoadServiceBase):
             )
         return loaded_model

+    def load_ckpt_from_path(
+        self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
+    ) -> LoadedModel:
+        """
+        Load the checkpoint-format model file located at the indicated Path.
+
+        This will load an arbitrary model file into the RAM cache. If the optional loader
+        argument is provided, the loader will be invoked to load the model into
+        memory. Otherwise the method will call safetensors.torch.load_file() or
+        torch.load() as appropriate to the file suffix.
+
+        Be aware that the LoadedModel object will have a `config` attribute of None.
+
+        Args:
+          model_path: A pathlib.Path to a checkpoint-style models file
+          loader: A Callable that expects a Path and returns a Dict[str|int, Any]
+
+        Returns:
+          A LoadedModel object.
+        """
+        cache_key = str(model_path)
+        ram_cache = self.ram_cache
+        try:
+            return LoadedModel(_locker=ram_cache.get(key=cache_key))
+        except IndexError:
+            pass
+
+        def torch_load_file(checkpoint: Path) -> Dict[str, Tensor]:
+            scan_result = scan_file_path(checkpoint)
+            if scan_result.infected_files != 0:
+                raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
+            result: Dict[str, Tensor] = torch_load(checkpoint, map_location="cpu")
+            return result
+
+        if loader is None:
+            loader = (
+                torch_load_file
+                if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
+                else lambda path: safetensors_load_file(path, device="cpu")
+            )
+
+        raw_model = loader(model_path)
+        ram_cache.put(key=cache_key, model=raw_model)
+        return LoadedModel(_locker=ram_cache.get(key=cache_key))
+
     def _emit_load_event(
         self,
         context_data: InvocationContextData,
diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py
index af1b68e1ec..7a5f433aca 100644
--- a/invokeai/app/services/model_manager/model_manager_base.py
+++ b/invokeai/app/services/model_manager/model_manager_base.py
@@ -1,11 +1,15 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team

 from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Callable, Dict, Optional

 import torch
+from pydantic.networks import AnyHttpUrl
 from typing_extensions import Self

 from invokeai.app.services.invoker import Invoker
+from invokeai.backend.model_manager.load import LoadedModel

 from ..config import InvokeAIAppConfig
 from ..download import DownloadQueueServiceBase
@@ -66,3 +70,35 @@ class ModelManagerServiceBase(ABC):
     @abstractmethod
     def stop(self, invoker: Invoker) -> None:
         pass
+
+    @abstractmethod
+    def load_ckpt_from_url(
+        self,
+        source: str | AnyHttpUrl,
+        access_token: Optional[str] = None,
+        timeout: Optional[int] = 0,
+        loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
+    ) -> LoadedModel:
+        """
+        Download, cache, and Load the model file located at the indicated URL.
+
+        This will check the model download cache for the model designated
+        by the provided URL and download it if needed using download_and_cache_ckpt().
+        It will then load the model into the RAM cache. If the optional loader
+        argument is provided, the loader will be invoked to load the model into
+        memory. Otherwise the method will call safetensors.torch.load_file() or
+        torch.load() as appropriate to the file suffix.
+
+        Be aware that the LoadedModel object will have a `config` attribute of None.
+
+        Args:
+          source: A URL or a string that can be converted in one. Repo_ids
+            do not work here.
+          access_token: Optional access token for restricted resources.
+          timeout: Wait up to the indicated number of seconds before timing
+            out long downloads.
+          loader: A Callable that expects a Path and returns a Dict[str|int, Any]
+
+        Returns:
+          A LoadedModel object.
+        """
diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py
index 1a2b9a3402..57c409c066 100644
--- a/invokeai/app/services/model_manager/model_manager_default.py
+++ b/invokeai/app/services/model_manager/model_manager_default.py
@@ -1,13 +1,15 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
 """Implementation of ModelManagerServiceBase."""

-from typing import Optional
+from pathlib import Path
+from typing import Callable, Dict, Optional

 import torch
+from pydantic.networks import AnyHttpUrl
 from typing_extensions import Self

 from invokeai.app.services.invoker import Invoker

-from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
+from invokeai.backend.model_manager.load import LoadedModel, ModelCache, ModelConvertCache, ModelLoaderRegistry
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
@@ -100,3 +102,36 @@ class ModelManagerService(ModelManagerServiceBase):
             event_bus=events,
         )
         return cls(store=model_record_service, install=installer, load=loader)
+
+    def load_ckpt_from_url(
+        self,
+        source: str | AnyHttpUrl,
+        access_token: Optional[str] = None,
+        timeout: Optional[int] = 0,
+        loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
+    ) -> LoadedModel:
+        """
+        Download, cache, and Load the model file located at the indicated URL.
+
+        This will check the model download cache for the model designated
+        by the provided URL and download it if needed using download_and_cache_ckpt().
+        It will then load the model into the RAM cache. If the optional loader
+        argument is provided, the loader will be invoked to load the model into
+        memory. Otherwise the method will call safetensors.torch.load_file() or
+        torch.load() as appropriate to the file suffix.
+
+        Be aware that the LoadedModel object will have a `config` attribute of None.
+
+        Args:
+          source: A URL or a string that can be converted in one. Repo_ids
+            do not work here.
+          access_token: Optional access token for restricted resources.
+          timeout: Wait up to the indicated number of seconds before timing
+            out long downloads.
+          loader: A Callable that expects a Path and returns a Dict[str|int, Any]
+
+        Returns:
+          A LoadedModel object.
+        """
+        model_path = self.install.download_and_cache_ckpt(source=source, access_token=access_token, timeout=timeout)
+        return self.load.load_ckpt_from_path(model_path=model_path, loader=loader)
diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py
index 50551efa31..485be2ba91 100644
--- a/invokeai/app/services/shared/invocation_context.py
+++ b/invokeai/app/services/shared/invocation_context.py
@@ -1,14 +1,11 @@
 import threading
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Union

-from picklescan.scanner import scan_file_path
+import torch
 from PIL.Image import Image
 from pydantic.networks import AnyHttpUrl
-from safetensors.torch import load_file as safetensors_load_file
-from torch import Tensor
-from torch import load as torch_load

 from invokeai.app.invocations.constants import IMAGE_MODES
 from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
@@ -263,7 +260,7 @@ class ImagesInterface(InvocationContextInterface):


 class TensorsInterface(InvocationContextInterface):
-    def save(self, tensor: Tensor) -> str:
+    def save(self, tensor: torch.Tensor) -> str:
         """Saves a tensor, returning its name.

         Args:
@@ -276,7 +273,7 @@ class ImagesInterface(InvocationContextInterface):
         name = self._services.tensors.save(obj=tensor)
         return name

-    def load(self, name: str) -> Tensor:
+    def load(self, name: str) -> torch.Tensor:
         """Loads a tensor by name.

         Args:
@@ -316,8 +313,10 @@ class ConditioningInterface(InvocationContextInterface):


 class ModelsInterface(InvocationContextInterface):
+    """Common API for loading, downloading and managing models."""
+
     def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
-        """Checks if a model exists.
+        """Check if a model exists.

         Args:
             identifier: The key or ModelField representing the model.
@@ -326,14 +325,18 @@
             True if the model exists, False if not.
         """
         if isinstance(identifier, str):
-            return self._services.model_manager.store.exists(identifier)
-
-        return self._services.model_manager.store.exists(identifier.key)
+            # For some reason, Mypy is not getting the type annotations for many of
+            # the model manager service calls and raises a "returning Any in typed
+            # context" error. Hence the extra typing hints here and below.
+            result: bool = self._services.model_manager.store.exists(identifier)
+        else:
+            result = self._services.model_manager.store.exists(identifier.key)
+        return result

     def load(
         self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
     ) -> LoadedModel:
-        """Loads a model.
+        """Load a model.

         Args:
             identifier: The key or ModelField representing the model.
@@ -342,22 +345,22 @@ class ModelsInterface(InvocationContextInterface):
         Returns:
             An object representing the loaded model.
         """
-
         # The model manager emits events as it loads the model. It needs the context data to build
         # the event payloads.

         if isinstance(identifier, str):
             model = self._services.model_manager.store.get_model(identifier)
-            return self._services.model_manager.load.load_model(model, submodel_type, self._data)
+            result: LoadedModel = self._services.model_manager.load.load_model(model, submodel_type, self._data)
         else:
             _submodel_type = submodel_type or identifier.submodel_type
             model = self._services.model_manager.store.get_model(identifier.key)
-            return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
+            result = self._services.model_manager.load.load_model(model, _submodel_type, self._data)
+        return result

     def load_by_attrs(
         self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
     ) -> LoadedModel:
-        """Loads a model by its attributes.
+        """Load a model by its attributes.

         Args:
             name: Name of the model.
@@ -369,7 +372,6 @@
         Returns:
             An object representing the loaded model.
         """
-
         configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
         if len(configs) == 0:
             raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
@@ -377,10 +379,11 @@
         if len(configs) > 1:
             raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")

-        return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
+        result: LoadedModel = self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
+        return result

     def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
-        """Gets a model's config.
+        """Get a model's config.

         Args:
             identifier: The key or ModelField representing the model.
@@ -389,12 +392,13 @@
             The model's config.
         """
         if isinstance(identifier, str):
-            return self._services.model_manager.store.get_model(identifier)
-
-        return self._services.model_manager.store.get_model(identifier.key)
+            result: AnyModelConfig = self._services.model_manager.store.get_model(identifier)
+        else:
+            result = self._services.model_manager.store.get_model(identifier.key)
+        return result

     def search_by_path(self, path: Path) -> list[AnyModelConfig]:
-        """Searches for models by path.
+        """Search for models by path.

         Args:
             path: The path to search for.
@@ -402,7 +406,8 @@
         Returns:
             A list of models that match the path.
         """
-        return self._services.model_manager.store.search_by_path(path)
+        result: list[AnyModelConfig] = self._services.model_manager.store.search_by_path(path)
+        return result

     def search_by_attrs(
         self,
@@ -411,7 +416,7 @@
         type: Optional[ModelType] = None,
         format: Optional[ModelFormat] = None,
     ) -> list[AnyModelConfig]:
-        """Searches for models by attributes.
+        """Search for models by attributes.

         Args:
             name: The name to search for (exact match).
@@ -422,13 +427,13 @@
         Returns:
             A list of models that match the attributes.
""" - - return self._services.model_manager.store.search_by_attr( + result: list[AnyModelConfig] = self._services.model_manager.store.search_by_attr( model_name=name, base_model=base, model_type=type, model_format=format, ) + return result def download_and_cache_ckpt( self, @@ -451,26 +456,49 @@ class ModelsInterface(InvocationContextInterface): out long downloads. Result: - Path of the downloaded model + Path to the downloaded model May Raise: HTTPError TimeoutError """ installer = self._services.model_manager.install - path: Path = installer.download_and_cache( + path: Path = installer.download_and_cache_ckpt( source=source, access_token=access_token, timeout=timeout, ) return path + def load_ckpt_from_path( + self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None + ) -> LoadedModel: + """ + Load the checkpoint-format model file located at the indicated Path. + + This will load an arbitrary model file into the RAM cache. If the optional loader + argument is provided, the loader will be invoked to load the model into + memory. Otherwise the method will call safetensors.torch.load_file() or + torch.load() as appropriate to the file suffix. + + Be aware that the LoadedModel object will have a `config` attribute of None. + + Args: + model_path: A pathlib.Path to a checkpoint-style models file + loader: A Callable that expects a Path and returns a Dict[str|int, Any] + + Returns: + A LoadedModel object. + """ + result: LoadedModel = self._services.model_manager.load.load_ckpt_from_path(model_path, loader=loader) + return result + def load_ckpt_from_url( self, - source: Union[str, AnyHttpUrl], + source: str | AnyHttpUrl, access_token: Optional[str] = None, timeout: Optional[int] = 0, - loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None, + loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, ) -> LoadedModel: """ Download, cache, and Load the model file located at the indicated URL. @@ -495,29 +523,10 @@ class ModelsInterface(InvocationContextInterface): Returns: A LoadedModel object. """ - ram_cache = self._services.model_manager.load.ram_cache - try: - return LoadedModel(_locker=ram_cache.get(key=str(source))) - except IndexError: - pass - - def torch_load_file(checkpoint: Path) -> Dict[str | int, Any]: - scan_result = scan_file_path(checkpoint) - if scan_result.infected_files != 0: - raise Exception("The model at {checkpoint} is potentially infected by malware. 
Aborting load.") - return torch_load(path, map_location="cpu") - - path = self.download_and_cache_ckpt(source, access_token, timeout) - if loader is None: - loader = ( - torch_load_file - if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")) - else lambda path: safetensors_load_file(path, device="cpu") - ) - - raw_model = loader(path) - ram_cache.put(key=str(source), model=raw_model) - return LoadedModel(_locker=ram_cache.get(key=str(source))) + result: LoadedModel = self._services.model_manager.load_ckpt_from_url( + source=source, access_token=access_token, timeout=timeout, loader=loader + ) + return result class ConfigInterface(InvocationContextInterface): diff --git a/invokeai/backend/image_util/infill_methods/lama.py b/invokeai/backend/image_util/infill_methods/lama.py index 8c3f33efad..c7fea497ca 100644 --- a/invokeai/backend/image_util/infill_methods/lama.py +++ b/invokeai/backend/image_util/infill_methods/lama.py @@ -1,11 +1,13 @@ -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np import torch from PIL import Image import invokeai.backend.util.logging as logger -from invokeai.app.services.shared.invocation_context import InvocationContext + +if TYPE_CHECKING: + from invokeai.app.services.model_manager import ModelManagerServiceBase def norm_img(np_img): @@ -16,20 +18,20 @@ def norm_img(np_img): return np_img -def load_jit_model(url_or_path, device): +def load_jit_model(url_or_path, device) -> torch.nn.Module: model_path = url_or_path logger.info(f"Loading model from: {model_path}") - model = torch.jit.load(model_path, map_location="cpu").to(device) + model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device) # type: ignore model.eval() return model class LaMA: - def __init__(self, context: InvocationContext): - self._context = context + def __init__(self, model_manager: "ModelManagerServiceBase"): + self._model_manager = model_manager def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any: - loaded_model = self._context.models.load_ckpt_from_url( + loaded_model = self._model_manager.load_ckpt_from_url( source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", loader=lambda path: load_jit_model(path, "cpu"), ) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 82f88c0e81..1a5d95b7d8 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -36,7 +36,7 @@ from ..raw_model import RawModel # ModelMixin is the base class for all diffusers and transformers models # RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime -AnyModel = Union[ModelMixin, RawModel, torch.nn.Module] +AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]] class InvalidModelConfigException(Exception): diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py index 463be86c68..167c2a09df 100644 --- a/tests/app/services/model_load/test_load_api.py +++ b/tests/app/services/model_load/test_load_api.py @@ -47,5 +47,3 @@ def test_download_and_load(mock_context: InvocationContext): with loaded_model_1 as model_1, loaded_model_2 as model_2: assert model_1 == model_2 assert isinstance(model_1, dict) - -