mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor load_ckpt_from_url()
This commit is contained in:
parent
d72f272f16
commit
70903ef057
@ -133,7 +133,9 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
def infill(self, image: Image.Image, context: InvocationContext):
|
||||
lama = LaMA(context)
|
||||
# Note that this accesses a protected attribute to get to the model manager service.
|
||||
# Is there a better way?
|
||||
lama = LaMA(context._services.model_manager)
|
||||
return lama(image)
|
||||
|
||||
|
||||
|
@ -468,7 +468,12 @@ class ModelInstallServiceBase(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
|
||||
def download_and_cache_ckpt(
|
||||
self,
|
||||
source: str | AnyHttpUrl,
|
||||
access_token: Optional[str] = None,
|
||||
timeout: int = 0,
|
||||
) -> Path:
|
||||
"""
|
||||
Download the model file located at source to the models cache and return its Path.
|
||||
|
||||
|
@ -399,9 +399,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
escaped_source = slugify(str(source))
|
||||
return app_config.download_cache_path / escaped_source
|
||||
|
||||
def download_and_cache(
|
||||
def download_and_cache_ckpt(
|
||||
self,
|
||||
source: Union[str, AnyHttpUrl],
|
||||
source: str | AnyHttpUrl,
|
||||
access_token: Optional[str] = None,
|
||||
timeout: int = 0,
|
||||
) -> Path:
|
||||
|
@ -2,7 +2,10 @@
|
||||
"""Base class for model loader."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
@ -38,3 +41,25 @@ class ModelLoadServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
|
||||
@abstractmethod
|
||||
def load_ckpt_from_path(
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Load the checkpoint-format model file located at the indicated Path.
|
||||
|
||||
This will load an arbitrary model file into the RAM cache. If the optional loader
|
||||
argument is provided, the loader will be invoked to load the model into
|
||||
memory. Otherwise the method will call safetensors.torch.load_file() or
|
||||
torch.load() as appropriate to the file suffix.
|
||||
|
||||
Be aware that the LoadedModel object will have a `config` attribute of None.
|
||||
|
||||
Args:
|
||||
model_path: A pathlib.Path to a checkpoint-style models file
|
||||
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
||||
|
||||
Returns:
|
||||
A LoadedModel object.
|
||||
"""
|
||||
|
@ -1,7 +1,13 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
|
||||
"""Implementation of model loader service."""
|
||||
|
||||
from typing import Optional, Type
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Optional, Type
|
||||
|
||||
from picklescan.scanner import scan_file_path
|
||||
from safetensors.torch import load_file as safetensors_load_file
|
||||
from torch import Tensor
|
||||
from torch import load as torch_load
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
@ -88,6 +94,51 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
)
|
||||
return loaded_model
|
||||
|
||||
def load_ckpt_from_path(
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Load the checkpoint-format model file located at the indicated Path.
|
||||
|
||||
This will load an arbitrary model file into the RAM cache. If the optional loader
|
||||
argument is provided, the loader will be invoked to load the model into
|
||||
memory. Otherwise the method will call safetensors.torch.load_file() or
|
||||
torch.load() as appropriate to the file suffix.
|
||||
|
||||
Be aware that the LoadedModel object will have a `config` attribute of None.
|
||||
|
||||
Args:
|
||||
model_path: A pathlib.Path to a checkpoint-style models file
|
||||
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
||||
|
||||
Returns:
|
||||
A LoadedModel object.
|
||||
"""
|
||||
cache_key = str(model_path)
|
||||
ram_cache = self.ram_cache
|
||||
try:
|
||||
return LoadedModel(_locker=ram_cache.get(key=cache_key))
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
def torch_load_file(checkpoint: Path) -> Dict[str, Tensor]:
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
|
||||
result: Dict[str, Tensor] = torch_load(checkpoint, map_location="cpu")
|
||||
return result
|
||||
|
||||
if loader is None:
|
||||
loader = (
|
||||
torch_load_file
|
||||
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
|
||||
else lambda path: safetensors_load_file(path, device="cpu")
|
||||
)
|
||||
|
||||
raw_model = loader(model_path)
|
||||
ram_cache.put(key=cache_key, model=raw_model)
|
||||
return LoadedModel(_locker=ram_cache.get(key=cache_key))
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context_data: InvocationContextData,
|
||||
|
@ -1,11 +1,15 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import torch
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
|
||||
from ..config import InvokeAIAppConfig
|
||||
from ..download import DownloadQueueServiceBase
|
||||
@ -66,3 +70,35 @@ class ModelManagerServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_ckpt_from_url(
|
||||
self,
|
||||
source: str | AnyHttpUrl,
|
||||
access_token: Optional[str] = None,
|
||||
timeout: Optional[int] = 0,
|
||||
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Download, cache, and Load the model file located at the indicated URL.
|
||||
|
||||
This will check the model download cache for the model designated
|
||||
by the provided URL and download it if needed using download_and_cache_ckpt().
|
||||
It will then load the model into the RAM cache. If the optional loader
|
||||
argument is provided, the loader will be invoked to load the model into
|
||||
memory. Otherwise the method will call safetensors.torch.load_file() or
|
||||
torch.load() as appropriate to the file suffix.
|
||||
|
||||
Be aware that the LoadedModel object will have a `config` attribute of None.
|
||||
|
||||
Args:
|
||||
source: A URL or a string that can be converted in one. Repo_ids
|
||||
do not work here.
|
||||
access_token: Optional access token for restricted resources.
|
||||
timeout: Wait up to the indicated number of seconds before timing
|
||||
out long downloads.
|
||||
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
||||
|
||||
Returns:
|
||||
A LoadedModel object.
|
||||
"""
|
||||
|
@ -1,13 +1,15 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
"""Implementation of ModelManagerServiceBase."""
|
||||
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import torch
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load import LoadedModel, ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
@ -100,3 +102,36 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
event_bus=events,
|
||||
)
|
||||
return cls(store=model_record_service, install=installer, load=loader)
|
||||
|
||||
def load_ckpt_from_url(
|
||||
self,
|
||||
source: str | AnyHttpUrl,
|
||||
access_token: Optional[str] = None,
|
||||
timeout: Optional[int] = 0,
|
||||
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Download, cache, and Load the model file located at the indicated URL.
|
||||
|
||||
This will check the model download cache for the model designated
|
||||
by the provided URL and download it if needed using download_and_cache_ckpt().
|
||||
It will then load the model into the RAM cache. If the optional loader
|
||||
argument is provided, the loader will be invoked to load the model into
|
||||
memory. Otherwise the method will call safetensors.torch.load_file() or
|
||||
torch.load() as appropriate to the file suffix.
|
||||
|
||||
Be aware that the LoadedModel object will have a `config` attribute of None.
|
||||
|
||||
Args:
|
||||
source: A URL or a string that can be converted in one. Repo_ids
|
||||
do not work here.
|
||||
access_token: Optional access token for restricted resources.
|
||||
timeout: Wait up to the indicated number of seconds before timing
|
||||
out long downloads.
|
||||
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
||||
|
||||
Returns:
|
||||
A LoadedModel object.
|
||||
"""
|
||||
model_path = self.install.download_and_cache_ckpt(source=source, access_token=access_token, timeout=timeout)
|
||||
return self.load.load_ckpt_from_path(model_path=model_path, loader=loader)
|
||||
|
@ -1,14 +1,11 @@
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
|
||||
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
|
||||
|
||||
from picklescan.scanner import scan_file_path
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from safetensors.torch import load_file as safetensors_load_file
|
||||
from torch import Tensor
|
||||
from torch import load as torch_load
|
||||
|
||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
|
||||
@ -263,7 +260,7 @@ class ImagesInterface(InvocationContextInterface):
|
||||
|
||||
|
||||
class TensorsInterface(InvocationContextInterface):
|
||||
def save(self, tensor: Tensor) -> str:
|
||||
def save(self, tensor: torch.Tensor) -> str:
|
||||
"""Saves a tensor, returning its name.
|
||||
|
||||
Args:
|
||||
@ -276,7 +273,7 @@ class TensorsInterface(InvocationContextInterface):
|
||||
name = self._services.tensors.save(obj=tensor)
|
||||
return name
|
||||
|
||||
def load(self, name: str) -> Tensor:
|
||||
def load(self, name: str) -> torch.Tensor:
|
||||
"""Loads a tensor by name.
|
||||
|
||||
Args:
|
||||
@ -316,8 +313,10 @@ class ConditioningInterface(InvocationContextInterface):
|
||||
|
||||
|
||||
class ModelsInterface(InvocationContextInterface):
|
||||
"""Common API for loading, downloading and managing models."""
|
||||
|
||||
def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
|
||||
"""Checks if a model exists.
|
||||
"""Check if a model exists.
|
||||
|
||||
Args:
|
||||
identifier: The key or ModelField representing the model.
|
||||
@ -326,14 +325,18 @@ class ModelsInterface(InvocationContextInterface):
|
||||
True if the model exists, False if not.
|
||||
"""
|
||||
if isinstance(identifier, str):
|
||||
return self._services.model_manager.store.exists(identifier)
|
||||
|
||||
return self._services.model_manager.store.exists(identifier.key)
|
||||
# For some reason, Mypy is not getting the type annotations for many of
|
||||
# the model manager service calls and raises a "returning Any in typed
|
||||
# context" error. Hence the extra typing hints here and below.
|
||||
result: bool = self._services.model_manager.store.exists(identifier)
|
||||
else:
|
||||
result = self._services.model_manager.store.exists(identifier.key)
|
||||
return result
|
||||
|
||||
def load(
|
||||
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
|
||||
) -> LoadedModel:
|
||||
"""Loads a model.
|
||||
"""Load a model.
|
||||
|
||||
Args:
|
||||
identifier: The key or ModelField representing the model.
|
||||
@ -342,22 +345,22 @@ class ModelsInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
An object representing the loaded model.
|
||||
"""
|
||||
|
||||
# The model manager emits events as it loads the model. It needs the context data to build
|
||||
# the event payloads.
|
||||
|
||||
if isinstance(identifier, str):
|
||||
model = self._services.model_manager.store.get_model(identifier)
|
||||
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
|
||||
result: LoadedModel = self._services.model_manager.load.load_model(model, submodel_type, self._data)
|
||||
else:
|
||||
_submodel_type = submodel_type or identifier.submodel_type
|
||||
model = self._services.model_manager.store.get_model(identifier.key)
|
||||
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
|
||||
result = self._services.model_manager.load.load_model(model, _submodel_type, self._data)
|
||||
return result
|
||||
|
||||
def load_by_attrs(
|
||||
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
||||
) -> LoadedModel:
|
||||
"""Loads a model by its attributes.
|
||||
"""Load a model by its attributes.
|
||||
|
||||
Args:
|
||||
name: Name of the model.
|
||||
@ -369,7 +372,6 @@ class ModelsInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
An object representing the loaded model.
|
||||
"""
|
||||
|
||||
configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
|
||||
if len(configs) == 0:
|
||||
raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
|
||||
@ -377,10 +379,11 @@ class ModelsInterface(InvocationContextInterface):
|
||||
if len(configs) > 1:
|
||||
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
|
||||
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
|
||||
result: LoadedModel = self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
|
||||
return result
|
||||
|
||||
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
|
||||
"""Gets a model's config.
|
||||
"""Get a model's config.
|
||||
|
||||
Args:
|
||||
identifier: The key or ModelField representing the model.
|
||||
@ -389,12 +392,13 @@ class ModelsInterface(InvocationContextInterface):
|
||||
The model's config.
|
||||
"""
|
||||
if isinstance(identifier, str):
|
||||
return self._services.model_manager.store.get_model(identifier)
|
||||
|
||||
return self._services.model_manager.store.get_model(identifier.key)
|
||||
result: AnyModelConfig = self._services.model_manager.store.get_model(identifier)
|
||||
else:
|
||||
result = self._services.model_manager.store.get_model(identifier.key)
|
||||
return result
|
||||
|
||||
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
|
||||
"""Searches for models by path.
|
||||
"""Search for models by path.
|
||||
|
||||
Args:
|
||||
path: The path to search for.
|
||||
@ -402,7 +406,8 @@ class ModelsInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
A list of models that match the path.
|
||||
"""
|
||||
return self._services.model_manager.store.search_by_path(path)
|
||||
result: list[AnyModelConfig] = self._services.model_manager.store.search_by_path(path)
|
||||
return result
|
||||
|
||||
def search_by_attrs(
|
||||
self,
|
||||
@ -411,7 +416,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
type: Optional[ModelType] = None,
|
||||
format: Optional[ModelFormat] = None,
|
||||
) -> list[AnyModelConfig]:
|
||||
"""Searches for models by attributes.
|
||||
"""Search for models by attributes.
|
||||
|
||||
Args:
|
||||
name: The name to search for (exact match).
|
||||
@ -422,13 +427,13 @@ class ModelsInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
A list of models that match the attributes.
|
||||
"""
|
||||
|
||||
return self._services.model_manager.store.search_by_attr(
|
||||
result: list[AnyModelConfig] = self._services.model_manager.store.search_by_attr(
|
||||
model_name=name,
|
||||
base_model=base,
|
||||
model_type=type,
|
||||
model_format=format,
|
||||
)
|
||||
return result
|
||||
|
||||
def download_and_cache_ckpt(
|
||||
self,
|
||||
@ -451,26 +456,49 @@ class ModelsInterface(InvocationContextInterface):
|
||||
out long downloads.
|
||||
|
||||
Result:
|
||||
Path of the downloaded model
|
||||
Path to the downloaded model
|
||||
|
||||
May Raise:
|
||||
HTTPError
|
||||
TimeoutError
|
||||
"""
|
||||
installer = self._services.model_manager.install
|
||||
path: Path = installer.download_and_cache(
|
||||
path: Path = installer.download_and_cache_ckpt(
|
||||
source=source,
|
||||
access_token=access_token,
|
||||
timeout=timeout,
|
||||
)
|
||||
return path
|
||||
|
||||
def load_ckpt_from_path(
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Load the checkpoint-format model file located at the indicated Path.
|
||||
|
||||
This will load an arbitrary model file into the RAM cache. If the optional loader
|
||||
argument is provided, the loader will be invoked to load the model into
|
||||
memory. Otherwise the method will call safetensors.torch.load_file() or
|
||||
torch.load() as appropriate to the file suffix.
|
||||
|
||||
Be aware that the LoadedModel object will have a `config` attribute of None.
|
||||
|
||||
Args:
|
||||
model_path: A pathlib.Path to a checkpoint-style models file
|
||||
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
||||
|
||||
Returns:
|
||||
A LoadedModel object.
|
||||
"""
|
||||
result: LoadedModel = self._services.model_manager.load.load_ckpt_from_path(model_path, loader=loader)
|
||||
return result
|
||||
|
||||
def load_ckpt_from_url(
|
||||
self,
|
||||
source: Union[str, AnyHttpUrl],
|
||||
source: str | AnyHttpUrl,
|
||||
access_token: Optional[str] = None,
|
||||
timeout: Optional[int] = 0,
|
||||
loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None,
|
||||
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Download, cache, and Load the model file located at the indicated URL.
|
||||
@ -495,29 +523,10 @@ class ModelsInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
A LoadedModel object.
|
||||
"""
|
||||
ram_cache = self._services.model_manager.load.ram_cache
|
||||
try:
|
||||
return LoadedModel(_locker=ram_cache.get(key=str(source)))
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
def torch_load_file(checkpoint: Path) -> Dict[str | int, Any]:
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
|
||||
return torch_load(path, map_location="cpu")
|
||||
|
||||
path = self.download_and_cache_ckpt(source, access_token, timeout)
|
||||
if loader is None:
|
||||
loader = (
|
||||
torch_load_file
|
||||
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
|
||||
else lambda path: safetensors_load_file(path, device="cpu")
|
||||
result: LoadedModel = self._services.model_manager.load_ckpt_from_url(
|
||||
source=source, access_token=access_token, timeout=timeout, loader=loader
|
||||
)
|
||||
|
||||
raw_model = loader(path)
|
||||
ram_cache.put(key=str(source), model=raw_model)
|
||||
return LoadedModel(_locker=ram_cache.get(key=str(source)))
|
||||
return result
|
||||
|
||||
|
||||
class ConfigInterface(InvocationContextInterface):
|
||||
|
@ -1,11 +1,13 @@
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.model_manager import ModelManagerServiceBase
|
||||
|
||||
|
||||
def norm_img(np_img):
|
||||
@ -16,20 +18,20 @@ def norm_img(np_img):
|
||||
return np_img
|
||||
|
||||
|
||||
def load_jit_model(url_or_path, device):
|
||||
def load_jit_model(url_or_path, device) -> torch.nn.Module:
|
||||
model_path = url_or_path
|
||||
logger.info(f"Loading model from: {model_path}")
|
||||
model = torch.jit.load(model_path, map_location="cpu").to(device)
|
||||
model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device) # type: ignore
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
class LaMA:
|
||||
def __init__(self, context: InvocationContext):
|
||||
self._context = context
|
||||
def __init__(self, model_manager: "ModelManagerServiceBase"):
|
||||
self._model_manager = model_manager
|
||||
|
||||
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
|
||||
loaded_model = self._context.models.load_ckpt_from_url(
|
||||
loaded_model = self._model_manager.load_ckpt_from_url(
|
||||
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
loader=lambda path: load_jit_model(path, "cpu"),
|
||||
)
|
||||
|
@ -36,7 +36,7 @@ from ..raw_model import RawModel
|
||||
|
||||
# ModelMixin is the base class for all diffusers and transformers models
|
||||
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
||||
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
|
||||
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
|
||||
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
|
@ -47,5 +47,3 @@ def test_download_and_load(mock_context: InvocationContext):
|
||||
with loaded_model_1 as model_1, loaded_model_2 as model_2:
|
||||
assert model_1 == model_2
|
||||
assert isinstance(model_1, dict)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user