add invocation_context.load_ckpt_from_url() method

Lincoln Stein 2024-04-12 00:55:21 -04:00
parent 9cc1f20ad5
commit df5ebdbc4f
8 changed files with 131 additions and 25 deletions

View File

@@ -1,5 +1,4 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path
from typing import Literal
import cv2
@@ -11,7 +10,6 @@ from pydantic import ConfigDict
from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
from invokeai.backend.util.devices import choose_torch_device
@@ -56,7 +54,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
rrdbnet_model = None
netscale = None
esrgan_model_path = None
if self.model_name in [
"RealESRGAN_x4plus.pth",
@@ -99,16 +96,13 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
context.logger.error(msg)
raise ValueError(msg)
esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}")
# Downloads the ESRGAN model if it doesn't already exist
download_with_progress_bar(
name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
loadnet = context.models.load_ckpt_from_url(
source=ESRGAN_MODEL_URLS[self.model_name],
)
upscaler = RealESRGAN(
scale=netscale,
model_path=esrgan_model_path,
loadnet=loadnet.model,
model=rrdbnet_model,
half=False,
tile=self.tile_size,
@@ -118,6 +112,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
# TODO: This strips the alpha... is that okay?
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
upscaled_image = upscaler.upscale(cv2_image)
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
torch.cuda.empty_cache()
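
Taken together, the hunks above replace the manual download_with_progress_bar() call and on-disk path handling with a single URL-keyed cached load. A minimal sketch of the resulting flow inside invoke(), condensed from the diff (illustrative only, not the complete method body):

# Condensed sketch of the new flow (assumes the fields shown elsewhere in this file):
loadnet = context.models.load_ckpt_from_url(
    source=ESRGAN_MODEL_URLS[self.model_name],  # downloaded once, then served from the cache
)
upscaler = RealESRGAN(
    scale=netscale,
    loadnet=loadnet.model,   # the raw checkpoint dict held by the RAM cache
    model=rrdbnet_model,
    half=False,
    tile=self.tile_size,
)
upscaled_image = upscaler.upscale(cv2_image)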

View File

@@ -1,11 +1,14 @@
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
from picklescan.scanner import scan_file_path
from PIL.Image import Image
from pydantic.networks import AnyHttpUrl
from safetensors.torch import load_file as safetensors_load_file
from torch import Tensor
from torch import load as torch_load
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
@@ -476,13 +479,14 @@ class ModelsInterface(InvocationContextInterface):
key: str = job.config_out.key
return key
def download_and_cache_model(
def download_and_cache_ckpt(
self,
source: Union[str, AnyHttpUrl],
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
) -> Path:
"""Download the model file located at source to the models cache and return its Path.
"""
Download the model file located at source to the models cache and return its Path.
This can be used to install single-file models and other resources of arbitrary types
which should not get registered with the database. If the model is already
@@ -510,10 +514,65 @@
)
return path
def load_ckpt_from_url(
self,
source: Union[str, AnyHttpUrl],
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None,
) -> LoadedModel:
"""
Load and cache the model file located at the indicated URL.
This will check the model download cache for the model designated
by the provided URL and download it if needed using download_and_cache_ckpt().
It will then load the model into the RAM cache. If the optional loader
argument is provided, the loader will be invoked to load the model into
memory. Otherwise the method will call safetensors.torch.load_file() or
torch.load() as appropriate to the file suffix.
Be aware that the LoadedModel object will have a `config` attribute of None.
Args:
source: A URL or a string that can be converted into one. Repo_ids
do not work here.
access_token: Optional access token for restricted resources.
timeout: Wait up to the indicated number of seconds before timing
out long downloads.
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
Returns:
A LoadedModel object.
"""
ram_cache = self._services.model_manager.load.ram_cache
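# Fast path: reuse the checkpoint if this URL is already resident in the RAM cache
# (ram_cache.get() raises IndexError on a cache miss, hence the try/except below).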
try:
return LoadedModel(_locker=ram_cache.get(key=str(source)))
except IndexError:
pass
def torch_load_file(checkpoint: Path) -> Dict[str | int, Any]:
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
return torch_load(checkpoint, map_location="cpu")
path = self.download_and_cache_ckpt(source, access_token, timeout)
if loader is None:
loader = (
torch_load_file
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
else lambda path: safetensors_load_file(path, device="cpu")
)
raw_model = loader(path)
ram_cache.put(key=str(source), model=raw_model)
return LoadedModel(_locker=ram_cache.get(key=str(source)))
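
For reference, a short usage sketch of the new method as seen from an invocation. The URL is hypothetical, and the custom-loader variant is only needed when neither safetensors nor torch.load() is appropriate:

# Default loader: chosen from the file suffix (.safetensors vs. .ckpt/.pt/.pth/.bin).
loaded = context.models.load_ckpt_from_url(
    source="https://example.com/checkpoints/some_model.safetensors",  # hypothetical URL
)
with loaded as state_dict:  # LoadedModel mediates RAM<->VRAM; its config attribute is None here
    keys = list(state_dict.keys())

# Custom loader: a callable that receives the cached Path and returns the state dict.
loaded = context.models.load_ckpt_from_url(
    source="https://example.com/checkpoints/some_model.ckpt",  # hypothetical URL
    loader=lambda p: torch_load(p, map_location="cpu"),
)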
class ConfigInterface(InvocationContextInterface):
def get(self) -> InvokeAIAppConfig:
"""Gets the app's config.
"""
Gets the app's config.
Returns:
The app's config.

View File

@@ -1,6 +1,5 @@
import math
from enum import Enum
from pathlib import Path
from typing import Any, Optional
import cv2
@@ -11,6 +10,7 @@ from cv2.typing import MatLike
from tqdm import tqdm
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.util.devices import choose_torch_device
"""
@@ -52,7 +52,7 @@ class RealESRGAN:
def __init__(
self,
scale: int,
model_path: Path,
loadnet: AnyModel,
model: RRDBNet,
tile: int = 0,
tile_pad: int = 10,
@@ -67,8 +67,6 @@
self.half = half
self.device = choose_torch_device()
loadnet = torch.load(model_path, map_location=torch.device("cpu"))
# prefer to use params_ema
if "params_ema" in loadnet:
keyname = "params_ema"

View File

@@ -23,8 +23,8 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
class LoadedModel:
"""Context manager object that mediates transfer from RAM<->VRAM."""
config: AnyModelConfig
_locker: ModelLockerBase
config: Optional[AnyModelConfig] = None
def __enter__(self) -> AnyModel:
"""Context entry."""

View File

@@ -16,7 +16,7 @@ from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
@@ -95,7 +95,6 @@ class ModelLoader(ModelLoaderBase):
config.key,
submodel_type=submodel_type,
model=loaded_model,
size=calc_model_size_by_data(loaded_model),
)
return self._ram_cache.get(
@@ -126,9 +125,7 @@
if subtype == submodel_type:
continue
if submodel := getattr(pipeline, subtype.value, None):
self._ram_cache.put(
config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
)
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:

View File

@@ -143,7 +143,6 @@ class ModelCacheBase(ABC, Generic[T]):
self,
key: str,
model: T,
size: int,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""

View File

@@ -30,6 +30,7 @@ import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
@@ -157,13 +158,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
self,
key: str,
model: AnyModel,
size: int,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
return
size = calc_model_size_by_data(model)
self.make_room(size)
cache_record = CacheRecord(key, model, size)
self._cached_models[key] = cache_record
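
The net effect is that callers no longer compute or pass a size; ModelCache.put() derives it from the model itself. Compare the new and old call forms (the old form is taken from the hunks removed above):

# New: size is computed internally via calc_model_size_by_data(model)
ram_cache.put(key=str(source), model=raw_model)
# Old: every caller had to supply it explicitly
# ram_cache.put(key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel))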

View File

@@ -0,0 +1,57 @@
from pathlib import Path
import pytest
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context
from invokeai.backend.model_manager.load.load_base import LoadedModel
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
@pytest.fixture()
def mock_context(
mock_services: InvocationServices,
mm2_model_manager: ModelManagerServiceBase,
) -> InvocationContext:
mock_services.model_manager = mm2_model_manager
return build_invocation_context(
services=mock_services,
data=None, # type: ignore
cancel_event=None, # type: ignore
)
def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path):
downloaded_path = mock_context.models.download_and_cache_ckpt(
"https://www.test.foo/download/test_embedding.safetensors"
)
assert downloaded_path.is_file()
assert downloaded_path.exists()
assert downloaded_path.name == "test_embedding.safetensors"
assert downloaded_path.parent.parent == mm2_root_dir / "models/.cache"
downloaded_path_2 = mock_context.models.download_and_cache_ckpt(
"https://www.test.foo/download/test_embedding.safetensors"
)
assert downloaded_path == downloaded_path_2
def test_download_and_load(mock_context: InvocationContext):
loaded_model_1 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors")
assert isinstance(loaded_model_1, LoadedModel)
loaded_model_2 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors")
assert isinstance(loaded_model_2, LoadedModel)
with loaded_model_1 as model_1, loaded_model_2 as model_2:
assert model_1 == model_2
assert isinstance(model_1, dict)
def test_install_model(mock_context: InvocationContext):
key = mock_context.models.install_model("https://www.test.foo/download/test_embedding.safetensors")
assert key is not None
model = mock_context.models.load(key)
assert model is not None
assert model.config.key == key