tidy(mm): move load_model_from_url from mm to invocation context

psychedelicious 2024-06-03 08:51:21 +10:00
parent e3a70e598e
commit b124440023
3 changed files with 5 additions and 64 deletions
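The change is mechanical: instead of a dedicated load_model_from_url method on the model manager service, the invocation context now composes the manager's install and load sub-services directly. A minimal before/after sketch of the call pattern, assuming a model_manager service handle and an optional my_loader callback (both names illustrative, not from the diff):

# Before this commit: a dedicated method on the model manager service.
loaded_model = model_manager.load_model_from_url(source=url, loader=my_loader)

# After this commit: the same two steps, composed at the call site.
model_path = model_manager.install.download_and_cache_model(source=str(url))
loaded_model = model_manager.load.load_model_from_path(model_path=model_path, loader=my_loader)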

View File

@@ -1,15 +1,11 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Dict, Optional
import torch
from pydantic.networks import AnyHttpUrl
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import LoadedModel
from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase
@@ -70,30 +66,3 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def stop(self, invoker: Invoker) -> None:
pass

@abstractmethod
def load_model_from_url(
self,
source: str | AnyHttpUrl,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel:
"""
Download, cache, and load the model file located at the indicated URL.

This will check the model download cache for the model designated
by the provided URL and download it if needed using download_and_cache_model().
It will then load the model into the RAM cache. If the optional loader
argument is provided, the loader will be invoked to load the model into
memory. Otherwise the method will call safetensors.torch.load_file() or
torch.load() as appropriate to the file suffix.

Be aware that the LoadedModel object will have a `config` attribute of None.

Args:
    source: A URL, or a string that can be converted into one. Repo_ids
        do not work here.
    loader: A Callable that expects a Path and returns a Dict[str, torch.Tensor].

Returns:
    A LoadedModel object.
"""

View File

@@ -1,15 +1,13 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""
from pathlib import Path
from typing import Callable, Dict, Optional
from typing import Optional
import torch
from pydantic.networks import AnyHttpUrl
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import LoadedModel, ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -64,34 +62,6 @@ class ModelManagerService(ModelManagerServiceBase):
if hasattr(service, "stop"):
service.stop(invoker)

def load_model_from_url(
self,
source: str | AnyHttpUrl,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel:
"""
Download, cache, and load the model file located at the indicated URL.

This will check the model download cache for the model designated
by the provided URL and download it if needed using download_and_cache_model().
It will then load the model into the RAM cache. If the optional loader
argument is provided, the loader will be invoked to load the model into
memory. Otherwise the method will call safetensors.torch.load_file() or
torch.load() as appropriate to the file suffix.

Be aware that the LoadedModel object will have a `config` attribute of None.

Args:
    source: A URL, or a string that can be converted into one. Repo_ids
        do not work here.
    loader: A Callable that expects a Path and returns a Dict[str, torch.Tensor].

Returns:
    A LoadedModel object.
"""
model_path = self.install.download_and_cache_model(source=str(source))
return self.load.load_model_from_path(model_path=model_path, loader=loader)

@classmethod
def build_model_manager(
cls,
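Note that the removed implementation was already just the two service calls now inlined at the call site, so behavior is unchanged; in particular, the documented caveat about `config` still holds. A hedged sketch of what callers of the removed method could rely on, per its docstring (URL illustrative):

loaded_model = model_manager.load_model_from_url(source="https://example.com/model.safetensors")
# Per the docstring, models loaded this way carry no registry config.
assert loaded_model.config is None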

View File

@@ -483,7 +483,9 @@ class ModelsInterface(InvocationContextInterface):
if isinstance(source, Path):
return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
else:
return self._services.model_manager.load_model_from_url(source=source, loader=loader)
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
class ConfigInterface(InvocationContextInterface):
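From an invocation's point of view the refactor is invisible: the ModelsInterface method containing the hunk above still accepts either a local Path or a URL. A sketch of both branches, assuming the surrounding method is exposed as context.models.load_and_cache_model (the method name is not shown in this hunk and is an assumption, as are the file paths and URL):

from pathlib import Path

# Inside an invocation's invoke(self, context) method:
# a Path goes straight to the RAM-cache loader (first branch above) ...
local = context.models.load_and_cache_model(Path("/models/foo.safetensors"))
# ... while a URL is resolved through the download cache first (else branch).
remote = context.models.load_and_cache_model("https://example.com/bar.safetensors")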