From b1244400234cd3a44d4b060751c758d231a8da98 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Mon, 3 Jun 2024 08:51:21 +1000
Subject: [PATCH] tidy(mm): move `load_model_from_url` from mm to invocation
 context

---
 .../model_manager/model_manager_base.py       | 31 -----------------
 .../model_manager/model_manager_default.py    | 34 ++-----------------
 .../app/services/shared/invocation_context.py |  4 ++-
 3 files changed, 5 insertions(+), 64 deletions(-)

diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py
index 063979ebe6..af1b68e1ec 100644
--- a/invokeai/app/services/model_manager/model_manager_base.py
+++ b/invokeai/app/services/model_manager/model_manager_base.py
@@ -1,15 +1,11 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
 
 from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import Callable, Dict, Optional
 
 import torch
-from pydantic.networks import AnyHttpUrl
 from typing_extensions import Self
 
 from invokeai.app.services.invoker import Invoker
-from invokeai.backend.model_manager.load import LoadedModel
 
 from ..config import InvokeAIAppConfig
 from ..download import DownloadQueueServiceBase
@@ -70,30 +66,3 @@ class ModelManagerServiceBase(ABC):
     @abstractmethod
     def stop(self, invoker: Invoker) -> None:
         pass
-
-    @abstractmethod
-    def load_model_from_url(
-        self,
-        source: str | AnyHttpUrl,
-        loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
-    ) -> LoadedModel:
-        """
-        Download, cache, and Load the model file located at the indicated URL.
-
-        This will check the model download cache for the model designated
-        by the provided URL and download it if needed using download_and_cache_ckpt().
-        It will then load the model into the RAM cache. If the optional loader
-        argument is provided, the loader will be invoked to load the model into
-        memory. Otherwise the method will call safetensors.torch.load_file() or
-        torch.load() as appropriate to the file suffix.
-
-        Be aware that the LoadedModel object will have a `config` attribute of None.
-
-        Args:
-          source: A URL or a string that can be converted in one. Repo_ids
-          do not work here.
-          loader: A Callable that expects a Path and returns a Dict[str|int, Any]
-
-        Returns:
-          A LoadedModel object.
-        """
diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py
index dd78f1f3b2..1a2b9a3402 100644
--- a/invokeai/app/services/model_manager/model_manager_default.py
+++ b/invokeai/app/services/model_manager/model_manager_default.py
@@ -1,15 +1,13 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
 """Implementation of ModelManagerServiceBase."""
 
-from pathlib import Path
-from typing import Callable, Dict, Optional
+from typing import Optional
 
 import torch
-from pydantic.networks import AnyHttpUrl
 from typing_extensions import Self
 
 from invokeai.app.services.invoker import Invoker
-from invokeai.backend.model_manager.load import LoadedModel, ModelCache, ModelConvertCache, ModelLoaderRegistry
+from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
 
@@ -64,34 +62,6 @@ class ModelManagerService(ModelManagerServiceBase):
             if hasattr(service, "stop"):
                 service.stop(invoker)
 
-    def load_model_from_url(
-        self,
-        source: str | AnyHttpUrl,
-        loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
-    ) -> LoadedModel:
-        """
-        Download, cache, and Load the model file located at the indicated URL.
-
-        This will check the model download cache for the model designated
-        by the provided URL and download it if needed using download_and_cache_ckpt().
-        It will then load the model into the RAM cache. If the optional loader
-        argument is provided, the loader will be invoked to load the model into
-        memory. Otherwise the method will call safetensors.torch.load_file() or
-        torch.load() as appropriate to the file suffix.
-
-        Be aware that the LoadedModel object will have a `config` attribute of None.
-
-        Args:
-          source: A URL or a string that can be converted in one. Repo_ids
-          do not work here.
-          loader: A Callable that expects a Path and returns a Dict[str|int, Any]
-
-        Returns:
-          A LoadedModel object.
-        """
-        model_path = self.install.download_and_cache_model(source=str(source))
-        return self.load.load_model_from_path(model_path=model_path, loader=loader)
-
     @classmethod
     def build_model_manager(
         cls,
diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py
index 27a29f6646..b0d1ee4d2f 100644
--- a/invokeai/app/services/shared/invocation_context.py
+++ b/invokeai/app/services/shared/invocation_context.py
@@ -483,7 +483,9 @@ class ModelsInterface(InvocationContextInterface):
         if isinstance(source, Path):
            return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
         else:
-            return self._services.model_manager.load_model_from_url(source=source, loader=loader)
+            model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
+            return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+
 
 
 class ConfigInterface(InvocationContextInterface):
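
Note on the removed docstring: it describes the loader dispatch that
load_model_from_path() performs when no custom loader is passed, namely
safetensors.torch.load_file() for .safetensors files and torch.load()
otherwise. The sketch below illustrates that dispatch using only what the
docstring states; load_checkpoint is a hypothetical stand-in for
illustration, not part of the InvokeAI API.

    from pathlib import Path
    from typing import Callable, Dict, Optional

    import torch
    from safetensors.torch import load_file

    def load_checkpoint(
        model_path: Path,
        loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
    ) -> Dict[str, torch.Tensor]:
        # Hypothetical helper mirroring the dispatch described in the
        # removed docstring; not the InvokeAI implementation.
        if loader is not None:
            # A caller-supplied loader takes precedence over suffix-based dispatch.
            return loader(model_path)
        if model_path.suffix == ".safetensors":
            return load_file(model_path)
        # torch.load covers the remaining checkpoint suffixes (.pt, .ckpt, .bin).
        # map_location="cpu" is an assumption; the docstring does not name a device.
        return torch.load(model_path, map_location="cpu")

As the final hunk shows, URL sources now reach this loading path through the
invocation context, which chains install.download_and_cache_model() with
load.load_model_from_path() rather than calling the removed service-level
method.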