Address change requests in first round of PR reviews.

Pending:

- Move model install calls into model manager and create passthrus in invocation_context.
- Consider splitting load_model_from_url() into a call to get the path and a call to load the path.
This commit is contained in:
Lincoln Stein 2024-04-24 23:53:30 -04:00
parent 34cdfc61ab
commit d72f272f16
7 changed files with 64 additions and 123 deletions

View File

@ -11,7 +11,6 @@ from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, invocation from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithBoard, WithMetadata from .fields import InputField, WithBoard, WithMetadata
@ -96,22 +95,21 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
source=ESRGAN_MODEL_URLS[self.model_name], source=ESRGAN_MODEL_URLS[self.model_name],
) )
upscaler = RealESRGAN( with loadnet as loadnet_model:
scale=netscale, upscaler = RealESRGAN(
loadnet=loadnet.model, scale=netscale,
model=rrdbnet_model, loadnet=loadnet_model,
half=False, model=rrdbnet_model,
tile=self.tile_size, half=False,
) tile=self.tile_size,
)
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
# TODO: This strips the alpha... is that okay? # TODO: This strips the alpha... is that okay?
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR) cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
upscaled_image = upscaler.upscale(cv2_image) upscaled_image = upscaler.upscale(cv2_image)
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA") pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
TorchDevice.empty_cache()
image_dto = context.images.save(image=pil_image) image_dto = context.images.save(image=pil_image)

View File

