diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py
index d687384fcb..e09618960e 100644
--- a/invokeai/app/invocations/upscale.py
+++ b/invokeai/app/invocations/upscale.py
@@ -1,5 +1,4 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
-from pathlib import Path
 from typing import Literal

 import cv2
@@ -11,7 +10,6 @@ from pydantic import ConfigDict
 from invokeai.app.invocations.fields import ImageField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.app.util.download_with_progress import download_with_progress_bar
 from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
 from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
 from invokeai.backend.util.devices import choose_torch_device
@@ -56,7 +54,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):

         rrdbnet_model = None
         netscale = None
-        esrgan_model_path = None

         if self.model_name in [
             "RealESRGAN_x4plus.pth",
@@ -99,16 +96,13 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
             context.logger.error(msg)
             raise ValueError(msg)

-        esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}")
-
-        # Downloads the ESRGAN model if it doesn't already exist
-        download_with_progress_bar(
-            name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
+        loadnet = context.models.load_ckpt_from_url(
+            source=ESRGAN_MODEL_URLS[self.model_name],
         )

         upscaler = RealESRGAN(
             scale=netscale,
-            model_path=esrgan_model_path,
+            loadnet=loadnet.model,
             model=rrdbnet_model,
             half=False,
             tile=self.tile_size,
@@ -118,6 +112,7 @@
         # TODO: This strips the alpha... is that okay?
         cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
         upscaled_image = upscaler.upscale(cv2_image)
+
         pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")

         torch.cuda.empty_cache()
diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py
index 20b03bb28c..5cfe9c17a1 100644
--- a/invokeai/app/services/shared/invocation_context.py
+++ b/invokeai/app/services/shared/invocation_context.py
@@ -1,11 +1,14 @@
 import threading
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union

+from picklescan.scanner import scan_file_path
 from PIL.Image import Image
 from pydantic.networks import AnyHttpUrl
+from safetensors.torch import load_file as safetensors_load_file
 from torch import Tensor
+from torch import load as torch_load

 from invokeai.app.invocations.constants import IMAGE_MODES
 from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
@@ -476,13 +479,14 @@ class ModelsInterface(InvocationContextInterface):
             key: str = job.config_out.key
         return key

-    def download_and_cache_model(
+    def download_and_cache_ckpt(
         self,
         source: Union[str, AnyHttpUrl],
         access_token: Optional[str] = None,
         timeout: Optional[int] = 0,
     ) -> Path:
-        """Download the model file located at source to the models cache and return its Path.
+        """
+        Download the model file located at source to the models cache and return its Path.
         This can be used to single-file install models and other resources of arbitrary types
         which should not get registered with the database. If the model is already
@@ -510,10 +514,65 @@ class ModelsInterface(InvocationContextInterface):
         )
         return path

+    def load_ckpt_from_url(
+        self,
+        source: Union[str, AnyHttpUrl],
+        access_token: Optional[str] = None,
+        timeout: Optional[int] = 0,
+        loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None,
+    ) -> LoadedModel:
+        """
+        Load and cache the model file located at the indicated URL.
+
+        This will check the model download cache for the model designated
+        by the provided URL and download it if needed using download_and_cache_ckpt().
+        It will then load the model into the RAM cache. If the optional loader
+        argument is provided, the loader will be invoked to load the model into
+        memory. Otherwise the method will call safetensors.torch.load_file() or
+        torch.load() as appropriate to the file suffix.
+
+        Be aware that the LoadedModel object will have a `config` attribute of None.
+
+        Args:
+            source: A URL or a string that can be converted into one. Repo IDs
+                do not work here.
+            access_token: Optional access token for restricted resources.
+            timeout: Wait up to the indicated number of seconds before timing
+                out long downloads.
+            loader: A Callable that expects a Path and returns a Dict[str | int, Any].
+
+        Returns:
+            A LoadedModel object.
+        """
+        ram_cache = self._services.model_manager.load.ram_cache
+        try:
+            return LoadedModel(_locker=ram_cache.get(key=str(source)))
+        except IndexError:
+            pass
+
+        def torch_load_file(checkpoint: Path) -> Dict[str | int, Any]:
+            scan_result = scan_file_path(checkpoint)
+            if scan_result.infected_files != 0:
+                raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
+            return torch_load(checkpoint, map_location="cpu")
+
+        path = self.download_and_cache_ckpt(source, access_token, timeout)
+        if loader is None:
+            loader = (
+                torch_load_file
+                if path.suffix in (".ckpt", ".pt", ".pth", ".bin")
+                else lambda path: safetensors_load_file(path, device="cpu")
+            )
+
+        raw_model = loader(path)
+        ram_cache.put(key=str(source), model=raw_model)
+        return LoadedModel(_locker=ram_cache.get(key=str(source)))
+

 class ConfigInterface(InvocationContextInterface):
     def get(self) -> InvokeAIAppConfig:
-        """Gets the app's config.
+        """
+        Gets the app's config.

         Returns:
             The app's config.
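Usage note (not part of the patch): a minimal sketch of how a node might call the new load_ckpt_from_url() API. The URL and the custom loader below are illustrative assumptions; by default the method picks torch.load() or safetensors.torch.load_file() from the file suffix.

    from pathlib import Path
    from typing import Any, Dict

    import torch

    from invokeai.app.services.shared.invocation_context import InvocationContext


    def load_custom_weights(context: InvocationContext) -> None:
        # Default behaviour: download to the models cache (if needed), load the raw
        # state dict, and keep it in the RAM cache keyed by the source URL.
        loaded = context.models.load_ckpt_from_url(
            source="https://example.com/weights/my_model.safetensors",  # hypothetical URL
        )
        with loaded as state_dict:
            print(list(state_dict.keys())[:5])  # raw checkpoint dict; loaded.config is None

        # A custom loader can replace the default handling, e.g. to unwrap a nested
        # "state_dict" key in an older-style .ckpt file (illustrative only).
        def unwrap(checkpoint: Path) -> Dict[str | int, Any]:
            return torch.load(checkpoint, map_location="cpu")["state_dict"]

        loaded = context.models.load_ckpt_from_url(
            source="https://example.com/weights/my_model.ckpt",  # hypothetical URL
            loader=unwrap,
        )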
diff --git a/invokeai/backend/image_util/realesrgan/realesrgan.py b/invokeai/backend/image_util/realesrgan/realesrgan.py
index c06504b608..7c4d90f5bd 100644
--- a/invokeai/backend/image_util/realesrgan/realesrgan.py
+++ b/invokeai/backend/image_util/realesrgan/realesrgan.py
@@ -1,6 +1,5 @@
 import math
 from enum import Enum
-from pathlib import Path
 from typing import Any, Optional

 import cv2
@@ -11,6 +10,7 @@ from cv2.typing import MatLike
 from tqdm import tqdm

 from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
+from invokeai.backend.model_manager.config import AnyModel
 from invokeai.backend.util.devices import choose_torch_device

 """
@@ -52,7 +52,7 @@ class RealESRGAN:
     def __init__(
         self,
         scale: int,
-        model_path: Path,
+        loadnet: AnyModel,
         model: RRDBNet,
         tile: int = 0,
         tile_pad: int = 10,
@@ -67,8 +67,6 @@
         self.half = half
         self.device = choose_torch_device()

-        loadnet = torch.load(model_path, map_location=torch.device("cpu"))
-
         # prefer to use params_ema
         if "params_ema" in loadnet:
             keyname = "params_ema"
diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py
index c336926aea..41a36d7b51 100644
--- a/invokeai/backend/model_manager/load/load_base.py
+++ b/invokeai/backend/model_manager/load/load_base.py
@@ -23,8 +23,8 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
 class LoadedModel:
     """Context manager object that mediates transfer from RAM<->VRAM."""

-    config: AnyModelConfig
     _locker: ModelLockerBase
+    config: Optional[AnyModelConfig] = None

     def __enter__(self) -> AnyModel:
         """Context entry."""
diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py
index 60cc1f5e6c..8e08f065e1 100644
--- a/invokeai/backend/model_manager/load/load_default.py
+++ b/invokeai/backend/model_manager/load/load_default.py
@@ -16,7 +16,7 @@ from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
-from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
+from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.util.devices import choose_torch_device, torch_dtype

@@ -95,7 +94,6 @@ class ModelLoader(ModelLoaderBase):
                 config.key,
                 submodel_type=submodel_type,
                 model=loaded_model,
-                size=calc_model_size_by_data(loaded_model),
             )

         return self._ram_cache.get(
@@ -126,9 +125,7 @@
             if subtype == submodel_type:
                 continue
             if submodel := getattr(pipeline, subtype.value, None):
-                self._ram_cache.put(
-                    config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
-                )
+                self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
         return getattr(pipeline, submodel_type.value) if submodel_type else pipeline

     def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
index eb82f87cb2..3ffd6714a1 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
@@ -143,7 +143,6 @@ class ModelCacheBase(ABC, Generic[T]):
         self,
         key: str,
         model: T,
-        size: int,
         submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Store model under key and optional submodel_type."""
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index b8312f619a..b8fe19047c 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -30,6 +30,7 @@ import torch

 from invokeai.backend.model_manager import AnyModel, SubModelType
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
 from invokeai.backend.util.devices import choose_torch_device
 from invokeai.backend.util.logging import InvokeAILogger

@@ -157,13 +158,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self,
         key: str,
         model: AnyModel,
-        size: int,
         submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Store model under key and optional submodel_type."""
         key = self._make_cache_key(key, submodel_type)
         if key in self._cached_models:
             return
+        size = calc_model_size_by_data(model)
         self.make_room(size)
         cache_record = CacheRecord(key, model, size)
         self._cached_models[key] = cache_record
diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py
new file mode 100644
index 0000000000..b99cc7d824
--- /dev/null
+++ b/tests/app/services/model_load/test_load_api.py
@@ -0,0 +1,57 @@
+from pathlib import Path
+
+import pytest
+
+from invokeai.app.services.invocation_services import InvocationServices
+from invokeai.app.services.model_manager import ModelManagerServiceBase
+from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context
+from invokeai.backend.model_manager.load.load_base import LoadedModel
+from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403
+
+
+@pytest.fixture()
+def mock_context(
+    mock_services: InvocationServices,
+    mm2_model_manager: ModelManagerServiceBase,
+) -> InvocationContext:
+    mock_services.model_manager = mm2_model_manager
+    return build_invocation_context(
+        services=mock_services,
+        data=None,  # type: ignore
+        cancel_event=None,  # type: ignore
+    )
+
+
+def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path):
+    downloaded_path = mock_context.models.download_and_cache_ckpt(
+        "https://www.test.foo/download/test_embedding.safetensors"
+    )
+    assert downloaded_path.is_file()
+    assert downloaded_path.exists()
+    assert downloaded_path.name == "test_embedding.safetensors"
+    assert downloaded_path.parent.parent == mm2_root_dir / "models/.cache"
+
+    downloaded_path_2 = mock_context.models.download_and_cache_ckpt(
+        "https://www.test.foo/download/test_embedding.safetensors"
+    )
+    assert downloaded_path == downloaded_path_2
+
+
+def test_download_and_load(mock_context: InvocationContext):
+    loaded_model_1 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors")
+    assert isinstance(loaded_model_1, LoadedModel)
+
+    loaded_model_2 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors")
+    assert isinstance(loaded_model_2, LoadedModel)
+
+    with loaded_model_1 as model_1, loaded_model_2 as model_2:
+        assert model_1 == model_2
+        assert isinstance(model_1, dict)
+
+
+def test_install_model(mock_context: InvocationContext):
+    key = mock_context.models.install_model("https://www.test.foo/download/test_embedding.safetensors")
+    assert key is not None
+    model = mock_context.models.load(key)
+    assert model is not None
+    assert model.config.key == key
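Note on test_download_and_load (an illustration, not additional test code): the equality assertion holds because load_ckpt_from_url() keys the RAM cache by str(source), so a second call with the same URL returns a LoadedModel wrapping the already-cached checkpoint dict instead of reloading it.

    # Sketch of the caching behaviour the test relies on (URL reused from the fixture above).
    url = "https://www.test.foo/download/test_embedding.safetensors"

    first = mock_context.models.load_ckpt_from_url(url)   # downloads + loads, then caches under `url`
    second = mock_context.models.load_ckpt_from_url(url)  # served from ram_cache.get(key=url)

    with first as model_a, second as model_b:
        assert model_a == model_b  # both context managers yield the same cached state dict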