mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

add invocation_context.load_ckpt_from_url() method

This commit is contained in:
parent 9cc1f20ad5
commit df5ebdbc4f
@@ -1,5 +1,4 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path
from typing import Literal

import cv2
@@ -11,7 +10,6 @@ from pydantic import ConfigDict
from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
from invokeai.backend.util.devices import choose_torch_device
@@ -56,7 +54,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):

        rrdbnet_model = None
        netscale = None
        esrgan_model_path = None

        if self.model_name in [
            "RealESRGAN_x4plus.pth",
@@ -99,16 +96,13 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
            context.logger.error(msg)
            raise ValueError(msg)

        esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}")

        # Downloads the ESRGAN model if it doesn't already exist
        download_with_progress_bar(
            name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
        loadnet = context.models.load_ckpt_from_url(
            source=ESRGAN_MODEL_URLS[self.model_name],
        )

        upscaler = RealESRGAN(
            scale=netscale,
            model_path=esrgan_model_path,
            loadnet=loadnet.model,
            model=rrdbnet_model,
            half=False,
            tile=self.tile_size,
@@ -118,6 +112,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
        # TODO: This strips the alpha... is that okay?
        cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
        upscaled_image = upscaler.upscale(cv2_image)

        pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")

        torch.cuda.empty_cache()
@@ -1,11 +1,14 @@
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union

from picklescan.scanner import scan_file_path
from PIL.Image import Image
from pydantic.networks import AnyHttpUrl
from safetensors.torch import load_file as safetensors_load_file
from torch import Tensor
from torch import load as torch_load

from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
@@ -476,13 +479,14 @@ class ModelsInterface(InvocationContextInterface):
        key: str = job.config_out.key
        return key

    def download_and_cache_model(
    def download_and_cache_ckpt(
        self,
        source: Union[str, AnyHttpUrl],
        access_token: Optional[str] = None,
        timeout: Optional[int] = 0,
    ) -> Path:
        """Download the model file located at source to the models cache and return its Path.
        """
        Download the model file located at source to the models cache and return its Path.

        This can be used to single-file install models and other resources of arbitrary types
        which should not get registered with the database. If the model is already
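For reference, a minimal usage sketch of download_and_cache_ckpt() from inside an invocation. The URL below is illustrative, not from this commit; `context` is the InvocationContext passed to invoke():

    # Fetch a single file into the models cache and get back its local Path.
    # A second call with the same URL reuses the already-downloaded copy.
    local_path = context.models.download_and_cache_ckpt(
        "https://example.com/upscalers/RealESRGAN_x4plus.pth"  # illustrative URL
    )
    assert local_path.is_file()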
@@ -510,10 +514,65 @@ class ModelsInterface(InvocationContextInterface):
        )
        return path

    def load_ckpt_from_url(
        self,
        source: Union[str, AnyHttpUrl],
        access_token: Optional[str] = None,
        timeout: Optional[int] = 0,
        loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None,
    ) -> LoadedModel:
        """
        Load and cache the model file located at the indicated URL.

        This will check the model download cache for the model designated
        by the provided URL and download it if needed using download_and_cache_ckpt().
        It will then load the model into the RAM cache. If the optional loader
        argument is provided, the loader will be invoked to load the model into
        memory. Otherwise the method will call safetensors.torch.load_file() or
        torch.load() as appropriate to the file suffix.

        Be aware that the LoadedModel object will have a `config` attribute of None.

        Args:
            source: A URL or a string that can be converted into one. Repo_ids
                do not work here.
            access_token: Optional access token for restricted resources.
            timeout: Wait up to the indicated number of seconds before timing
                out long downloads.
            loader: A Callable that expects a Path and returns a Dict[str | int, Any].

        Returns:
            A LoadedModel object.
        """
        ram_cache = self._services.model_manager.load.ram_cache
        try:
            return LoadedModel(_locker=ram_cache.get(key=str(source)))
        except IndexError:
            pass

        def torch_load_file(checkpoint: Path) -> Dict[str | int, Any]:
            scan_result = scan_file_path(checkpoint)
            if scan_result.infected_files != 0:
                raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
            return torch_load(checkpoint, map_location="cpu")

        path = self.download_and_cache_ckpt(source, access_token, timeout)
        if loader is None:
            loader = (
                torch_load_file
                if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
                else lambda path: safetensors_load_file(path, device="cpu")
            )

        raw_model = loader(path)
        ram_cache.put(key=str(source), model=raw_model)
        return LoadedModel(_locker=ram_cache.get(key=str(source)))
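For reference, a minimal usage sketch of load_ckpt_from_url() with a custom loader. The URL and the my_loader helper are illustrative assumptions, not part of this commit; `context` is the InvocationContext passed to invoke():

    from pathlib import Path
    from typing import Any, Dict

    from safetensors.torch import load_file as safetensors_load_file

    def my_loader(path: Path) -> Dict[str | int, Any]:
        # Force safetensors loading regardless of the file suffix.
        return safetensors_load_file(path, device="cpu")

    loaded = context.models.load_ckpt_from_url(
        source="https://example.com/checkpoints/my_model.safetensors",  # illustrative URL
        loader=my_loader,
    )
    with loaded as state_dict:
        # state_dict is whatever the loader returned; loaded.config is None here.
        first_key = next(iter(state_dict))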


class ConfigInterface(InvocationContextInterface):
    def get(self) -> InvokeAIAppConfig:
        """Gets the app's config.
        """
        Gets the app's config.

        Returns:
            The app's config.
@@ -1,6 +1,5 @@
import math
from enum import Enum
from pathlib import Path
from typing import Any, Optional

import cv2
@@ -11,6 +10,7 @@ from cv2.typing import MatLike
from tqdm import tqdm

from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.util.devices import choose_torch_device

"""
@@ -52,7 +52,7 @@ class RealESRGAN:
    def __init__(
        self,
        scale: int,
        model_path: Path,
        loadnet: AnyModel,
        model: RRDBNet,
        tile: int = 0,
        tile_pad: int = 10,
@@ -67,8 +67,6 @@ class RealESRGAN:
        self.half = half
        self.device = choose_torch_device()

        loadnet = torch.load(model_path, map_location=torch.device("cpu"))

        # prefer to use params_ema
        if "params_ema" in loadnet:
            keyname = "params_ema"
@@ -23,8 +23,8 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
class LoadedModel:
    """Context manager object that mediates transfer from RAM<->VRAM."""

    config: AnyModelConfig
    _locker: ModelLockerBase
    config: Optional[AnyModelConfig] = None

    def __enter__(self) -> AnyModel:
        """Context entry."""
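For reference, a minimal sketch of what the now-optional config means for callers. The loaded_model variable is assumed to come from load_ckpt_from_url():

    if loaded_model.config is not None:
        # Only models loaded through a registered config carry a key; a bare
        # URL load via load_ckpt_from_url() leaves config as None.
        model_key = loaded_model.config.key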
@@ -16,7 +16,7 @@ from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import choose_torch_device, torch_dtype

@@ -95,7 +95,6 @@ class ModelLoader(ModelLoaderBase):
            config.key,
            submodel_type=submodel_type,
            model=loaded_model,
            size=calc_model_size_by_data(loaded_model),
        )

        return self._ram_cache.get(
@@ -126,9 +125,7 @@
            if subtype == submodel_type:
                continue
            if submodel := getattr(pipeline, subtype.value, None):
                self._ram_cache.put(
                    config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
                )
                self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
        return getattr(pipeline, submodel_type.value) if submodel_type else pipeline

    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
@@ -143,7 +143,6 @@ class ModelCacheBase(ABC, Generic[T]):
        self,
        key: str,
        model: T,
        size: int,
        submodel_type: Optional[SubModelType] = None,
    ) -> None:
        """Store model under key and optional submodel_type."""
@@ -30,6 +30,7 @@ import torch

from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger

@@ -157,13 +158,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
        self,
        key: str,
        model: AnyModel,
        size: int,
        submodel_type: Optional[SubModelType] = None,
    ) -> None:
        """Store model under key and optional submodel_type."""
        key = self._make_cache_key(key, submodel_type)
        if key in self._cached_models:
            return
        size = calc_model_size_by_data(model)
        self.make_room(size)
        cache_record = CacheRecord(key, model, size)
        self._cached_models[key] = cache_record
tests/app/services/model_load/test_load_api.py (new file, 57 lines)
@@ -0,0 +1,57 @@
from pathlib import Path

import pytest

from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context
from invokeai.backend.model_manager.load.load_base import LoadedModel
from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403


@pytest.fixture()
def mock_context(
    mock_services: InvocationServices,
    mm2_model_manager: ModelManagerServiceBase,
) -> InvocationContext:
    mock_services.model_manager = mm2_model_manager
    return build_invocation_context(
        services=mock_services,
        data=None,  # type: ignore
        cancel_event=None,  # type: ignore
    )


def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path):
    downloaded_path = mock_context.models.download_and_cache_ckpt(
        "https://www.test.foo/download/test_embedding.safetensors"
    )
    assert downloaded_path.is_file()
    assert downloaded_path.exists()
    assert downloaded_path.name == "test_embedding.safetensors"
    assert downloaded_path.parent.parent == mm2_root_dir / "models/.cache"

    downloaded_path_2 = mock_context.models.download_and_cache_ckpt(
        "https://www.test.foo/download/test_embedding.safetensors"
    )
    assert downloaded_path == downloaded_path_2


def test_download_and_load(mock_context: InvocationContext):
    loaded_model_1 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors")
    assert isinstance(loaded_model_1, LoadedModel)

    loaded_model_2 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors")
    assert isinstance(loaded_model_2, LoadedModel)

    with loaded_model_1 as model_1, loaded_model_2 as model_2:
        assert model_1 == model_2
        assert isinstance(model_1, dict)


def test_install_model(mock_context: InvocationContext):
    key = mock_context.models.install_model("https://www.test.foo/download/test_embedding.safetensors")
    assert key is not None
    model = mock_context.models.load(key)
    assert model is not None
    assert model.config.key == key