mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add invocation_context.load_ckpt_from_url() method
This commit is contained in:
parent
9cc1f20ad5
commit
df5ebdbc4f
@ -1,5 +1,4 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||||
from pathlib import Path
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
@ -11,7 +10,6 @@ from pydantic import ConfigDict
|
|||||||
from invokeai.app.invocations.fields import ImageField
|
from invokeai.app.invocations.fields import ImageField
|
||||||
from invokeai.app.invocations.primitives import ImageOutput
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
|
||||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||||
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
@ -56,7 +54,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
|
|
||||||
rrdbnet_model = None
|
rrdbnet_model = None
|
||||||
netscale = None
|
netscale = None
|
||||||
esrgan_model_path = None
|
|
||||||
|
|
||||||
if self.model_name in [
|
if self.model_name in [
|
||||||
"RealESRGAN_x4plus.pth",
|
"RealESRGAN_x4plus.pth",
|
||||||
@ -99,16 +96,13 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
context.logger.error(msg)
|
context.logger.error(msg)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}")
|
loadnet = context.models.load_ckpt_from_url(
|
||||||
|
source=ESRGAN_MODEL_URLS[self.model_name],
|
||||||
# Downloads the ESRGAN model if it doesn't already exist
|
|
||||||
download_with_progress_bar(
|
|
||||||
name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
|
|
||||||
)
|
)
|
||||||
|
|
||||||
upscaler = RealESRGAN(
|
upscaler = RealESRGAN(
|
||||||
scale=netscale,
|
scale=netscale,
|
||||||
model_path=esrgan_model_path,
|
loadnet=loadnet.model,
|
||||||
model=rrdbnet_model,
|
model=rrdbnet_model,
|
||||||
half=False,
|
half=False,
|
||||||
tile=self.tile_size,
|
tile=self.tile_size,
|
||||||
@ -118,6 +112,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
# TODO: This strips the alpha... is that okay?
|
# TODO: This strips the alpha... is that okay?
|
||||||
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
|
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
|
||||||
upscaled_image = upscaler.upscale(cv2_image)
|
upscaled_image = upscaler.upscale(cv2_image)
|
||||||
|
|
||||||
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
|
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
@ -1,11 +1,14 @@
|
|||||||
import threading
|
import threading
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
|
||||||
|
|
||||||
|
from picklescan.scanner import scan_file_path
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
|
from safetensors.torch import load_file as safetensors_load_file
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from torch import load as torch_load
|
||||||
|
|
||||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||||
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
|
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
|
||||||
@ -476,13 +479,14 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
key: str = job.config_out.key
|
key: str = job.config_out.key
|
||||||
return key
|
return key
|
||||||
|
|
||||||
def download_and_cache_model(
|
def download_and_cache_ckpt(
|
||||||
self,
|
self,
|
||||||
source: Union[str, AnyHttpUrl],
|
source: Union[str, AnyHttpUrl],
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
timeout: Optional[int] = 0,
|
timeout: Optional[int] = 0,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
"""Download the model file located at source to the models cache and return its Path.
|
"""
|
||||||
|
Download the model file located at source to the models cache and return its Path.
|
||||||
|
|
||||||
This can be used to single-file install models and other resources of arbitrary types
|
This can be used to single-file install models and other resources of arbitrary types
|
||||||
which should not get registered with the database. If the model is already
|
which should not get registered with the database. If the model is already
|
||||||
@ -510,10 +514,65 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
)
|
)
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
def load_ckpt_from_url(
|
||||||
|
self,
|
||||||
|
source: Union[str, AnyHttpUrl],
|
||||||
|
access_token: Optional[str] = None,
|
||||||
|
timeout: Optional[int] = 0,
|
||||||
|
loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None,
|
||||||
|
) -> LoadedModel:
|
||||||
|
"""
|
||||||
|
Load and cache the model file located at the indicated URL.
|
||||||
|
|
||||||
|
This will check the model download cache for the model designated
|
||||||
|
by the provided URL and download it if needed using download_and_cache_model().
|
||||||
|
It will then load the model into the RAM cache. If the optional loader
|
||||||
|
argument is provided, the loader will be invoked to load the model into
|
||||||
|
memory. Otherwise the method will call safetensors.torch.load_file() or
|
||||||
|
torch.load() as appropriate to the file suffix.
|
||||||
|
|
||||||
|
Be aware that the LoadedModel object will have a `config` attribute of None.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: A URL or a string that can be converted in one. Repo_ids
|
||||||
|
do not work here.
|
||||||
|
access_token: Optional access token for restricted resources.
|
||||||
|
timeout: Wait up to the indicated number of seconds before timing
|
||||||
|
out long downloads.
|
||||||
|
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A LoadedModel object.
|
||||||
|
"""
|
||||||
|
ram_cache = self._services.model_manager.load.ram_cache
|
||||||
|
try:
|
||||||
|
return LoadedModel(_locker=ram_cache.get(key=str(source)))
|
||||||
|
except IndexError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def torch_load_file(checkpoint: Path) -> Dict[str | int, Any]:
|
||||||
|
scan_result = scan_file_path(checkpoint)
|
||||||
|
if scan_result.infected_files != 0:
|
||||||
|
raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
|
||||||
|
return torch_load(path, map_location="cpu")
|
||||||
|
|
||||||
|
path = self.download_and_cache_ckpt(source, access_token, timeout)
|
||||||
|
if loader is None:
|
||||||
|
loader = (
|
||||||
|
torch_load_file
|
||||||
|
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
|
||||||
|
else lambda path: safetensors_load_file(path, device="cpu")
|
||||||
|
)
|
||||||
|
|
||||||
|
raw_model = loader(path)
|
||||||
|
ram_cache.put(key=str(source), model=raw_model)
|
||||||
|
return LoadedModel(_locker=ram_cache.get(key=str(source)))
|
||||||
|
|
||||||
|
|
||||||
class ConfigInterface(InvocationContextInterface):
|
class ConfigInterface(InvocationContextInterface):
|
||||||
def get(self) -> InvokeAIAppConfig:
|
def get(self) -> InvokeAIAppConfig:
|
||||||
"""Gets the app's config.
|
"""
|
||||||
|
Gets the app's config.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The app's config.
|
The app's config.
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import math
|
import math
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
@ -11,6 +10,7 @@ from cv2.typing import MatLike
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||||
|
from invokeai.backend.model_manager.config import AnyModel
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -52,7 +52,7 @@ class RealESRGAN:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
scale: int,
|
scale: int,
|
||||||
model_path: Path,
|
loadnet: AnyModel,
|
||||||
model: RRDBNet,
|
model: RRDBNet,
|
||||||
tile: int = 0,
|
tile: int = 0,
|
||||||
tile_pad: int = 10,
|
tile_pad: int = 10,
|
||||||
@ -67,8 +67,6 @@ class RealESRGAN:
|
|||||||
self.half = half
|
self.half = half
|
||||||
self.device = choose_torch_device()
|
self.device = choose_torch_device()
|
||||||
|
|
||||||
loadnet = torch.load(model_path, map_location=torch.device("cpu"))
|
|
||||||
|
|
||||||
# prefer to use params_ema
|
# prefer to use params_ema
|
||||||
if "params_ema" in loadnet:
|
if "params_ema" in loadnet:
|
||||||
keyname = "params_ema"
|
keyname = "params_ema"
|
||||||
|
@ -23,8 +23,8 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
|
|||||||
class LoadedModel:
|
class LoadedModel:
|
||||||
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
||||||
|
|
||||||
config: AnyModelConfig
|
|
||||||
_locker: ModelLockerBase
|
_locker: ModelLockerBase
|
||||||
|
config: Optional[AnyModelConfig] = None
|
||||||
|
|
||||||
def __enter__(self) -> AnyModel:
|
def __enter__(self) -> AnyModel:
|
||||||
"""Context entry."""
|
"""Context entry."""
|
||||||
|
@ -16,7 +16,7 @@ from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
|
|||||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
|
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
|
||||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||||
|
|
||||||
@ -95,7 +95,6 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
config.key,
|
config.key,
|
||||||
submodel_type=submodel_type,
|
submodel_type=submodel_type,
|
||||||
model=loaded_model,
|
model=loaded_model,
|
||||||
size=calc_model_size_by_data(loaded_model),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._ram_cache.get(
|
return self._ram_cache.get(
|
||||||
@ -126,9 +125,7 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
if subtype == submodel_type:
|
if subtype == submodel_type:
|
||||||
continue
|
continue
|
||||||
if submodel := getattr(pipeline, subtype.value, None):
|
if submodel := getattr(pipeline, subtype.value, None):
|
||||||
self._ram_cache.put(
|
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
|
||||||
config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
|
|
||||||
)
|
|
||||||
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline
|
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline
|
||||||
|
|
||||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||||
|
@ -143,7 +143,6 @@ class ModelCacheBase(ABC, Generic[T]):
|
|||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
model: T,
|
model: T,
|
||||||
size: int,
|
|
||||||
submodel_type: Optional[SubModelType] = None,
|
submodel_type: Optional[SubModelType] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Store model under key and optional submodel_type."""
|
"""Store model under key and optional submodel_type."""
|
||||||
|
@ -30,6 +30,7 @@ import torch
|
|||||||
|
|
||||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||||
|
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
@ -157,13 +158,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
model: AnyModel,
|
model: AnyModel,
|
||||||
size: int,
|
|
||||||
submodel_type: Optional[SubModelType] = None,
|
submodel_type: Optional[SubModelType] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Store model under key and optional submodel_type."""
|
"""Store model under key and optional submodel_type."""
|
||||||
key = self._make_cache_key(key, submodel_type)
|
key = self._make_cache_key(key, submodel_type)
|
||||||
if key in self._cached_models:
|
if key in self._cached_models:
|
||||||
return
|
return
|
||||||
|
size = calc_model_size_by_data(model)
|
||||||
self.make_room(size)
|
self.make_room(size)
|
||||||
cache_record = CacheRecord(key, model, size)
|
cache_record = CacheRecord(key, model, size)
|
||||||
self._cached_models[key] = cache_record
|
self._cached_models[key] = cache_record
|
||||||
|
57
tests/app/services/model_load/test_load_api.py
Normal file
57
tests/app/services/model_load/test_load_api.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
|
from invokeai.app.services.model_manager import ModelManagerServiceBase
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context
|
||||||
|
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||||
|
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def mock_context(
|
||||||
|
mock_services: InvocationServices,
|
||||||
|
mm2_model_manager: ModelManagerServiceBase,
|
||||||
|
) -> InvocationContext:
|
||||||
|
mock_services.model_manager = mm2_model_manager
|
||||||
|
return build_invocation_context(
|
||||||
|
services=mock_services,
|
||||||
|
data=None, # type: ignore
|
||||||
|
cancel_event=None, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path):
|
||||||
|
downloaded_path = mock_context.models.download_and_cache_ckpt(
|
||||||
|
"https://www.test.foo/download/test_embedding.safetensors"
|
||||||
|
)
|
||||||
|
assert downloaded_path.is_file()
|
||||||
|
assert downloaded_path.exists()
|
||||||
|
assert downloaded_path.name == "test_embedding.safetensors"
|
||||||
|
assert downloaded_path.parent.parent == mm2_root_dir / "models/.cache"
|
||||||
|
|
||||||
|
downloaded_path_2 = mock_context.models.download_and_cache_ckpt(
|
||||||
|
"https://www.test.foo/download/test_embedding.safetensors"
|
||||||
|
)
|
||||||
|
assert downloaded_path == downloaded_path_2
|
||||||
|
|
||||||
|
|
||||||
|
def test_download_and_load(mock_context: InvocationContext):
|
||||||
|
loaded_model_1 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors")
|
||||||
|
assert isinstance(loaded_model_1, LoadedModel)
|
||||||
|
|
||||||
|
loaded_model_2 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors")
|
||||||
|
assert isinstance(loaded_model_2, LoadedModel)
|
||||||
|
|
||||||
|
with loaded_model_1 as model_1, loaded_model_2 as model_2:
|
||||||
|
assert model_1 == model_2
|
||||||
|
assert isinstance(model_1, dict)
|
||||||
|
|
||||||
|
|
||||||
|
def test_install_model(mock_context: InvocationContext):
|
||||||
|
key = mock_context.models.install_model("https://www.test.foo/download/test_embedding.safetensors")
|
||||||
|
assert key is not None
|
||||||
|
model = mock_context.models.load(key)
|
||||||
|
assert model is not None
|
||||||
|
assert model.config.key == key
|
Loading…
Reference in New Issue
Block a user