Address change requests in first round of PR reviews.

Pending:
- Move model install calls into model manager and create passthrus in invocation_context.
- Consider splitting load_model_from_url() into a call to get the path and a call to load the path.
parent 34cdfc61ab
commit d72f272f16
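The second pending item proposes splitting load_model_from_url() into a download step that returns a path and a separate load step. A hedged sketch of what that split might look like; the function names, the urllib stand-in, and the torch.load usage are illustrative assumptions, not InvokeAI's actual API:

from pathlib import Path
import urllib.request

import torch

def download_model_from_url(url: str, cache_dir: Path) -> Path:
    """Step 1: fetch the file (or reuse a cached copy) and return only its path."""
    cache_dir.mkdir(parents=True, exist_ok=True)
    target = cache_dir / url.rsplit("/", 1)[-1]
    if not target.exists():
        urllib.request.urlretrieve(url, target)  # simplistic stand-in for the real download queue
    return target

def load_model_from_path(path: Path):
    """Step 2: load an already-downloaded model, independent of how it got to disk."""
    return torch.load(path, map_location="cpu")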
@@ -11,7 +11,6 @@ from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
 from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
-from invokeai.backend.util.devices import TorchDevice
 
 from .baseinvocation import BaseInvocation, invocation
 from .fields import InputField, WithBoard, WithMetadata
@@ -96,22 +95,21 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
             source=ESRGAN_MODEL_URLS[self.model_name],
         )
 
-        upscaler = RealESRGAN(
-            scale=netscale,
-            loadnet=loadnet.model,
-            model=rrdbnet_model,
-            half=False,
-            tile=self.tile_size,
-        )
-
-        # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
-        # TODO: This strips the alpha... is that okay?
-        cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
-        upscaled_image = upscaler.upscale(cv2_image)
-
-        pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
-
-        TorchDevice.empty_cache()
+        with loadnet as loadnet_model:
+            upscaler = RealESRGAN(
+                scale=netscale,
+                loadnet=loadnet_model,
+                model=rrdbnet_model,
+                half=False,
+                tile=self.tile_size,
+            )
+
+            # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
+            # TODO: This strips the alpha... is that okay?
+            cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
+            upscaled_image = upscaler.upscale(cv2_image)
+
+            pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
 
         image_dto = context.images.save(image=pil_image)
 
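The change above stops reaching into loadnet.model directly and instead enters the loaded-model handle as a context manager, so the weights stay locked in the model cache for the duration of the upscale. A minimal sketch of that pattern, with illustrative names rather than InvokeAI's actual classes:

class LoadedModelHandle:
    """Hypothetical stand-in for the handle returned by the model loader."""

    def __init__(self, model, cache):
        self._model = model
        self._cache = cache

    def __enter__(self):
        self._cache.lock(self._model)  # pin the weights in memory while in use
        return self._model

    def __exit__(self, exc_type, exc_value, traceback):
        self._cache.unlock(self._model)  # make the entry evictable again
        return False  # never swallow exceptions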
@@ -6,7 +6,6 @@ import re
 import signal
 import threading
 import time
-from hashlib import sha256
 from pathlib import Path
 from queue import Empty, Queue
 from shutil import copyfile, copytree, move, rmtree
@@ -44,6 +43,7 @@ from invokeai.backend.model_manager.probe import ModelProbe
 from invokeai.backend.model_manager.search import ModelSearch
 from invokeai.backend.util import InvokeAILogger
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.util import slugify
 
 from .model_install_base import (
     MODEL_SOURCE_TO_TYPE_MAP,
@@ -396,8 +396,8 @@ class ModelInstallService(ModelInstallServiceBase):
 
     @classmethod
     def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: InvokeAIAppConfig) -> Path:
-        model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32]
-        return app_config.download_cache_path / model_hash
+        escaped_source = slugify(str(source))
+        return app_config.download_cache_path / escaped_source
 
     def download_and_cache(
         self,
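The hunk above renames download-cache entries from an opaque hash to a readable slug of the source. For one of the URLs this commit touches, the two schemes compare as follows; the slug shown is hand-traced from the slugify() helper added later in this commit, so treat it as a sketch rather than verified output:

from hashlib import sha256

from invokeai.backend.util.util import slugify

source = "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true"

# Before: an opaque 32-hex-digit prefix of the sha256 of the source string.
old_entry = sha256(source.encode("utf-8")).hexdigest()[0:32]

# After: a readable, filesystem-safe slug of the source itself, e.g.
# "https__huggingfaceco_yzd-v_dwpose_resolve_main_yolox_lonnxdownloadtrue"
new_entry = slugify(source)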
@@ -430,67 +430,6 @@ class ModelsInterface(InvocationContextInterface):
             model_format=format,
         )
 
-    def install_model(
-        self,
-        source: str,
-        config: Optional[Dict[str, Any]] = None,
-        access_token: Optional[str] = None,
-        inplace: Optional[bool] = False,
-        timeout: Optional[int] = 0,
-    ) -> str:
-        """Install and register a model in the database.
-
-        Args:
-            source: String source; see below
-            config: Optional dict. Any fields in this dict
-                will override corresponding autoassigned probe fields in the
-                model's config record.
-            access_token: Optional access token for remote sources.
-            inplace: If true, installs a local model in place rather than copying
-                it into the models directory
-            timeout: How long to wait on install (in seconds). A value of 0 (default)
-                blocks indefinitely
-
-        The source can be:
-        1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`)
-        2. An http or https URL (`https://foo.bar/foo`)
-        3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`)
-
-        We extend the HuggingFace repo_id syntax to include the variant and the
-        subfolder or path. The following are acceptable alternatives:
-            stabilityai/stable-diffusion-v4
-            stabilityai/stable-diffusion-v4:fp16
-            stabilityai/stable-diffusion-v4:fp16:vae
-            stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
-            stabilityai/stable-diffusion-v4:onnx:vae
-
-        Because a local file path can look like a huggingface repo_id, the logic
-        first checks whether the path exists on disk, and if not, it is treated as
-        a parseable huggingface repo.
-
-        Returns:
-            Key to the newly installed model.
-
-        May Raise:
-            ValueError -- bad source
-            UnknownModelException -- remote model not found
-            InvalidModelException -- what was retrieved from remote is not a model
-            TimeoutError -- model could not be installed within timeout
-            Exception -- another error condition
-        """
-        installer = self._services.model_manager.install
-        job = installer.heuristic_import(
-            source=source,
-            config=config,
-            access_token=access_token,
-            inplace=inplace,
-        )
-        installer.wait_for_job(job, timeout)
-        if job.errored:
-            raise Exception(job.error)
-        key: str = job.config_out.key
-        return key
-
     def download_and_cache_ckpt(
         self,
         source: Union[str, AnyHttpUrl],
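The removed install_model() docstring documents an extended HuggingFace repo_id syntax of up to three colon-separated fields (base, variant, subfolder). A small illustration of how such a source string splits; this parser is illustrative only, not InvokeAI's actual logic:

from typing import Optional, Tuple

def split_repo_id(source: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Split 'base[:variant[:subfolder]]'; empty fields become None."""
    parts = source.split(":", 2)
    base = parts[0]
    variant = (parts[1] or None) if len(parts) > 1 else None
    subfolder = (parts[2] or None) if len(parts) > 2 else None
    return base, variant, subfolder

assert split_repo_id("stabilityai/stable-diffusion-v4") == ("stabilityai/stable-diffusion-v4", None, None)
assert split_repo_id("stabilityai/stable-diffusion-v4:fp16:vae") == ("stabilityai/stable-diffusion-v4", "fp16", "vae")
assert split_repo_id("stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors") == (
    "stabilityai/stable-diffusion-v4",
    None,
    "/checkpoints/sd4.safetensors",
)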
@@ -1,28 +1,26 @@
-import pathlib
 import shutil
 import sqlite3
 from logging import Logger
 
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.model_install.model_install_default import ModelInstallService
 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
 
-LEGACY_CORE_MODELS = {
+LEGACY_CORE_MODELS = [
     # OpenPose
-    "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true": "any/annotators/dwpose/yolox_l.onnx",
-    "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true": "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
+    "any/annotators/dwpose/yolox_l.onnx",
+    "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
     # DepthAnything
-    "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true": "any/annotators/depth_anything/depth_anything_vitl14.pth",
-    "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true": "any/annotators/depth_anything/depth_anything_vitb14.pth",
-    "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true": "any/annotators/depth_anything/depth_anything_vits14.pth",
+    "any/annotators/depth_anything/depth_anything_vitl14.pth",
+    "any/annotators/depth_anything/depth_anything_vitb14.pth",
+    "any/annotators/depth_anything/depth_anything_vits14.pth",
     # Lama inpaint
-    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt": "core/misc/lama/lama.pt",
+    "core/misc/lama/lama.pt",
     # RealESRGAN upscale
-    "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth": "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
-    "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth": "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
-    "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth": "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
-    "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth": "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
-}
+    "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
+    "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
+    "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
+    "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
+]
 
 
 class Migration10Callback:
@@ -31,34 +29,24 @@ class Migration10Callback:
         self._logger = logger
 
     def __call__(self, cursor: sqlite3.Cursor) -> None:
-        self._rename_convert_cache()
-        self._migrate_downloaded_models_cache()
+        self._remove_convert_cache()
+        self._remove_downloaded_models()
         self._remove_unused_core_models()
 
-    def _rename_convert_cache(self) -> None:
+    def _remove_convert_cache(self) -> None:
         """Rename models/.cache to models/.convert_cache."""
+        self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.")
         legacy_convert_path = self._app_config.root_path / "models" / ".cache"
-        configured_convert_dir = self._app_config.convert_cache_dir
-        configured_convert_path = self._app_config.convert_cache_path
-        # old convert dir was in use, and current convert dir has not been changed
-        if legacy_convert_path.exists() and configured_convert_dir == pathlib.Path("models/.convert_cache"):
-            self._logger.info(
-                f"Migrating legacy convert cache directory from {str(legacy_convert_path)} to {str(configured_convert_path)}"
-            )
-            shutil.rmtree(configured_convert_path, ignore_errors=True)  # shouldn't be needed, but just in case...
-            shutil.move(legacy_convert_path, configured_convert_path)
+        shutil.rmtree(legacy_convert_path, ignore_errors=True)
 
-    def _migrate_downloaded_models_cache(self) -> None:
-        """Move used core models to modsl/.download_cache."""
-        self._logger.info(f"Migrating legacy core models to {str(self._app_config.download_cache_path)}")
-        for url, legacy_dest in LEGACY_CORE_MODELS.items():
-            legacy_dest_path = self._app_config.models_path / legacy_dest
-            if not legacy_dest_path.exists():
-                continue
-            # this returns a unique directory path
-            new_path = ModelInstallService._download_cache_path(url, self._app_config)
-            new_path.mkdir(parents=True, exist_ok=True)
-            shutil.move(legacy_dest_path, new_path / legacy_dest_path.name)
+    def _remove_downloaded_models(self) -> None:
+        """Remove models from their old locations; they will re-download when needed."""
+        self._logger.info(
+            "Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache."
+        )
+        for model_path in LEGACY_CORE_MODELS:
+            legacy_dest_path = self._app_config.models_path / model_path
+            legacy_dest_path.unlink(missing_ok=True)
 
     def _remove_unused_core_models(self) -> None:
         """Remove unused core models and their directories."""
@@ -19,6 +19,7 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
 from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.util import slugify
 
 
 # TO DO: The loader is not thread safe!
@@ -84,7 +85,7 @@ class ModelLoader(ModelLoaderBase):
         except IndexError:
             pass
 
-        cache_path: Path = self._convert_cache.cache_path(config.key)
+        cache_path: Path = self._convert_cache.cache_path(slugify(model_path))
         if self._needs_conversion(config, model_path, cache_path):
            loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
         else:
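This hunk keys the convert cache by a slug of the model's on-disk path rather than by its opaque config key, matching the download-cache change earlier in the commit. A minimal sketch of the idea; the cache_path helper here is illustrative, not the real ConvertCache API:

from pathlib import Path

from invokeai.backend.util.util import slugify

def convert_cache_path(cache_dir: Path, model_path: Path) -> Path:
    # slugify() calls str() on its argument, so a Path is fine; the result is a
    # stable, human-readable directory name derived from the source path.
    return cache_dir / slugify(model_path)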
@@ -1,6 +1,8 @@
 import base64
 import io
 import os
+import re
+import unicodedata
 import warnings
 from pathlib import Path
@@ -12,6 +14,25 @@ from transformers import logging as transformers_logging
 GIG = 1073741824
 
 
+def slugify(value: str, allow_unicode: bool = False) -> str:
+    """
+    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
+    dashes to single dashes. Remove characters that aren't alphanumerics,
+    underscores, or hyphens. Convert to lowercase. Also strip leading and
+    trailing whitespace, dashes, and underscores.
+
+    Adapted from Django: https://github.com/django/django/blob/main/django/utils/text.py
+    """
+    value = str(value)
+    if allow_unicode:
+        value = unicodedata.normalize("NFKC", value)
+    else:
+        value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
+    value = re.sub(r"[/]", "_", value.lower())
+    value = re.sub(r"[^\w\s-]", "", value.lower())
+    return re.sub(r"[-\s]+", "-", value).strip("-_")
+
+
 def directory_size(directory: Path) -> int:
     """
     Return the aggregate size of all files in a directory (bytes).
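For a concrete feel of the new helper, here is its behavior on one of the model URLs touched by this commit; the output is hand-traced from the implementation above rather than verified, so treat it as a sketch:

from invokeai.backend.util.util import slugify

url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
print(slugify(url))
# -> https__githubcom_xinntao_real-esrgan_releases_download_v010_realesrgan_x4pluspth
# Slashes become underscores, remaining punctuation is dropped, and the result is
# lowercase, so it can serve directly as a stable cache-directory name.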
@@ -49,9 +49,3 @@ def test_download_and_load(mock_context: InvocationContext):
     assert isinstance(model_1, dict)
 
 
-def test_install_model(mock_context: InvocationContext):
-    key = mock_context.models.install_model("https://www.test.foo/download/test_embedding.safetensors")
-    assert key is not None
-    model = mock_context.models.load(key)
-    assert model is not None
-    assert model.config.key == key