mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(mm): move load_model_from_url
from mm to invocation context
This commit is contained in:
parent
e3a70e598e
commit
b124440023
@ -1,15 +1,11 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import torch
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
|
||||
from ..config import InvokeAIAppConfig
|
||||
from ..download import DownloadQueueServiceBase
|
||||
@ -70,30 +66,3 @@ class ModelManagerServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model_from_url(
|
||||
self,
|
||||
source: str | AnyHttpUrl,
|
||||
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Download, cache, and Load the model file located at the indicated URL.
|
||||
|
||||
This will check the model download cache for the model designated
|
||||
by the provided URL and download it if needed using download_and_cache_ckpt().
|
||||
It will then load the model into the RAM cache. If the optional loader
|
||||
argument is provided, the loader will be invoked to load the model into
|
||||
memory. Otherwise the method will call safetensors.torch.load_file() or
|
||||
torch.load() as appropriate to the file suffix.
|
||||
|
||||
Be aware that the LoadedModel object will have a `config` attribute of None.
|
||||
|
||||
Args:
|
||||
source: A URL or a string that can be converted in one. Repo_ids
|
||||
do not work here.
|
||||
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
||||
|
||||
Returns:
|
||||
A LoadedModel object.
|
||||
"""
|
||||
|
@ -1,15 +1,13 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
"""Implementation of ModelManagerServiceBase."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager.load import LoadedModel, ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
@ -64,34 +62,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
if hasattr(service, "stop"):
|
||||
service.stop(invoker)
|
||||
|
||||
def load_model_from_url(
|
||||
self,
|
||||
source: str | AnyHttpUrl,
|
||||
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Download, cache, and Load the model file located at the indicated URL.
|
||||
|
||||
This will check the model download cache for the model designated
|
||||
by the provided URL and download it if needed using download_and_cache_ckpt().
|
||||
It will then load the model into the RAM cache. If the optional loader
|
||||
argument is provided, the loader will be invoked to load the model into
|
||||
memory. Otherwise the method will call safetensors.torch.load_file() or
|
||||
torch.load() as appropriate to the file suffix.
|
||||
|
||||
Be aware that the LoadedModel object will have a `config` attribute of None.
|
||||
|
||||
Args:
|
||||
source: A URL or a string that can be converted in one. Repo_ids
|
||||
do not work here.
|
||||
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
||||
|
||||
Returns:
|
||||
A LoadedModel object.
|
||||
"""
|
||||
model_path = self.install.download_and_cache_model(source=str(source))
|
||||
return self.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||
|
||||
@classmethod
|
||||
def build_model_manager(
|
||||
cls,
|
||||
|
@ -483,7 +483,9 @@ class ModelsInterface(InvocationContextInterface):
|
||||
if isinstance(source, Path):
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
|
||||
else:
|
||||
return self._services.model_manager.load_model_from_url(source=source, loader=loader)
|
||||
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||
|
||||
|
||||
|
||||
class ConfigInterface(InvocationContextInterface):
|
||||
|
Loading…
Reference in New Issue
Block a user