mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(mm): move load_model_from_url
from mm to invocation context
This commit is contained in:
parent
e3a70e598e
commit
b124440023
@ -1,15 +1,11 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable, Dict, Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic.networks import AnyHttpUrl
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.backend.model_manager.load import LoadedModel
|
|
||||||
|
|
||||||
from ..config import InvokeAIAppConfig
|
from ..config import InvokeAIAppConfig
|
||||||
from ..download import DownloadQueueServiceBase
|
from ..download import DownloadQueueServiceBase
|
||||||
@ -70,30 +66,3 @@ class ModelManagerServiceBase(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def stop(self, invoker: Invoker) -> None:
|
def stop(self, invoker: Invoker) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def load_model_from_url(
|
|
||||||
self,
|
|
||||||
source: str | AnyHttpUrl,
|
|
||||||
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
|
|
||||||
) -> LoadedModel:
|
|
||||||
"""
|
|
||||||
Download, cache, and Load the model file located at the indicated URL.
|
|
||||||
|
|
||||||
This will check the model download cache for the model designated
|
|
||||||
by the provided URL and download it if needed using download_and_cache_ckpt().
|
|
||||||
It will then load the model into the RAM cache. If the optional loader
|
|
||||||
argument is provided, the loader will be invoked to load the model into
|
|
||||||
memory. Otherwise the method will call safetensors.torch.load_file() or
|
|
||||||
torch.load() as appropriate to the file suffix.
|
|
||||||
|
|
||||||
Be aware that the LoadedModel object will have a `config` attribute of None.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source: A URL or a string that can be converted in one. Repo_ids
|
|
||||||
do not work here.
|
|
||||||
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A LoadedModel object.
|
|
||||||
"""
|
|
||||||
|
@ -1,15 +1,13 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||||
"""Implementation of ModelManagerServiceBase."""
|
"""Implementation of ModelManagerServiceBase."""
|
||||||
|
|
||||||
from pathlib import Path
|
from typing import Optional
|
||||||
from typing import Callable, Dict, Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic.networks import AnyHttpUrl
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.backend.model_manager.load import LoadedModel, ModelCache, ModelConvertCache, ModelLoaderRegistry
|
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
@ -64,34 +62,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
if hasattr(service, "stop"):
|
if hasattr(service, "stop"):
|
||||||
service.stop(invoker)
|
service.stop(invoker)
|
||||||
|
|
||||||
def load_model_from_url(
|
|
||||||
self,
|
|
||||||
source: str | AnyHttpUrl,
|
|
||||||
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
|
|
||||||
) -> LoadedModel:
|
|
||||||
"""
|
|
||||||
Download, cache, and Load the model file located at the indicated URL.
|
|
||||||
|
|
||||||
This will check the model download cache for the model designated
|
|
||||||
by the provided URL and download it if needed using download_and_cache_ckpt().
|
|
||||||
It will then load the model into the RAM cache. If the optional loader
|
|
||||||
argument is provided, the loader will be invoked to load the model into
|
|
||||||
memory. Otherwise the method will call safetensors.torch.load_file() or
|
|
||||||
torch.load() as appropriate to the file suffix.
|
|
||||||
|
|
||||||
Be aware that the LoadedModel object will have a `config` attribute of None.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source: A URL or a string that can be converted in one. Repo_ids
|
|
||||||
do not work here.
|
|
||||||
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A LoadedModel object.
|
|
||||||
"""
|
|
||||||
model_path = self.install.download_and_cache_model(source=str(source))
|
|
||||||
return self.load.load_model_from_path(model_path=model_path, loader=loader)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_model_manager(
|
def build_model_manager(
|
||||||
cls,
|
cls,
|
||||||
|
@ -483,7 +483,9 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
if isinstance(source, Path):
|
if isinstance(source, Path):
|
||||||
return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
|
return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
|
||||||
else:
|
else:
|
||||||
return self._services.model_manager.load_model_from_url(source=source, loader=loader)
|
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
|
||||||
|
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigInterface(InvocationContextInterface):
|
class ConfigInterface(InvocationContextInterface):
|
||||||
|
Loading…
Reference in New Issue
Block a user