@ -6,7 +6,6 @@ import re
import signal import signal
import threading import threading
import time import time
from hashlib import sha256
from pathlib import Path from pathlib import Path
from queue import Empty, Queue from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree from shutil import copyfile, copytree, move, rmtree
@ -44,6 +43,7 @@ from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import InvokeAILogger from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.util import slugify
from .model_install_base import ( from .model_install_base import (
MODEL_SOURCE_TO_TYPE_MAP, MODEL_SOURCE_TO_TYPE_MAP,
@ -396,8 +396,8 @@ class ModelInstallService(ModelInstallServiceBase):
@classmethod @classmethod
def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: InvokeAIAppConfig) -> Path: def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: InvokeAIAppConfig) -> Path:
model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32] escaped_source = slugify(str(source))
return app_config.download_cache_path / model_hash return app_config.download_cache_path / escaped_source
def download_and_cache( def download_and_cache(
self, self,

View File

@ -430,67 +430,6 @@ class ModelsInterface(InvocationContextInterface):
model_format=format, model_format=format,
) )
def install_model(
self,
source: str,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
timeout: Optional[int] = 0,
) -> str:
"""Install and register a model in the database.
Args:
source: String source; see below
config: Optional dict. Any fields in this dict
will override corresponding autoassigned probe fields in the
model's config record.
access_token: Optional access token for remote sources.
inplace: If true, installs a local model in place rather than copying
it into the models directory
timeout: How long to wait on install (in seconds). A value of 0 (default)
blocks indefinitely
The source can be:
1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`)
2. An http or https URL (`https://foo.bar/foo`)
3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`)
We extend the HuggingFace repo_id syntax to include the variant and the
subfolder or path. The following are acceptable alternatives:
stabilityai/stable-diffusion-v4
stabilityai/stable-diffusion-v4:fp16
stabilityai/stable-diffusion-v4:fp16:vae
stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
stabilityai/stable-diffusion-v4:onnx:vae
Because a local file path can look like a huggingface repo_id, the logic
first checks whether the path exists on disk, and if not, it is treated as
a parseable huggingface repo.
Returns:
Key to the newly installed model.
May Raise:
ValueError -- bad source
UnknownModelException -- remote model not found
InvalidModelException -- what was retrieved from remote is not a model
TimeoutError -- model could not be installed within timeout
Exception -- another error condition
"""
installer = self._services.model_manager.install
job = installer.heuristic_import(
source=source,
config=config,
access_token=access_token,
inplace=inplace,
)
installer.wait_for_job(job, timeout)
if job.errored:
raise Exception(job.error)
key: str = job.config_out.key
return key
def download_and_cache_ckpt( def download_and_cache_ckpt(
self, self,
source: Union[str, AnyHttpUrl], source: Union[str, AnyHttpUrl],

View File

@ -1,28 +1,26 @@
import pathlib
import shutil import shutil
import sqlite3 import sqlite3
from logging import Logger from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_install.model_install_default import ModelInstallService
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
LEGACY_CORE_MODELS = { LEGACY_CORE_MODELS = [
# OpenPose # OpenPose
"https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true": "any/annotators/dwpose/yolox_l.onnx", "any/annotators/dwpose/yolox_l.onnx",
"https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true": "any/annotators/dwpose/dw-ll_ucoco_384.onnx", "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
# DepthAnything # DepthAnything
"https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true": "any/annotators/depth_anything/depth_anything_vitl14.pth", "any/annotators/depth_anything/depth_anything_vitl14.pth",
"https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true": "any/annotators/depth_anything/depth_anything_vitb14.pth", "any/annotators/depth_anything/depth_anything_vitb14.pth",
"https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true": "any/annotators/depth_anything/depth_anything_vits14.pth", "any/annotators/depth_anything/depth_anything_vits14.pth",
# Lama inpaint # Lama inpaint
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt": "core/misc/lama/lama.pt", "core/misc/lama/lama.pt",
# RealESRGAN upscale # RealESRGAN upscale
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth": "core/upscaling/realesrgan/RealESRGAN_x4plus.pth", "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth": "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth", "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth": "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth": "core/upscaling/realesrgan/RealESRGAN_x2plus.pth", "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
} ]
class Migration10Callback: class Migration10Callback:
@ -31,34 +29,24 @@ class Migration10Callback:
self._logger = logger self._logger = logger
def __call__(self, cursor: sqlite3.Cursor) -> None: def __call__(self, cursor: sqlite3.Cursor) -> None:
self._rename_convert_cache() self._remove_convert_cache()
self._migrate_downloaded_models_cache() self._remove_downloaded_models()
self._remove_unused_core_models() self._remove_unused_core_models()
def _rename_convert_cache(self) -> None: def _remove_convert_cache(self) -> None:
"""Rename models/.cache to models/.convert_cache.""" """Rename models/.cache to models/.convert_cache."""
self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.")
legacy_convert_path = self._app_config.root_path / "models" / ".cache" legacy_convert_path = self._app_config.root_path / "models" / ".cache"
configured_convert_dir = self._app_config.convert_cache_dir shutil.rmtree(legacy_convert_path, ignore_errors=True)
configured_convert_path = self._app_config.convert_cache_path
# old convert dir was in use, and current convert dir has not been changed
if legacy_convert_path.exists() and configured_convert_dir == pathlib.Path("models/.convert_cache"):
self._logger.info(
f"Migrating legacy convert cache directory from {str(legacy_convert_path)} to {str(configured_convert_path)}"
)
shutil.rmtree(configured_convert_path, ignore_errors=True) # shouldn't be needed, but just in case...
shutil.move(legacy_convert_path, configured_convert_path)
def _migrate_downloaded_models_cache(self) -> None: def _remove_downloaded_models(self) -> None:
"""Move used core models to modsl/.download_cache.""" """Remove models from their old locations; they will re-download when needed."""
self._logger.info(f"Migrating legacy core models to {str(self._app_config.download_cache_path)}") self._logger.info(
for url, legacy_dest in LEGACY_CORE_MODELS.items(): "Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache."
legacy_dest_path = self._app_config.models_path / legacy_dest )
if not legacy_dest_path.exists(): for model_path in LEGACY_CORE_MODELS:
continue legacy_dest_path = self._app_config.models_path / model_path
# this returns a unique directory path legacy_dest_path.unlink(missing_ok=True)
new_path = ModelInstallService._download_cache_path(url, self._app_config)
new_path.mkdir(parents=True, exist_ok=True)
shutil.move(legacy_dest_path, new_path / legacy_dest_path.name)
def _remove_unused_core_models(self) -> None: def _remove_unused_core_models(self) -> None:
"""Remove unused core models and their directories.""" """Remove unused core models and their directories."""

View File

@ -19,6 +19,7 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.util import slugify
# TO DO: The loader is not thread safe! # TO DO: The loader is not thread safe!
@ -84,7 +85,7 @@ class ModelLoader(ModelLoaderBase):
except IndexError: except IndexError:
pass pass
cache_path: Path = self._convert_cache.cache_path(config.key) cache_path: Path = self._convert_cache.cache_path(slugify(model_path))
if self._needs_conversion(config, model_path, cache_path): if self._needs_conversion(config, model_path, cache_path):
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type) loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
else: else:

View File

@ -1,6 +1,8 @@
import base64 import base64
import io import io
import os import os
import re
import unicodedata
import warnings import warnings
from pathlib import Path from pathlib import Path
@ -12,6 +14,25 @@ from transformers import logging as transformers_logging
GIG = 1073741824 GIG = 1073741824
def slugify(value: str, allow_unicode: bool = False) -> str:
"""
Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
dashes to single dashes. Remove characters that aren't alphanumerics,
underscores, or hyphens. Convert to lowercase. Also strip leading and
trailing whitespace, dashes, and underscores.
Adapted from Django: https://github.com/django/django/blob/main/django/utils/text.py
"""
value = str(value)
if allow_unicode:
value = unicodedata.normalize("NFKC", value)
else:
value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
value = re.sub(r"[/]", "_", value.lower())
value = re.sub(r"[^\w\s-]", "", value.lower())
return re.sub(r"[-\s]+", "-", value).strip("-_")
def directory_size(directory: Path) -> int: def directory_size(directory: Path) -> int:
""" """
Return the aggregate size of all files in a directory (bytes). Return the aggregate size of all files in a directory (bytes).

View File

@ -49,9 +49,3 @@ def test_download_and_load(mock_context: InvocationContext):
assert isinstance(model_1, dict) assert isinstance(model_1, dict)
def test_install_model(mock_context: InvocationContext):
key = mock_context.models.install_model("https://www.test.foo/download/test_embedding.safetensors")
assert key is not None
model = mock_context.models.load(key)
assert model is not None
assert model.config.key == key