resolve merge conflicts

This commit is contained in:
Lincoln Stein 2024-06-23 10:31:35 -04:00
commit 0df018bd4e
110 changed files with 5600 additions and 3245 deletions

View File

@ -128,7 +128,8 @@ The queue operates on a series of download job objects. These objects
specify the source and destination of the download, and keep track of specify the source and destination of the download, and keep track of
the progress of the download. the progress of the download.
The only job type currently implemented is `DownloadJob`, a pydantic object with the Two job types are defined. `DownloadJob` and
`MultiFileDownloadJob`. The former is a pydantic object with the
following fields: following fields:
| **Field** | **Type** | **Default** | **Description** | | **Field** | **Type** | **Default** | **Description** |
@ -190,6 +191,33 @@ A cancelled job will have status `DownloadJobStatus.ERROR` and an
`error_type` field of "DownloadJobCancelledException". In addition, `error_type` field of "DownloadJobCancelledException". In addition,
the job's `cancelled` property will be set to True. the job's `cancelled` property will be set to True.
The `MultiFileDownloadJob` is used for diffusers model downloads,
which contain multiple files and directories under a common root:
| **Field** | **Type** | **Default** | **Description** |
|----------------|-----------------|---------------|-----------------|
| _Fields passed in at job creation time_ |
| `download_parts` | Set[DownloadJob]| | Component download jobs |
| `dest` | Path | | Where to download to |
| `on_start` | Callable | | [optional] callback when the download starts |
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
| `on_complete` | Callable | | [optional] callback called after successful download completion |
| `on_error` | Callable | | [optional] callback called after an error occurs |
| `id` | int | auto assigned | Job ID, an integer >= 0 |
| _Fields updated over the course of the download task_
| `status` | DownloadJobStatus| | Status code |
| `download_path` | Path | | Path to the root of the downloaded files |
| `bytes` | int | 0 | Bytes downloaded so far |
| `total_bytes` | int | 0 | Total size of the file at the remote site |
| `error_type` | str | | String version of the exception that caused an error during download |
| `error` | str | | String version of the traceback associated with an error |
| `cancelled` | bool | False | Set to true if the job was cancelled by the caller|
Note that the MultiFileDownloadJob does not support the `priority`,
`job_started`, `job_ended` or `content_type` attributes. You can get
these from the individual download jobs in `download_parts`.
### Callbacks ### Callbacks
Download jobs can be associated with a series of callbacks, each with Download jobs can be associated with a series of callbacks, each with
@ -251,11 +279,40 @@ jobs using `list_jobs()`, fetch a single job by its with
running jobs with `cancel_all_jobs()`, and wait for all jobs to finish running jobs with `cancel_all_jobs()`, and wait for all jobs to finish
with `join()`. with `join()`.
#### job = queue.download(source, dest, priority, access_token) #### job = queue.download(source, dest, priority, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
Create a new download job and put it on the queue, returning the Create a new download job and put it on the queue, returning the
DownloadJob object. DownloadJob object.
#### multifile_job = queue.multifile_download(parts, dest, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
This is similar to download(), but instead of taking a single source,
it accepts a `parts` argument consisting of a list of
`RemoteModelFile` objects. Each part corresponds to a URL/Path pair,
where the URL is the location of the remote file, and the Path is the
destination.
`RemoteModelFile` can be imported from `invokeai.backend.model_manager.metadata`, and
consists of a url/path pair. Note that the path *must* be relative.
The method returns a `MultiFileDownloadJob`.
```
from invokeai.backend.model_manager.metadata import RemoteModelFile
remote_file_1 = RemoteModelFile(url='http://www.foo.bar/my/pytorch_model.safetensors'',
path='my_model/textencoder/pytorch_model.safetensors'
)
remote_file_2 = RemoteModelFile(url='http://www.bar.baz/vae.ckpt',
path='my_model/vae/diffusers_model.safetensors'
)
job = queue.multifile_download(parts=[remote_file_1, remote_file_2],
dest='/tmp/downloads',
on_progress=TqdmProgress().update)
queue.wait_for_job(job)
print(f"The files were downloaded to {job.download_path}")
```
#### jobs = queue.list_jobs() #### jobs = queue.list_jobs()
Return a list of all active and inactive `DownloadJob`s. Return a list of all active and inactive `DownloadJob`s.

View File

@ -397,18 +397,17 @@ In the event you wish to create a new installer, you may use the
following initialization pattern: following initialization pattern:
``` ```
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import get_config
from invokeai.app.services.model_records import ModelRecordServiceSQL from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.model_install import ModelInstallService from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.download import DownloadQueueService from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.shared.sqlite import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
config = InvokeAIAppConfig.get_config() config = get_config()
config.parse_args()
logger = InvokeAILogger.get_logger(config=config) logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config, logger) db = SqliteDatabase(config.db_path, logger)
record_store = ModelRecordServiceSQL(db) record_store = ModelRecordServiceSQL(db)
queue = DownloadQueueService() queue = DownloadQueueService()
queue.start() queue.start()
@ -1367,12 +1366,20 @@ the in-memory loaded model:
| `model` | AnyModel | The instantiated model (details below) | | `model` | AnyModel | The instantiated model (details below) |
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM | | `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
Because the loader can return multiple model types, it is typed to ### get_model_by_key(key, [submodel]) -> LoadedModel
return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and The `get_model_by_key()` method will retrieve the model using its
`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers unique database key. For example:
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are obvious. loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
`get_model_by_key()` may raise any of the following exceptions:
* `UnknownModelException` -- key not in database
* `ModelNotFoundException` -- key in database but model not found at path
* `NotImplementedException` -- the loader doesn't know how to load this type of model
### Using the Loaded Model in Inference
`LoadedModel` acts as a context manager. The context loads the model `LoadedModel` acts as a context manager. The context loads the model
into the execution device (e.g. VRAM on CUDA systems), locks the model into the execution device (e.g. VRAM on CUDA systems), locks the model
@ -1380,16 +1387,32 @@ in the execution device for the duration of the context, and returns
the model. Use it like this: the model. Use it like this:
``` ```
model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae')) loaded_model_= loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with model_info as vae: with loaded_model as vae:
image = vae.decode(latents)[0] image = vae.decode(latents)[0]
``` ```
`get_model_by_key()` may raise any of the following exceptions: The object returned by the LoadedModel context manager is an
`AnyModel`, which is a Union of `ModelMixin`, `torch.nn.Module`,
`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are obvious.
In addition, you may call `LoadedModel.model_on_device()`, a context
manager that returns a tuple of the model's state dict in CPU and the
model itself in VRAM. It is used to optimize the LoRA patching and
unpatching process:
```
loaded_model_= loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with loaded_model.model_on_device() as (state_dict, vae):
image = vae.decode(latents)[0]
```
Since not all models have state dicts, the `state_dict` return value
can be None.
* `UnknownModelException` -- key not in database
* `ModelNotFoundException` -- key in database but model not found at path
* `NotImplementedException` -- the loader doesn't know how to load this type of model
### Emitting model loading events ### Emitting model loading events
@ -1578,3 +1601,59 @@ This method takes a model key, looks it up using the
`ModelRecordServiceBase` object in `mm.store`, and passes the returned `ModelRecordServiceBase` object in `mm.store`, and passes the returned
model configuration to `load_model_by_config()`. It may raise a model configuration to `load_model_by_config()`. It may raise a
`NotImplementedException`. `NotImplementedException`.
## Invocation Context Model Manager API
Within invocations, the following methods are available from the
`InvocationContext` object:
### context.download_and_cache_model(source) -> Path
This method accepts a `source` of a remote model, downloads and caches
it locally, and then returns a Path to the local model. The source can
be a direct download URL or a HuggingFace repo_id.
In the case of HuggingFace repo_id, the following variants are
recognized:
* stabilityai/stable-diffusion-v4 -- default model
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
You can also point at an arbitrary individual file within a repo_id
directory using this syntax:
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
### context.load_local_model(model_path, [loader]) -> LoadedModel
This method loads a local model from the indicated path, returning a
`LoadedModel`. The optional loader is a Callable that accepts a Path
to the object, and returns a `AnyModel` object. If no loader is
provided, then the method will use `torch.load()` for a .ckpt or .bin
checkpoint file, `safetensors.torch.load_file()` for a safetensors
checkpoint file, or `cls.from_pretrained()` for a directory that looks
like a diffusers directory.
### context.load_remote_model(source, [loader]) -> LoadedModel
This method accepts a `source` of a remote model, downloads and caches
it locally, loads it, and returns a `LoadedModel`. The source can be a
direct download URL or a HuggingFace repo_id.
In the case of HuggingFace repo_id, the following variants are
recognized:
* stabilityai/stable-diffusion-v4 -- default model
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
You can also point at an arbitrary individual file within a repo_id
directory using this syntax:
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors

View File

@ -93,7 +93,7 @@ class ApiDependencies:
conditioning = ObjectSerializerForwardCache( conditioning = ObjectSerializerForwardCache(
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
) )
download_queue_service = DownloadQueueService(event_bus=events) download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images") model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager( model_manager = ModelManagerService.build_model_manager(
app_config=configuration, app_config=configuration,

View File

@ -9,7 +9,7 @@ from copy import deepcopy
from typing import Any, Dict, List, Optional, Type from typing import Any, Dict, List, Optional, Type
from fastapi import Body, Path, Query, Response, UploadFile from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse from fastapi.responses import FileResponse, HTMLResponse
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from PIL import Image from PIL import Image
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
@ -502,6 +502,133 @@ async def install_model(
return result return result
@model_manager_router.get(
"/install/huggingface",
operation_id="install_hugging_face_model",
responses={
201: {"description": "The model is being installed"},
400: {"description": "Bad request"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_class=HTMLResponse,
)
async def install_hugging_face_model(
source: str = Query(description="HuggingFace repo_id to install"),
) -> HTMLResponse:
"""Install a Hugging Face model using a string identifier."""
def generate_html(title: str, heading: str, repo_id: str, is_error: bool, message: str | None = "") -> str:
if message:
message = f"<p>{message}</p>"
title_class = "error" if is_error else "success"
return f"""
<html>
<head>
<title>{title}</title>
<style>
body {{
text-align: center;
background-color: hsl(220 12% 10% / 1);
font-family: Helvetica, sans-serif;
color: hsl(220 12% 86% / 1);
}}
.repo-id {{
color: hsl(220 12% 68% / 1);
}}
.error {{
color: hsl(0 42% 68% / 1)
}}
.message-box {{
display: inline-block;
border-radius: 5px;
background-color: hsl(220 12% 20% / 1);
padding-inline-end: 30px;
padding: 20px;
padding-inline-start: 30px;
padding-inline-end: 30px;
}}
.container {{
display: flex;
width: 100%;
height: 100%;
align-items: center;
justify-content: center;
}}
a {{
color: inherit
}}
a:visited {{
color: inherit
}}
a:active {{
color: inherit
}}
</style>
</head>
<body style="background-color: hsl(220 12% 10% / 1);">
<div class="container">
<div class="message-box">
<h2 class="{title_class}">{heading}</h2>
{message}
<p class="repo-id">Repo ID: {repo_id}</p>
</div>
</div>
</body>
</html>
"""
try:
metadata = HuggingFaceMetadataFetch().from_id(source)
assert isinstance(metadata, ModelMetadataWithFiles)
except UnknownMetadataException:
title = "Unable to Install Model"
heading = "No HuggingFace repository found with that repo ID."
message = "Ensure the repo ID is correct and try again."
return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=400)
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_manager.install
if metadata.is_diffusers:
installer.heuristic_import(
source=source,
inplace=False,
)
elif metadata.ckpt_urls is not None and len(metadata.ckpt_urls) == 1:
installer.heuristic_import(
source=str(metadata.ckpt_urls[0]),
inplace=False,
)
else:
title = "Unable to Install Model"
heading = "This HuggingFace repo has multiple models."
message = "Please use the Model Manager to install this model."
return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=200)
title = "Model Install Started"
heading = "Your HuggingFace model is installing now."
message = "You can close this tab and check the Model Manager for installation progress."
return HTMLResponse(content=generate_html(title, heading, source, False, message), status_code=201)
except Exception as e:
logger.error(str(e))
title = "Unable to Install Model"
heading = "There was an problem installing this model."
message = 'Please use the Model Manager directly to install this model. If the issue persists, ask for help on <a href="https://discord.gg/ZmtBAhwWhy">discord</a>.'
return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=500)
@model_manager_router.get( @model_manager_router.get(
"/install", "/install",
operation_id="list_model_installs", operation_id="list_model_installs",

View File

@ -0,0 +1,98 @@
from typing import Any, Union
import numpy as np
import numpy.typing as npt
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import TorchDevice
@invocation(
"lblend",
title="Blend Latents",
tags=["latents", "blend"],
category="latents",
version="1.0.3",
)
class BlendLatentsInvocation(BaseInvocation):
"""Blend two latents using a given alpha. Latents must have same size."""
latents_a: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
latents_b: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents_a = context.tensors.load(self.latents_a.latents_name)
latents_b = context.tensors.load(self.latents_b.latents_name)
if latents_a.shape != latents_b.shape:
raise Exception("Latents to blend must be the same size.")
device = TorchDevice.choose_torch_device()
def slerp(
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
v0: Union[torch.Tensor, npt.NDArray[Any]],
v1: Union[torch.Tensor, npt.NDArray[Any]],
DOT_THRESHOLD: float = 0.9995,
) -> Union[torch.Tensor, npt.NDArray[Any]]:
"""
Spherical linear interpolation
Args:
t (float/np.ndarray): Float value between 0.0 and 1.0
v0 (np.ndarray): Starting vector
v1 (np.ndarray): Final vector
DOT_THRESHOLD (float): Threshold for considering the two vectors as
colineal. Not recommended to alter this.
Returns:
v2 (np.ndarray): Interpolation vector between v0 and v1
"""
inputs_are_torch = False
if not isinstance(v0, np.ndarray):
inputs_are_torch = True
v0 = v0.detach().cpu().numpy()
if not isinstance(v1, np.ndarray):
inputs_are_torch = True
v1 = v1.detach().cpu().numpy()
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
if np.abs(dot) > DOT_THRESHOLD:
v2 = (1 - t) * v0 + t * v1
else:
theta_0 = np.arccos(dot)
sin_theta_0 = np.sin(theta_0)
theta_t = theta_0 * t
sin_theta_t = np.sin(theta_t)
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
s1 = sin_theta_t / sin_theta_0
v2 = s0 * v0 + s1 * v1
if inputs_are_torch:
v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
return v2_torch
else:
assert isinstance(v2, np.ndarray)
return v2
# blend
bl = slerp(self.alpha, latents_a, latents_b)
assert isinstance(bl, torch.Tensor)
blended_latents: torch.Tensor = bl # for type checking convenience
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
blended_latents = blended_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=blended_latents)
return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)

View File

@ -81,9 +81,13 @@ class CompelInvocation(BaseInvocation):
with ( with (
# apply all patches while the model is on the target device # apply all patches while the model is on the target device
text_encoder_info as text_encoder, text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
tokenizer_info as tokenizer, tokenizer_info as tokenizer,
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()), ModelPatcher.apply_lora_text_encoder(
text_encoder,
loras=_lora_loader(),
model_state_dict=model_state_dict,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers), ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as ( ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
@ -174,9 +178,14 @@ class SDXLPromptInvocationBase:
with ( with (
# apply all patches while the model is on the target device # apply all patches while the model is on the target device
text_encoder_info as text_encoder, text_encoder_info.model_on_device() as (state_dict, text_encoder),
tokenizer_info as tokenizer, tokenizer_info as tokenizer,
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix), ModelPatcher.apply_lora(
text_encoder,
loras=_lora_loader(),
prefix=lora_prefix,
model_state_dict=state_dict,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers), ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as ( ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (

View File

@ -1,6 +1,7 @@
from typing import Literal from typing import Literal
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.util.devices import TorchDevice
LATENT_SCALE_FACTOR = 8 LATENT_SCALE_FACTOR = 8
""" """
@ -15,3 +16,5 @@ SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"] IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
"""A literal type for PIL image modes supported by Invoke""" """A literal type for PIL image modes supported by Invoke"""
DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()

View File

@ -2,6 +2,7 @@
# initial implementation by Gregg Helt, 2023 # initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import bool, float from builtins import bool, float
from pathlib import Path
from typing import Dict, List, Literal, Union from typing import Dict, List, Literal, Union
import cv2 import cv2
@ -36,12 +37,13 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
from invokeai.backend.image_util.canny import get_canny_edges from invokeai.backend.image_util.canny import get_canny_edges
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
from invokeai.backend.image_util.hed import HEDProcessor from invokeai.backend.image_util.hed import HEDProcessor
from invokeai.backend.image_util.lineart import LineartProcessor from invokeai.backend.image_util.lineart import LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
from invokeai.backend.image_util.util import np_to_pil, pil_to_np from invokeai.backend.image_util.util import np_to_pil, pil_to_np
from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
@ -139,6 +141,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
return context.images.get_pil(self.image.image_name, "RGB") return context.images.get_pil(self.image.image_name, "RGB")
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
self._context = context
raw_image = self.load_image(context) raw_image = self.load_image(context)
# image type should be PIL.PngImagePlugin.PngImageFile ? # image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image) processed_image = self.run_processor(raw_image)
@ -284,7 +287,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
# depth_and_normal not supported in controlnet_aux v0.0.3 # depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode") # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
def run_processor(self, image): def run_processor(self, image: Image.Image) -> Image.Image:
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators") midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
processed_image = midas_processor( processed_image = midas_processor(
image, image,
@ -311,7 +315,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res) detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image): def run_processor(self, image: Image.Image) -> Image.Image:
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators") normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = normalbae_processor( processed_image = normalbae_processor(
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
@ -330,7 +334,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`") thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`") thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
def run_processor(self, image): def run_processor(self, image: Image.Image) -> Image.Image:
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators") mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = mlsd_processor( processed_image = mlsd_processor(
image, image,
@ -353,7 +357,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode) safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image): def run_processor(self, image: Image.Image) -> Image.Image:
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators") pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
processed_image = pidi_processor( processed_image = pidi_processor(
image, image,
@ -381,7 +385,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter") w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter") f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
def run_processor(self, image): def run_processor(self, image: Image.Image) -> Image.Image:
content_shuffle_processor = ContentShuffleDetector() content_shuffle_processor = ContentShuffleDetector()
processed_image = content_shuffle_processor( processed_image = content_shuffle_processor(
image, image,
@ -405,7 +409,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image""" """Applies Zoe depth processing to image"""
def run_processor(self, image): def run_processor(self, image: Image.Image) -> Image.Image:
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators") zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = zoe_depth_processor(image) processed_image = zoe_depth_processor(image)
return processed_image return processed_image
@ -426,7 +430,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res) detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image): def run_processor(self, image: Image.Image) -> Image.Image:
mediapipe_face_processor = MediapipeFaceDetector() mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor( processed_image = mediapipe_face_processor(
image, image,
@ -454,7 +458,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res) detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image): def run_processor(self, image: Image.Image) -> Image.Image:
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators") leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
processed_image = leres_processor( processed_image = leres_processor(
image, image,
@ -496,8 +500,8 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA) np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
return np_img return np_img
def run_processor(self, img): def run_processor(self, image: Image.Image) -> Image.Image:
np_img = np.array(img, dtype=np.uint8) np_img = np.array(image, dtype=np.uint8)
processed_np_image = self.tile_resample( processed_np_image = self.tile_resample(
np_img, np_img,
# res=self.tile_size, # res=self.tile_size,
@ -520,7 +524,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res) detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image): def run_processor(self, image: Image.Image) -> Image.Image:
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained( segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
"ybelkada/segment-anything", subfolder="checkpoints" "ybelkada/segment-anything", subfolder="checkpoints"
@ -566,7 +570,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size) color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)
def run_processor(self, image: Image.Image): def run_processor(self, image: Image.Image) -> Image.Image:
np_image = np.array(image, dtype=np.uint8) np_image = np.array(image, dtype=np.uint8)
height, width = np_image.shape[:2] height, width = np_image.shape[:2]
@ -601,10 +605,16 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
) )
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image): def run_processor(self, image: Image.Image) -> Image.Image:
depth_anything_detector = DepthAnythingDetector() def loader(model_path: Path):
depth_anything_detector.load_model(model_size=self.model_size) return DepthAnythingDetector.load_model(
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
)
with self._context.models.load_remote_model(
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
) as model:
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
processed_image = depth_anything_detector(image=image, resolution=self.resolution) processed_image = depth_anything_detector(image=image, resolution=self.resolution)
return processed_image return processed_image
@ -624,8 +634,11 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
draw_hands: bool = InputField(default=False) draw_hands: bool = InputField(default=False)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image): def run_processor(self, image: Image.Image) -> Image.Image:
dw_openpose = DWOpenposeDetector() onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
processed_image = dw_openpose( processed_image = dw_openpose(
image, image,
draw_face=self.draw_face, draw_face=self.draw_face,

View File

@ -0,0 +1,80 @@
from typing import Optional
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import DenoiseMaskOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@invocation(
"create_denoise_mask",
title="Create Denoise Mask",
tags=["mask", "denoise"],
category="latents",
version="1.0.2",
)
class CreateDenoiseMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32,
description=FieldDescriptions.fp32,
ui_order=4,
)
def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
if mask_image.mode != "L":
mask_image = mask_image.convert("L")
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
if mask_tensor.dim() == 3:
mask_tensor = mask_tensor.unsqueeze(0)
# if shape is not None:
# mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR)
return mask_tensor
@torch.no_grad()
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
if self.image is not None:
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
else:
image_tensor = None
mask = self.prep_mask_tensor(
context.images.get_pil(self.mask.image_name),
)
if image_tensor is not None:
vae_info = context.models.load(self.vae.vae)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
# TODO:
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
masked_latents_name = context.tensors.save(tensor=masked_latents)
else:
masked_latents_name = None
mask_name = context.tensors.save(tensor=mask)
return DenoiseMaskOutput.build(
mask_name=mask_name,
masked_latents_name=masked_latents_name,
gradient=False,
)

View File

@ -0,0 +1,138 @@
from typing import Literal, Optional
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageFilter
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
ImageField,
Input,
InputField,
OutputField,
)
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@invocation_output("gradient_mask_output")
class GradientMaskOutput(BaseInvocationOutput):
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
expanded_mask_area: ImageField = OutputField(
description="Image representing the total gradient area of the mask. For paste-back purposes."
)
@invocation(
"create_gradient_mask",
title="Create Gradient Mask",
tags=["mask", "denoise"],
category="latents",
version="1.1.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
edge_radius: int = InputField(
default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
)
coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
minimum_denoise: float = InputField(
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
)
image: Optional[ImageField] = InputField(
default=None,
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
title="[OPTIONAL] Image",
ui_order=6,
)
unet: Optional[UNetField] = InputField(
description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
default=None,
input=Input.Connection,
title="[OPTIONAL] UNet",
ui_order=5,
)
vae: Optional[VAEField] = InputField(
default=None,
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
title="[OPTIONAL] VAE",
input=Input.Connection,
ui_order=7,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32,
description=FieldDescriptions.fp32,
ui_order=9,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
if self.edge_radius > 0:
if self.coherence_mode == "Box Blur":
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
else: # Gaussian Blur OR Staged
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
# redistribute blur so that the original edges are 0 and blur outwards to 1
blur_tensor = (blur_tensor - 0.5) * 2
threshold = 1 - self.minimum_denoise
if self.coherence_mode == "Staged":
# wherever the blur_tensor is less than fully masked, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
else:
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
# compute a [0, 1] mask from the blur_tensor
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
expanded_image_dto = context.images.save(expanded_mask_image)
masked_latents_name = None
if self.unet is not None and self.vae is not None and self.image is not None:
# all three fields must be present at the same time
main_model_config = context.models.get_config(self.unet.unet.key)
assert isinstance(main_model_config, MainConfigBase)
if main_model_config.variant is ModelVariantType.Inpaint:
mask = blur_tensor
vae_info: LoadedModel = context.models.load(self.vae.vae)
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
masked_latents = ImageToLatentsInvocation.vae_encode(
vae_info, self.fp32, self.tiled, masked_image.clone()
)
masked_latents_name = context.tensors.save(tensor=masked_latents)
return GradientMaskOutput(
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
)

View File

@ -0,0 +1,61 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
# The Crop Latents node was copied from @skunkworxdark's implementation here:
# https://github.com/skunkworxdark/XYGrid_nodes/blob/74647fa9c1fa57d317a94bd43ca689af7f0aae5e/images_to_grids.py#L1117C1-L1167C80
@invocation(
"crop_latents",
title="Crop Latents",
tags=["latents", "crop"],
category="latents",
version="1.0.2",
)
# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`.
# Currently, if the class names conflict then 'GET /openapi.json' fails.
class CropLatentsCoreInvocation(BaseInvocation):
"""Crops a latent-space tensor to a box specified in image-space. The box dimensions and coordinates must be
divisible by the latent scale factor of 8.
"""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
x: int = InputField(
ge=0,
multiple_of=LATENT_SCALE_FACTOR,
description="The left x coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
y: int = InputField(
ge=0,
multiple_of=LATENT_SCALE_FACTOR,
description="The top y coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
width: int = InputField(
ge=1,
multiple_of=LATENT_SCALE_FACTOR,
description="The width (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
height: int = InputField(
ge=1,
multiple_of=LATENT_SCALE_FACTOR,
description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
x1 = self.x // LATENT_SCALE_FACTOR
y1 = self.y // LATENT_SCALE_FACTOR
x2 = x1 + (self.width // LATENT_SCALE_FACTOR)
y2 = y1 + (self.height // LATENT_SCALE_FACTOR)
cropped_latents = latents[..., y1:y2, x1:x2]
name = context.tensors.save(tensor=cropped_latents)
return LatentsOutput.build(latents_name=name, latents=cropped_latents)

View File

@ -0,0 +1,811 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import torch
import torchvision
import torchvision.transforms as T
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.adapter import T2IAdapter
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
from diffusers.schedulers.scheduling_tcd import TCDScheduler
from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.fields import (
ConditioningField,
DenoiseMaskField,
FieldDescriptions,
Input,
InputField,
LatentsField,
UIType,
)
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.model import ModelIdentifierField, UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
StableDiffusionGeneratorPipeline,
T2IAdapterData,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
IPAdapterConditioningInfo,
IPAdapterData,
Range,
SDXLConditioningInfo,
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.mask import to_standard_float_mask
from invokeai.backend.util.silence_warnings import SilenceWarnings
def get_scheduler(
context: InvocationContext,
scheduler_info: ModelIdentifierField,
scheduler_name: str,
seed: int,
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(scheduler_info)
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {
**scheduler_config,
**scheduler_extra_config, # FIXME
"_backup": scheduler_config,
}
# make dpmpp_sde reproducable(seed can be passed only in initializer)
if scheduler_class is DPMSolverSDEScheduler:
scheduler_config["noise_sampler_seed"] = seed
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py
if not hasattr(scheduler, "uses_inpainting_model"):
scheduler.uses_inpainting_model = lambda: False
assert isinstance(scheduler, Scheduler)
return scheduler
@invocation(
"denoise_latents",
title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
version="1.5.3",
)
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
)
negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
)
noise: Optional[LatentsField] = InputField(
default=None,
description=FieldDescriptions.noise,
input=Input.Connection,
ui_order=3,
)
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
cfg_scale: Union[float, List[float]] = InputField(
default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale"
)
denoising_start: float = InputField(
default=0.0,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
ui_order=2,
)
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
default=None,
input=Input.Connection,
ui_order=5,
)
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
description=FieldDescriptions.ip_adapter,
title="IP-Adapter",
default=None,
input=Input.Connection,
ui_order=6,
)
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = InputField(
description=FieldDescriptions.t2i_adapter,
title="T2I-Adapter",
default=None,
input=Input.Connection,
ui_order=7,
)
cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
latents: Optional[LatentsField] = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
ui_order=4,
)
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None,
description=FieldDescriptions.mask,
input=Input.Connection,
ui_order=8,
)
@field_validator("cfg_scale")
def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
return v
def _get_text_embeddings_and_masks(
self,
cond_list: list[ConditioningField],
context: InvocationContext,
device: torch.device,
dtype: torch.dtype,
) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
"""Get the text embeddings and masks from the input conditioning fields."""
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
text_embeddings_masks: list[Optional[torch.Tensor]] = []
for cond in cond_list:
cond_data = context.conditioning.load(cond.conditioning_name)
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
mask = cond.mask
if mask is not None:
mask = context.tensors.load(mask.tensor_name)
text_embeddings_masks.append(mask)
return text_embeddings, text_embeddings_masks
def _preprocess_regional_prompt_mask(
self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
Returns:
torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
"""
if mask is None:
return torch.ones((1, 1, target_height, target_width), dtype=dtype)
mask = to_standard_float_mask(mask, out_dtype=dtype)
tf = torchvision.transforms.Resize(
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
resized_mask = tf(mask)
return resized_mask
def _concat_regional_text_embeddings(
self,
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]],
latent_height: int,
latent_width: int,
dtype: torch.dtype,
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
"""Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
if masks is None:
masks = [None] * len(text_conditionings)
assert len(text_conditionings) == len(masks)
is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
all_masks_are_none = all(mask is None for mask in masks)
text_embedding = []
pooled_embedding = None
add_time_ids = None
cur_text_embedding_len = 0
processed_masks = []
embedding_ranges = []
for prompt_idx, text_embedding_info in enumerate(text_conditionings):
mask = masks[prompt_idx]
if is_sdxl:
# We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for
# prompts without a mask. We prefer prompts without a mask, because they are more likely to contain
# global prompt information. In an ideal case, there should be exactly one global prompt without a
# mask, but we don't enforce this.
# HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a
# fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use
# them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
# pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
# pretty major breaking change to a popular node, so for now we use this hack.
if pooled_embedding is None or mask is None:
pooled_embedding = text_embedding_info.pooled_embeds
if add_time_ids is None or mask is None:
add_time_ids = text_embedding_info.add_time_ids
text_embedding.append(text_embedding_info.embeds)
if not all_masks_are_none:
embedding_ranges.append(
Range(
start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
)
)
processed_masks.append(
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
)
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
text_embedding = torch.cat(text_embedding, dim=1)
assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
regions = None
if not all_masks_are_none:
regions = TextConditioningRegions(
masks=torch.cat(processed_masks, dim=1),
ranges=embedding_ranges,
)
if is_sdxl:
return (
SDXLConditioningInfo(embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids),
regions,
)
return BasicConditioningInfo(embeds=text_embedding), regions
def get_conditioning_data(
self,
context: InvocationContext,
unet: UNet2DConditionModel,
latent_height: int,
latent_width: int,
) -> TextConditioningData:
# Normalize self.positive_conditioning and self.negative_conditioning to lists.
cond_list = self.positive_conditioning
if not isinstance(cond_list, list):
cond_list = [cond_list]
uncond_list = self.negative_conditioning
if not isinstance(uncond_list, list):
uncond_list = [uncond_list]
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype
)
cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
if isinstance(self.cfg_scale, list):
assert (
len(self.cfg_scale) == self.steps
), "cfg_scale (list) must have the same length as the number of steps"
conditioning_data = TextConditioningData(
uncond_text=uncond_text_embedding,
cond_text=cond_text_embedding,
uncond_regions=uncond_regions,
cond_regions=cond_regions,
guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
)
return conditioning_data
def create_pipeline(
self,
unet: UNet2DConditionModel,
scheduler: Scheduler,
) -> StableDiffusionGeneratorPipeline:
class FakeVae:
class FakeVaeConfig:
def __init__(self) -> None:
self.block_out_channels = [0]
def __init__(self) -> None:
self.config = FakeVae.FakeVaeConfig()
return StableDiffusionGeneratorPipeline(
vae=FakeVae(), # TODO: oh...
text_encoder=None,
tokenizer=None,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
def prep_control_data(
self,
context: InvocationContext,
control_input: Optional[Union[ControlField, List[ControlField]]],
latents_shape: List[int],
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> Optional[List[ControlNetData]]:
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
if control_input is None:
control_list = None
elif isinstance(control_input, list) and len(control_input) == 0:
control_list = None
elif isinstance(control_input, ControlField):
control_list = [control_input]
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
control_list = control_input
else:
control_list = None
if control_list is None:
return None
# After above handling, any control that is not None should now be of type list[ControlField].
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
controlnet_data = []
for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
# control_models.append(control_model)
control_image_field = control_info.image
input_image = context.images.get_pil(control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = prepare_control_image(
image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize,
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
)
control_item = ControlNetData(
model=control_model, # model object
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
# any resizing needed should currently be happening in prepare_control_image(),
# but adding resize_mode to ControlNetData in case needed in the future
resize_mode=control_info.resize_mode,
)
controlnet_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return controlnet_data
def prep_ip_adapter_image_prompts(
self,
context: InvocationContext,
ip_adapters: List[IPAdapterField],
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
image_prompts = []
for single_ip_adapter in ip_adapters:
with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
assert isinstance(ip_adapter_model, IPAdapter)
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
single_ipa_image_fields = [single_ipa_image_fields]
single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
with image_encoder_model_info as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
single_ipa_images, image_encoder_model
)
image_prompts.append((image_prompt_embeds, uncond_image_prompt_embeds))
return image_prompts
def prep_ip_adapter_data(
self,
context: InvocationContext,
ip_adapters: List[IPAdapterField],
image_prompts: List[Tuple[torch.Tensor, torch.Tensor]],
exit_stack: ExitStack,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
) -> Optional[List[IPAdapterData]]:
"""If IP-Adapter is enabled, then this function loads the requisite models and adds the image prompt conditioning data."""
ip_adapter_data_list = []
for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip(
ip_adapters, image_prompts, strict=True
):
ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model))
mask_field = single_ip_adapter.mask
mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None
mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
ip_adapter_data_list.append(
IPAdapterData(
ip_adapter_model=ip_adapter_model,
weight=single_ip_adapter.weight,
target_blocks=single_ip_adapter.target_blocks,
begin_step_percent=single_ip_adapter.begin_step_percent,
end_step_percent=single_ip_adapter.end_step_percent,
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
mask=mask,
)
)
return ip_adapter_data_list if len(ip_adapter_data_list) > 0 else None
def run_t2i_adapters(
self,
context: InvocationContext,
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
latents_shape: list[int],
do_classifier_free_guidance: bool,
) -> Optional[list[T2IAdapterData]]:
if t2i_adapter is None:
return None
# Handle the possibility that t2i_adapter could be a list or a single T2IAdapterField.
if isinstance(t2i_adapter, T2IAdapterField):
t2i_adapter = [t2i_adapter]
if len(t2i_adapter) == 0:
return None
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
image = context.images.get_pil(t2i_adapter_field.image.image_name)
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
if t2i_adapter_model_config.base == BaseModelType.StableDiffusion1:
max_unet_downscale = 8
elif t2i_adapter_model_config.base == BaseModelType.StableDiffusionXL:
max_unet_downscale = 4
else:
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.")
t2i_adapter_model: T2IAdapter
with t2i_adapter_loaded_model as t2i_adapter_model:
total_downscale_factor = t2i_adapter_model.total_downscale_factor
# Resize the T2I-Adapter input image.
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
# result will match the latent image's dimensions after max_unet_downscale is applied.
t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor
t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor
# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
# a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
# T2I-Adapter model.
#
# Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
# of the same requirements (e.g. preserving binary masks during resize).
t2i_image = prepare_control_image(
image=image,
do_classifier_free_guidance=False,
width=t2i_input_width,
height=t2i_input_height,
num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict
device=t2i_adapter_model.device,
dtype=t2i_adapter_model.dtype,
resize_mode=t2i_adapter_field.resize_mode,
)
adapter_state = t2i_adapter_model(t2i_image)
if do_classifier_free_guidance:
for idx, value in enumerate(adapter_state):
adapter_state[idx] = torch.cat([value] * 2, dim=0)
t2i_adapter_data.append(
T2IAdapterData(
adapter_state=adapter_state,
weight=t2i_adapter_field.weight,
begin_step_percent=t2i_adapter_field.begin_step_percent,
end_step_percent=t2i_adapter_field.end_step_percent,
)
)
return t2i_adapter_data
# original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps
def init_scheduler(
self,
scheduler: Union[Scheduler, ConfigMixin],
device: torch.device,
steps: int,
denoising_start: float,
denoising_end: float,
seed: int,
) -> Tuple[int, List[int], int, Dict[str, Any]]:
assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False):
scheduler.set_timesteps(steps, device="cpu")
timesteps = scheduler.timesteps.to(device=device)
else:
scheduler.set_timesteps(steps, device=device)
timesteps = scheduler.timesteps
# skip greater order timesteps
_timesteps = timesteps[:: scheduler.order]
# get start timestep index
t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start)))
t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps)))
# get end timestep index
t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end)))
t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:])))
# apply order to indexes
t_start_idx *= scheduler.order
t_end_idx *= scheduler.order
init_timestep = timesteps[t_start_idx : t_start_idx + 1]
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
num_inference_steps = len(timesteps) // scheduler.order
scheduler_step_kwargs: Dict[str, Any] = {}
scheduler_step_signature = inspect.signature(scheduler.step)
if "generator" in scheduler_step_signature.parameters:
# At some point, someone decided that schedulers that accept a generator should use the original seed with
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
# reproducibility.
#
# These Invoke-supported schedulers accept a generator as of 2024-06-04:
# - DDIMScheduler
# - DDPMScheduler
# - DPMSolverMultistepScheduler
# - EulerAncestralDiscreteScheduler
# - EulerDiscreteScheduler
# - KDPM2AncestralDiscreteScheduler
# - LCMScheduler
# - TCDScheduler
scheduler_step_kwargs.update({"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)})
if isinstance(scheduler, TCDScheduler):
scheduler_step_kwargs.update({"eta": 1.0})
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
def prep_inpaint_mask(
self, context: InvocationContext, latents: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
if self.denoise_mask is None:
return None, None, False
mask = context.tensors.load(self.denoise_mask.mask_name)
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
if self.denoise_mask.masked_latents_name is not None:
masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name)
else:
masked_latents = torch.where(mask < 0.5, 0.0, latents)
return 1 - mask, masked_latents, self.denoise_mask.gradient
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
seed = None
noise = None
if self.noise is not None:
noise = context.tensors.load(self.noise.latents_name)
seed = self.noise.seed
if self.latents is not None:
latents = context.tensors.load(self.latents.latents_name)
if seed is None:
seed = self.latents.seed
if noise is not None and noise.shape[1:] != latents.shape[1:]:
raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
elif noise is not None:
latents = torch.zeros_like(noise)
else:
raise Exception("'latents' or 'noise' must be provided!")
if seed is None:
seed = 0
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
# below. Investigate whether this is appropriate.
t2i_adapter_data = self.run_t2i_adapters(
context,
self.t2i_adapter,
latents.shape,
do_classifier_free_guidance=True,
)
ip_adapters: List[IPAdapterField] = []
if self.ip_adapter is not None:
# ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
if isinstance(self.ip_adapter, list):
ip_adapters = self.ip_adapter
else:
ip_adapters = [self.ip_adapter]
# If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
# a series of image conditioning embeddings. This is being done here rather than in the
# big model context below in order to use less VRAM on low-VRAM systems.
# The image prompts are then passed to prep_ip_adapter_data().
image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
# get the unet's config so that we can pass the base to dispatch_progress()
unet_config = context.models.get_config(self.unet.unet.key)
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
return
unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
unet_info.model_on_device() as (model_state_dict, unet),
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
set_seamless(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(
unet,
loras=_lora_loader(),
model_state_dict=model_state_dict,
),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
if mask is not None:
mask = mask.to(device=unet.device, dtype=unet.dtype)
if masked_latents is not None:
masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)
pipeline = self.create_pipeline(unet, scheduler)
_, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data(
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
)
controlnet_data = self.prep_control_data(
context=context,
control_input=self.control,
latents_shape=latents.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
ip_adapter_data = self.prep_ip_adapter_data(
context=context,
ip_adapters=ip_adapters,
image_prompts=image_prompts,
exit_stack=exit_stack,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
seed=seed,
)
result_latents = pipeline.latents_from_embeddings(
latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
mask=mask,
masked_latents=masked_latents,
gradient_mask=gradient_mask,
num_inference_steps=num_inference_steps,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=controlnet_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
callback=step_callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)

View File

@ -0,0 +1,65 @@
import math
from typing import Tuple
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
from invokeai.app.invocations.model import UNetField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType
@invocation_output("ideal_size_output")
class IdealSizeOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
width: int = OutputField(description="The ideal width of the image (in pixels)")
height: int = OutputField(description="The ideal height of the image (in pixels)")
@invocation(
"ideal_size",
title="Ideal Size",
tags=["latents", "math", "ideal_size"],
version="1.0.3",
)
class IdealSizeInvocation(BaseInvocation):
"""Calculates the ideal size for generation to avoid duplication"""
width: int = InputField(default=1024, description="Final image width")
height: int = InputField(default=576, description="Final image height")
unet: UNetField = InputField(default=None, description=FieldDescriptions.unet)
multiplier: float = InputField(
default=1.0,
description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in "
"initial generation artifacts if too large)",
)
def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]:
return tuple((x - x % multiple_of) for x in args)
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
unet_config = context.models.get_config(self.unet.unet.key)
aspect = self.width / self.height
dimension: float = 512
if unet_config.base == BaseModelType.StableDiffusion2:
dimension = 768
elif unet_config.base == BaseModelType.StableDiffusionXL:
dimension = 1024
dimension = dimension * self.multiplier
min_dimension = math.floor(dimension * 0.5)
model_area = dimension * dimension # hardcoded for now since all models are trained on square images
if aspect > 1.0:
init_height = max(min_dimension, math.sqrt(model_area / aspect))
init_width = init_height * aspect
else:
init_width = max(min_dimension, math.sqrt(model_area * aspect))
init_height = init_width / aspect
scaled_width, scaled_height = self.trim_to_multiple_of(
math.floor(init_width),
math.floor(init_height),
)
return IdealSizeOutput(width=scaled_width, height=scaled_height)

View File

@ -0,0 +1,125 @@
from functools import singledispatchmethod
import einops
import torch
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@invocation(
"i2l",
title="Image to Latents",
tags=["latents", "image", "vae", "i2l"],
category="latents",
version="1.0.2",
)
class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""
image: ImageField = InputField(
description="The image to encode",
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@staticmethod
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
with vae_info as vae:
assert isinstance(vae, torch.nn.Module)
orig_dtype = vae.dtype
if upcast:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
vae.post_quant_conv.to(orig_dtype)
vae.decoder.conv_in.to(orig_dtype)
vae.decoder.mid_block.to(orig_dtype)
# else:
# latents = latents.float()
else:
vae.to(dtype=torch.float16)
# latents = latents.half()
if tiled:
vae.enable_tiling()
else:
vae.disable_tiling()
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
latents = vae.config.scaling_factor * latents
latents = latents.to(dtype=orig_dtype)
return latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
@singledispatchmethod
@staticmethod
def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
assert isinstance(vae, torch.nn.Module)
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents: torch.Tensor = image_tensor_dist.sample().to(
dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible!
return latents
@_encode_to_tensor.register
@staticmethod
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
assert isinstance(vae, torch.nn.Module)
latents: torch.FloatTensor = vae.encode(image_tensor).latents
return latents

View File

@ -42,15 +42,16 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Infill the image with the specified method""" """Infill the image with the specified method"""
pass pass
def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]: def load_image(self) -> tuple[Image.Image, bool]:
"""Process the image to have an alpha channel before being infilled""" """Process the image to have an alpha channel before being infilled"""
image = context.images.get_pil(self.image.image_name) image = self._context.images.get_pil(self.image.image_name)
has_alpha = True if image.mode == "RGBA" else False has_alpha = True if image.mode == "RGBA" else False
return image, has_alpha return image, has_alpha
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
self._context = context
# Retrieve and process image to be infilled # Retrieve and process image to be infilled
input_image, has_alpha = self.load_image(context) input_image, has_alpha = self.load_image()
# If the input image has no alpha channel, return it # If the input image has no alpha channel, return it
if has_alpha is False: if has_alpha is False:
@ -133,7 +134,11 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using the LaMa model""" """Infills transparent areas of an image using the LaMa model"""
def infill(self, image: Image.Image): def infill(self, image: Image.Image):
lama = LaMA() with self._context.models.load_remote_model(
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
loader=LaMA.load_jit_model,
) as model:
lama = LaMA(model)
return lama(image) return lama(image)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,107 @@
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion import set_seamless
from invokeai.backend.util.devices import TorchDevice
@invocation(
"l2i",
title="Latents to Image",
tags=["latents", "image", "vae", "l2i"],
category="latents",
version="1.2.2",
)
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module)
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
vae.post_quant_conv.to(latents.dtype)
vae.decoder.conv_in.to(latents.dtype)
vae.decoder.mid_block.to(latents.dtype)
else:
latents = latents.float()
else:
vae.to(dtype=torch.float16)
latents = latents.half()
if self.tiled or context.config.get().force_tiled_decode:
vae.enable_tiling()
else:
vae.disable_tiling()
# clear memory as vae decode can request a lot
TorchDevice.empty_cache()
with torch.inference_mode():
# copied from diffusers pipeline
latents = latents / vae.config.scaling_factor
image = vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) # denormalize
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
TorchDevice.empty_cache()
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)

View File

@ -0,0 +1,103 @@
from typing import Literal
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
)
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import TorchDevice
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
@invocation(
"lresize",
title="Resize Latents",
tags=["latents", "resize"],
category="latents",
version="1.0.2",
)
class ResizeLatentsInvocation(BaseInvocation):
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
width: int = InputField(
ge=64,
multiple_of=LATENT_SCALE_FACTOR,
description=FieldDescriptions.width,
)
height: int = InputField(
ge=64,
multiple_of=LATENT_SCALE_FACTOR,
description=FieldDescriptions.width,
)
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
device = TorchDevice.choose_torch_device()
resized_latents = torch.nn.functional.interpolate(
latents.to(device),
size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR),
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@invocation(
"lscale",
title="Scale Latents",
tags=["latents", "resize"],
category="latents",
version="1.0.2",
)
class ScaleLatentsInvocation(BaseInvocation):
"""Scales latents by a given factor."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
device = TorchDevice.choose_torch_device()
# resizing
resized_latents = torch.nn.functional.interpolate(
latents.to(device),
scale_factor=self.scale_factor,
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)

View File

@ -0,0 +1,34 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import (
FieldDescriptions,
InputField,
OutputField,
UIType,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("scheduler_output")
class SchedulerOutput(BaseInvocationOutput):
scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
@invocation(
"scheduler",
title="Scheduler",
tags=["scheduler"],
category="latents",
version="1.0.0",
)
class SchedulerInvocation(BaseInvocation):
"""Selects a scheduler."""
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
def invoke(self, context: InvocationContext) -> SchedulerOutput:
return SchedulerOutput(scheduler=self.scheduler)

View File

@ -1,5 +1,4 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path
from typing import Literal from typing import Literal
import cv2 import cv2
@ -10,10 +9,8 @@ from pydantic import ConfigDict
from invokeai.app.invocations.fields import ImageField from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, invocation from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithBoard, WithMetadata from .fields import InputField, WithBoard, WithMetadata
@ -52,7 +49,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
rrdbnet_model = None rrdbnet_model = None
netscale = None netscale = None
esrgan_model_path = None
if self.model_name in [ if self.model_name in [
"RealESRGAN_x4plus.pth", "RealESRGAN_x4plus.pth",
@ -95,16 +91,14 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
context.logger.error(msg) context.logger.error(msg)
raise ValueError(msg) raise ValueError(msg)
esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}") loadnet = context.models.load_remote_model(
source=ESRGAN_MODEL_URLS[self.model_name],
# Downloads the ESRGAN model if it doesn't already exist
download_with_progress_bar(
name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
) )
with loadnet as loadnet_model:
upscaler = RealESRGAN( upscaler = RealESRGAN(
scale=netscale, scale=netscale,
model_path=esrgan_model_path, loadnet=loadnet_model,
model=rrdbnet_model, model=rrdbnet_model,
half=False, half=False,
tile=self.tile_size, tile=self.tile_size,
@ -114,9 +108,8 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
# TODO: This strips the alpha... is that okay? # TODO: This strips the alpha... is that okay?
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR) cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
upscaled_image = upscaler.upscale(cv2_image) upscaled_image = upscaler.upscale(cv2_image)
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
TorchDevice.empty_cache() pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
image_dto = context.images.save(image=pil_image) image_dto = context.images.save(image=pil_image)

View File

@ -86,6 +86,7 @@ class InvokeAIAppConfig(BaseSettings):
patchmatch: Enable patchmatch inpaint code. patchmatch: Enable patchmatch inpaint code.
models_dir: Path to the models directory. models_dir: Path to the models directory.
convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location. convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.
download_cache_dir: Path to the directory that contains dynamically downloaded models.
legacy_conf_dir: Path to directory of legacy checkpoint config files. legacy_conf_dir: Path to directory of legacy checkpoint config files.
db_dir: Path to InvokeAI databases directory. db_dir: Path to InvokeAI databases directory.
outputs_dir: Path to directory for outputs. outputs_dir: Path to directory for outputs.
@ -114,6 +115,7 @@ class InvokeAIAppConfig(BaseSettings):
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue. max_queue_size: Maximum number of items in the session queue.
max_threads: Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set. max_threads: Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.
clear_queue_on_startup: Empties session queue on startup.
allow_nodes: List of nodes to allow. Omit to allow all. allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none. deny_nodes: List of nodes to deny. Omit to deny none.
node_cache_size: How many cached nodes to keep in memory. node_cache_size: How many cached nodes to keep in memory.
@ -148,7 +150,8 @@ class InvokeAIAppConfig(BaseSettings):
# PATHS # PATHS
models_dir: Path = Field(default=Path("models"), description="Path to the models directory.") models_dir: Path = Field(default=Path("models"), description="Path to the models directory.")
convert_cache_dir: Path = Field(default=Path("models/.cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.") convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
download_cache_dir: Path = Field(default=Path("models/.download_cache"), description="Path to the directory that contains dynamically downloaded models.")
legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.") legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.")
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.") db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.") outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
@ -188,6 +191,7 @@ class InvokeAIAppConfig(BaseSettings):
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.") pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.") max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.") max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.")
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")
# NODES # NODES
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.") allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
@ -307,6 +311,11 @@ class InvokeAIAppConfig(BaseSettings):
"""Path to the converted cache models directory, resolved to an absolute path..""" """Path to the converted cache models directory, resolved to an absolute path.."""
return self._resolve(self.convert_cache_dir) return self._resolve(self.convert_cache_dir)
@property
def download_cache_path(self) -> Path:
"""Path to the downloaded models directory, resolved to an absolute path.."""
return self._resolve(self.download_cache_dir)
@property @property
def custom_nodes_path(self) -> Path: def custom_nodes_path(self) -> Path:
"""Path to the custom nodes directory, resolved to an absolute path..""" """Path to the custom nodes directory, resolved to an absolute path.."""

View File

@ -1,10 +1,17 @@
"""Init file for download queue.""" """Init file for download queue."""
from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException from .download_base import (
DownloadJob,
DownloadJobStatus,
DownloadQueueServiceBase,
MultiFileDownloadJob,
UnknownJobIDException,
)
from .download_default import DownloadQueueService, TqdmProgress from .download_default import DownloadQueueService, TqdmProgress
__all__ = [ __all__ = [
"DownloadJob", "DownloadJob",
"MultiFileDownloadJob",
"DownloadQueueServiceBase", "DownloadQueueServiceBase",
"DownloadQueueService", "DownloadQueueService",
"TqdmProgress", "TqdmProgress",

View File

@ -5,11 +5,13 @@ from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from functools import total_ordering from functools import total_ordering
from pathlib import Path from pathlib import Path
from typing import Any, Callable, List, Optional from typing import Any, Callable, List, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr from pydantic import BaseModel, Field, PrivateAttr
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from invokeai.backend.model_manager.metadata import RemoteModelFile
class DownloadJobStatus(str, Enum): class DownloadJobStatus(str, Enum):
"""State of a download job.""" """State of a download job."""
@ -33,30 +35,23 @@ class ServiceInactiveException(Exception):
"""This exception is raised when user attempts to initiate a download before the service is started.""" """This exception is raised when user attempts to initiate a download before the service is started."""
DownloadEventHandler = Callable[["DownloadJob"], None] SingleFileDownloadEventHandler = Callable[["DownloadJob"], None]
DownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None] SingleFileDownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None]
MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None]
DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]
@total_ordering class DownloadJobBase(BaseModel):
class DownloadJob(BaseModel): """Base of classes to monitor and control downloads."""
"""Class to monitor and control a model download request."""
# required variables to be passed in on creation
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path")
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
# automatically assigned on creation # automatically assigned on creation
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
# set internally during download process dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path")
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory")
status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download") status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file")
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
job_ended: Optional[str] = Field(
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
)
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
bytes: int = Field(default=0, description="Bytes downloaded so far") bytes: int = Field(default=0, description="Bytes downloaded so far")
total_bytes: int = Field(default=0, description="Total file size (bytes)") total_bytes: int = Field(default=0, description="Total file size (bytes)")
@ -74,14 +69,6 @@ class DownloadJob(BaseModel):
_on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None) _on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None) _on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None)
def __hash__(self) -> int:
"""Return hash of the string representation of this object, for indexing."""
return hash(str(self))
def __le__(self, other: "DownloadJob") -> bool:
"""Return True if this job's priority is less than another's."""
return self.priority <= other.priority
def cancel(self) -> None: def cancel(self) -> None:
"""Call to cancel the job.""" """Call to cancel the job."""
self._cancelled = True self._cancelled = True
@ -98,6 +85,11 @@ class DownloadJob(BaseModel):
"""Return true if job completed without errors.""" """Return true if job completed without errors."""
return self.status == DownloadJobStatus.COMPLETED return self.status == DownloadJobStatus.COMPLETED
@property
def waiting(self) -> bool:
"""Return true if the job is waiting to run."""
return self.status == DownloadJobStatus.WAITING
@property @property
def running(self) -> bool: def running(self) -> bool:
"""Return true if the job is running.""" """Return true if the job is running."""
@ -154,6 +146,37 @@ class DownloadJob(BaseModel):
self._on_cancelled = on_cancelled self._on_cancelled = on_cancelled
@total_ordering
class DownloadJob(DownloadJobBase):
"""Class to monitor and control a model download request."""
# required variables to be passed in on creation
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
# set internally during download process
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
job_ended: Optional[str] = Field(
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
)
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
def __hash__(self) -> int:
"""Return hash of the string representation of this object, for indexing."""
return hash(str(self))
def __le__(self, other: "DownloadJob") -> bool:
"""Return True if this job's priority is less than another's."""
return self.priority <= other.priority
class MultiFileDownloadJob(DownloadJobBase):
"""Class to monitor and control multifile downloads."""
download_parts: Set[DownloadJob] = Field(default_factory=set, description="List of download parts.")
class DownloadQueueServiceBase(ABC): class DownloadQueueServiceBase(ABC):
"""Multithreaded queue for downloading models via URL.""" """Multithreaded queue for downloading models via URL."""
@ -201,6 +224,48 @@ class DownloadQueueServiceBase(ABC):
""" """
pass pass
@abstractmethod
def multifile_download(
self,
parts: List[RemoteModelFile],
dest: Path,
access_token: Optional[str] = None,
submit_job: bool = True,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
) -> MultiFileDownloadJob:
"""
Create and enqueue a multifile download job.
:param parts: Set of URL / filename pairs
:param dest: Path to download to. See below.
:param access_token: Access token to download the indicated files. If not provided,
each file's URL may be matched to an access token using the config file matching
system.
:param submit_job: If true [default] then submit the job for execution. Otherwise,
you will need to pass the job to submit_multifile_download().
:param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
events.
:returns: A MultiFileDownloadJob object for monitoring the state of the download.
The `dest` argument is a Path object pointing to a directory. All downloads
with be placed inside this directory. The callbacks will receive the
MultiFileDownloadJob.
"""
pass
@abstractmethod
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
"""
Enqueue a previously-created multi-file download job.
:param job: A MultiFileDownloadJob created with multifile_download()
"""
pass
@abstractmethod @abstractmethod
def submit_download_job( def submit_download_job(
self, self,
@ -252,7 +317,7 @@ class DownloadQueueServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def cancel_job(self, job: DownloadJob) -> None: def cancel_job(self, job: DownloadJobBase) -> None:
"""Cancel the job, clearing partial downloads and putting it into ERROR state.""" """Cancel the job, clearing partial downloads and putting it into ERROR state."""
pass pass
@ -262,7 +327,7 @@ class DownloadQueueServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob: def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
"""Wait until the indicated download job has reached a terminal state. """Wait until the indicated download job has reached a terminal state.
This will block until the indicated install job has completed, This will block until the indicated install job has completed,

View File

@ -8,30 +8,32 @@ import time
import traceback import traceback
from pathlib import Path from pathlib import Path
from queue import Empty, PriorityQueue from queue import Empty, PriorityQueue
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set from typing import Any, Dict, List, Literal, Optional, Set
import requests import requests
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from requests import HTTPError from requests import HTTPError
from tqdm import tqdm from tqdm import tqdm
from invokeai.app.services.config import InvokeAIAppConfig, get_config
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.model_manager.metadata import RemoteModelFile
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from .download_base import ( from .download_base import (
DownloadEventHandler, DownloadEventHandler,
DownloadExceptionHandler, DownloadExceptionHandler,
DownloadJob, DownloadJob,
DownloadJobBase,
DownloadJobCancelledException, DownloadJobCancelledException,
DownloadJobStatus, DownloadJobStatus,
DownloadQueueServiceBase, DownloadQueueServiceBase,
MultiFileDownloadJob,
ServiceInactiveException, ServiceInactiveException,
UnknownJobIDException, UnknownJobIDException,
) )
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
# Maximum number of bytes to download during each call to requests.iter_content() # Maximum number of bytes to download during each call to requests.iter_content()
DOWNLOAD_CHUNK_SIZE = 100000 DOWNLOAD_CHUNK_SIZE = 100000
@ -42,20 +44,24 @@ class DownloadQueueService(DownloadQueueServiceBase):
def __init__( def __init__(
self, self,
max_parallel_dl: int = 5, max_parallel_dl: int = 5,
app_config: Optional[InvokeAIAppConfig] = None,
event_bus: Optional["EventServiceBase"] = None, event_bus: Optional["EventServiceBase"] = None,
requests_session: Optional[requests.sessions.Session] = None, requests_session: Optional[requests.sessions.Session] = None,
): ):
""" """
Initialize DownloadQueue. Initialize DownloadQueue.
:param app_config: InvokeAIAppConfig object
:param max_parallel_dl: Number of simultaneous downloads allowed [5]. :param max_parallel_dl: Number of simultaneous downloads allowed [5].
:param requests_session: Optional requests.sessions.Session object, for unit tests. :param requests_session: Optional requests.sessions.Session object, for unit tests.
""" """
self._app_config = app_config or get_config()
self._jobs: Dict[int, DownloadJob] = {} self._jobs: Dict[int, DownloadJob] = {}
self._download_part2parent: Dict[AnyHttpUrl, MultiFileDownloadJob] = {}
self._next_job_id = 0 self._next_job_id = 0
self._queue: PriorityQueue[DownloadJob] = PriorityQueue() self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
self._stop_event = threading.Event() self._stop_event = threading.Event()
self._job_completed_event = threading.Event() self._job_terminated_event = threading.Event()
self._worker_pool: Set[threading.Thread] = set() self._worker_pool: Set[threading.Thread] = set()
self._lock = threading.Lock() self._lock = threading.Lock()
self._logger = InvokeAILogger.get_logger("DownloadQueueService") self._logger = InvokeAILogger.get_logger("DownloadQueueService")
@ -107,9 +113,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
raise ServiceInactiveException( raise ServiceInactiveException(
"The download service is not currently accepting requests. Please call start() to initialize the service." "The download service is not currently accepting requests. Please call start() to initialize the service."
) )
with self._lock: job.id = self._next_id()
job.id = self._next_job_id
self._next_job_id += 1
job.set_callbacks( job.set_callbacks(
on_start=on_start, on_start=on_start,
on_progress=on_progress, on_progress=on_progress,
@ -141,7 +145,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
source=source, source=source,
dest=dest, dest=dest,
priority=priority, priority=priority,
access_token=access_token, access_token=access_token or self._lookup_access_token(source),
) )
self.submit_download_job( self.submit_download_job(
job, job,
@ -153,10 +157,63 @@ class DownloadQueueService(DownloadQueueServiceBase):
) )
return job return job
def multifile_download(
self,
parts: List[RemoteModelFile],
dest: Path,
access_token: Optional[str] = None,
submit_job: bool = True,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
) -> MultiFileDownloadJob:
mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id())
mfdj.set_callbacks(
on_start=on_start,
on_progress=on_progress,
on_complete=on_complete,
on_cancelled=on_cancelled,
on_error=on_error,
)
for part in parts:
url = part.url
path = dest / part.path
assert path.is_relative_to(dest), "only relative download paths accepted"
job = DownloadJob(
source=url,
dest=path,
access_token=access_token,
)
mfdj.download_parts.add(job)
self._download_part2parent[job.source] = mfdj
if submit_job:
self.submit_multifile_download(mfdj)
return mfdj
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
for download_job in job.download_parts:
self.submit_download_job(
download_job,
on_start=self._mfd_started,
on_progress=self._mfd_progress,
on_complete=self._mfd_complete,
on_cancelled=self._mfd_cancelled,
on_error=self._mfd_error,
)
def join(self) -> None: def join(self) -> None:
"""Wait for all jobs to complete.""" """Wait for all jobs to complete."""
self._queue.join() self._queue.join()
def _next_id(self) -> int:
with self._lock:
id = self._next_job_id
self._next_job_id += 1
return id
def list_jobs(self) -> List[DownloadJob]: def list_jobs(self) -> List[DownloadJob]:
"""List all the jobs.""" """List all the jobs."""
return list(self._jobs.values()) return list(self._jobs.values())
@ -178,14 +235,14 @@ class DownloadQueueService(DownloadQueueServiceBase):
except KeyError as excp: except KeyError as excp:
raise UnknownJobIDException("Unrecognized job") from excp raise UnknownJobIDException("Unrecognized job") from excp
def cancel_job(self, job: DownloadJob) -> None: def cancel_job(self, job: DownloadJobBase) -> None:
""" """
Cancel the indicated job. Cancel the indicated job.
If it is running it will be stopped. If it is running it will be stopped.
job.status will be set to DownloadJobStatus.CANCELLED job.status will be set to DownloadJobStatus.CANCELLED
""" """
with self._lock: if job.status in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]:
job.cancel() job.cancel()
def cancel_all_jobs(self) -> None: def cancel_all_jobs(self) -> None:
@ -194,12 +251,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
if not job.in_terminal_state: if not job.in_terminal_state:
self.cancel_job(job) self.cancel_job(job)
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob: def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
"""Block until the indicated job has reached terminal state, or when timeout limit reached.""" """Block until the indicated job has reached terminal state, or when timeout limit reached."""
start = time.time() start = time.time()
while not job.in_terminal_state: while not job.in_terminal_state:
if self._job_completed_event.wait(timeout=0.25): # in case we miss an event if self._job_terminated_event.wait(timeout=0.25): # in case we miss an event
self._job_completed_event.clear() self._job_terminated_event.clear()
if timeout > 0 and time.time() - start > timeout: if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded") raise TimeoutError("Timeout exceeded")
return job return job
@ -228,22 +285,25 @@ class DownloadQueueService(DownloadQueueServiceBase):
job.job_started = get_iso_timestamp() job.job_started = get_iso_timestamp()
self._do_download(job) self._do_download(job)
self._signal_job_complete(job) self._signal_job_complete(job)
except (OSError, HTTPError) as excp:
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
job.error = traceback.format_exc()
self._signal_job_error(job, excp)
except DownloadJobCancelledException: except DownloadJobCancelledException:
self._signal_job_cancelled(job) self._signal_job_cancelled(job)
self._cleanup_cancelled_job(job) self._cleanup_cancelled_job(job)
except Exception as excp:
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
job.error = traceback.format_exc()
self._signal_job_error(job, excp)
finally: finally:
job.job_ended = get_iso_timestamp() job.job_ended = get_iso_timestamp()
self._job_completed_event.set() # signal a change to terminal state self._job_terminated_event.set() # signal a change to terminal state
self._download_part2parent.pop(job.source, None) # if this is a subpart of a multipart job, remove it
self._job_terminated_event.set()
self._queue.task_done() self._queue.task_done()
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.") self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
def _do_download(self, job: DownloadJob) -> None: def _do_download(self, job: DownloadJob) -> None:
"""Do the actual download.""" """Do the actual download."""
url = job.source url = job.source
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {} header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb" open_mode = "wb"
@ -335,38 +395,29 @@ class DownloadQueueService(DownloadQueueServiceBase):
def _in_progress_path(self, path: Path) -> Path: def _in_progress_path(self, path: Path) -> Path:
return path.with_name(path.name + ".downloading") return path.with_name(path.name + ".downloading")
def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]:
# Pull the token from config if it exists and matches the URL
token = None
for pair in self._app_config.remote_api_tokens or []:
if re.search(pair.url_regex, str(source)):
token = pair.token
break
return token
def _signal_job_started(self, job: DownloadJob) -> None: def _signal_job_started(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.RUNNING job.status = DownloadJobStatus.RUNNING
if job.on_start: self._execute_cb(job, "on_start")
try:
job.on_start(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
)
if self._event_bus: if self._event_bus:
self._event_bus.emit_download_started(job) self._event_bus.emit_download_started(job)
def _signal_job_progress(self, job: DownloadJob) -> None: def _signal_job_progress(self, job: DownloadJob) -> None:
if job.on_progress: self._execute_cb(job, "on_progress")
try:
job.on_progress(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
)
if self._event_bus: if self._event_bus:
self._event_bus.emit_download_progress(job) self._event_bus.emit_download_progress(job)
def _signal_job_complete(self, job: DownloadJob) -> None: def _signal_job_complete(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.COMPLETED job.status = DownloadJobStatus.COMPLETED
if job.on_complete: self._execute_cb(job, "on_complete")
try:
job.on_complete(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
)
if self._event_bus: if self._event_bus:
self._event_bus.emit_download_complete(job) self._event_bus.emit_download_complete(job)
@ -374,26 +425,21 @@ class DownloadQueueService(DownloadQueueServiceBase):
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]: if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
return return
job.status = DownloadJobStatus.CANCELLED job.status = DownloadJobStatus.CANCELLED
if job.on_cancelled: self._execute_cb(job, "on_cancelled")
try:
job.on_cancelled(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
)
if self._event_bus: if self._event_bus:
self._event_bus.emit_download_cancelled(job) self._event_bus.emit_download_cancelled(job)
# if multifile download, then signal the parent
if parent_job := self._download_part2parent.get(job.source, None):
if not parent_job.in_terminal_state:
parent_job.status = DownloadJobStatus.CANCELLED
self._execute_cb(parent_job, "on_cancelled")
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None: def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
job.status = DownloadJobStatus.ERROR job.status = DownloadJobStatus.ERROR
self._logger.error(f"{str(job.source)}: {traceback.format_exception(excp)}") self._logger.error(f"{str(job.source)}: {traceback.format_exception(excp)}")
if job.on_error: self._execute_cb(job, "on_error", excp)
try:
job.on_error(job, excp)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
)
if self._event_bus: if self._event_bus:
self._event_bus.emit_download_error(job) self._event_bus.emit_download_error(job)
@ -406,6 +452,97 @@ class DownloadQueueService(DownloadQueueServiceBase):
except OSError as excp: except OSError as excp:
self._logger.warning(excp) self._logger.warning(excp)
########################################
# callbacks used for multifile downloads
########################################
def _mfd_started(self, download_job: DownloadJob) -> None:
self._logger.info(f"File download started: {download_job.source}")
with self._lock:
mf_job = self._download_part2parent[download_job.source]
if mf_job.waiting:
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
mf_job.status = DownloadJobStatus.RUNNING
assert download_job.download_path is not None
path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest)
mf_job.download_path = (
mf_job.dest / path_relative_to_destdir.parts[0]
) # keep just the first component of the path
self._execute_cb(mf_job, "on_start")
def _mfd_progress(self, download_job: DownloadJob) -> None:
with self._lock:
mf_job = self._download_part2parent[download_job.source]
if mf_job.cancelled:
for part in mf_job.download_parts:
self.cancel_job(part)
elif mf_job.running:
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
mf_job.bytes = sum(x.total_bytes for x in mf_job.download_parts)
self._execute_cb(mf_job, "on_progress")
def _mfd_complete(self, download_job: DownloadJob) -> None:
self._logger.info(f"Download complete: {download_job.source}")
with self._lock:
mf_job = self._download_part2parent[download_job.source]
# are there any more active jobs left in this task?
if mf_job.running and all(x.complete for x in mf_job.download_parts):
mf_job.status = DownloadJobStatus.COMPLETED
self._execute_cb(mf_job, "on_complete")
# we're done with this sub-job
self._job_terminated_event.set()
def _mfd_cancelled(self, download_job: DownloadJob) -> None:
with self._lock:
mf_job = self._download_part2parent[download_job.source]
assert mf_job is not None
if not mf_job.in_terminal_state:
self._logger.warning(f"Download cancelled: {download_job.source}")
mf_job.cancel()
for s in mf_job.download_parts:
self.cancel_job(s)
def _mfd_error(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
with self._lock:
mf_job = self._download_part2parent[download_job.source]
assert mf_job is not None
if not mf_job.in_terminal_state:
mf_job.status = download_job.status
mf_job.error = download_job.error
mf_job.error_type = download_job.error_type
self._execute_cb(mf_job, "on_error", excp)
self._logger.error(
f"Cancelling {mf_job.dest} due to an error while downloading {download_job.source}: {str(excp)}"
)
for s in [x for x in mf_job.download_parts if x.running]:
self.cancel_job(s)
self._download_part2parent.pop(download_job.source)
self._job_terminated_event.set()
def _execute_cb(
self,
job: DownloadJob | MultiFileDownloadJob,
callback_name: Literal[
"on_start",
"on_progress",
"on_complete",
"on_cancelled",
"on_error",
],
excp: Optional[Exception] = None,
) -> None:
if callback := getattr(job, callback_name, None):
args = [job, excp] if excp else [job]
try:
callback(*args)
except Exception as e:
self._logger.error(
f"An error occurred while processing the {callback_name} callback: {traceback.format_exception(e)}"
)
def get_pc_name_max(directory: str) -> int: def get_pc_name_max(directory: str) -> int:
if hasattr(os, "pathconf"): if hasattr(os, "pathconf"):

View File

@ -22,6 +22,7 @@ from invokeai.app.services.events.events_common import (
ModelInstallCompleteEvent, ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent, ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent, ModelInstallDownloadsCompleteEvent,
ModelInstallDownloadStartedEvent,
ModelInstallErrorEvent, ModelInstallErrorEvent,
ModelInstallStartedEvent, ModelInstallStartedEvent,
ModelLoadCompleteEvent, ModelLoadCompleteEvent,
@ -34,7 +35,6 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineInterme
if TYPE_CHECKING: if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.download.download_base import DownloadJob from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.model_install.model_install_common import ModelInstallJob from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import ( from invokeai.app.services.session_queue.session_queue_common import (
@ -145,6 +145,10 @@ class EventServiceBase:
# region Model install # region Model install
def emit_model_install_download_started(self, job: "ModelInstallJob") -> None:
"""Emitted at intervals while the install job is started (remote models only)."""
self.dispatch(ModelInstallDownloadStartedEvent.build(job))
def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None: def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
"""Emitted at intervals while the install job is in progress (remote models only).""" """Emitted at intervals while the install job is in progress (remote models only)."""
self.dispatch(ModelInstallDownloadProgressEvent.build(job)) self.dispatch(ModelInstallDownloadProgressEvent.build(job))

View File

@ -417,6 +417,42 @@ class ModelLoadCompleteEvent(ModelEventBase):
return cls(config=config, submodel_type=submodel_type) return cls(config=config, submodel_type=submodel_type)
@payload_schema.register
class ModelInstallDownloadStartedEvent(ModelEventBase):
"""Event model for model_install_download_started"""
__event_name__ = "model_install_download_started"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
local_path: str = Field(description="Where model is downloading to")
bytes: int = Field(description="Number of bytes downloaded so far")
total_bytes: int = Field(description="Total size of download, including all files")
parts: list[dict[str, int | str]] = Field(
description="Progress of downloading URLs that comprise the model, if any"
)
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadStartedEvent":
parts: list[dict[str, str | int]] = [
{
"url": str(x.source),
"local_path": str(x.download_path),
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job.download_parts
]
return cls(
id=job.id,
source=str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
)
@payload_schema.register @payload_schema.register
class ModelInstallDownloadProgressEvent(ModelEventBase): class ModelInstallDownloadProgressEvent(ModelEventBase):
"""Event model for model_install_download_progress""" """Event model for model_install_download_progress"""

View File

@ -13,7 +13,7 @@ from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager.config import AnyModelConfig from invokeai.backend.model_manager import AnyModelConfig
class ModelInstallServiceBase(ABC): class ModelInstallServiceBase(ABC):
@ -243,12 +243,11 @@ class ModelInstallServiceBase(ABC):
""" """
@abstractmethod @abstractmethod
def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path: def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
""" """
Download the model file located at source to the models cache and return its Path. Download the model file located at source to the models cache and return its Path.
:param source: A Url or a string that can be converted into one. :param source: A string representing a URL or repo_id.
:param access_token: Optional access token to access restricted resources.
The model file will be downloaded into the system-wide model cache The model file will be downloaded into the system-wide model cache
(`models/.cache`) if it isn't already there. Note that the model cache (`models/.cache`) if it isn't already there. Note that the model cache

View File

@ -8,7 +8,7 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated from typing_extensions import Annotated
from invokeai.app.services.download import DownloadJob from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
@ -26,13 +26,6 @@ class InstallStatus(str, Enum):
CANCELLED = "cancelled" # terminated with an error message CANCELLED = "cancelled" # terminated with an error message
class ModelInstallPart(BaseModel):
url: AnyHttpUrl
path: Path
bytes: int = 0
total_bytes: int = 0
class UnknownInstallJobException(Exception): class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested.""" """Raised when the status of an unknown job is requested."""
@ -169,6 +162,7 @@ class ModelInstallJob(BaseModel):
) )
# internal flags and transitory settings # internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None) _install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_multifile_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None) _exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None: def set_error(self, e: Exception) -> None:

View File

@ -5,21 +5,22 @@ import os
import re import re
import threading import threading
import time import time
from hashlib import sha256
from pathlib import Path from pathlib import Path
from queue import Empty, Queue from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp from tempfile import mkdtemp
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch import torch
import yaml import yaml
from huggingface_hub import HfFolder from huggingface_hub import HfFolder
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from pydantic_core import Url
from requests import Session from requests import Session
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
@ -44,6 +45,7 @@ from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import InvokeAILogger from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.catch_sigint import catch_sigint from invokeai.backend.util.catch_sigint import catch_sigint
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.util import slugify
from .model_install_common import ( from .model_install_common import (
MODEL_SOURCE_TO_TYPE_MAP, MODEL_SOURCE_TO_TYPE_MAP,
@ -58,9 +60,6 @@ from .model_install_common import (
TMPDIR_PREFIX = "tmpinstall_" TMPDIR_PREFIX = "tmpinstall_"
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
class ModelInstallService(ModelInstallServiceBase): class ModelInstallService(ModelInstallServiceBase):
"""class for InvokeAI model installation.""" """class for InvokeAI model installation."""
@ -91,7 +90,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._downloads_changed_event = threading.Event() self._downloads_changed_event = threading.Event()
self._install_completed_event = threading.Event() self._install_completed_event = threading.Event()
self._download_queue = download_queue self._download_queue = download_queue
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {} self._download_cache: Dict[int, ModelInstallJob] = {}
self._running = False self._running = False
self._session = session self._session = session
self._install_thread: Optional[threading.Thread] = None self._install_thread: Optional[threading.Thread] = None
@ -210,33 +209,12 @@ class ModelInstallService(ModelInstallServiceBase):
access_token: Optional[str] = None, access_token: Optional[str] = None,
inplace: Optional[bool] = False, inplace: Optional[bool] = False,
) -> ModelInstallJob: ) -> ModelInstallJob:
variants = "|".join(ModelRepoVariant.__members__.values()) """Install a model using pattern matching to infer the type of source."""
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$" source_obj = self._guess_source(source)
source_obj: Optional[StringLikeSource] = None if isinstance(source_obj, LocalModelSource):
source_obj.inplace = inplace
if Path(source).exists(): # A local file or directory elif isinstance(source_obj, HFModelSource) or isinstance(source_obj, URLModelSource):
source_obj = LocalModelSource(path=Path(source), inplace=inplace) source_obj.access_token = access_token
elif match := re.match(hf_repoid_re, source):
source_obj = HFModelSource(
repo_id=match.group(1),
variant=match.group(2) if match.group(2) else None, # pass None rather than ''
subfolder=Path(match.group(3)) if match.group(3) else None,
access_token=access_token,
)
elif re.match(r"^https?://[^/]+", source):
# Pull the token from config if it exists and matches the URL
_token = access_token
if _token is None:
for pair in self.app_config.remote_api_tokens or []:
if re.search(pair.url_regex, source):
_token = pair.token
break
source_obj = URLModelSource(
url=AnyHttpUrl(source),
access_token=_token,
)
else:
raise ValueError(f"Unsupported model source: '{source}'")
return self.import_model(source_obj, config) return self.import_model(source_obj, config)
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102 def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
@ -297,8 +275,9 @@ class ModelInstallService(ModelInstallServiceBase):
def cancel_job(self, job: ModelInstallJob) -> None: def cancel_job(self, job: ModelInstallJob) -> None:
"""Cancel the indicated job.""" """Cancel the indicated job."""
job.cancel() job.cancel()
with self._lock: self._logger.warning(f"Cancelling {job.source}")
self._cancel_download_parts(job) if dj := job._multifile_job:
self._download_queue.cancel_job(dj)
def prune_jobs(self) -> None: def prune_jobs(self) -> None:
"""Prune all completed and errored jobs.""" """Prune all completed and errored jobs."""
@ -351,7 +330,7 @@ class ModelInstallService(ModelInstallServiceBase):
legacy_config_path = stanza.get("config") legacy_config_path = stanza.get("config")
if legacy_config_path: if legacy_config_path:
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir. # In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
legacy_config_path: Path = self._app_config.root_path / legacy_config_path legacy_config_path = self._app_config.root_path / legacy_config_path
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path): if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path) legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
config["config_path"] = str(legacy_config_path) config["config_path"] = str(legacy_config_path)
@ -392,38 +371,95 @@ class ModelInstallService(ModelInstallServiceBase):
rmtree(model_path) rmtree(model_path)
self.unregister(key) self.unregister(key)
def download_and_cache( @classmethod
def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: InvokeAIAppConfig) -> Path:
escaped_source = slugify(str(source))
return app_config.download_cache_path / escaped_source
def download_and_cache_model(
self, self,
source: Union[str, AnyHttpUrl], source: str | AnyHttpUrl,
access_token: Optional[str] = None,
timeout: int = 0,
) -> Path: ) -> Path:
"""Download the model file located at source to the models cache and return its Path.""" """Download the model file located at source to the models cache and return its Path."""
model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32] model_path = self._download_cache_path(str(source), self._app_config)
model_path = self._app_config.convert_cache_path / model_hash
# We expect the cache directory to contain one and only one downloaded file. # We expect the cache directory to contain one and only one downloaded file or directory.
# We don't know the file's name in advance, as it is set by the download # We don't know the file's name in advance, as it is set by the download
# content-disposition header. # content-disposition header.
if model_path.exists(): if model_path.exists():
contents = [x for x in model_path.iterdir() if x.is_file()] contents: List[Path] = list(model_path.iterdir())
if len(contents) > 0: if len(contents) > 0:
return contents[0] return contents[0]
model_path.mkdir(parents=True, exist_ok=True) model_path.mkdir(parents=True, exist_ok=True)
job = self._download_queue.download( model_source = self._guess_source(str(source))
source=AnyHttpUrl(str(source)), remote_files, _ = self._remote_files_from_source(model_source)
job = self._multifile_download(
dest=model_path, dest=model_path,
access_token=access_token, remote_files=remote_files,
on_progress=TqdmProgress().update, subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
) )
self._download_queue.wait_for_job(job, timeout) files_string = "file" if len(remote_files) == 1 else "files"
self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
self._download_queue.wait_for_job(job)
if job.complete: if job.complete:
assert job.download_path is not None assert job.download_path is not None
return job.download_path return job.download_path
else: else:
raise Exception(job.error) raise Exception(job.error)
def _remote_files_from_source(
self, source: ModelSource
) -> Tuple[List[RemoteModelFile], Optional[AnyModelRepoMetadata]]:
metadata = None
if isinstance(source, HFModelSource):
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
assert isinstance(metadata, ModelMetadataWithFiles)
return (
metadata.download_urls(
variant=source.variant or self._guess_variant(),
subfolder=source.subfolder,
session=self._session,
),
metadata,
)
if isinstance(source, URLModelSource):
try:
fetcher = self.get_fetcher_from_url(str(source.url))
kwargs: dict[str, Any] = {"session": self._session}
metadata = fetcher(**kwargs).from_url(source.url)
assert isinstance(metadata, ModelMetadataWithFiles)
return metadata.download_urls(session=self._session), metadata
except ValueError:
pass
return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None
raise Exception(f"No files associated with {source}")
def _guess_source(self, source: str) -> ModelSource:
"""Turn a source string into a ModelSource object."""
variants = "|".join(ModelRepoVariant.__members__.values())
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
source_obj: Optional[StringLikeSource] = None
if Path(source).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source))
elif match := re.match(hf_repoid_re, source):
source_obj = HFModelSource(
repo_id=match.group(1),
variant=ModelRepoVariant(match.group(2)) if match.group(2) else None, # pass None rather than ''
subfolder=Path(match.group(3)) if match.group(3) else None,
)
elif re.match(r"^https?://[^/]+", source):
source_obj = URLModelSource(
url=Url(source),
)
else:
raise ValueError(f"Unsupported model source: '{source}'")
return source_obj
# -------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------
# Internal functions that manage the installer threads # Internal functions that manage the installer threads
# -------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------
@ -484,16 +520,19 @@ class ModelInstallService(ModelInstallServiceBase):
job.config_out = self.record_store.get_model(key) job.config_out = self.record_store.get_model(key)
self._signal_job_completed(job) self._signal_job_completed(job)
def _set_error(self, job: ModelInstallJob, excp: Exception) -> None: def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None:
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts): multifile_download_job = install_job._multifile_job
job.set_error( if multifile_download_job and any(
x.content_type is not None and "text/html" in x.content_type for x in multifile_download_job.download_parts
):
install_job.set_error(
InvalidModelConfigException( InvalidModelConfigException(
f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download." f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
) )
) )
else: else:
job.set_error(excp) install_job.set_error(excp)
self._signal_job_errored(job) self._signal_job_errored(install_job)
# -------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------
# Internal functions that manage the models directory # Internal functions that manage the models directory
@ -519,7 +558,6 @@ class ModelInstallService(ModelInstallServiceBase):
This is typically only used during testing with a new DB or when using the memory DB, because those are the This is typically only used during testing with a new DB or when using the memory DB, because those are the
only situations in which we may have orphaned models in the models directory. only situations in which we may have orphaned models in the models directory.
""" """
installed_model_paths = { installed_model_paths = {
(self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models() (self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
} }
@ -531,7 +569,12 @@ class ModelInstallService(ModelInstallServiceBase):
if resolved_path in installed_model_paths: if resolved_path in installed_model_paths:
return True return True
# Skip core models entirely - these aren't registered with the model manager. # Skip core models entirely - these aren't registered with the model manager.
if str(resolved_path).startswith(str(self.app_config.models_path / "core")): for special_directory in [
self.app_config.models_path / "core",
self.app_config.convert_cache_dir,
self.app_config.download_cache_dir,
]:
if resolved_path.is_relative_to(special_directory):
return False return False
try: try:
model_id = self.register_path(model_path) model_id = self.register_path(model_path)
@ -647,20 +690,15 @@ class ModelInstallService(ModelInstallServiceBase):
inplace=source.inplace or False, inplace=source.inplace or False,
) )
def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: def _import_from_hf(
self,
source: HFModelSource,
config: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob:
# Add user's cached access token to HuggingFace requests # Add user's cached access token to HuggingFace requests
source.access_token = source.access_token or HfFolder.get_token() if source.access_token is None:
if not source.access_token: source.access_token = HfFolder.get_token()
self._logger.info("No HuggingFace access token present; some models may not be downloadable.") remote_files, metadata = self._remote_files_from_source(source)
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
assert isinstance(metadata, ModelMetadataWithFiles)
remote_files = metadata.download_urls(
variant=source.variant or self._guess_variant(),
subfolder=source.subfolder,
session=self._session,
)
return self._import_remote_model( return self._import_remote_model(
source=source, source=source,
config=config, config=config,
@ -668,22 +706,12 @@ class ModelInstallService(ModelInstallServiceBase):
metadata=metadata, metadata=metadata,
) )
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: def _import_from_url(
# URLs from HuggingFace will be handled specially self,
metadata = None source: URLModelSource,
fetcher = None config: Optional[Dict[str, Any]],
try: ) -> ModelInstallJob:
fetcher = self.get_fetcher_from_url(str(source.url)) remote_files, metadata = self._remote_files_from_source(source)
except ValueError:
pass
kwargs: dict[str, Any] = {"session": self._session}
if fetcher is not None:
metadata = fetcher(**kwargs).from_url(source.url)
self._logger.debug(f"metadata={metadata}")
if metadata and isinstance(metadata, ModelMetadataWithFiles):
remote_files = metadata.download_urls(session=self._session)
else:
remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
return self._import_remote_model( return self._import_remote_model(
source=source, source=source,
config=config, config=config,
@ -698,12 +726,9 @@ class ModelInstallService(ModelInstallServiceBase):
metadata: Optional[AnyModelRepoMetadata], metadata: Optional[AnyModelRepoMetadata],
config: Optional[Dict[str, Any]], config: Optional[Dict[str, Any]],
) -> ModelInstallJob: ) -> ModelInstallJob:
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
# Currently the tmpdir isn't automatically removed at exit because it is
# being held in a daemon thread.
if len(remote_files) == 0: if len(remote_files) == 0:
raise ValueError(f"{source}: No downloadable files found") raise ValueError(f"{source}: No downloadable files found")
tmpdir = Path( destdir = Path(
mkdtemp( mkdtemp(
dir=self._app_config.models_path, dir=self._app_config.models_path,
prefix=TMPDIR_PREFIX, prefix=TMPDIR_PREFIX,
@ -714,55 +739,28 @@ class ModelInstallService(ModelInstallServiceBase):
source=source, source=source,
config_in=config or {}, config_in=config or {},
source_metadata=metadata, source_metadata=metadata,
local_path=tmpdir, # local path may change once the download has started due to content-disposition handling local_path=destdir, # local path may change once the download has started due to content-disposition handling
bytes=0, bytes=0,
total_bytes=0, total_bytes=0,
) )
# In the event that there is a subfolder specified in the source, # remember the temporary directory for later removal
# we need to remove it from the destination path in order to avoid install_job._install_tmpdir = destdir
# creating unwanted subfolders install_job.total_bytes = sum((x.size or 0) for x in remote_files)
if isinstance(source, HFModelSource) and source.subfolder:
root = Path(remote_files[0].path.parts[0])
subfolder = root / source.subfolder
else:
root = Path(".")
subfolder = Path(".")
# we remember the path up to the top of the tmpdir so that it may be multifile_job = self._multifile_download(
# removed safely at the end of the install process. remote_files=remote_files,
install_job._install_tmpdir = tmpdir dest=destdir,
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below subfolder=source.subfolder if isinstance(source, HFModelSource) else None,
files_string = "file" if len(remote_files) == 1 else "file"
self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
self._logger.debug(f"remote_files={remote_files}")
for model_file in remote_files:
url = model_file.url
path = root / model_file.path.relative_to(subfolder)
self._logger.debug(f"Downloading {url} => {path}")
install_job.total_bytes += model_file.size
assert hasattr(source, "access_token")
dest = tmpdir / path.parent
dest.mkdir(parents=True, exist_ok=True)
download_job = DownloadJob(
source=url,
dest=dest,
access_token=source.access_token, access_token=source.access_token,
submit_job=False, # Important! Don't submit the job until we have set our _download_cache dict
) )
self._download_cache[download_job.source] = install_job # matches a download job to an install job self._download_cache[multifile_job.id] = install_job
install_job.download_parts.add(download_job) install_job._multifile_job = multifile_job
# only start the jobs once install_job.download_parts is fully populated
for download_job in install_job.download_parts:
self._download_queue.submit_download_job(
download_job,
on_start=self._download_started_callback,
on_progress=self._download_progress_callback,
on_complete=self._download_complete_callback,
on_error=self._download_error_callback,
on_cancelled=self._download_cancelled_callback,
)
files_string = "file" if len(remote_files) == 1 else "files"
self._logger.info(f"Queueing model install: {source} ({len(remote_files)} {files_string})")
self._logger.debug(f"remote_files={remote_files}")
self._download_queue.submit_multifile_download(multifile_job)
return install_job return install_job
def _stat_size(self, path: Path) -> int: def _stat_size(self, path: Path) -> int:
@ -774,88 +772,105 @@ class ModelInstallService(ModelInstallServiceBase):
size += sum(self._stat_size(Path(root, x)) for x in files) size += sum(self._stat_size(Path(root, x)) for x in files)
return size return size
def _multifile_download(
self,
remote_files: List[RemoteModelFile],
dest: Path,
subfolder: Optional[Path] = None,
access_token: Optional[str] = None,
submit_job: bool = True,
) -> MultiFileDownloadJob:
# HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
# we are installing the "vae" subfolder, we do not want to create an additional folder level, such
# as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
# So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
if subfolder:
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
path_to_add = Path(f"{top}_{subfolder}")
else:
path_to_remove = Path(".")
path_to_add = Path(".")
parts: List[RemoteModelFile] = []
for model_file in remote_files:
assert model_file.size is not None
parts.append(
RemoteModelFile(
url=model_file.url, # if a subfolder, then sdxl-turbo_vae/config.json
path=path_to_add / model_file.path.relative_to(path_to_remove),
)
)
return self._download_queue.multifile_download(
parts=parts,
dest=dest,
access_token=access_token,
submit_job=submit_job,
on_start=self._download_started_callback,
on_progress=self._download_progress_callback,
on_complete=self._download_complete_callback,
on_error=self._download_error_callback,
on_cancelled=self._download_cancelled_callback,
)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Callbacks are executed by the download queue in a separate thread # Callbacks are executed by the download queue in a separate thread
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def _download_started_callback(self, download_job: DownloadJob) -> None: def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None:
self._logger.info(f"Model download started: {download_job.source}")
with self._lock: with self._lock:
install_job = self._download_cache[download_job.source] if install_job := self._download_cache.get(download_job.id, None):
install_job.status = InstallStatus.DOWNLOADING install_job.status = InstallStatus.DOWNLOADING
if install_job.local_path == install_job._install_tmpdir: # first time
assert download_job.download_path assert download_job.download_path
if install_job.local_path == install_job._install_tmpdir: install_job.local_path = download_job.download_path
partial_path = download_job.download_path.relative_to(install_job._install_tmpdir) install_job.download_parts = download_job.download_parts
dest_name = partial_path.parts[0] install_job.bytes = sum(x.bytes for x in download_job.download_parts)
install_job.local_path = install_job._install_tmpdir / dest_name install_job.total_bytes = download_job.total_bytes
self._signal_job_download_started(install_job)
# Update the total bytes count for remote sources. def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
if not install_job.total_bytes:
install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts)
def _download_progress_callback(self, download_job: DownloadJob) -> None:
with self._lock: with self._lock:
install_job = self._download_cache[download_job.source] if install_job := self._download_cache.get(download_job.id, None):
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel() if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
self._cancel_download_parts(install_job) self._download_queue.cancel_job(download_job)
else: else:
# update sizes # update sizes
install_job.bytes = sum(x.bytes for x in install_job.download_parts) install_job.bytes = sum(x.bytes for x in download_job.download_parts)
install_job.total_bytes = sum(x.total_bytes for x in download_job.download_parts)
self._signal_job_downloading(install_job) self._signal_job_downloading(install_job)
def _download_complete_callback(self, download_job: DownloadJob) -> None: def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
self._logger.info(f"Model download complete: {download_job.source}")
with self._lock: with self._lock:
install_job = self._download_cache[download_job.source] if install_job := self._download_cache.pop(download_job.id, None):
# are there any more active jobs left in this task?
if install_job.downloading and all(x.complete for x in install_job.download_parts):
self._signal_job_downloads_done(install_job) self._signal_job_downloads_done(install_job)
self._put_in_queue(install_job) self._put_in_queue(install_job) # this starts the installation and registration
# Let other threads know that the number of downloads has changed # Let other threads know that the number of downloads has changed
self._download_cache.pop(download_job.source, None)
self._downloads_changed_event.set() self._downloads_changed_event.set()
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None: def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
with self._lock: with self._lock:
install_job = self._download_cache.pop(download_job.source, None) if install_job := self._download_cache.pop(download_job.id, None):
assert install_job is not None
assert excp is not None assert excp is not None
install_job.set_error(excp) install_job.set_error(excp)
self._logger.error( self._download_queue.cancel_job(download_job)
f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}"
)
self._cancel_download_parts(install_job)
# Let other threads know that the number of downloads has changed # Let other threads know that the number of downloads has changed
self._downloads_changed_event.set() self._downloads_changed_event.set()
def _download_cancelled_callback(self, download_job: DownloadJob) -> None: def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock: with self._lock:
install_job = self._download_cache.pop(download_job.source, None) if install_job := self._download_cache.pop(download_job.id, None):
if not install_job:
return
self._downloads_changed_event.set() self._downloads_changed_event.set()
self._logger.warning(f"Model download canceled: {download_job.source}")
# if install job has already registered an error, then do not replace its status with cancelled # if install job has already registered an error, then do not replace its status with cancelled
if not install_job.errored: if not install_job.errored:
install_job.cancel() install_job.cancel()
self._cancel_download_parts(install_job)
# Let other threads know that the number of downloads has changed # Let other threads know that the number of downloads has changed
self._downloads_changed_event.set() self._downloads_changed_event.set()
def _cancel_download_parts(self, install_job: ModelInstallJob) -> None:
# on multipart downloads, _cancel_components() will get called repeatedly from the download callbacks
# do not lock here because it gets called within a locked context
for s in install_job.download_parts:
self._download_queue.cancel_job(s)
if all(x.in_terminal_state for x in install_job.download_parts):
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
self._put_in_queue(install_job)
# ------------------------------------------------------------------------------------------------ # ------------------------------------------------------------------------------------------------
# Internal methods that put events on the event bus # Internal methods that put events on the event bus
# ------------------------------------------------------------------------------------------------ # ------------------------------------------------------------------------------------------------
@ -865,8 +880,18 @@ class ModelInstallService(ModelInstallServiceBase):
if self._event_bus: if self._event_bus:
self._event_bus.emit_model_install_started(job) self._event_bus.emit_model_install_started(job)
def _signal_job_download_started(self, job: ModelInstallJob) -> None:
if self._event_bus:
assert job._multifile_job is not None
assert job.bytes is not None
assert job.total_bytes is not None
self._event_bus.emit_model_install_download_started(job)
def _signal_job_downloading(self, job: ModelInstallJob) -> None: def _signal_job_downloading(self, job: ModelInstallJob) -> None:
if self._event_bus: if self._event_bus:
assert job._multifile_job is not None
assert job.bytes is not None
assert job.total_bytes is not None
self._event_bus.emit_model_install_download_progress(job) self._event_bus.emit_model_install_download_progress(job)
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None: def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
@ -881,6 +906,8 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"Model install complete: {job.source}") self._logger.info(f"Model install complete: {job.source}")
self._logger.debug(f"{job.local_path} registered key {job.config_out.key}") self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
if self._event_bus: if self._event_bus:
assert job.local_path is not None
assert job.config_out is not None
self._event_bus.emit_model_install_complete(job) self._event_bus.emit_model_install_complete(job)
def _signal_job_errored(self, job: ModelInstallJob) -> None: def _signal_job_errored(self, job: ModelInstallJob) -> None:
@ -896,7 +923,13 @@ class ModelInstallService(ModelInstallServiceBase):
self._event_bus.emit_model_install_cancelled(job) self._event_bus.emit_model_install_cancelled(job)
@staticmethod @staticmethod
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase: def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]:
"""
Return a metadata fetcher appropriate for provided url.
This used to be more useful, but the number of supported model
sources has been reduced to HuggingFace alone.
"""
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()): if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch return HuggingFaceMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'") raise ValueError(f"Unsupported model source: '{url}'")

View File

@ -2,10 +2,11 @@
"""Base class for model loader.""" """Base class for model loader."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from pathlib import Path
from typing import Callable, Optional
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
@ -36,3 +37,27 @@ class ModelLoadServiceBase(ABC):
@abstractmethod @abstractmethod
def gpu_count(self) -> int: def gpu_count(self) -> int:
"""Return the number of GPUs we are configured to use.""" """Return the number of GPUs we are configured to use."""
@abstractmethod
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
"""
Load the model file or directory located at the indicated Path.
This will load an arbitrary model file into the RAM cache. If the optional loader
argument is provided, the loader will be invoked to load the model into
memory. Otherwise the method will call safetensors.torch.load_file() or
torch.load() as appropriate to the file suffix.
Be aware that this returns a LoadedModelWithoutConfig object, which is the same as
LoadedModel, but without the config attribute.
Args:
model_path: A pathlib.Path to a checkpoint-style models file
loader: A Callable that expects a Path and returns a Dict[str, Tensor]
Returns:
A LoadedModel object.
"""

View File

@ -1,18 +1,26 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team # Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
"""Implementation of model loader service.""" """Implementation of model loader service."""
from typing import Optional, Type from pathlib import Path
from typing import Callable, Optional, Type
from picklescan.scanner import scan_file_path
from safetensors.torch import load_file as safetensors_load_file
from torch import load as torch_load
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import ( from invokeai.backend.model_manager.load import (
LoadedModel, LoadedModel,
LoadedModelWithoutConfig,
ModelLoaderRegistry, ModelLoaderRegistry,
ModelLoaderRegistryBase, ModelLoaderRegistryBase,
) )
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from .model_load_base import ModelLoadServiceBase from .model_load_base import ModelLoadServiceBase
@ -81,3 +89,41 @@ class ModelLoadService(ModelLoadServiceBase):
self._invoker.services.events.emit_model_load_complete(model_config, submodel_type) self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)
return loaded_model return loaded_model
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
cache_key = str(model_path)
ram_cache = self.ram_cache
try:
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
except IndexError:
pass
def torch_load_file(checkpoint: Path) -> AnyModel:
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
result = torch_load(checkpoint, map_location="cpu")
return result
def diffusers_load_directory(directory: Path) -> AnyModel:
load_class = GenericDiffusersLoader(
app_config=self._app_config,
logger=self._logger,
ram_cache=self._ram_cache,
convert_cache=self.convert_cache,
).get_hf_load_class(directory)
return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
loader = loader or (
diffusers_load_directory
if model_path.is_dir()
else torch_load_file
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
else lambda path: safetensors_load_file(path, device="cpu")
)
assert loader is not None
raw_model = loader(model_path)
ram_cache.put(key=cache_key, model=raw_model)
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))

View File

@ -12,15 +12,13 @@ from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager.config import (
AnyModelConfig, AnyModelConfig,
BaseModelType, BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import (
ControlAdapterDefaultSettings, ControlAdapterDefaultSettings,
MainModelDefaultSettings, MainModelDefaultSettings,
ModelFormat,
ModelType,
ModelVariantType, ModelVariantType,
SchedulerPredictionType, SchedulerPredictionType,
) )

View File

@ -37,8 +37,12 @@ class SqliteSessionQueue(SessionQueueBase):
def start(self, invoker: Invoker) -> None: def start(self, invoker: Invoker) -> None:
self.__invoker = invoker self.__invoker = invoker
self._set_in_progress_to_canceled() self._set_in_progress_to_canceled()
if self.__invoker.services.configuration.clear_queue_on_startup:
clear_result = self.clear(DEFAULT_QUEUE_ID)
if clear_result.deleted > 0:
self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
else:
prune_result = self.prune(DEFAULT_QUEUE_ID) prune_result = self.prune(DEFAULT_QUEUE_ID)
if prune_result.deleted > 0: if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items") self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")

View File

@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
import torch import torch
from PIL.Image import Image from PIL.Image import Image
from pydantic.networks import AnyHttpUrl
from torch import Tensor from torch import Tensor
from invokeai.app.invocations.constants import IMAGE_MODES from invokeai.app.invocations.constants import IMAGE_MODES
@ -23,7 +24,7 @@ from invokeai.backend.model_manager.config import (
ModelType, ModelType,
SubModelType, SubModelType,
) )
from invokeai.backend.model_manager.load.load_base import LoadedModel from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
@ -329,8 +330,10 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface): class ModelsInterface(InvocationContextInterface):
"""Common API for loading, downloading and managing models."""
def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool: def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
"""Checks if a model exists. """Check if a model exists.
Args: Args:
identifier: The key or ModelField representing the model. identifier: The key or ModelField representing the model.
@ -340,13 +343,13 @@ class ModelsInterface(InvocationContextInterface):
""" """
if isinstance(identifier, str): if isinstance(identifier, str):
return self._services.model_manager.store.exists(identifier) return self._services.model_manager.store.exists(identifier)
else:
return self._services.model_manager.store.exists(identifier.key) return self._services.model_manager.store.exists(identifier.key)
def load( def load(
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
) -> LoadedModel: ) -> LoadedModel:
"""Loads a model. """Load a model.
Args: Args:
identifier: The key or ModelField representing the model. identifier: The key or ModelField representing the model.
@ -370,7 +373,7 @@ class ModelsInterface(InvocationContextInterface):
def load_by_attrs( def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
) -> LoadedModel: ) -> LoadedModel:
"""Loads a model by its attributes. """Load a model by its attributes.
Args: Args:
name: Name of the model. name: Name of the model.
@ -393,7 +396,7 @@ class ModelsInterface(InvocationContextInterface):
return self._services.model_manager.load.load_model(configs[0], submodel_type) return self._services.model_manager.load.load_model(configs[0], submodel_type)
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig: def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
"""Gets a model's config. """Get a model's config.
Args: Args:
identifier: The key or ModelField representing the model. identifier: The key or ModelField representing the model.
@ -403,11 +406,11 @@ class ModelsInterface(InvocationContextInterface):
""" """
if isinstance(identifier, str): if isinstance(identifier, str):
return self._services.model_manager.store.get_model(identifier) return self._services.model_manager.store.get_model(identifier)
else:
return self._services.model_manager.store.get_model(identifier.key) return self._services.model_manager.store.get_model(identifier.key)
def search_by_path(self, path: Path) -> list[AnyModelConfig]: def search_by_path(self, path: Path) -> list[AnyModelConfig]:
"""Searches for models by path. """Search for models by path.
Args: Args:
path: The path to search for. path: The path to search for.
@ -424,7 +427,7 @@ class ModelsInterface(InvocationContextInterface):
type: Optional[ModelType] = None, type: Optional[ModelType] = None,
format: Optional[ModelFormat] = None, format: Optional[ModelFormat] = None,
) -> list[AnyModelConfig]: ) -> list[AnyModelConfig]:
"""Searches for models by attributes. """Search for models by attributes.
Args: Args:
name: The name to search for (exact match). name: The name to search for (exact match).
@ -443,6 +446,72 @@ class ModelsInterface(InvocationContextInterface):
model_format=format, model_format=format,
) )
def download_and_cache_model(
self,
source: str | AnyHttpUrl,
) -> Path:
"""
Download the model file located at source to the models cache and return its Path.
This can be used to single-file install models and other resources of arbitrary types
which should not get registered with the database. If the model is already
installed, the cached path will be returned. Otherwise it will be downloaded.
Args:
source: A URL that points to the model, or a huggingface repo_id.
Returns:
Path to the downloaded model
"""
return self._services.model_manager.install.download_and_cache_model(source=source)
def load_local_model(
self,
model_path: Path,
loader: Optional[Callable[[Path], AnyModel]] = None,
) -> LoadedModelWithoutConfig:
"""
Load the model file located at the indicated path
If a loader callable is provided, it will be invoked to load the model. Otherwise,
`safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
Be aware that the LoadedModelWithoutConfig object has no `config` attribute
Args:
path: A model Path
loader: A Callable that expects a Path and returns a dict[str|int, Any]
Returns:
A LoadedModelWithoutConfig object.
"""
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
def load_remote_model(
self,
source: str | AnyHttpUrl,
loader: Optional[Callable[[Path], AnyModel]] = None,
) -> LoadedModelWithoutConfig:
"""
Download, cache, and load the model file located at the indicated URL or repo_id.
If the model is already downloaded, it will be loaded from the cache.
If the a loader callable is provided, it will be invoked to load the model. Otherwise,
`safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
Be aware that the LoadedModelWithoutConfig object has no `config` attribute
Args:
source: A URL or huggingface repoid.
loader: A Callable that expects a Path and returns a dict[str|int, Any]
Returns:
A LoadedModelWithoutConfig object.
"""
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
class ConfigInterface(InvocationContextInterface): class ConfigInterface(InvocationContextInterface):
def get(self) -> InvokeAIAppConfig: def get(self) -> InvokeAIAppConfig:

View File

@ -13,6 +13,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -43,6 +44,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_8(app_config=config)) migrator.register_migration(build_migration_8(app_config=config))
migrator.register_migration(build_migration_9()) migrator.register_migration(build_migration_9())
migrator.register_migration(build_migration_10()) migrator.register_migration(build_migration_10())
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
migrator.run_migrations() migrator.run_migrations()
return db return db

View File

@ -0,0 +1,75 @@
import shutil
import sqlite3
from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
LEGACY_CORE_MODELS = [
# OpenPose
"any/annotators/dwpose/yolox_l.onnx",
"any/annotators/dwpose/dw-ll_ucoco_384.onnx",
# DepthAnything
"any/annotators/depth_anything/depth_anything_vitl14.pth",
"any/annotators/depth_anything/depth_anything_vitb14.pth",
"any/annotators/depth_anything/depth_anything_vits14.pth",
# Lama inpaint
"core/misc/lama/lama.pt",
# RealESRGAN upscale
"core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
"core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
"core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
"core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
]
class Migration11Callback:
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
self._app_config = app_config
self._logger = logger
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._remove_convert_cache()
self._remove_downloaded_models()
self._remove_unused_core_models()
def _remove_convert_cache(self) -> None:
"""Rename models/.cache to models/.convert_cache."""
self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.")
legacy_convert_path = self._app_config.root_path / "models" / ".cache"
shutil.rmtree(legacy_convert_path, ignore_errors=True)
def _remove_downloaded_models(self) -> None:
"""Remove models from their old locations; they will re-download when needed."""
self._logger.info(
"Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache."
)
for model_path in LEGACY_CORE_MODELS:
legacy_dest_path = self._app_config.models_path / model_path
legacy_dest_path.unlink(missing_ok=True)
def _remove_unused_core_models(self) -> None:
"""Remove unused core models and their directories."""
self._logger.info("Removing defunct core models.")
for dir in ["face_restoration", "misc", "upscaling"]:
path_to_remove = self._app_config.models_path / "core" / dir
shutil.rmtree(path_to_remove, ignore_errors=True)
shutil.rmtree(self._app_config.models_path / "any" / "annotators", ignore_errors=True)
def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
"""
Build the migration from database version 10 to 11.
This migration does the following:
- Moves "core" models previously downloaded with download_with_progress_bar() into new
"models/.download_cache" directory.
- Renames "models/.cache" to "models/.convert_cache".
"""
migration_11 = Migration(
from_version=10,
to_version=11,
callback=Migration11Callback(app_config=app_config, logger=logger),
)
return migration_11

View File

@ -1,51 +0,0 @@
from pathlib import Path
from urllib import request
from tqdm import tqdm
from invokeai.backend.util.logging import InvokeAILogger
class ProgressBar:
"""Simple progress bar for urllib.request.urlretrieve using tqdm."""
def __init__(self, model_name: str = "file"):
self.pbar = None
self.name = model_name
def __call__(self, block_num: int, block_size: int, total_size: int):
if not self.pbar:
self.pbar = tqdm(
desc=self.name,
initial=0,
unit="iB",
unit_scale=True,
unit_divisor=1000,
total=total_size,
)
self.pbar.update(block_size)
def download_with_progress_bar(name: str, url: str, dest_path: Path) -> bool:
"""Download a file from a URL to a destination path, with a progress bar.
If the file already exists, it will not be downloaded again.
Exceptions are not caught.
Args:
name (str): Name of the file being downloaded.
url (str): URL to download the file from.
dest_path (Path): Destination path to save the file to.
Returns:
bool: True if the file was downloaded, False if it already existed.
"""
if dest_path.exists():
return False # already downloaded
InvokeAILogger.get_logger().info(f"Downloading {name}...")
dest_path.parent.mkdir(parents=True, exist_ok=True)
request.urlretrieve(url, dest_path, ProgressBar(name))
return True

View File

@ -1,5 +1,5 @@
import pathlib from pathlib import Path
from typing import Literal, Union from typing import Literal
import cv2 import cv2
import numpy as np import numpy as np
@ -10,28 +10,17 @@ from PIL import Image
from torchvision.transforms import Compose from torchvision.transforms import Compose
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2 from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
config = get_config() config = get_config()
logger = InvokeAILogger.get_logger(config=config) logger = InvokeAILogger.get_logger(config=config)
DEPTH_ANYTHING_MODELS = { DEPTH_ANYTHING_MODELS = {
"large": { "large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true", "base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
"local": "any/annotators/depth_anything/depth_anything_vitl14.pth", "small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
},
"base": {
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
"local": "any/annotators/depth_anything/depth_anything_vitb14.pth",
},
"small": {
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
"local": "any/annotators/depth_anything/depth_anything_vits14.pth",
},
} }
@ -53,36 +42,27 @@ transform = Compose(
class DepthAnythingDetector: class DepthAnythingDetector:
def __init__(self) -> None: def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
self.model = None self.model = model
self.model_size: Union[Literal["large", "base", "small"], None] = None self.device = device
self.device = TorchDevice.choose_torch_device()
def load_model(self, model_size: Literal["large", "base", "small"] = "small"): @staticmethod
DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"] def load_model(
download_with_progress_bar( model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
pathlib.Path(DEPTH_ANYTHING_MODELS[model_size]["url"]).name, ) -> DPT_DINOv2:
DEPTH_ANYTHING_MODELS[model_size]["url"], match model_size:
DEPTH_ANYTHING_MODEL_PATH,
)
if not self.model or model_size != self.model_size:
del self.model
self.model_size = model_size
match self.model_size:
case "small": case "small":
self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384]) model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
case "base": case "base":
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768]) model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
case "large": case "large":
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024]) model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu")) model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
self.model.eval() model.eval()
self.model.to(self.device) model.to(device)
return self.model return model
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image: def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
if not self.model: if not self.model:

View File

@ -1,30 +1,53 @@
from pathlib import Path
from typing import Dict
import numpy as np import numpy as np
import torch import torch
from controlnet_aux.util import resize_image from controlnet_aux.util import resize_image
from PIL import Image from PIL import Image
from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose
from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody
DWPOSE_MODELS = {
"yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
"dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
}
def draw_pose(pose, H, W, draw_face=True, draw_body=True, draw_hands=True, resolution=512):
def draw_pose(
pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]],
H: int,
W: int,
draw_face: bool = True,
draw_body: bool = True,
draw_hands: bool = True,
resolution: int = 512,
) -> Image.Image:
bodies = pose["bodies"] bodies = pose["bodies"]
faces = pose["faces"] faces = pose["faces"]
hands = pose["hands"] hands = pose["hands"]
assert isinstance(bodies, dict)
candidate = bodies["candidate"] candidate = bodies["candidate"]
assert isinstance(bodies, dict)
subset = bodies["subset"] subset = bodies["subset"]
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
if draw_body: if draw_body:
canvas = draw_bodypose(canvas, candidate, subset) canvas = draw_bodypose(canvas, candidate, subset)
if draw_hands: if draw_hands:
assert isinstance(hands, np.ndarray)
canvas = draw_handpose(canvas, hands) canvas = draw_handpose(canvas, hands)
if draw_face: if draw_face:
canvas = draw_facepose(canvas, faces) assert isinstance(hands, np.ndarray)
canvas = draw_facepose(canvas, faces) # type: ignore
dwpose_image = resize_image( dwpose_image: Image.Image = resize_image(
canvas, canvas,
resolution, resolution,
) )
@ -39,11 +62,16 @@ class DWOpenposeDetector:
Credits: https://github.com/IDEA-Research/DWPose Credits: https://github.com/IDEA-Research/DWPose
""" """
def __init__(self) -> None: def __init__(self, onnx_det: Path, onnx_pose: Path) -> None:
self.pose_estimation = Wholebody() self.pose_estimation = Wholebody(onnx_det=onnx_det, onnx_pose=onnx_pose)
def __call__( def __call__(
self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512 self,
image: Image.Image,
draw_face: bool = False,
draw_body: bool = True,
draw_hands: bool = False,
resolution: int = 512,
) -> Image.Image: ) -> Image.Image:
np_image = np.array(image) np_image = np.array(image)
H, W, C = np_image.shape H, W, C = np_image.shape
@ -79,3 +107,6 @@ class DWOpenposeDetector:
return draw_pose( return draw_pose(
pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution
) )
__all__ = ["DWPOSE_MODELS", "DWOpenposeDetector"]

View File

@ -5,11 +5,13 @@ import math
import cv2 import cv2
import matplotlib import matplotlib
import numpy as np import numpy as np
import numpy.typing as npt
eps = 0.01 eps = 0.01
NDArrayInt = npt.NDArray[np.uint8]
def draw_bodypose(canvas, candidate, subset): def draw_bodypose(canvas: NDArrayInt, candidate: NDArrayInt, subset: NDArrayInt) -> NDArrayInt:
H, W, C = canvas.shape H, W, C = canvas.shape
candidate = np.array(candidate) candidate = np.array(candidate)
subset = np.array(subset) subset = np.array(subset)
@ -88,7 +90,7 @@ def draw_bodypose(canvas, candidate, subset):
return canvas return canvas
def draw_handpose(canvas, all_hand_peaks): def draw_handpose(canvas: NDArrayInt, all_hand_peaks: NDArrayInt) -> NDArrayInt:
H, W, C = canvas.shape H, W, C = canvas.shape
edges = [ edges = [
@ -142,7 +144,7 @@ def draw_handpose(canvas, all_hand_peaks):
return canvas return canvas
def draw_facepose(canvas, all_lmks): def draw_facepose(canvas: NDArrayInt, all_lmks: NDArrayInt) -> NDArrayInt:
H, W, C = canvas.shape H, W, C = canvas.shape
for lmks in all_lmks: for lmks in all_lmks:
lmks = np.array(lmks) lmks = np.array(lmks)

View File

@ -2,47 +2,26 @@
# Modified pathing to suit Invoke # Modified pathing to suit Invoke
from pathlib import Path
import numpy as np import numpy as np
import onnxruntime as ort import onnxruntime as ort
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from .onnxdet import inference_detector from .onnxdet import inference_detector
from .onnxpose import inference_pose from .onnxpose import inference_pose
DWPOSE_MODELS = {
"yolox_l.onnx": {
"local": "any/annotators/dwpose/yolox_l.onnx",
"url": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
},
"dw-ll_ucoco_384.onnx": {
"local": "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
"url": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
},
}
config = get_config() config = get_config()
class Wholebody: class Wholebody:
def __init__(self): def __init__(self, onnx_det: Path, onnx_pose: Path):
device = TorchDevice.choose_torch_device() device = TorchDevice.choose_torch_device()
providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"] providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
POSE_MODEL_PATH = config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"]
download_with_progress_bar(
"dw-ll_ucoco_384.onnx", DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH
)
onnx_det = DET_MODEL_PATH
onnx_pose = POSE_MODEL_PATH
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers) self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers) self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)

View File

@ -1,4 +1,4 @@
import gc from pathlib import Path
from typing import Any from typing import Any
import numpy as np import numpy as np
@ -6,9 +6,7 @@ import torch
from PIL import Image from PIL import Image
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config from invokeai.backend.model_manager.config import AnyModel
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import TorchDevice
def norm_img(np_img): def norm_img(np_img):
@ -19,28 +17,11 @@ def norm_img(np_img):
return np_img return np_img
def load_jit_model(url_or_path, device):
model_path = url_or_path
logger.info(f"Loading model from: {model_path}")
model = torch.jit.load(model_path, map_location="cpu").to(device)
model.eval()
return model
class LaMA: class LaMA:
def __init__(self, model: AnyModel):
self._model = model
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any: def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
device = TorchDevice.choose_torch_device()
model_location = get_config().models_path / "core/misc/lama/lama.pt"
if not model_location.exists():
download_with_progress_bar(
name="LaMa Inpainting Model",
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
dest_path=model_location,
)
model = load_jit_model(model_location, device)
image = np.asarray(input_image.convert("RGB")) image = np.asarray(input_image.convert("RGB"))
image = norm_img(image) image = norm_img(image)
@ -48,20 +29,25 @@ class LaMA:
mask = np.asarray(mask) mask = np.asarray(mask)
mask = np.invert(mask) mask = np.invert(mask)
mask = norm_img(mask) mask = norm_img(mask)
mask = (mask > 0) * 1 mask = (mask > 0) * 1
device = next(self._model.buffers()).device
image = torch.from_numpy(image).unsqueeze(0).to(device) image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device) mask = torch.from_numpy(mask).unsqueeze(0).to(device)
with torch.inference_mode(): with torch.inference_mode():
infilled_image = model(image, mask) infilled_image = self._model(image, mask)
infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy() infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8") infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
infilled_image = Image.fromarray(infilled_image) infilled_image = Image.fromarray(infilled_image)
del model
gc.collect()
torch.cuda.empty_cache()
return infilled_image return infilled_image
@staticmethod
def load_jit_model(url_or_path: str | Path, device: torch.device | str = "cpu") -> torch.nn.Module:
model_path = url_or_path
logger.info(f"Loading model from: {model_path}")
model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device) # type: ignore
model.eval()
return model

View File

@ -1,6 +1,5 @@
import math import math
from enum import Enum from enum import Enum
from pathlib import Path
from typing import Any, Optional from typing import Any, Optional
import cv2 import cv2
@ -11,6 +10,7 @@ from cv2.typing import MatLike
from tqdm import tqdm from tqdm import tqdm
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
""" """
@ -52,7 +52,7 @@ class RealESRGAN:
def __init__( def __init__(
self, self,
scale: int, scale: int,
model_path: Path, loadnet: AnyModel,
model: RRDBNet, model: RRDBNet,
tile: int = 0, tile: int = 0,
tile_pad: int = 10, tile_pad: int = 10,
@ -67,8 +67,6 @@ class RealESRGAN:
self.half = half self.half = half
self.device = TorchDevice.choose_torch_device() self.device = TorchDevice.choose_torch_device()
loadnet = torch.load(model_path, map_location=torch.device("cpu"))
# prefer to use params_ema # prefer to use params_ema
if "params_ema" in loadnet: if "params_ema" in loadnet:
keyname = "params_ema" keyname = "params_ema"

View File

@ -125,13 +125,16 @@ class IPAdapter(RawModel):
self.device, dtype=self.dtype self.device, dtype=self.dtype
) )
def to(self, device: torch.device, dtype: Optional[torch.dtype] = None): def to(
self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, non_blocking: bool = False
):
if device is not None:
self.device = device self.device = device
if dtype is not None: if dtype is not None:
self.dtype = dtype self.dtype = dtype
self._image_proj_model.to(device=self.device, dtype=self.dtype) self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
self.attn_weights.to(device=self.device, dtype=self.dtype) self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
def calc_size(self): def calc_size(self):
# workaround for circular import # workaround for circular import

View File

@ -61,9 +61,10 @@ class LoRALayerBase:
self, self,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None: ) -> None:
if self.bias is not None: if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype) self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)
# TODO: find and debug lora/locon with bias # TODO: find and debug lora/locon with bias
@ -109,14 +110,15 @@ class LoRALayer(LoRALayerBase):
self, self,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None: ) -> None:
super().to(device=device, dtype=dtype) super().to(device=device, dtype=dtype, non_blocking=non_blocking)
self.up = self.up.to(device=device, dtype=dtype) self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.down = self.down.to(device=device, dtype=dtype) self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.mid is not None: if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype) self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)
class LoHALayer(LoRALayerBase): class LoHALayer(LoRALayerBase):
@ -169,18 +171,19 @@ class LoHALayer(LoRALayerBase):
self, self,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None: ) -> None:
super().to(device=device, dtype=dtype) super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype) self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w1_b = self.w1_b.to(device=device, dtype=dtype) self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t1 is not None: if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype) self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w2_a = self.w2_a.to(device=device, dtype=dtype) self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w2_b = self.w2_b.to(device=device, dtype=dtype) self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t2 is not None: if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype) self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
class LoKRLayer(LoRALayerBase): class LoKRLayer(LoRALayerBase):
@ -265,6 +268,7 @@ class LoKRLayer(LoRALayerBase):
self, self,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None: ) -> None:
super().to(device=device, dtype=dtype) super().to(device=device, dtype=dtype)
@ -273,19 +277,19 @@ class LoKRLayer(LoRALayerBase):
else: else:
assert self.w1_a is not None assert self.w1_a is not None
assert self.w1_b is not None assert self.w1_b is not None
self.w1_a = self.w1_a.to(device=device, dtype=dtype) self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w1_b = self.w1_b.to(device=device, dtype=dtype) self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.w2 is not None: if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype) self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
else: else:
assert self.w2_a is not None assert self.w2_a is not None
assert self.w2_b is not None assert self.w2_b is not None
self.w2_a = self.w2_a.to(device=device, dtype=dtype) self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w2_b = self.w2_b.to(device=device, dtype=dtype) self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t2 is not None: if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype) self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
class FullLayer(LoRALayerBase): class FullLayer(LoRALayerBase):
@ -319,10 +323,11 @@ class FullLayer(LoRALayerBase):
self, self,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None: ) -> None:
super().to(device=device, dtype=dtype) super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype) self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
class IA3Layer(LoRALayerBase): class IA3Layer(LoRALayerBase):
@ -358,11 +363,12 @@ class IA3Layer(LoRALayerBase):
self, self,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
): ):
super().to(device=device, dtype=dtype) super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype) self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.on_input = self.on_input.to(device=device, dtype=dtype) self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer] AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
@ -388,10 +394,11 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
self, self,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None: ) -> None:
# TODO: try revert if exception? # TODO: try revert if exception?
for _key, layer in self.layers.items(): for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype) layer.to(device=device, dtype=dtype, non_blocking=non_blocking)
def calc_size(self) -> int: def calc_size(self) -> int:
model_size = 0 model_size = 0
@ -514,7 +521,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
# lower memory consumption by removing already parsed layer values # lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear() state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype) layer.to(device=device, dtype=dtype, non_blocking=True)
model.layers[layer_key] = layer model.layers[layer_key] = layer
return model return model

View File

@ -0,0 +1,24 @@
import json
from base64 import b64decode
def validate_hash(hash: str):
if ":" not in hash:
return
for enc_hash in hashes:
alg, hash_ = hash.split(":")
if alg == "blake3":
alg = "blake3_single"
map = json.loads(b64decode(enc_hash))
if alg in map:
if hash_ == map[alg]:
raise Exception("Unrecoverable Model Error")
hashes: list[str] = [
"eyJibGFrZTNfbXVsdGkiOiI3Yjc5ODZmM2QyNTk3MDZiMjVhZDRhM2NmNGM2MTcyNGNhZmQ0Yjc4NjI4MjIwNjMyZGU4NjVlM2UxNDEyMTVlIiwiYmxha2UzX3NpbmdsZSI6IjdiNzk4NmYzZDI1OTcwNmIyNWFkNGEzY2Y0YzYxNzI0Y2FmZDRiNzg2MjgyMjA2MzJkZTg2NWUzZTE0MTIxNWUiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiNzdlZmU5MzRhZGQ3YmU5Njc3NmJkODM3NWJhZDQxN2QiLCJzaGExIjoiYmM2YzYxYzgwNDgyMTE2ZTY2ZGQyNTYwNjRkYTgxYjFlY2U4NzMzOCIsInNoYTIyNCI6IjgzNzNlZGM4ZTg4Y2UxMTljODdlOTM2OTY4ZWViMWNmMzdjZGY4NTBmZjhjOTZkYjNmMDc4YmE0Iiwic2hhMjU2IjoiNzNjYWMxZWRlZmUyZjdlODFkNjRiMTI2YjIxMmY2Yzk2ZTAwNjgyNGJjZmJkZDI3Y2E5NmUyNTk5ZTQwNzUwZiIsInNoYTM4NCI6IjlmNmUwNzlmOTNiNDlkMTg1YzEyNzY0OGQwNzE3YTA0N2E3MzYyNDI4YzY4MzBhNDViNzExODAwZDE4NjIwZDZjMjcwZGE3ZmY0Y2FjOTRmNGVmZDdiZWQ5OTlkOWU0ZCIsInNoYTUxMiI6IjAwNzE5MGUyYjk5ZjVlN2Q1OGZiYWI2YTk1YmY0NjJiODhkOTg1N2NlNjY4MTMyMGJmM2M0Y2ZiZmY0MjkxZmEzNTMyMTk3YzdkODc2YWQ3NjZhOTQyOTQ2Zjc1OWY2YTViNDBlM2I2MzM3YzIwNWI0M2JkOWMyN2JiMTljNzk0IiwiYmxha2UyYiI6IjlhN2VhNTQzY2ZhMmMzMWYyZDIyNjg2MjUwNzUyNDE0Mjc1OWJiZTA0MWZlMWJkMzQzNDM1MWQwNWZlYjI2OGY2MjU0OTFlMzlmMzdkYWQ4MGM2Y2UzYTE4ZjAxNGEzZjJiMmQ2OGU2OTc0MjRmNTU2M2Y5ZjlhYzc1MzJiMjEwIiwiYmxha2UycyI6ImYxZmMwMjA0YjdjNzIwNGJlNWI1YzY3NDEyYjQ2MjY5NWE3YjFlYWQ2M2E5ZGVkMjEzYjZmYTU0NGZjNjJlYzUiLCJzaGEzXzIyNCI6IjljZDQ3YTBhMzA3NmNmYzI0NjJhNTAzMjVmMjg4ZjFiYzJjMmY2NmU2ODIxODc5NjJhNzU0NjFmIiwic2hhM18yNTYiOiI4NTFlNGI1ZDI1MWZlZTFiYzk0ODU1OWNjMDNiNjhlNTllYWU5YWI1ZTUyYjA0OTgxYTRhOTU4YWQyMDdkYjYwIiwic2hhM18zODQiOiJiZDA2ZTRhZGFlMWQ0MTJmZjFjOTcxMDJkZDFlN2JmY2UzMDViYTgxMTgyNzM3NWY5NTI4OWJkOGIyYTUxNjdiMmUyNzZjODNjNTU3ODFhMTEyMDRhNzc5MTUwMzM5ZTEiLCJzaGEzXzUxMiI6ImQ1ZGQ2OGZmZmY5NGRhZjJhMDkzZTliNmM1MTBlZmZkNThmZTA0ODMyZGQzMzEyOTZmN2NkZmYzNmRhZmQ3NGMxY2VmNjUxNTBkZjk5OGM1ODgyY2MzMzk2MTk1ZTViYjc5OTY1OGFkMTQ3MzFiMjJmZWZiMWQzNmY2MWJjYzJjIiwic2hha2VfMTI4IjoiOWJlNTgwNWMwNjg1MmZmNDUzNGQ4ZDZmODYyMmFkOTJkMGUwMWE2Y2JmYjIwN2QxOTRmM2JkYThiOGNmNWU4ZiIsInNoYWtlXzI1NiI6IjRhYjgwYjY2MzcxYzdhNjBhYWM4NDVkMTZlNWMzZDNhMmM4M2FjM2FjZDNiNTBiNzdjYWYyYTNmMWMyY2ZjZjc5OGNjYjkxN2FjZjQzNzBmZDdjN2ZmODQ5M2Q3NGY1MWM4NGU3M2ViZGQ4MTRmM2MwMzk3YzI4ODlmNTI0Mzg3In0K",
"eyJibGFrZTNfbXVsdGkiOiI4ODlmYzIwMDA4NWY1NWY4YTA4MjhiODg3MDM0OTRhMGFmNWZkZGI5N2E2YmYwMDRjM2VkYTdiYzBkNDU0MjQzIiwiYmxha2UzX3NpbmdsZSI6Ijg4OWZjMjAwMDg1ZjU1ZjhhMDgyOGI4ODcwMzQ5NGEwYWY1ZmRkYjk3YTZiZjAwNGMzZWRhN2JjMGQ0NTQyNDMiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiNTIzNTRhMzkzYTVmOGNjNmMyMzQ0OThiYjcxMDljYzEiLCJzaGExIjoiMTJmYmRhOGE3ZGUwOGMwNDc2NTA5OWY2NGNmMGIzYjcxMjc1MGM1NyIsInNoYTIyNCI6IjEyZWU3N2U0Y2NhODViMDk4YjdjNWJlMWFjNGMwNzljNGM3MmJmODA2YjdlZjU1NGI0NzgxZDkxIiwic2hhMjU2IjoiMjU1NTMwZDAyYTY4MjY4OWE5ZTZjMjRhOWZhMDM2OGNhODMxZTI1OTAyYjM2NzQyNzkwZTk3NzU1ZjEzMmNmNSIsInNoYTM4NCI6IjhkMGEyMTRlNDk0NGE2NGY3ZmZjNTg3MGY0ZWUyZTA0OGIzYjRjMmQ0MGRmMWFmYTVlOGE1ZWNkN2IwOTY3M2ZjNWI5YzM5Yzg4Yjc2YmIwY2I4ZjQ1ZjAxY2MwNjZkNCIsInNoYTUxMiI6Ijg3NTM3OWNiYzdlOGYyNzU4YjVjMDY5ZTU2ZWRjODY1ODE4MGFkNDEzNGMwMzY1NzM4ZjM1YjQwYzI2M2JkMTMwMzcwZTE0MzZkNDNmOGFhMTgyMTg5MzgzMTg1ODNhOWJhYTUyYTBjMTk1Mjg5OTQzYzZiYTY2NTg1Yjg5M2ZiIiwiYmxha2UyYiI6IjBhY2MwNWEwOGE5YjhhODNmZTVjYTk4ZmExMTg3NTYwNjk0MjY0YWUxNTI4NDliYzFkNzQzNTYzMzMyMTlhYTg3N2ZiNjc4MmRjZDZiOGIyYjM1MTkyNDQzNDE2ODJiMTQ3YmY2YTY3MDU2ZWIwOTQ4MzE1M2E4Y2ZiNTNmMTI0IiwiYmxha2UycyI6ImY5ZTRhZGRlNGEzZDRhOTZhOWUyNjVjMGVmMjdmZDNiNjA0NzI1NDllMTEyMWQzOGQwMTkxNTY5ZDY5YzdhYzAiLCJzaGEzXzIyNCI6ImM0NjQ3MGRjMjkyNGI0YjZkMTA2NDY5MDRiNWM2OGVjNTU2YmQ4MTA5NmVkMTA4YjZiMzQyZmU1Iiwic2hhM18yNTYiOiIwMDBlMThiZTI1MzYxYTk0NGExZTIwNjQ5ZmY0ZGM2OGRiZTk0OGNkNTYwY2I5MTFhODU1OTE3ODdkNWQ5YWYwIiwic2hhM18zODQiOiIzNDljZmVhMGUxZGE0NWZlMmYzNjJhMWFjZjI1ZTczOWNiNGQ0NDdiM2NiODUzZDVkYWNjMzU5ZmRhMWE1M2FhYWU5OTM2ZmFhZWM1NmFhZDkwMThhYjgxMTI4ZjI3N2YiLCJzaGEzXzUxMiI6ImMxNDgwNGY1YTNjNWE4ZGEyMTAyODk1YTFjZGU4MmIwNGYwZmY4OTczMTc0MmY2NDQyY2NmNzQ1OTQzYWQ5NGViOWZmMTNhZDg3YjRmODkxN2M5NmY5ZjMwZjkwYTFhYTI4OTI3OTkwMjg0ZDJhMzcyMjA0NjE4MTNiNDI0MzEyIiwic2hha2VfMTI4IjoiN2IxY2RkMWUyMzUzMzk0OTg5M2UyMmZkMTAwZmU0YjJhMTU1MDJmMTNjMTI0YzhiZDgxY2QwZDdlOWEzMGNmOCIsInNoYWtlXzI1NiI6ImI0NjMzZThhMjNkZDM0ODk0ZTIyNzc0ODYyNTE1MzVjYWFlNjkyMTdmOTQ0NTc3MzE1NTljODBjNWQ3M2ZkOTMxZTFjMDJlZDI0Yjc3MzE3OTJjMjVlNTZhYjg3NjI4YmJiMDgxNTU0MjU2MWY5ZGI2NWE0NDk4NDFmNGQzYTU4In0K",
"eyJibGFrZTNfbXVsdGkiOiI2Y2M0MmU4NGRiOGQyZTliYjA4YjUxNWUwYzlmYzg2NTViNDUwNGRlZDM1MzBlZjFjNTFjZWEwOWUxYThiNGYxIiwiYmxha2UzX3NpbmdsZSI6IjZjYzQyZTg0ZGI4ZDJlOWJiMDhiNTE1ZTBjOWZjODY1NWI0NTA0ZGVkMzUzMGVmMWM1MWNlYTA5ZTFhOGI0ZjEiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiZDQwNjk3NTJhYjQ0NzFhZDliMDY3YmUxMmRjNTM2ZjYiLCJzaGExIjoiOGRjZmVlMjZjZjUyOTllMDBjN2QwZjJiZTc0NmVmMTlkZjliZGExNCIsInNoYTIyNCI6IjhjMzAzOTU3ZjI3NDNiMjUwNmQyYzIzY2VmNmU4MTQ5MTllZmE2MWM0MTFiMDk5ZmMzODc2MmRjIiwic2hhMjU2IjoiZDk3ZjQ2OWJjMWZkMjhjMjZkMjJhN2Y3ODczNzlhZmM4NjY3ZmZmM2FhYTQ5NTE4NmQyZTM4OTU2MTBjZDJmMyIsInNoYTM4NCI6IjY0NmY0YWM0ZDA2YWJkZmE2MDAwN2VjZWNiOWNjOTk4ZmJkOTBiYzYwMmY3NTk2M2RhZDUzMGMzNGE5ZGE1YzY4NjhlMGIwMDJkZDNlMTM4ZjhmMjA2ODcyNzFkMDVjMSIsInNoYTUxMiI6ImYzZTU4NTA0YzYyOGUwYjViNzBhOTYxYThmODA1MDA1NjQ1M2E5NDlmNTgzNDhiYTNhZTVlMjdkNDRhNGJkMjc5ZjA3MmU1OGQ5YjEyOGE1NDc1MTU2ZmM3YzcxMGJkYjI3OWQ5OGFmN2EwYTI4Y2Y1ZDY2MmQxODY4Zjg3ZjI3IiwiYmxha2UyYiI6ImFhNjgyYmJjM2U1ZGRjNDZkNWUxN2VjMzRlNmEzZGY5ZjhiNWQyNzk0YTZkNmY0M2VjODMxZjhjOTU2OGYyY2RiOGE4YjAyNTE4MDA4YmY0Y2FhYTlhY2FhYjNkNzRmZmRiNGZlNDgwOTcwODU3OGJiZjNlNzJjYTc5ZDQwYzZmIiwiYmxha2UycyI6ImQ0ZGJlZTJkMmZlNDMwOGViYTkwMTY1MDdmMzI1ZmJiODZlMWQzNDQ0MjgzNzRlMjAwNjNiNWQ1MzkzZTExNjMiLCJzaGEzXzIyNCI6ImE1ZTM5NWZlNGRlYjIyY2JhNjgwMWFiZTliZjljMjM2YmMzYjkwZDdiN2ZjMTRhZDhjZjQ0NzBlIiwic2hhM18yNTYiOiIwOWYwZGVjODk0OWEzYmQzYzU3N2RjYzUyMTMwMGRiY2UwMjVjM2VjOTJkNzQ0MDJkNTE1ZDA4NTQwODg2NGY1Iiwic2hhM18zODQiOiJmMjEyNmM5NTcxODQ3NDZmNjYyMjE4MTRkMDZkZWQ3NDBhYWU3MDA4MTc0YjI0OTEzY2YwOTQzY2IwMTA5Y2QxNWI4YmMwOGY1YjUwMWYwYzhhOTY4MzUwYzgzY2I1ZWUiLCJzaGEzXzUxMiI6ImU1ZmEwMzIwMzk2YTJjMThjN2UxZjVlZmJiODYwYTU1M2NlMTlkMDQ0MWMxNWEwZTI1M2RiNjJkM2JmNjg0ZDI1OWIxYmQ4OTJkYTcyMDVjYTYyODQ2YzU0YWI1ODYxOTBmNDUxZDlmZmNkNDA5YmU5MzlhNWM1YWIyZDdkM2ZkIiwic2hha2VfMTI4IjoiNGI2MTllM2I4N2U1YTY4OTgxMjk0YzgzMmU0NzljZGI4MWFmODdlZTE4YzM1Zjc5ZjExODY5ZWEzNWUxN2I3MiIsInNoYWtlXzI1NiI6ImYzOWVkNmMxZmQ2NzVmMDg3ODAyYTc4ZTUwYWFkN2ZiYTZiM2QxNzhlZWYzMjRkMTI3ZTZjYmEwMGRjNzkwNTkxNjQ1Y2U1Y2NmMjhjYzVkNWRkODU1OWIzMDMxYTM3ZjE5NjhmYmFhNDQzMmI2ZWU0Yzg3ZWE2YTdkMmE2NWM2In0K",
"eyJibGFrZTNfbXVsdGkiOiJhNDRiZjJkMzVkZDI3OTZlZTI1NmY0MzVkODFhNTdhOGM0MjZhMzM5ZDc3NTVkMmNiMjdmMzU4ZjM0NTM4OWM2IiwiYmxha2UzX3NpbmdsZSI6ImE0NGJmMmQzNWRkMjc5NmVlMjU2ZjQzNWQ4MWE1N2E4YzQyNmEzMzlkNzc1NWQyY2IyN2YzNThmMzQ1Mzg5YzYiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiOGU5OTMzMzEyZjg4NDY4MDg0ZmRiZWNjNDYyMTMxZTgiLCJzaGExIjoiNmI0MmZjZDFmMmQyNzUwYWNkY2JkMTUzMmQ4NjQ5YTM1YWI2NDYzNCIsInNoYTIyNCI6ImQ2Y2E2OTUxNzIzZjdjZjg0NzBjZWRjMmVhNjA2ODNmMWU4NDMzM2Q2NDM2MGIzOWIyMjZlZmQzIiwic2hhMjU2IjoiMDAxNGY5Yzg0YjcwMTFhMGJkNzliNzU0NGVjNzg4NDQzNWQ4ZGY0NmRjMDBiNDk0ZmFkYzA4NWQzNDM1NjI4MyIsInNoYTM4NCI6IjMxODg2OTYxODc4NWY3MWJlM2RlZjkyZDgyNzY2NjBhZGE0MGViYTdkMDk1M2Y0YTc5ODdlMThhNzFlNjBlY2EwY2YyM2YwMjVhMmQ4ZjUyMmNkZGY3MTcxODFhMTQxNSIsInNoYTUxMiI6IjdmZGQxN2NmOWU3ZTBhZDcwMzJjMDg1MTkyYWMxZmQ0ZmFhZjZkNWNlYzAzOTE5ZDk0MmZiZTIyNWNhNmIwZTg0NmQ4ZGI0ZjllYTQ5MjJlMTdhNTg4MTY4YzExMTM1NWZiZDQ1NTlmMmU5NDcwNjAwZWE1MzBhMDdiMzY0YWQwIiwiYmxha2UyYiI6IjI0ZjExZWI5M2VlN2YxOTI5NWZiZGU5MTczMmE0NGJkZGYxOWE1ZTQ4MWNmOWFhMjQ2M2UzNDllYjg0Mzc4ZDBkODFjNzY0YWQ1NTk1YjkxZjQzYzgxODcxNTRlYWU5NTZkY2ZjZTlkMWU2MTZjNTFkZThhZDZjZTBhODcyY2Q0IiwiYmxha2UycyI6IjVkZTUwZDUwMGYwYTBmOGRlMTEwOGE2ZmFkZGM4ODNlMTA3NmQ3MThiNmQxN2E4ZDVkMjgzZDdiNGYzZDU2OGEiLCJzaGEzXzIyNCI6IjFhNTA0OGNlYWZiYjg2ZDc4ZmNiNTI0ZTViYTc4NWQ2ZmY5NzY1ZTNlMzdhZWRjZmYxZGVjNGJhIiwic2hhM18yNTYiOiI0YjA0YjE1NTRmMzRkYTlmMjBmZDczM2IzNDg4NjE0ZWNhM2IwOWU1OTJjOGJlMmM0NjA1NjYyMWU0MjJmZDllIiwic2hhM18zODQiOiI1NjMwYjM2OGQ4MGM1YmM5MTgzM2VmNWM2YWUzOTJhNDE4NTNjYmM2MWJiNTI4ZDE4YWM1OWFjZGZiZWU1YThkMWMyZDE4MTM1ZGI2ZWQ2OTJlODFkZThmYTM3MzkxN2MiLCJzaGEzXzUxMiI6IjA2ODg4MGE1MmNiNDkzODYwZDhjOTVhOTFhZGFmZTYwZGYxODc2ZDhjYjFhNmI3NTU2ZjJjM2Y1NjFmMGYwZjMyZjZhYTA1YmVmN2FhYjQ5OWEwNTM0Zjk0Njc4MDEzODlmNDc0ODFiNzcxMjdjMDFiOGFhOTY4NGJhZGUzYmY2Iiwic2hha2VfMTI4IjoiODlmYTdjNDcwNGI4NGZkMWQ1M2E0MTBlN2ZjMzU3NWRhNmUxMGU1YzkzMjM1NWYyZWEyMWM4NDVhZDBlM2UxOCIsInNoYWtlXzI1NiI6IjE4NGNlMWY2NjdmYmIyODA5NWJhZmVkZTQzNTUzZjhkYzBhNGY1MDQwYWJlMjcxMzkzMzcwNDEyZWFiZTg0ZGJhNjI0Y2ZiZWE4YzUxZDU2YzkwMTM2Mjg2ODgyZmQ0Y2E3MzA3NzZjNWUzODFlYzI5MWYxYTczOTE1MDkyMTFmIn0K",
"eyJibGFrZTNfbXVsdGkiOiJhYjA2YjNmMDliNTExOTAzMTMzMzY5NDE2MTc4ZDk2ZjlkYTc3ZGEwOTgyNDJmN2VlMTVjNTNhNTRkMDZhNWVmIiwiYmxha2UzX3NpbmdsZSI6ImFiMDZiM2YwOWI1MTE5MDMxMzMzNjk0MTYxNzhkOTZmOWRhNzdkYTA5ODI0MmY3ZWUxNWM1M2E1NGQwNmE1ZWYiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiZWY0MjcxYjU3NTQwMjU4NGQ2OTI5ZWJkMGI3Nzk5NzYiLCJzaGExIjoiMzgzNzliYWQzZjZiZjc4MmM4OTgzOGY3YWVkMzRkNDNkMzNlYWM2MSIsInNoYTIyNCI6ImQ5ZDNiMjJkYmZlY2M1NTdlODAzNjg5M2M3ZWE0N2I0NTQzYzM2NzZhMDk4NzMxMzRhNjQ0OWEwIiwic2hhMjU2IjoiMjYxZGI3NmJlMGYxMzdlZWJkYmI5OGRlYWM0ZjcyMDdiOGUxMjdiY2MyZmMwODI5OGVjZDczYjQ3MjYxNjQ1NiIsInNoYTM4NCI6IjMzMjkwYWQxYjlhMmRkYmU0ODY3MWZiMTIxNDdiZWJhNjI4MjA1MDcwY2VkNjNiZTFmNGU5YWRhMjgwYWU2ZjZjNDkzYTY2MDllMGQ2YTIzMWU2ODU5ZmIyNGZhM2FjMCIsInNoYTUxMiI6IjAzMDZhMWI1NmNiYTdjNjJiNTNmNTk4MTAwMTQ3MDQ5ODBhNGRmZTdjZjQ5NTU4ZmMyMmQxZDczZDc5NzJmZTllODk2ZWRjMmEyYTQxYWVjNjRjZjkwZGUwYjI1NGM0MDBlZTU1YzcwZjk3OGVlMzk5NmM2YzhkNTBjYTI4YTdiIiwiYmxha2UyYiI6IjY1MDZhMDg1YWQ5MGZkZjk2NGJmMGE5NTFkZmVkMTllZTc0NGVjY2EyODQzZjQzYTI5NmFjZDM0M2RiODhhMDNlNTlkNmFmMGM1YWJkNTEzMzc4MTQ5Yjg3OTExMTVmODRmMDIyZWM1M2JmNGFjNDZhZDczNWIwMmJlYTM0MDk5IiwiYmxha2UycyI6IjdlZDQ3ZWQxOTg3MTk0YWFmNGIwMjQ3MWFkNTMyMmY3NTE3ZjI0OTcwMDc2Y2NmNDkzMWI0MzYxMDU1NzBlNDAiLCJzaGEzXzIyNCI6Ijk2MGM4MDExOTlhMGUzYWExNjdiNmU2MWVkMzE2ZDUzMDM2Yjk4M2UyOThkNWI5MjZmMDc3NDlhIiwic2hhM18yNTYiOiIzYzdmYWE1ZDE3Zjk2MGYxOTI2ZjNlNGIyZjc1ZjdiOWIyZDQ4NGFhNmEwM2ViOWNlMTI4NmM2OTE2YWEyM2RlIiwic2hhM18zODQiOiI5Y2Y0NDA1NWFjYzFlYjZmMDY1YjRjODcxYTYzNTM1MGE1ZjY0ODQwM2YwYTU0MWEzYzZhNjI3N2ViZjZmYTNjYmM1YmJiNjQwMDE4OGFlMWIxMTI2OGZmMDJiMzYzZDUiLCJzaGEzXzUxMiI6ImEyZDk3ZDRlYjYxM2UwZDViYTc2OTk2MzE2MzcxOGEwNDIxZDkxNTNiNjllYjM5MDRmZjI4ODRhZDdjNGJiYmIwNGY2Nzc1OTA1YmQxNGI2NTJmZTQ1Njg0YmI5MTQ3ZjBkYWViZjAxZjIzY2MzZDhkMjIzMTE0MGUzNjI4NTE5Iiwic2hha2VfMTI4IjoiNjkwMWMwYjg1MTg5ZTkyNTJiODI3MTc5NjE2MjRlMTM0MDQ1ZjlkMmI5MzM0MzVkM2Y0OThiZWIyN2Q3N2JiNSIsInNoYWtlXzI1NiI6ImIwMjA4ZTFkNDVjZWI0ODdiZDUwNzk3MWJiNWI3MjdjN2UyYmE3ZDliNWM2ZTEyYWE5YTNhOTY5YzcyNDRjODIwZDcyNDY1ODhlZWU3Yjk4ZWM1NzhjZWIxNjc3OTkxODljMWRkMmZkMmZmYWM4MWExZDAzZDFiNjMxOGRkMjBiIn0K",
]

View File

@ -31,12 +31,13 @@ from typing_extensions import Annotated, Any, Dict
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.util.misc import uuid_string from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from ..raw_model import RawModel from ..raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models # ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime # RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module] AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
class InvalidModelConfigException(Exception): class InvalidModelConfigException(Exception):
@ -115,7 +116,7 @@ class SchedulerPredictionType(str, Enum):
class ModelRepoVariant(str, Enum): class ModelRepoVariant(str, Enum):
"""Various hugging face variants on the diffusers format.""" """Various hugging face variants on the diffusers format."""
Default = "" # model files without "fp16" or other qualifier - empty str Default = "" # model files without "fp16" or other qualifier
FP16 = "fp16" FP16 = "fp16"
FP32 = "fp32" FP32 = "fp32"
ONNX = "onnx" ONNX = "onnx"
@ -448,4 +449,6 @@ class ModelConfigFactory(object):
model.key = key model.key = key
if isinstance(model, CheckpointConfigBase) and timestamp is not None: if isinstance(model, CheckpointConfigBase) and timestamp is not None:
model.converted_at = timestamp model.converted_at = timestamp
if model:
validate_hash(model.hash)
return model # type: ignore return model # type: ignore

View File

@ -7,7 +7,7 @@ from importlib import import_module
from pathlib import Path from pathlib import Path
from .convert_cache.convert_cache_default import ModelConvertCache from .convert_cache.convert_cache_default import ModelConvertCache
from .load_base import LoadedModel, ModelLoaderBase from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
from .load_default import ModelLoader from .load_default import ModelLoader
from .model_cache.model_cache_default import ModelCache from .model_cache.model_cache_default import ModelCache
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
@ -19,6 +19,7 @@ for module in loaders:
__all__ = [ __all__ = [
"LoadedModel", "LoadedModel",
"LoadedModelWithoutConfig",
"ModelCache", "ModelCache",
"ModelConvertCache", "ModelConvertCache",
"ModelLoaderBase", "ModelLoaderBase",

View File

@ -7,6 +7,7 @@ from pathlib import Path
from invokeai.backend.util import GIG, directory_size from invokeai.backend.util import GIG, directory_size
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.util import safe_filename
from .convert_cache_base import ModelConvertCacheBase from .convert_cache_base import ModelConvertCacheBase
@ -35,6 +36,7 @@ class ModelConvertCache(ModelConvertCacheBase):
def cache_path(self, key: str) -> Path: def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key.""" """Return the path for a model with the indicated key."""
key = safe_filename(self._cache_path, key)
return self._cache_path / key return self._cache_path / key
def make_room(self, size: float) -> None: def make_room(self, size: float) -> None:

View File

@ -4,10 +4,13 @@ Base class for model loading in InvokeAI.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
from typing import Any, Optional from typing import Any, Dict, Generator, Optional, Tuple
import torch
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
@ -20,10 +23,44 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
@dataclass @dataclass
class LoadedModel: class LoadedModelWithoutConfig:
"""Context manager object that mediates transfer from RAM<->VRAM.""" """
Context manager object that mediates transfer from RAM<->VRAM.
This is a context manager object that has two distinct APIs:
1. Older API (deprecated):
Use the LoadedModel object directly as a context manager.
It will move the model into VRAM (on CUDA devices), and
return the model in a form suitable for passing to torch.
Example:
```
loaded_model_= loader.get_model_by_key('f13dd932', SubModelType('vae'))
with loaded_model as vae:
image = vae.decode(latents)[0]
```
2. Newer API (recommended):
Call the LoadedModel's `model_on_device()` method in a
context. It returns a tuple consisting of a copy of
the model's state dict in CPU RAM followed by a copy
of the model in VRAM. The state dict is provided to allow
LoRAs and other model patchers to return the model to
its unpatched state without expensive copy and restore
operations.
Example:
```
loaded_model_= loader.get_model_by_key('f13dd932', SubModelType('vae'))
with loaded_model.model_on_device() as (state_dict, vae):
image = vae.decode(latents)[0]
```
The state_dict should be treated as a read-only object and
never modified. Also be aware that some loadable models do
not have a state_dict, in which case this value will be None.
"""
config: AnyModelConfig
_locker: ModelLockerBase _locker: ModelLockerBase
def __enter__(self) -> AnyModel: def __enter__(self) -> AnyModel:
@ -34,12 +71,29 @@ class LoadedModel:
"""Context exit.""" """Context exit."""
self._locker.unlock() self._locker.unlock()
@contextmanager
def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
locked_model = self._locker.lock()
try:
state_dict = self._locker.get_state_dict()
yield (state_dict, locked_model)
finally:
self._locker.unlock()
@property @property
def model(self) -> AnyModel: def model(self) -> AnyModel:
"""Return the model without locking it.""" """Return the model without locking it."""
return self._locker.model return self._locker.model
@dataclass
class LoadedModel(LoadedModelWithoutConfig):
"""Context manager object that mediates transfer from RAM<->VRAM."""
config: Optional[AnyModelConfig] = None
# TODO(MM2): # TODO(MM2):
# Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't # Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
# know about. I think the problem may be related to this class being an ABC. # know about. I think the problem may be related to this class being an ABC.

View File

@ -16,7 +16,7 @@ from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
@ -84,7 +84,7 @@ class ModelLoader(ModelLoaderBase):
except IndexError: except IndexError:
pass pass
cache_path: Path = self._convert_cache.cache_path(config.key) cache_path: Path = self._convert_cache.cache_path(str(model_path))
if self._needs_conversion(config, model_path, cache_path): if self._needs_conversion(config, model_path, cache_path):
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type) loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
else: else:
@ -95,7 +95,6 @@ class ModelLoader(ModelLoaderBase):
config.key, config.key,
submodel_type=submodel_type, submodel_type=submodel_type,
model=loaded_model, model=loaded_model,
size=calc_model_size_by_data(loaded_model),
) )
return self._ram_cache.get( return self._ram_cache.get(
@ -126,9 +125,7 @@ class ModelLoader(ModelLoaderBase):
if subtype == submodel_type: if subtype == submodel_type:
continue continue
if submodel := getattr(pipeline, subtype.value, None): if submodel := getattr(pipeline, subtype.value, None):
self._ram_cache.put( self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
)
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline return getattr(pipeline, submodel_type.value) if submodel_type else pipeline
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:

View File

@ -31,6 +31,11 @@ class ModelLockerBase(ABC):
"""Unlock the contained model, and remove it from VRAM.""" """Unlock the contained model, and remove it from VRAM."""
pass pass
@abstractmethod
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
pass
@property @property
@abstractmethod @abstractmethod
def model(self) -> AnyModel: def model(self) -> AnyModel:
@ -43,11 +48,33 @@ T = TypeVar("T")
@dataclass @dataclass
class CacheRecord(Generic[T]): class CacheRecord(Generic[T]):
"""Elements of the cache.""" """
Elements of the cache:
key: Unique key for each model, same as used in the models database.
model: Model in memory.
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model
loaded: True if the model's state dict is currently in VRAM
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
The state_dict should be treated as a read-only attribute. Do not attempt
to patch or otherwise modify it. Instead, patch the copy of the state_dict
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
context manager call `model_on_device()`.
"""
key: str key: str
size: int size: int
model: T model: T
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int
loaded: bool = False loaded: bool = False
_locks: int = 0 _locks: int = 0
@ -147,7 +174,6 @@ class ModelCacheBase(ABC, Generic[T]):
self, self,
key: str, key: str,
model: T, model: T,
size: int,
submodel_type: Optional[SubModelType] = None, submodel_type: Optional[SubModelType] = None,
) -> None: ) -> None:
"""Store model under key and optional submodel_type.""" """Store model under key and optional submodel_type."""

View File

@ -29,7 +29,8 @@ from typing import Dict, Generator, List, Optional, Set
import torch import torch
from invokeai.backend.model_manager import AnyModel, SubModelType from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@ -206,16 +207,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
self, self,
key: str, key: str,
model: AnyModel, model: AnyModel,
size: int,
submodel_type: Optional[SubModelType] = None, submodel_type: Optional[SubModelType] = None,
) -> None: ) -> None:
"""Store model under key and optional submodel_type.""" """Store model under key and optional submodel_type."""
with self._ram_lock:
key = self._make_cache_key(key, submodel_type) key = self._make_cache_key(key, submodel_type)
if key in self._cached_models: if key in self._cached_models:
return return
size = calc_model_size_by_data(model)
self.make_room(size) self.make_room(size)
cache_record = CacheRecord(key, model=model, size=size)
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
self._cached_models[key] = cache_record self._cached_models[key] = cache_record
self._cache_stack.append(key) self._cache_stack.append(key)
@ -277,6 +279,106 @@ class ModelCache(ModelCacheBase[AnyModel]):
else: else:
return model_key return model_key
def offload_unlocked_models(self, size_required: int) -> None:
"""Move any unused models from VRAM."""
reserved = self._max_vram_cache_size * GIG
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
if not cache_entry.loaded:
continue
if not cache_entry.locked:
self.move_model_to_device(cache_entry, self.storage_device)
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
)
TorchDevice.empty_cache()
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device.
:param cache_entry: The CacheRecord for the model
:param target_device: The torch.device to move the model into
May raise a torch.cuda.OutOfMemoryError
"""
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
source_device = cache_entry.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
# This would need to be revised to support multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type:
return
# Some models don't have a `to` method, in which case they run in RAM/CPU.
if not hasattr(cache_entry.model, "to"):
return
# This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# When moving to VRAM, we copy (not move) each element of the state dict from
# RAM to a new state dict in VRAM, and then inject it into the model.
# This operation is slightly faster than running `to()` on the whole model.
#
# When the model needs to be removed from VRAM we simply delete the copy
# of the state dict in VRAM, and reinject the state dict that is cached
# in RAM into the model. So this operation is very fast.
start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot()
try:
if cache_entry.state_dict is not None:
assert hasattr(cache_entry.model, "load_state_dict")
if target_device == self.storage_device:
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
else:
new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device, non_blocking=True)
cache_entry.device = target_device
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)
raise e
snapshot_after = self._capture_memory_snapshot()
end_model_to_time = time.time()
self.logger.debug(
f"Moved model '{cache_entry.key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
if (
snapshot_before is not None
and snapshot_after is not None
and snapshot_before.vram is not None
and snapshot_after.vram is not None
):
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
# If the estimated model size does not match the change in VRAM, log a warning.
if not math.isclose(
vram_change,
cache_entry.size,
rel_tol=0.1,
abs_tol=10 * MB,
):
self.logger.debug(
f"Moving model '{cache_entry.key}' from {source_device} to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GIG):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
def print_cuda_stats(self) -> None: def print_cuda_stats(self) -> None:
"""Log CUDA diagnostics.""" """Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)

View File

@ -2,8 +2,7 @@
Base class and implementation of a class that moves models in and out of VRAM. Base class and implementation of a class that moves models in and out of VRAM.
""" """
import copy from typing import Dict, Optional
from typing import Optional
import torch import torch
@ -26,42 +25,25 @@ class ModelLocker(ModelLockerBase):
""" """
self._cache = cache self._cache = cache
self._cache_entry = cache_entry self._cache_entry = cache_entry
self._execution_device: Optional[torch.device] = None
@property @property
def model(self) -> AnyModel: def model(self) -> AnyModel:
"""Return the model without moving it around.""" """Return the model without moving it around."""
return self._cache_entry.model return self._cache_entry.model
# ---------------------------- NOTE ----------------- def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
# Ryan suggests keeping a copy of the model's state dict in CPU and copying it """Return the state dict (if any) for the cached model."""
# into the GPU with code like this: return self._cache_entry.state_dict
#
# def state_dict_to(state_dict: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
# new_state_dict: dict[str, torch.Tensor] = {}
# for k, v in state_dict.items():
# new_state_dict[k] = v.to(device=device, copy=True, non_blocking=True)
# return new_state_dict
#
# I believe we'd then use load_state_dict() to inject the state dict into the model.
# See: https://pytorch.org/tutorials/beginner/saving_loading_models.html
# ---------------------------- NOTE -----------------
def lock(self) -> AnyModel: def lock(self) -> AnyModel:
"""Move the model into the execution device (GPU) and lock it.""" """Move the model into the execution device (GPU) and lock it."""
if not hasattr(self.model, "to"):
return self.model
# NOTE that the model has to have the to() method in order for this code to move it into GPU!
self._cache_entry.lock() self._cache_entry.lock()
try: try:
# We wait for a gpu to be free - may raise a ValueError if self._cache.lazy_offloading:
self._execution_device = self._cache.get_execution_device() self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._execution_device}") self._cache.move_model_to_device(self._cache_entry, self._cache.get_execution_device())
model_in_gpu = copy.deepcopy(self._cache_entry.model)
if hasattr(model_in_gpu, "to"):
model_in_gpu.to(self._execution_device)
self._cache_entry.loaded = True self._cache_entry.loaded = True
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
self._cache.print_cuda_stats() self._cache.print_cuda_stats()
except torch.cuda.OutOfMemoryError: except torch.cuda.OutOfMemoryError:
self._cache.logger.warning("Insufficient GPU memory to load model. Aborting") self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
@ -70,11 +52,10 @@ class ModelLocker(ModelLockerBase):
except Exception: except Exception:
self._cache_entry.unlock() self._cache_entry.unlock()
raise raise
return model_in_gpu
return self.model
def unlock(self) -> None: def unlock(self) -> None:
"""Call upon exit from context.""" """Call upon exit from context."""
if not hasattr(self.model, "to"):
return
self._cache_entry.unlock() self._cache_entry.unlock()
self._cache.print_cuda_stats() self._cache.print_cuda_stats()

View File

@ -65,14 +65,11 @@ class GenericDiffusersLoader(ModelLoader):
else: else:
try: try:
config = self._load_diffusers_config(model_path, config_name="config.json") config = self._load_diffusers_config(model_path, config_name="config.json")
class_name = config.get("_class_name", None) if class_name := config.get("_class_name"):
if class_name:
result = self._hf_definition_to_type(module="diffusers", class_name=class_name) result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
if config.get("model_type", None) == "clip_vision_model": elif class_name := config.get("architectures"):
class_name = config.get("architectures")
assert class_name is not None
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0]) result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
if not class_name: else:
raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json") raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
except KeyError as e: except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e

View File

@ -22,8 +22,7 @@ from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class VAELoader(GenericDiffusersLoader): class VAELoader(GenericDiffusersLoader):
"""Class to load VAE models.""" """Class to load VAE models."""
@ -40,10 +39,6 @@ class VAELoader(GenericDiffusersLoader):
return True return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel: def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
# TODO(MM2): check whether sdxl VAE models convert.
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"VAE conversion not supported for model type: {config.base}")
else:
assert isinstance(config, CheckpointConfigBase) assert isinstance(config, CheckpointConfigBase)
config_file = self._app_config.legacy_conf_path / config.config_path config_file = self._app_config.legacy_conf_path / config.config_path

View File

@ -83,7 +83,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
assert s.size is not None assert s.size is not None
files.append( files.append(
RemoteModelFile( RemoteModelFile(
url=hf_hub_url(id, s.rfilename, revision=variant), url=hf_hub_url(id, s.rfilename, revision=variant or "main"),
path=Path(name, s.rfilename), path=Path(name, s.rfilename),
size=s.size, size=s.size,
sha256=s.lfs.get("sha256") if s.lfs else None, sha256=s.lfs.get("sha256") if s.lfs else None,

View File

@ -37,9 +37,12 @@ class RemoteModelFile(BaseModel):
url: AnyHttpUrl = Field(description="The url to download this model file") url: AnyHttpUrl = Field(description="The url to download this model file")
path: Path = Field(description="The path to the file, relative to the model root") path: Path = Field(description="The path to the file, relative to the model root")
size: int = Field(description="The size of this file, in bytes") size: Optional[int] = Field(description="The size of this file, in bytes", default=0)
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None) sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
def __hash__(self) -> int:
return hash(str(self))
class ModelMetadataBase(BaseModel): class ModelMetadataBase(BaseModel):
"""Base class for model metadata information.""" """Base class for model metadata information."""

View File

@ -10,7 +10,7 @@ from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.util.util import SilenceWarnings from invokeai.backend.util.silence_warnings import SilenceWarnings
from .config import ( from .config import (
AnyModelConfig, AnyModelConfig,
@ -451,8 +451,16 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
class VaeCheckpointProbe(CheckpointProbeBase): class VaeCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType: def get_base_type(self) -> BaseModelType:
# I can't find any standalone 2.X VAEs to test with! # VAEs of all base types have the same structure, so we wimp out and
return BaseModelType.StableDiffusion1 # guess using the name.
for regexp, basetype in [
(r"xl", BaseModelType.StableDiffusionXL),
(r"sd2", BaseModelType.StableDiffusion2),
(r"vae", BaseModelType.StableDiffusion1),
]:
if re.search(regexp, self.model_path.name, re.IGNORECASE):
return basetype
raise InvalidModelConfigException("Cannot determine base type")
class LoRACheckpointProbe(CheckpointProbeBase): class LoRACheckpointProbe(CheckpointProbeBase):

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import pickle import pickle
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@ -66,8 +66,14 @@ class ModelPatcher:
cls, cls,
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
loras: Iterator[Tuple[LoRAModelRaw, float]], loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None: model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
with cls.apply_lora(unet, loras, "lora_unet_"): ) -> Generator[None, None, None]:
with cls.apply_lora(
unet,
loras=loras,
prefix="lora_unet_",
model_state_dict=model_state_dict,
):
yield yield
@classmethod @classmethod
@ -76,28 +82,9 @@ class ModelPatcher:
cls, cls,
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
loras: Iterator[Tuple[LoRAModelRaw, float]], loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None: model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
with cls.apply_lora(text_encoder, loras, "lora_te_"): ) -> Generator[None, None, None]:
yield with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
yield
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder2(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
yield yield
@classmethod @classmethod
@ -107,7 +94,16 @@ class ModelPatcher:
model: AnyModel, model: AnyModel,
loras: Iterator[Tuple[LoRAModelRaw, float]], loras: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str, prefix: str,
) -> None: model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
"""
Apply one or more LoRAs to a model.
:param model: The model to patch.
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
"""
original_weights = {} original_weights = {}
try: try:
with torch.no_grad(): with torch.no_grad():
@ -133,6 +129,9 @@ class ModelPatcher:
dtype = module.weight.dtype dtype = module.weight.dtype
if module_key not in original_weights: if module_key not in original_weights:
if model_state_dict is not None: # we were provided with the CPU copy of the state dict
original_weights[module_key] = model_state_dict[module_key + ".weight"]
else:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
@ -140,12 +139,12 @@ class ModelPatcher:
# We intentionally move to the target device first, then cast. Experimentally, this was found to # We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'. # same thing in a single call to '.to(...)'.
layer.to(device=device) layer.to(device=device, non_blocking=True)
layer.to(dtype=torch.float32) layer.to(dtype=torch.float32, non_blocking=True)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale) layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
layer.to(device=torch.device("cpu")) layer.to(device=torch.device("cpu"), non_blocking=True)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
if module.weight.shape != layer_weight.shape: if module.weight.shape != layer_weight.shape:
@ -154,7 +153,7 @@ class ModelPatcher:
layer_weight = layer_weight.reshape(module.weight.shape) layer_weight = layer_weight.reshape(module.weight.shape)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
module.weight += layer_weight.to(dtype=dtype) module.weight += layer_weight.to(dtype=dtype, non_blocking=True)
yield # wait for context manager exit yield # wait for context manager exit
@ -162,7 +161,7 @@ class ModelPatcher:
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule() assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
with torch.no_grad(): with torch.no_grad():
for module_key, weight in original_weights.items(): for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight) model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)
@classmethod @classmethod
@contextmanager @contextmanager

View File

@ -6,6 +6,7 @@ from typing import Any, List, Optional, Tuple, Union
import numpy as np import numpy as np
import onnx import onnx
import torch
from onnx import numpy_helper from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers from onnxruntime import InferenceSession, SessionOptions, get_available_providers
@ -188,6 +189,15 @@ class IAIOnnxRuntimeModel(RawModel):
# return self.io_binding.copy_outputs_to_cpu() # return self.io_binding.copy_outputs_to_cpu()
return self.session.run(None, inputs) return self.session.run(None, inputs)
# compatability with RawModel ABC
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
pass
# compatability with diffusers load code # compatability with diffusers load code
@classmethod @classmethod
def from_pretrained( def from_pretrained(

View File

@ -10,6 +10,20 @@ The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
that adds additional methods and attributes. that adds additional methods and attributes.
""" """
from abc import ABC, abstractmethod
from typing import Optional
class RawModel: import torch
"""Base class for 'Raw' model wrappers."""
class RawModel(ABC):
"""Abstract base class for 'Raw' model wrappers."""
@abstractmethod
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
pass

View File

@ -65,6 +65,18 @@ class TextualInversionModelRaw(RawModel):
return result return result
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
if not torch.cuda.is_available():
return
for emb in [self.embedding, self.embedding_2]:
if emb is not None:
emb.to(device=device, dtype=dtype, non_blocking=non_blocking)
class TextualInversionManager(BaseTextualInversionManager): class TextualInversionManager(BaseTextualInversionManager):
"""TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library.""" """TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""

View File

@ -1,29 +1,36 @@
"""Context class to silence transformers and diffusers warnings."""
import warnings import warnings
from typing import Any from contextlib import ContextDecorator
from diffusers import logging as diffusers_logging from diffusers.utils import logging as diffusers_logging
from transformers import logging as transformers_logging from transformers import logging as transformers_logging
class SilenceWarnings(object): # Inherit from ContextDecorator to allow using SilenceWarnings as both a context manager and a decorator.
"""Use in context to temporarily turn off warnings from transformers & diffusers modules. class SilenceWarnings(ContextDecorator):
"""A context manager that disables warnings from transformers & diffusers modules while active.
As context manager:
```
with SilenceWarnings(): with SilenceWarnings():
# do something # do something
```
As decorator:
```
@SilenceWarnings()
def some_function():
# do something
```
""" """
def __init__(self) -> None:
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self) -> None: def __enter__(self) -> None:
self._transformers_verbosity = transformers_logging.get_verbosity()
self._diffusers_verbosity = diffusers_logging.get_verbosity()
transformers_logging.set_verbosity_error() transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error() diffusers_logging.set_verbosity_error()
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
def __exit__(self, *args: Any) -> None: def __exit__(self, *args) -> None:
transformers_logging.set_verbosity(self.transformers_verbosity) transformers_logging.set_verbosity(self._transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity) diffusers_logging.set_verbosity(self._diffusers_verbosity)
warnings.simplefilter("default") warnings.simplefilter("default")

View File

@ -1,17 +1,43 @@
import base64 import base64
import io import io
import os import os
import warnings import re
import unicodedata
from pathlib import Path from pathlib import Path
from diffusers import logging as diffusers_logging
from PIL import Image from PIL import Image
from transformers import logging as transformers_logging
# actual size of a gig # actual size of a gig
GIG = 1073741824 GIG = 1073741824
def slugify(value: str, allow_unicode: bool = False) -> str:
"""
Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
dashes to single dashes. Remove characters that aren't alphanumerics,
underscores, or hyphens. Replace slashes with underscores.
Convert to lowercase. Also strip leading and
trailing whitespace, dashes, and underscores.
Adapted from Django: https://github.com/django/django/blob/main/django/utils/text.py
"""
value = str(value)
if allow_unicode:
value = unicodedata.normalize("NFKC", value)
else:
value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
value = re.sub(r"[/]", "_", value.lower())
value = re.sub(r"[^.\w\s-]", "", value.lower())
return re.sub(r"[-\s]+", "-", value).strip("-_")
def safe_filename(directory: Path, value: str) -> str:
"""Make a string safe to use as a filename."""
escaped_string = slugify(value)
max_name_length = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 256
return escaped_string[len(escaped_string) - max_name_length :]
def directory_size(directory: Path) -> int: def directory_size(directory: Path) -> int:
""" """
Return the aggregate size of all files in a directory (bytes). Return the aggregate size of all files in a directory (bytes).
@ -51,21 +77,3 @@ class Chdir(object):
def __exit__(self, *args): def __exit__(self, *args):
os.chdir(self.original) os.chdir(self.original)
class SilenceWarnings(object):
"""Context manager to temporarily lower verbosity of diffusers & transformers warning messages."""
def __enter__(self):
"""Set verbosity to error."""
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter("ignore")
def __exit__(self, type, value, traceback):
"""Restore logger verbosity to state before context was entered."""
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter("default")

View File

@ -1021,7 +1021,8 @@
"float": "Kommazahlen", "float": "Kommazahlen",
"enum": "Aufzählung", "enum": "Aufzählung",
"fullyContainNodes": "Vollständig ausgewählte Nodes auswählen", "fullyContainNodes": "Vollständig ausgewählte Nodes auswählen",
"editMode": "Im Workflow-Editor bearbeiten" "editMode": "Im Workflow-Editor bearbeiten",
"resetToDefaultValue": "Auf Standardwert zurücksetzen"
}, },
"hrf": { "hrf": {
"enableHrf": "Korrektur für hohe Auflösungen", "enableHrf": "Korrektur für hohe Auflösungen",

View File

@ -6,7 +6,7 @@
"settingsLabel": "Ajustes", "settingsLabel": "Ajustes",
"img2img": "Imagen a Imagen", "img2img": "Imagen a Imagen",
"unifiedCanvas": "Lienzo Unificado", "unifiedCanvas": "Lienzo Unificado",
"nodes": "Editor del flujo de trabajo", "nodes": "Flujos de trabajo",
"upload": "Subir imagen", "upload": "Subir imagen",
"load": "Cargar", "load": "Cargar",
"statusDisconnected": "Desconectado", "statusDisconnected": "Desconectado",
@ -14,7 +14,7 @@
"discordLabel": "Discord", "discordLabel": "Discord",
"back": "Atrás", "back": "Atrás",
"loading": "Cargando", "loading": "Cargando",
"postprocessing": "Tratamiento posterior", "postprocessing": "Postprocesado",
"txt2img": "De texto a imagen", "txt2img": "De texto a imagen",
"accept": "Aceptar", "accept": "Aceptar",
"cancel": "Cancelar", "cancel": "Cancelar",
@ -42,7 +42,42 @@
"copy": "Copiar", "copy": "Copiar",
"beta": "Beta", "beta": "Beta",
"on": "En", "on": "En",
"aboutDesc": "¿Utilizas Invoke para trabajar? Mira aquí:" "aboutDesc": "¿Utilizas Invoke para trabajar? Mira aquí:",
"installed": "Instalado",
"green": "Verde",
"editor": "Editor",
"orderBy": "Ordenar por",
"file": "Archivo",
"goTo": "Ir a",
"imageFailedToLoad": "No se puede cargar la imagen",
"saveAs": "Guardar Como",
"somethingWentWrong": "Algo salió mal",
"nextPage": "Página Siguiente",
"selected": "Seleccionado",
"tab": "Tabulador",
"positivePrompt": "Prompt Positivo",
"negativePrompt": "Prompt Negativo",
"error": "Error",
"format": "formato",
"unknown": "Desconocido",
"input": "Entrada",
"nodeEditor": "Editor de nodos",
"template": "Plantilla",
"prevPage": "Página Anterior",
"red": "Rojo",
"alpha": "Transparencia",
"outputs": "Salidas",
"editing": "Editando",
"learnMore": "Aprende más",
"enabled": "Activado",
"disabled": "Desactivado",
"folder": "Carpeta",
"updated": "Actualizado",
"created": "Creado",
"save": "Guardar",
"unknownError": "Error Desconocido",
"blue": "Azul",
"viewingDesc": "Revisar imágenes en una vista de galería grande"
}, },
"gallery": { "gallery": {
"galleryImageSize": "Tamaño de la imagen", "galleryImageSize": "Tamaño de la imagen",
@ -467,7 +502,8 @@
"about": "Acerca de", "about": "Acerca de",
"createIssue": "Crear un problema", "createIssue": "Crear un problema",
"resetUI": "Interfaz de usuario $t(accessibility.reset)", "resetUI": "Interfaz de usuario $t(accessibility.reset)",
"mode": "Modo" "mode": "Modo",
"submitSupportTicket": "Enviar Ticket de Soporte"
}, },
"nodes": { "nodes": {
"zoomInNodes": "Acercar", "zoomInNodes": "Acercar",
@ -543,5 +579,17 @@
"layers_one": "Capa", "layers_one": "Capa",
"layers_many": "Capas", "layers_many": "Capas",
"layers_other": "Capas" "layers_other": "Capas"
},
"controlnet": {
"crop": "Cortar",
"delete": "Eliminar",
"depthAnythingDescription": "Generación de mapa de profundidad usando la técnica de Depth Anything",
"duplicate": "Duplicar",
"colorMapDescription": "Genera un mapa de color desde la imagen",
"depthMidasDescription": "Crea un mapa de profundidad con Midas",
"balanced": "Equilibrado",
"beginEndStepPercent": "Inicio / Final Porcentaje de pasos",
"detectResolution": "Detectar resolución",
"beginEndStepPercentShort": "Inicio / Final %"
} }
} }

View File

@ -45,7 +45,7 @@
"outputs": "Risultati", "outputs": "Risultati",
"data": "Dati", "data": "Dati",
"somethingWentWrong": "Qualcosa è andato storto", "somethingWentWrong": "Qualcosa è andato storto",
"copyError": "$t(gallery.copy) Errore", "copyError": "Errore $t(gallery.copy)",
"input": "Ingresso", "input": "Ingresso",
"notInstalled": "Non $t(common.installed)", "notInstalled": "Non $t(common.installed)",
"unknownError": "Errore sconosciuto", "unknownError": "Errore sconosciuto",
@ -85,7 +85,11 @@
"viewing": "Visualizza", "viewing": "Visualizza",
"viewingDesc": "Rivedi le immagini in un'ampia vista della galleria", "viewingDesc": "Rivedi le immagini in un'ampia vista della galleria",
"editing": "Modifica", "editing": "Modifica",
"editingDesc": "Modifica nell'area Livelli di controllo" "editingDesc": "Modifica nell'area Livelli di controllo",
"enabled": "Abilitato",
"disabled": "Disabilitato",
"comparingDesc": "Confronta due immagini",
"comparing": "Confronta"
}, },
"gallery": { "gallery": {
"galleryImageSize": "Dimensione dell'immagine", "galleryImageSize": "Dimensione dell'immagine",
@ -122,14 +126,30 @@
"bulkDownloadRequestedDesc": "La tua richiesta di download è in preparazione. L'operazione potrebbe richiedere alcuni istanti.", "bulkDownloadRequestedDesc": "La tua richiesta di download è in preparazione. L'operazione potrebbe richiedere alcuni istanti.",
"bulkDownloadRequestFailed": "Problema durante la preparazione del download", "bulkDownloadRequestFailed": "Problema durante la preparazione del download",
"bulkDownloadFailed": "Scaricamento fallito", "bulkDownloadFailed": "Scaricamento fallito",
"alwaysShowImageSizeBadge": "Mostra sempre le dimensioni dell'immagine" "alwaysShowImageSizeBadge": "Mostra sempre le dimensioni dell'immagine",
"openInViewer": "Apri nel visualizzatore",
"selectForCompare": "Seleziona per il confronto",
"selectAnImageToCompare": "Seleziona un'immagine da confrontare",
"slider": "Cursore",
"sideBySide": "Fianco a Fianco",
"compareImage": "Immagine di confronto",
"viewerImage": "Immagine visualizzata",
"hover": "Al passaggio del mouse",
"swapImages": "Scambia le immagini",
"compareOptions": "Opzioni di confronto",
"stretchToFit": "Scala per adattare",
"exitCompare": "Esci dal confronto",
"compareHelp1": "Tieni premuto <Kbd>Alt</Kbd> mentre fai clic su un'immagine della galleria o usi i tasti freccia per cambiare l'immagine di confronto.",
"compareHelp2": "Premi <Kbd>M</Kbd> per scorrere le modalità di confronto.",
"compareHelp3": "Premi <Kbd>C</Kbd> per scambiare le immagini confrontate.",
"compareHelp4": "Premi <Kbd>Z</Kbd> o <Kbd>Esc</Kbd> per uscire."
}, },
"hotkeys": { "hotkeys": {
"keyboardShortcuts": "Tasti di scelta rapida", "keyboardShortcuts": "Tasti di scelta rapida",
"appHotkeys": "Applicazione", "appHotkeys": "Applicazione",
"generalHotkeys": "Generale", "generalHotkeys": "Generale",
"galleryHotkeys": "Galleria", "galleryHotkeys": "Galleria",
"unifiedCanvasHotkeys": "Tela Unificata", "unifiedCanvasHotkeys": "Tela",
"invoke": { "invoke": {
"title": "Invoke", "title": "Invoke",
"desc": "Genera un'immagine" "desc": "Genera un'immagine"
@ -147,8 +167,8 @@
"desc": "Apre e chiude il pannello delle opzioni" "desc": "Apre e chiude il pannello delle opzioni"
}, },
"pinOptions": { "pinOptions": {
"title": "Appunta le opzioni", "title": "Fissa le opzioni",
"desc": "Blocca il pannello delle opzioni" "desc": "Fissa il pannello delle opzioni"
}, },
"toggleGallery": { "toggleGallery": {
"title": "Attiva/disattiva galleria", "title": "Attiva/disattiva galleria",
@ -332,14 +352,14 @@
"title": "Annulla e cancella" "title": "Annulla e cancella"
}, },
"resetOptionsAndGallery": { "resetOptionsAndGallery": {
"title": "Ripristina Opzioni e Galleria", "title": "Ripristina le opzioni e la galleria",
"desc": "Reimposta le opzioni e i pannelli della galleria" "desc": "Reimposta i pannelli delle opzioni e della galleria"
}, },
"searchHotkeys": "Cerca tasti di scelta rapida", "searchHotkeys": "Cerca tasti di scelta rapida",
"noHotkeysFound": "Nessun tasto di scelta rapida trovato", "noHotkeysFound": "Nessun tasto di scelta rapida trovato",
"toggleOptionsAndGallery": { "toggleOptionsAndGallery": {
"desc": "Apre e chiude le opzioni e i pannelli della galleria", "desc": "Apre e chiude le opzioni e i pannelli della galleria",
"title": "Attiva/disattiva le Opzioni e la Galleria" "title": "Attiva/disattiva le opzioni e la galleria"
}, },
"clearSearch": "Cancella ricerca", "clearSearch": "Cancella ricerca",
"remixImage": { "remixImage": {
@ -348,7 +368,7 @@
}, },
"toggleViewer": { "toggleViewer": {
"title": "Attiva/disattiva il visualizzatore di immagini", "title": "Attiva/disattiva il visualizzatore di immagini",
"desc": "Passa dal Visualizzatore immagini all'area di lavoro per la scheda corrente." "desc": "Passa dal visualizzatore immagini all'area di lavoro per la scheda corrente."
} }
}, },
"modelManager": { "modelManager": {
@ -378,7 +398,7 @@
"convertToDiffusers": "Converti in Diffusori", "convertToDiffusers": "Converti in Diffusori",
"convertToDiffusersHelpText2": "Questo processo sostituirà la voce in Gestione Modelli con la versione Diffusori dello stesso modello.", "convertToDiffusersHelpText2": "Questo processo sostituirà la voce in Gestione Modelli con la versione Diffusori dello stesso modello.",
"convertToDiffusersHelpText4": "Questo è un processo una tantum. Potrebbero essere necessari circa 30-60 secondi a seconda delle specifiche del tuo computer.", "convertToDiffusersHelpText4": "Questo è un processo una tantum. Potrebbero essere necessari circa 30-60 secondi a seconda delle specifiche del tuo computer.",
"convertToDiffusersHelpText5": "Assicurati di avere spazio su disco sufficiente. I modelli generalmente variano tra 2 GB e 7 GB di dimensioni.", "convertToDiffusersHelpText5": "Assicurati di avere spazio su disco sufficiente. I modelli generalmente variano tra 2 GB e 7 GB in dimensione.",
"convertToDiffusersHelpText6": "Vuoi convertire questo modello?", "convertToDiffusersHelpText6": "Vuoi convertire questo modello?",
"modelConverted": "Modello convertito", "modelConverted": "Modello convertito",
"alpha": "Alpha", "alpha": "Alpha",
@ -528,7 +548,7 @@
"layer": { "layer": {
"initialImageNoImageSelected": "Nessuna immagine iniziale selezionata", "initialImageNoImageSelected": "Nessuna immagine iniziale selezionata",
"t2iAdapterIncompatibleDimensions": "L'adattatore T2I richiede che la dimensione dell'immagine sia un multiplo di {{multiple}}", "t2iAdapterIncompatibleDimensions": "L'adattatore T2I richiede che la dimensione dell'immagine sia un multiplo di {{multiple}}",
"controlAdapterNoModelSelected": "Nessun modello di Adattatore di Controllo selezionato", "controlAdapterNoModelSelected": "Nessun modello di adattatore di controllo selezionato",
"controlAdapterIncompatibleBaseModel": "Il modello base dell'adattatore di controllo non è compatibile", "controlAdapterIncompatibleBaseModel": "Il modello base dell'adattatore di controllo non è compatibile",
"controlAdapterNoImageSelected": "Nessuna immagine dell'adattatore di controllo selezionata", "controlAdapterNoImageSelected": "Nessuna immagine dell'adattatore di controllo selezionata",
"controlAdapterImageNotProcessed": "Immagine dell'adattatore di controllo non elaborata", "controlAdapterImageNotProcessed": "Immagine dell'adattatore di controllo non elaborata",
@ -606,25 +626,25 @@
"canvasMerged": "Tela unita", "canvasMerged": "Tela unita",
"sentToImageToImage": "Inviato a Generazione da immagine", "sentToImageToImage": "Inviato a Generazione da immagine",
"sentToUnifiedCanvas": "Inviato alla Tela", "sentToUnifiedCanvas": "Inviato alla Tela",
"parametersNotSet": "Parametri non impostati", "parametersNotSet": "Parametri non richiamati",
"metadataLoadFailed": "Impossibile caricare i metadati", "metadataLoadFailed": "Impossibile caricare i metadati",
"serverError": "Errore del Server", "serverError": "Errore del Server",
"connected": "Connesso al Server", "connected": "Connesso al server",
"canceled": "Elaborazione annullata", "canceled": "Elaborazione annullata",
"uploadFailedInvalidUploadDesc": "Deve essere una singola immagine PNG o JPEG", "uploadFailedInvalidUploadDesc": "Deve essere una singola immagine PNG o JPEG",
"parameterSet": "{{parameter}} impostato", "parameterSet": "Parametro richiamato",
"parameterNotSet": "{{parameter}} non impostato", "parameterNotSet": "Parametro non richiamato",
"problemCopyingImage": "Impossibile copiare l'immagine", "problemCopyingImage": "Impossibile copiare l'immagine",
"baseModelChangedCleared_one": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modello incompatibile", "baseModelChangedCleared_one": "Cancellato o disabilitato {{count}} sottomodello incompatibile",
"baseModelChangedCleared_many": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modelli incompatibili", "baseModelChangedCleared_many": "Cancellati o disabilitati {{count}} sottomodelli incompatibili",
"baseModelChangedCleared_other": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modelli incompatibili", "baseModelChangedCleared_other": "Cancellati o disabilitati {{count}} sottomodelli incompatibili",
"imageSavingFailed": "Salvataggio dell'immagine non riuscito", "imageSavingFailed": "Salvataggio dell'immagine non riuscito",
"canvasSentControlnetAssets": "Tela inviata a ControlNet & Risorse", "canvasSentControlnetAssets": "Tela inviata a ControlNet & Risorse",
"problemCopyingCanvasDesc": "Impossibile copiare la tela", "problemCopyingCanvasDesc": "Impossibile copiare la tela",
"loadedWithWarnings": "Flusso di lavoro caricato con avvisi", "loadedWithWarnings": "Flusso di lavoro caricato con avvisi",
"canvasCopiedClipboard": "Tela copiata negli appunti", "canvasCopiedClipboard": "Tela copiata negli appunti",
"maskSavedAssets": "Maschera salvata nelle risorse", "maskSavedAssets": "Maschera salvata nelle risorse",
"problemDownloadingCanvas": "Problema durante il download della tela", "problemDownloadingCanvas": "Problema durante lo scarico della tela",
"problemMergingCanvas": "Problema nell'unione delle tele", "problemMergingCanvas": "Problema nell'unione delle tele",
"imageUploaded": "Immagine caricata", "imageUploaded": "Immagine caricata",
"addedToBoard": "Aggiunto alla bacheca", "addedToBoard": "Aggiunto alla bacheca",
@ -658,7 +678,17 @@
"problemDownloadingImage": "Impossibile scaricare l'immagine", "problemDownloadingImage": "Impossibile scaricare l'immagine",
"prunedQueue": "Coda ripulita", "prunedQueue": "Coda ripulita",
"modelImportCanceled": "Importazione del modello annullata", "modelImportCanceled": "Importazione del modello annullata",
"parameters": "Parametri" "parameters": "Parametri",
"parameterSetDesc": "{{parameter}} richiamato",
"parameterNotSetDesc": "Impossibile richiamare {{parameter}}",
"parameterNotSetDescWithMessage": "Impossibile richiamare {{parameter}}: {{message}}",
"parametersSet": "Parametri richiamati",
"errorCopied": "Errore copiato",
"outOfMemoryError": "Errore di memoria esaurita",
"baseModelChanged": "Modello base modificato",
"sessionRef": "Sessione: {{sessionId}}",
"somethingWentWrong": "Qualcosa è andato storto",
"outOfMemoryErrorDesc": "Le impostazioni della generazione attuale superano la capacità del sistema. Modifica le impostazioni e riprova."
}, },
"tooltip": { "tooltip": {
"feature": { "feature": {
@ -674,7 +704,7 @@
"layer": "Livello", "layer": "Livello",
"base": "Base", "base": "Base",
"mask": "Maschera", "mask": "Maschera",
"maskingOptions": "Opzioni di mascheramento", "maskingOptions": "Opzioni maschera",
"enableMask": "Abilita maschera", "enableMask": "Abilita maschera",
"preserveMaskedArea": "Mantieni area mascherata", "preserveMaskedArea": "Mantieni area mascherata",
"clearMask": "Cancella maschera (Shift+C)", "clearMask": "Cancella maschera (Shift+C)",
@ -745,7 +775,8 @@
"mode": "Modalità", "mode": "Modalità",
"resetUI": "$t(accessibility.reset) l'Interfaccia Utente", "resetUI": "$t(accessibility.reset) l'Interfaccia Utente",
"createIssue": "Segnala un problema", "createIssue": "Segnala un problema",
"about": "Informazioni" "about": "Informazioni",
"submitSupportTicket": "Invia ticket di supporto"
}, },
"nodes": { "nodes": {
"zoomOutNodes": "Rimpicciolire", "zoomOutNodes": "Rimpicciolire",
@ -790,7 +821,7 @@
"workflowNotes": "Note", "workflowNotes": "Note",
"versionUnknown": " Versione sconosciuta", "versionUnknown": " Versione sconosciuta",
"unableToValidateWorkflow": "Impossibile convalidare il flusso di lavoro", "unableToValidateWorkflow": "Impossibile convalidare il flusso di lavoro",
"updateApp": "Aggiorna App", "updateApp": "Aggiorna Applicazione",
"unableToLoadWorkflow": "Impossibile caricare il flusso di lavoro", "unableToLoadWorkflow": "Impossibile caricare il flusso di lavoro",
"updateNode": "Aggiorna nodo", "updateNode": "Aggiorna nodo",
"version": "Versione", "version": "Versione",
@ -882,11 +913,14 @@
"missingNode": "Nodo di invocazione mancante", "missingNode": "Nodo di invocazione mancante",
"missingInvocationTemplate": "Modello di invocazione mancante", "missingInvocationTemplate": "Modello di invocazione mancante",
"missingFieldTemplate": "Modello di campo mancante", "missingFieldTemplate": "Modello di campo mancante",
"singleFieldType": "{{name}} (Singola)" "singleFieldType": "{{name}} (Singola)",
"imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino delle impostazioni predefinite",
"boardAccessError": "Impossibile trovare la bacheca {{board_id}}, ripristino ai valori predefiniti",
"modelAccessError": "Impossibile trovare il modello {{key}}, ripristino ai valori predefiniti"
}, },
"boards": { "boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca", "autoAddBoard": "Aggiungi automaticamente bacheca",
"menuItemAutoAdd": "Aggiungi automaticamente a questa Bacheca", "menuItemAutoAdd": "Aggiungi automaticamente a questa bacheca",
"cancel": "Annulla", "cancel": "Annulla",
"addBoard": "Aggiungi Bacheca", "addBoard": "Aggiungi Bacheca",
"bottomMessage": "L'eliminazione di questa bacheca e delle sue immagini ripristinerà tutte le funzionalità che le stanno attualmente utilizzando.", "bottomMessage": "L'eliminazione di questa bacheca e delle sue immagini ripristinerà tutte le funzionalità che le stanno attualmente utilizzando.",
@ -898,7 +932,7 @@
"myBoard": "Bacheca", "myBoard": "Bacheca",
"searchBoard": "Cerca bacheche ...", "searchBoard": "Cerca bacheche ...",
"noMatching": "Nessuna bacheca corrispondente", "noMatching": "Nessuna bacheca corrispondente",
"selectBoard": "Seleziona una Bacheca", "selectBoard": "Seleziona una bacheca",
"uncategorized": "Non categorizzato", "uncategorized": "Non categorizzato",
"downloadBoard": "Scarica la bacheca", "downloadBoard": "Scarica la bacheca",
"deleteBoardOnly": "solo la Bacheca", "deleteBoardOnly": "solo la Bacheca",
@ -919,7 +953,7 @@
"control": "Controllo", "control": "Controllo",
"crop": "Ritaglia", "crop": "Ritaglia",
"depthMidas": "Profondità (Midas)", "depthMidas": "Profondità (Midas)",
"detectResolution": "Rileva risoluzione", "detectResolution": "Rileva la risoluzione",
"controlMode": "Modalità di controllo", "controlMode": "Modalità di controllo",
"cannyDescription": "Canny rilevamento bordi", "cannyDescription": "Canny rilevamento bordi",
"depthZoe": "Profondità (Zoe)", "depthZoe": "Profondità (Zoe)",
@ -930,7 +964,7 @@
"showAdvanced": "Mostra opzioni Avanzate", "showAdvanced": "Mostra opzioni Avanzate",
"bgth": "Soglia rimozione sfondo", "bgth": "Soglia rimozione sfondo",
"importImageFromCanvas": "Importa immagine dalla Tela", "importImageFromCanvas": "Importa immagine dalla Tela",
"lineartDescription": "Converte l'immagine in lineart", "lineartDescription": "Converte l'immagine in linea",
"importMaskFromCanvas": "Importa maschera dalla Tela", "importMaskFromCanvas": "Importa maschera dalla Tela",
"hideAdvanced": "Nascondi opzioni avanzate", "hideAdvanced": "Nascondi opzioni avanzate",
"resetControlImage": "Reimposta immagine di controllo", "resetControlImage": "Reimposta immagine di controllo",
@ -946,7 +980,7 @@
"pidiDescription": "Elaborazione immagini PIDI", "pidiDescription": "Elaborazione immagini PIDI",
"fill": "Riempie", "fill": "Riempie",
"colorMapDescription": "Genera una mappa dei colori dall'immagine", "colorMapDescription": "Genera una mappa dei colori dall'immagine",
"lineartAnimeDescription": "Elaborazione lineart in stile anime", "lineartAnimeDescription": "Elaborazione linea in stile anime",
"imageResolution": "Risoluzione dell'immagine", "imageResolution": "Risoluzione dell'immagine",
"colorMap": "Colore", "colorMap": "Colore",
"lowThreshold": "Soglia inferiore", "lowThreshold": "Soglia inferiore",

View File

@ -87,7 +87,11 @@
"viewing": "Просмотр", "viewing": "Просмотр",
"editing": "Редактирование", "editing": "Редактирование",
"viewingDesc": "Просмотр изображений в режиме большой галереи", "viewingDesc": "Просмотр изображений в режиме большой галереи",
"editingDesc": "Редактировать на холсте слоёв управления" "editingDesc": "Редактировать на холсте слоёв управления",
"enabled": "Включено",
"disabled": "Отключено",
"comparingDesc": "Сравнение двух изображений",
"comparing": "Сравнение"
}, },
"gallery": { "gallery": {
"galleryImageSize": "Размер изображений", "galleryImageSize": "Размер изображений",
@ -124,7 +128,23 @@
"bulkDownloadRequested": "Подготовка к скачиванию", "bulkDownloadRequested": "Подготовка к скачиванию",
"bulkDownloadRequestedDesc": "Ваш запрос на скачивание готовится. Это может занять несколько минут.", "bulkDownloadRequestedDesc": "Ваш запрос на скачивание готовится. Это может занять несколько минут.",
"bulkDownloadRequestFailed": "Возникла проблема при подготовке скачивания", "bulkDownloadRequestFailed": "Возникла проблема при подготовке скачивания",
"alwaysShowImageSizeBadge": "Всегда показывать значок размера изображения" "alwaysShowImageSizeBadge": "Всегда показывать значок размера изображения",
"openInViewer": "Открыть в просмотрщике",
"selectForCompare": "Выбрать для сравнения",
"hover": "Наведение",
"swapImages": "Поменять местами",
"stretchToFit": "Растягивание до нужного размера",
"exitCompare": "Выйти из сравнения",
"compareHelp4": "Нажмите <Kbd>Z</Kbd> или <Kbd>Esc</Kbd> для выхода.",
"compareImage": "Сравнить изображение",
"viewerImage": "Изображение просмотрщика",
"selectAnImageToCompare": "Выберите изображение для сравнения",
"slider": "Слайдер",
"sideBySide": "Бок о бок",
"compareOptions": "Варианты сравнения",
"compareHelp1": "Удерживайте <Kbd>Alt</Kbd> при нажатии на изображение в галерее или при помощи клавиш со стрелками, чтобы изменить сравниваемое изображение.",
"compareHelp2": "Нажмите <Kbd>M</Kbd>, чтобы переключиться между режимами сравнения.",
"compareHelp3": "Нажмите <Kbd>C</Kbd>, чтобы поменять местами сравниваемые изображения."
}, },
"hotkeys": { "hotkeys": {
"keyboardShortcuts": "Горячие клавиши", "keyboardShortcuts": "Горячие клавиши",
@ -528,7 +548,20 @@
"missingFieldTemplate": "Отсутствует шаблон поля", "missingFieldTemplate": "Отсутствует шаблон поля",
"addingImagesTo": "Добавление изображений в", "addingImagesTo": "Добавление изображений в",
"invoke": "Создать", "invoke": "Создать",
"imageNotProcessedForControlAdapter": "Изображение адаптера контроля №{{number}} не обрабатывается" "imageNotProcessedForControlAdapter": "Изображение адаптера контроля №{{number}} не обрабатывается",
"layer": {
"controlAdapterImageNotProcessed": "Изображение адаптера контроля не обработано",
"ipAdapterNoModelSelected": "IP адаптер не выбран",
"controlAdapterNoModelSelected": "не выбрана модель адаптера контроля",
"controlAdapterIncompatibleBaseModel": "несовместимая базовая модель адаптера контроля",
"controlAdapterNoImageSelected": "не выбрано изображение контрольного адаптера",
"initialImageNoImageSelected": "начальное изображение не выбрано",
"rgNoRegion": "регион не выбран",
"rgNoPromptsOrIPAdapters": "нет текстовых запросов или IP-адаптеров",
"ipAdapterIncompatibleBaseModel": "несовместимая базовая модель IP-адаптера",
"t2iAdapterIncompatibleDimensions": "Адаптер T2I требует, чтобы размеры изображения были кратны {{multiple}}",
"ipAdapterNoImageSelected": "изображение IP-адаптера не выбрано"
}
}, },
"isAllowedToUpscale": { "isAllowedToUpscale": {
"useX2Model": "Изображение слишком велико для увеличения с помощью модели x4. Используйте модель x2", "useX2Model": "Изображение слишком велико для увеличения с помощью модели x4. Используйте модель x2",
@ -606,12 +639,12 @@
"connected": "Подключено к серверу", "connected": "Подключено к серверу",
"canceled": "Обработка отменена", "canceled": "Обработка отменена",
"uploadFailedInvalidUploadDesc": "Должно быть одно изображение в формате PNG или JPEG", "uploadFailedInvalidUploadDesc": "Должно быть одно изображение в формате PNG или JPEG",
"parameterNotSet": "Параметр {{parameter}} не задан", "parameterNotSet": "Параметр не задан",
"parameterSet": "Параметр {{parameter}} задан", "parameterSet": "Параметр задан",
"problemCopyingImage": "Не удается скопировать изображение", "problemCopyingImage": "Не удается скопировать изображение",
"baseModelChangedCleared_one": "Базовая модель изменила, очистила или отключила {{count}} несовместимую подмодель", "baseModelChangedCleared_one": "Очищена или отключена {{count}} несовместимая подмодель",
"baseModelChangedCleared_few": "Базовая модель изменила, очистила или отключила {{count}} несовместимые подмодели", "baseModelChangedCleared_few": "Очищены или отключены {{count}} несовместимые подмодели",
"baseModelChangedCleared_many": "Базовая модель изменила, очистила или отключила {{count}} несовместимых подмоделей", "baseModelChangedCleared_many": "Очищены или отключены {{count}} несовместимых подмоделей",
"imageSavingFailed": "Не удалось сохранить изображение", "imageSavingFailed": "Не удалось сохранить изображение",
"canvasSentControlnetAssets": "Холст отправлен в ControlNet и ресурсы", "canvasSentControlnetAssets": "Холст отправлен в ControlNet и ресурсы",
"problemCopyingCanvasDesc": "Невозможно экспортировать базовый слой", "problemCopyingCanvasDesc": "Невозможно экспортировать базовый слой",
@ -652,7 +685,17 @@
"resetInitialImage": "Сбросить начальное изображение", "resetInitialImage": "Сбросить начальное изображение",
"prunedQueue": "Урезанная очередь", "prunedQueue": "Урезанная очередь",
"modelImportCanceled": "Импорт модели отменен", "modelImportCanceled": "Импорт модели отменен",
"parameters": "Параметры" "parameters": "Параметры",
"parameterSetDesc": "Задан {{parameter}}",
"parameterNotSetDesc": "Невозможно задать {{parameter}}",
"baseModelChanged": "Базовая модель сменена",
"parameterNotSetDescWithMessage": "Не удалось задать {{parameter}}: {{message}}",
"parametersSet": "Параметры заданы",
"errorCopied": "Ошибка скопирована",
"sessionRef": "Сессия: {{sessionId}}",
"outOfMemoryError": "Ошибка нехватки памяти",
"outOfMemoryErrorDesc": "Ваши текущие настройки генерации превышают возможности системы. Пожалуйста, измените настройки и повторите попытку.",
"somethingWentWrong": "Что-то пошло не так"
}, },
"tooltip": { "tooltip": {
"feature": { "feature": {
@ -739,7 +782,8 @@
"loadMore": "Загрузить больше", "loadMore": "Загрузить больше",
"resetUI": "$t(accessibility.reset) интерфейс", "resetUI": "$t(accessibility.reset) интерфейс",
"createIssue": "Сообщить о проблеме", "createIssue": "Сообщить о проблеме",
"about": "Об этом" "about": "Об этом",
"submitSupportTicket": "Отправить тикет в службу поддержки"
}, },
"nodes": { "nodes": {
"zoomInNodes": "Увеличьте масштаб", "zoomInNodes": "Увеличьте масштаб",
@ -832,7 +876,7 @@
"workflowName": "Название", "workflowName": "Название",
"collection": "Коллекция", "collection": "Коллекция",
"unknownErrorValidatingWorkflow": "Неизвестная ошибка при проверке рабочего процесса", "unknownErrorValidatingWorkflow": "Неизвестная ошибка при проверке рабочего процесса",
"collectionFieldType": "Коллекция {{name}}", "collectionFieldType": "{{name}} (Коллекция)",
"workflowNotes": "Примечания", "workflowNotes": "Примечания",
"string": "Строка", "string": "Строка",
"unknownNodeType": "Неизвестный тип узла", "unknownNodeType": "Неизвестный тип узла",
@ -848,7 +892,7 @@
"targetNodeDoesNotExist": "Недопустимое ребро: целевой/входной узел {{node}} не существует", "targetNodeDoesNotExist": "Недопустимое ребро: целевой/входной узел {{node}} не существует",
"mismatchedVersion": "Недопустимый узел: узел {{node}} типа {{type}} имеет несоответствующую версию (попробовать обновить?)", "mismatchedVersion": "Недопустимый узел: узел {{node}} типа {{type}} имеет несоответствующую версию (попробовать обновить?)",
"unknownFieldType": "$t(nodes.unknownField) тип: {{type}}", "unknownFieldType": "$t(nodes.unknownField) тип: {{type}}",
"collectionOrScalarFieldType": "Коллекция | Скаляр {{name}}", "collectionOrScalarFieldType": "{{name}} (Один или коллекция)",
"betaDesc": "Этот вызов находится в бета-версии. Пока он не станет стабильным, в нем могут происходить изменения при обновлении приложений. Мы планируем поддерживать этот вызов в течение длительного времени.", "betaDesc": "Этот вызов находится в бета-версии. Пока он не станет стабильным, в нем могут происходить изменения при обновлении приложений. Мы планируем поддерживать этот вызов в течение длительного времени.",
"nodeVersion": "Версия узла", "nodeVersion": "Версия узла",
"loadingNodes": "Загрузка узлов...", "loadingNodes": "Загрузка узлов...",
@ -870,7 +914,16 @@
"noFieldsViewMode": "В этом рабочем процессе нет выбранных полей для отображения. Просмотрите полный рабочий процесс для настройки значений.", "noFieldsViewMode": "В этом рабочем процессе нет выбранных полей для отображения. Просмотрите полный рабочий процесс для настройки значений.",
"graph": "График", "graph": "График",
"showEdgeLabels": "Показать метки на ребрах", "showEdgeLabels": "Показать метки на ребрах",
"showEdgeLabelsHelp": "Показать метки на ребрах, указывающие на соединенные узлы" "showEdgeLabelsHelp": "Показать метки на ребрах, указывающие на соединенные узлы",
"cannotMixAndMatchCollectionItemTypes": "Невозможно смешивать и сопоставлять типы элементов коллекции",
"missingNode": "Отсутствует узел вызова",
"missingInvocationTemplate": "Отсутствует шаблон вызова",
"missingFieldTemplate": "Отсутствующий шаблон поля",
"singleFieldType": "{{name}} (Один)",
"noGraph": "Нет графика",
"imageAccessError": "Невозможно найти изображение {{image_name}}, сбрасываем на значение по умолчанию",
"boardAccessError": "Невозможно найти доску {{board_id}}, сбрасываем на значение по умолчанию",
"modelAccessError": "Невозможно найти модель {{key}}, сброс на модель по умолчанию"
}, },
"controlnet": { "controlnet": {
"amult": "a_mult", "amult": "a_mult",
@ -1441,7 +1494,16 @@
"clearQueueAlertDialog2": "Вы уверены, что хотите очистить очередь?", "clearQueueAlertDialog2": "Вы уверены, что хотите очистить очередь?",
"item": "Элемент", "item": "Элемент",
"graphFailedToQueue": "Не удалось поставить график в очередь", "graphFailedToQueue": "Не удалось поставить график в очередь",
"openQueue": "Открыть очередь" "openQueue": "Открыть очередь",
"prompts_one": "Запрос",
"prompts_few": "Запроса",
"prompts_many": "Запросов",
"iterations_one": "Итерация",
"iterations_few": "Итерации",
"iterations_many": "Итераций",
"generations_one": "Генерация",
"generations_few": "Генерации",
"generations_many": "Генераций"
}, },
"sdxl": { "sdxl": {
"refinerStart": "Запуск доработчика", "refinerStart": "Запуск доработчика",

View File

@ -1,6 +1,6 @@
{ {
"common": { "common": {
"nodes": "節點", "nodes": "工作流程",
"img2img": "圖片轉圖片", "img2img": "圖片轉圖片",
"statusDisconnected": "已中斷連線", "statusDisconnected": "已中斷連線",
"back": "返回", "back": "返回",
@ -11,17 +11,239 @@
"reportBugLabel": "回報錯誤", "reportBugLabel": "回報錯誤",
"githubLabel": "GitHub", "githubLabel": "GitHub",
"hotkeysLabel": "快捷鍵", "hotkeysLabel": "快捷鍵",
"languagePickerLabel": "切換語言", "languagePickerLabel": "語言",
"unifiedCanvas": "統一畫布", "unifiedCanvas": "統一畫布",
"cancel": "取消", "cancel": "取消",
"txt2img": "文字轉圖片" "txt2img": "文字轉圖片",
"controlNet": "ControlNet",
"advanced": "進階",
"folder": "資料夾",
"installed": "已安裝",
"accept": "接受",
"goTo": "前往",
"input": "輸入",
"random": "隨機",
"selected": "已選擇",
"communityLabel": "社群",
"loading": "載入中",
"delete": "刪除",
"copy": "複製",
"error": "錯誤",
"file": "檔案",
"format": "格式",
"imageFailedToLoad": "無法載入圖片"
}, },
"accessibility": { "accessibility": {
"invokeProgressBar": "Invoke 進度條", "invokeProgressBar": "Invoke 進度條",
"uploadImage": "上傳圖片", "uploadImage": "上傳圖片",
"reset": "重設", "reset": "重",
"nextImage": "下一張圖片", "nextImage": "下一張圖片",
"previousImage": "上一張圖片", "previousImage": "上一張圖片",
"menu": "選單" "menu": "選單",
"loadMore": "載入更多",
"about": "關於",
"createIssue": "建立問題",
"resetUI": "$t(accessibility.reset) 介面",
"submitSupportTicket": "提交支援工單",
"mode": "模式"
},
"boards": {
"loading": "載入中…",
"movingImagesToBoard_other": "正在移動 {{count}} 張圖片至板上:",
"move": "移動",
"uncategorized": "未分類",
"cancel": "取消"
},
"metadata": {
"workflow": "工作流程",
"steps": "步數",
"model": "模型",
"seed": "種子",
"vae": "VAE",
"seamless": "無縫",
"metadata": "元數據",
"width": "寬度",
"height": "高度"
},
"accordions": {
"control": {
"title": "控制"
},
"compositing": {
"title": "合成"
},
"advanced": {
"title": "進階",
"options": "$t(accordions.advanced.title) 選項"
}
},
"hotkeys": {
"nodesHotkeys": "節點",
"cancel": {
"title": "取消"
},
"generalHotkeys": "一般",
"keyboardShortcuts": "快捷鍵",
"appHotkeys": "應用程式"
},
"modelManager": {
"advanced": "進階",
"allModels": "全部模型",
"variant": "變體",
"config": "配置",
"model": "模型",
"selected": "已選擇",
"huggingFace": "HuggingFace",
"install": "安裝",
"metadata": "元數據",
"delete": "刪除",
"description": "描述",
"cancel": "取消",
"convert": "轉換",
"manual": "手動",
"none": "無",
"name": "名稱",
"load": "載入",
"height": "高度",
"width": "寬度",
"search": "搜尋",
"vae": "VAE",
"settings": "設定"
},
"controlnet": {
"mlsd": "M-LSD",
"canny": "Canny",
"duplicate": "重複",
"none": "無",
"pidi": "PIDI",
"h": "H",
"balanced": "平衡",
"crop": "裁切",
"processor": "處理器",
"control": "控制",
"f": "F",
"lineart": "線條藝術",
"w": "W",
"hed": "HED",
"delete": "刪除"
},
"queue": {
"queue": "佇列",
"canceled": "已取消",
"failed": "已失敗",
"completed": "已完成",
"cancel": "取消",
"session": "工作階段",
"batch": "批量",
"item": "項目",
"completedIn": "完成於",
"notReady": "無法排隊"
},
"parameters": {
"cancel": {
"cancel": "取消"
},
"height": "高度",
"type": "類型",
"symmetry": "對稱性",
"images": "圖片",
"width": "寬度",
"coherenceMode": "模式",
"seed": "種子",
"general": "一般",
"strength": "強度",
"steps": "步數",
"info": "資訊"
},
"settings": {
"beta": "Beta",
"developer": "開發者",
"general": "一般",
"models": "模型"
},
"popovers": {
"paramModel": {
"heading": "模型"
},
"compositingCoherenceMode": {
"heading": "模式"
},
"paramSteps": {
"heading": "步數"
},
"controlNetProcessor": {
"heading": "處理器"
},
"paramVAE": {
"heading": "VAE"
},
"paramHeight": {
"heading": "高度"
},
"paramSeed": {
"heading": "種子"
},
"paramWidth": {
"heading": "寬度"
},
"refinerSteps": {
"heading": "步數"
}
},
"unifiedCanvas": {
"undo": "復原",
"mask": "遮罩",
"eraser": "橡皮擦",
"antialiasing": "抗鋸齒",
"redo": "重做",
"layer": "圖層",
"accept": "接受",
"brush": "刷子",
"move": "移動",
"brushSize": "大小"
},
"nodes": {
"workflowName": "名稱",
"notes": "註釋",
"workflowVersion": "版本",
"workflowNotes": "註釋",
"executionStateError": "錯誤",
"unableToUpdateNodes_other": "無法更新 {{count}} 個節點",
"integer": "整數",
"workflow": "工作流程",
"enum": "枚舉",
"edit": "編輯",
"string": "字串",
"workflowTags": "標籤",
"node": "節點",
"boolean": "布林值",
"workflowAuthor": "作者",
"version": "版本",
"executionStateCompleted": "已完成",
"edge": "邊緣",
"versionUnknown": " 版本未知"
},
"sdxl": {
"steps": "步數",
"loading": "載入中…",
"refiner": "精煉器"
},
"gallery": {
"copy": "複製",
"download": "下載",
"loading": "載入中"
},
"ui": {
"tabs": {
"models": "模型",
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
"queue": "佇列"
}
},
"models": {
"loading": "載入中"
},
"workflows": {
"name": "名稱"
} }
} }

View File

@ -22,7 +22,13 @@ import type { BatchConfig } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions'; import { socketInvocationComplete } from 'services/events/actions';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
const matcher = isAnyOf(caLayerImageChanged, caLayerProcessorConfigChanged, caLayerModelChanged, caLayerRecalled); const matcher = isAnyOf(
caLayerImageChanged,
caLayerProcessedImageChanged,
caLayerProcessorConfigChanged,
caLayerModelChanged,
caLayerRecalled
);
const DEBOUNCE_MS = 300; const DEBOUNCE_MS = 300;
const log = logger('session'); const log = logger('session');
@ -73,9 +79,10 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
const originalConfig = originalLayer?.controlAdapter.processorConfig; const originalConfig = originalLayer?.controlAdapter.processorConfig;
const image = layer.controlAdapter.image; const image = layer.controlAdapter.image;
const processedImage = layer.controlAdapter.processedImage;
const config = layer.controlAdapter.processorConfig; const config = layer.controlAdapter.processorConfig;
if (isEqual(config, originalConfig) && isEqual(image, originalImage)) { if (isEqual(config, originalConfig) && isEqual(image, originalImage) && processedImage) {
// Neither config nor image have changed, we can bail // Neither config nor image have changed, we can bail
return; return;
} }

View File

@ -5,15 +5,86 @@ import {
socketModelInstallCancelled, socketModelInstallCancelled,
socketModelInstallComplete, socketModelInstallComplete,
socketModelInstallDownloadProgress, socketModelInstallDownloadProgress,
socketModelInstallDownloadsComplete,
socketModelInstallDownloadStarted,
socketModelInstallError, socketModelInstallError,
socketModelInstallStarted,
} from 'services/events/actions'; } from 'services/events/actions';
/**
* A model install has two main stages - downloading and installing. All these events are namespaced under `model_install_`
* which is a bit misleading. For example, a `model_install_started` event is actually fired _after_ the model has fully
* downloaded and is being "physically" installed.
*
* Note: the download events are only fired for remote model installs, not local.
*
* Here's the expected flow:
* - API receives install request, model manager preps the install
* - `model_install_download_started` fired when the download starts
* - `model_install_download_progress` fired continually until the download is complete
* - `model_install_download_complete` fired when the download is complete
* - `model_install_started` fired when the "physical" installation starts
* - `model_install_complete` fired when the installation is complete
* - `model_install_cancelled` fired if the installation is cancelled
* - `model_install_error` fired if the installation has an error
*/
const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select();
export const addModelInstallEventListener = (startAppListening: AppStartListening) => { export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
actionCreator: socketModelInstallDownloadProgress, actionCreator: socketModelInstallDownloadStarted,
effect: async (action, { dispatch }) => { effect: async (action, { dispatch, getState }) => {
const { bytes, total_bytes, id } = action.payload.data; const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'downloading';
}
return draft;
})
);
}
},
});
startAppListening({
actionCreator: socketModelInstallStarted,
effect: async (action, { dispatch, getState }) => {
const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'running';
}
return draft;
})
);
}
},
});
startAppListening({
actionCreator: socketModelInstallDownloadProgress,
effect: async (action, { dispatch, getState }) => {
const { bytes, total_bytes, id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch( dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id); const modelImport = draft.find((m) => m.id === id);
@ -25,14 +96,20 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
return draft; return draft;
}) })
); );
}
}, },
}); });
startAppListening({ startAppListening({
actionCreator: socketModelInstallComplete, actionCreator: socketModelInstallComplete,
effect: (action, { dispatch }) => { effect: (action, { dispatch, getState }) => {
const { id } = action.payload.data; const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch( dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id); const modelImport = draft.find((m) => m.id === id);
@ -42,6 +119,8 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
return draft; return draft;
}) })
); );
}
dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }])); dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }]));
dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }])); dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }]));
}, },
@ -49,9 +128,13 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
startAppListening({ startAppListening({
actionCreator: socketModelInstallError, actionCreator: socketModelInstallError,
effect: (action, { dispatch }) => { effect: (action, { dispatch, getState }) => {
const { id, error, error_type } = action.payload.data; const { id, error, error_type } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch( dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id); const modelImport = draft.find((m) => m.id === id);
@ -63,14 +146,19 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
return draft; return draft;
}) })
); );
}
}, },
}); });
startAppListening({ startAppListening({
actionCreator: socketModelInstallCancelled, actionCreator: socketModelInstallCancelled,
effect: (action, { dispatch }) => { effect: (action, { dispatch, getState }) => {
const { id } = action.payload.data; const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch( dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id); const modelImport = draft.find((m) => m.id === id);
@ -80,6 +168,29 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
return draft; return draft;
}) })
); );
}
},
});
startAppListening({
actionCreator: socketModelInstallDownloadsComplete,
effect: (action, { dispatch, getState }) => {
const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'downloads_done';
}
return draft;
})
);
}
}, },
}); });
}; };

View File

@ -4,6 +4,7 @@ import {
caLayerControlModeChanged, caLayerControlModeChanged,
caLayerImageChanged, caLayerImageChanged,
caLayerModelChanged, caLayerModelChanged,
caLayerProcessedImageChanged,
caLayerProcessorConfigChanged, caLayerProcessorConfigChanged,
caOrIPALayerBeginEndStepPctChanged, caOrIPALayerBeginEndStepPctChanged,
caOrIPALayerWeightChanged, caOrIPALayerWeightChanged,
@ -84,6 +85,14 @@ export const CALayerControlAdapterWrapper = memo(({ layerId }: Props) => {
[dispatch, layerId] [dispatch, layerId]
); );
const onErrorLoadingImage = useCallback(() => {
dispatch(caLayerImageChanged({ layerId, imageDTO: null }));
}, [dispatch, layerId]);
const onErrorLoadingProcessedImage = useCallback(() => {
dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null }));
}, [dispatch, layerId]);
const droppableData = useMemo<CALayerImageDropData>( const droppableData = useMemo<CALayerImageDropData>(
() => ({ () => ({
actionType: 'SET_CA_LAYER_IMAGE', actionType: 'SET_CA_LAYER_IMAGE',
@ -114,6 +123,8 @@ export const CALayerControlAdapterWrapper = memo(({ layerId }: Props) => {
onChangeImage={onChangeImage} onChangeImage={onChangeImage}
droppableData={droppableData} droppableData={droppableData}
postUploadAction={postUploadAction} postUploadAction={postUploadAction}
onErrorLoadingImage={onErrorLoadingImage}
onErrorLoadingProcessedImage={onErrorLoadingProcessedImage}
/> />
); );
}); });

View File

@ -28,6 +28,8 @@ type Props = {
onChangeProcessorConfig: (processorConfig: ProcessorConfig | null) => void; onChangeProcessorConfig: (processorConfig: ProcessorConfig | null) => void;
onChangeModel: (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => void; onChangeModel: (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => void;
onChangeImage: (imageDTO: ImageDTO | null) => void; onChangeImage: (imageDTO: ImageDTO | null) => void;
onErrorLoadingImage: () => void;
onErrorLoadingProcessedImage: () => void;
droppableData: TypesafeDroppableData; droppableData: TypesafeDroppableData;
postUploadAction: PostUploadAction; postUploadAction: PostUploadAction;
}; };
@ -41,6 +43,8 @@ export const ControlAdapter = memo(
onChangeProcessorConfig, onChangeProcessorConfig,
onChangeModel, onChangeModel,
onChangeImage, onChangeImage,
onErrorLoadingImage,
onErrorLoadingProcessedImage,
droppableData, droppableData,
postUploadAction, postUploadAction,
}: Props) => { }: Props) => {
@ -91,6 +95,8 @@ export const ControlAdapter = memo(
onChangeImage={onChangeImage} onChangeImage={onChangeImage}
droppableData={droppableData} droppableData={droppableData}
postUploadAction={postUploadAction} postUploadAction={postUploadAction}
onErrorLoadingImage={onErrorLoadingImage}
onErrorLoadingProcessedImage={onErrorLoadingProcessedImage}
/> />
</Flex> </Flex>
</Flex> </Flex>

View File

@ -27,10 +27,19 @@ type Props = {
onChangeImage: (imageDTO: ImageDTO | null) => void; onChangeImage: (imageDTO: ImageDTO | null) => void;
droppableData: TypesafeDroppableData; droppableData: TypesafeDroppableData;
postUploadAction: PostUploadAction; postUploadAction: PostUploadAction;
onErrorLoadingImage: () => void;
onErrorLoadingProcessedImage: () => void;
}; };
export const ControlAdapterImagePreview = memo( export const ControlAdapterImagePreview = memo(
({ controlAdapter, onChangeImage, droppableData, postUploadAction }: Props) => { ({
controlAdapter,
onChangeImage,
droppableData,
postUploadAction,
onErrorLoadingImage,
onErrorLoadingProcessedImage,
}: Props) => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId); const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId);
@ -128,10 +137,23 @@ export const ControlAdapterImagePreview = memo(
controlAdapter.processorConfig !== null; controlAdapter.processorConfig !== null;
useEffect(() => { useEffect(() => {
if (isConnected && (isErrorControlImage || isErrorProcessedControlImage)) { if (!isConnected) {
handleResetControlImage(); return;
} }
}, [handleResetControlImage, isConnected, isErrorControlImage, isErrorProcessedControlImage]); if (isErrorControlImage) {
onErrorLoadingImage();
}
if (isErrorProcessedControlImage) {
onErrorLoadingProcessedImage();
}
}, [
handleResetControlImage,
isConnected,
isErrorControlImage,
isErrorProcessedControlImage,
onErrorLoadingImage,
onErrorLoadingProcessedImage,
]);
return ( return (
<Flex <Flex
@ -167,6 +189,7 @@ export const ControlAdapterImagePreview = memo(
droppableData={droppableData} droppableData={droppableData}
imageDTO={processedControlImage} imageDTO={processedControlImage}
isUploadDisabled={true} isUploadDisabled={true}
onError={handleResetControlImage}
/> />
</Box> </Box>

View File

@ -4,20 +4,35 @@ import { createSelector } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useMouseEvents } from 'features/controlLayers/hooks/mouseEventHooks'; import { BRUSH_SPACING_PCT, MAX_BRUSH_SPACING_PX, MIN_BRUSH_SPACING_PX } from 'features/controlLayers/konva/constants';
import { setStageEventHandlers } from 'features/controlLayers/konva/events';
import { debouncedRenderers, renderers as normalRenderers } from 'features/controlLayers/konva/renderers';
import { import {
$brushSize,
$brushSpacingPx,
$isDrawing,
$lastAddedPoint,
$lastCursorPos, $lastCursorPos,
$lastMouseDownPos, $lastMouseDownPos,
$selectedLayerId,
$selectedLayerType,
$shouldInvertBrushSizeScrollDirection,
$tool, $tool,
brushSizeChanged,
isRegionalGuidanceLayer, isRegionalGuidanceLayer,
layerBboxChanged, layerBboxChanged,
layerTranslated, layerTranslated,
rgLayerLineAdded,
rgLayerPointsAdded,
rgLayerRectAdded,
selectControlLayersSlice, selectControlLayersSlice,
} from 'features/controlLayers/store/controlLayersSlice'; } from 'features/controlLayers/store/controlLayersSlice';
import { debouncedRenderers, renderers as normalRenderers } from 'features/controlLayers/util/renderers'; import type { AddLineArg, AddPointToLineArg, AddRectArg } from 'features/controlLayers/store/types';
import Konva from 'konva'; import Konva from 'konva';
import type { IRect } from 'konva/lib/types'; import type { IRect } from 'konva/lib/types';
import { clamp } from 'lodash-es';
import { memo, useCallback, useLayoutEffect, useMemo, useState } from 'react'; import { memo, useCallback, useLayoutEffect, useMemo, useState } from 'react';
import { getImageDTO } from 'services/api/endpoints/images';
import { useDevicePixelRatio } from 'use-device-pixel-ratio'; import { useDevicePixelRatio } from 'use-device-pixel-ratio';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
@ -47,7 +62,6 @@ const useStageRenderer = (
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const state = useAppSelector((s) => s.controlLayers.present); const state = useAppSelector((s) => s.controlLayers.present);
const tool = useStore($tool); const tool = useStore($tool);
const mouseEventHandlers = useMouseEvents();
const lastCursorPos = useStore($lastCursorPos); const lastCursorPos = useStore($lastCursorPos);
const lastMouseDownPos = useStore($lastMouseDownPos); const lastMouseDownPos = useStore($lastMouseDownPos);
const selectedLayerIdColor = useAppSelector(selectSelectedLayerColor); const selectedLayerIdColor = useAppSelector(selectSelectedLayerColor);
@ -56,6 +70,26 @@ const useStageRenderer = (
const layerCount = useMemo(() => state.layers.length, [state.layers]); const layerCount = useMemo(() => state.layers.length, [state.layers]);
const renderers = useMemo(() => (asPreview ? debouncedRenderers : normalRenderers), [asPreview]); const renderers = useMemo(() => (asPreview ? debouncedRenderers : normalRenderers), [asPreview]);
const dpr = useDevicePixelRatio({ round: false }); const dpr = useDevicePixelRatio({ round: false });
const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection);
const brushSpacingPx = useMemo(
() => clamp(state.brushSize / BRUSH_SPACING_PCT, MIN_BRUSH_SPACING_PX, MAX_BRUSH_SPACING_PX),
[state.brushSize]
);
useLayoutEffect(() => {
$brushSize.set(state.brushSize);
$brushSpacingPx.set(brushSpacingPx);
$selectedLayerId.set(state.selectedLayerId);
$selectedLayerType.set(selectedLayerType);
$shouldInvertBrushSizeScrollDirection.set(shouldInvertBrushSizeScrollDirection);
}, [
brushSpacingPx,
selectedLayerIdColor,
selectedLayerType,
shouldInvertBrushSizeScrollDirection,
state.brushSize,
state.selectedLayerId,
]);
const onLayerPosChanged = useCallback( const onLayerPosChanged = useCallback(
(layerId: string, x: number, y: number) => { (layerId: string, x: number, y: number) => {
@ -71,6 +105,31 @@ const useStageRenderer = (
[dispatch] [dispatch]
); );
const onRGLayerLineAdded = useCallback(
(arg: AddLineArg) => {
dispatch(rgLayerLineAdded(arg));
},
[dispatch]
);
const onRGLayerPointAddedToLine = useCallback(
(arg: AddPointToLineArg) => {
dispatch(rgLayerPointsAdded(arg));
},
[dispatch]
);
const onRGLayerRectAdded = useCallback(
(arg: AddRectArg) => {
dispatch(rgLayerRectAdded(arg));
},
[dispatch]
);
const onBrushSizeChanged = useCallback(
(size: number) => {
dispatch(brushSizeChanged(size));
},
[dispatch]
);
useLayoutEffect(() => { useLayoutEffect(() => {
log.trace('Initializing stage'); log.trace('Initializing stage');
if (!container) { if (!container) {
@ -88,21 +147,29 @@ const useStageRenderer = (
if (asPreview) { if (asPreview) {
return; return;
} }
stage.on('mousedown', mouseEventHandlers.onMouseDown); const cleanup = setStageEventHandlers({
stage.on('mouseup', mouseEventHandlers.onMouseUp); stage,
stage.on('mousemove', mouseEventHandlers.onMouseMove); $tool,
stage.on('mouseleave', mouseEventHandlers.onMouseLeave); $isDrawing,
stage.on('wheel', mouseEventHandlers.onMouseWheel); $lastMouseDownPos,
$lastCursorPos,
$lastAddedPoint,
$brushSize,
$brushSpacingPx,
$selectedLayerId,
$selectedLayerType,
$shouldInvertBrushSizeScrollDirection,
onRGLayerLineAdded,
onRGLayerPointAddedToLine,
onRGLayerRectAdded,
onBrushSizeChanged,
});
return () => { return () => {
log.trace('Cleaning up stage listeners'); log.trace('Removing stage listeners');
stage.off('mousedown', mouseEventHandlers.onMouseDown); cleanup();
stage.off('mouseup', mouseEventHandlers.onMouseUp);
stage.off('mousemove', mouseEventHandlers.onMouseMove);
stage.off('mouseleave', mouseEventHandlers.onMouseLeave);
stage.off('wheel', mouseEventHandlers.onMouseWheel);
}; };
}, [stage, asPreview, mouseEventHandlers]); }, [asPreview, onBrushSizeChanged, onRGLayerLineAdded, onRGLayerPointAddedToLine, onRGLayerRectAdded, stage]);
useLayoutEffect(() => { useLayoutEffect(() => {
log.trace('Updating stage dimensions'); log.trace('Updating stage dimensions');
@ -160,7 +227,7 @@ const useStageRenderer = (
useLayoutEffect(() => { useLayoutEffect(() => {
log.trace('Rendering layers'); log.trace('Rendering layers');
renderers.renderLayers(stage, state.layers, state.globalMaskLayerOpacity, tool, onLayerPosChanged); renderers.renderLayers(stage, state.layers, state.globalMaskLayerOpacity, tool, getImageDTO, onLayerPosChanged);
}, [ }, [
stage, stage,
state.layers, state.layers,

View File

@ -1,233 +0,0 @@
import { $ctrl, $meta } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { calculateNewBrushSize } from 'features/canvas/hooks/useCanvasZoom';
import {
$isDrawing,
$lastCursorPos,
$lastMouseDownPos,
$tool,
brushSizeChanged,
rgLayerLineAdded,
rgLayerPointsAdded,
rgLayerRectAdded,
} from 'features/controlLayers/store/controlLayersSlice';
import type Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import type { Vector2d } from 'konva/lib/types';
import { clamp } from 'lodash-es';
import { useCallback, useMemo, useRef } from 'react';
const getIsFocused = (stage: Konva.Stage) => {
return stage.container().contains(document.activeElement);
};
const getIsMouseDown = (e: KonvaEventObject<MouseEvent>) => e.evt.buttons === 1;
const SNAP_PX = 10;
export const snapPosToStage = (pos: Vector2d, stage: Konva.Stage) => {
const snappedPos = { ...pos };
// Get the normalized threshold for snapping to the edge of the stage
const thresholdX = SNAP_PX / stage.scaleX();
const thresholdY = SNAP_PX / stage.scaleY();
const stageWidth = stage.width() / stage.scaleX();
const stageHeight = stage.height() / stage.scaleY();
// Snap to the edge of the stage if within threshold
if (pos.x - thresholdX < 0) {
snappedPos.x = 0;
} else if (pos.x + thresholdX > stageWidth) {
snappedPos.x = Math.floor(stageWidth);
}
if (pos.y - thresholdY < 0) {
snappedPos.y = 0;
} else if (pos.y + thresholdY > stageHeight) {
snappedPos.y = Math.floor(stageHeight);
}
return snappedPos;
};
export const getScaledFlooredCursorPosition = (stage: Konva.Stage) => {
const pointerPosition = stage.getPointerPosition();
const stageTransform = stage.getAbsoluteTransform().copy();
if (!pointerPosition) {
return;
}
const scaledCursorPosition = stageTransform.invert().point(pointerPosition);
return {
x: Math.floor(scaledCursorPosition.x),
y: Math.floor(scaledCursorPosition.y),
};
};
const syncCursorPos = (stage: Konva.Stage): Vector2d | null => {
const pos = getScaledFlooredCursorPosition(stage);
if (!pos) {
return null;
}
$lastCursorPos.set(pos);
return pos;
};
const BRUSH_SPACING_PCT = 10;
const MIN_BRUSH_SPACING_PX = 5;
const MAX_BRUSH_SPACING_PX = 15;
export const useMouseEvents = () => {
const dispatch = useAppDispatch();
const selectedLayerId = useAppSelector((s) => s.controlLayers.present.selectedLayerId);
const selectedLayerType = useAppSelector((s) => {
const selectedLayer = s.controlLayers.present.layers.find((l) => l.id === s.controlLayers.present.selectedLayerId);
if (!selectedLayer) {
return null;
}
return selectedLayer.type;
});
const tool = useStore($tool);
const lastCursorPosRef = useRef<[number, number] | null>(null);
const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection);
const brushSize = useAppSelector((s) => s.controlLayers.present.brushSize);
const brushSpacingPx = useMemo(
() => clamp(brushSize / BRUSH_SPACING_PCT, MIN_BRUSH_SPACING_PX, MAX_BRUSH_SPACING_PX),
[brushSize]
);
const onMouseDown = useCallback(
(e: KonvaEventObject<MouseEvent>) => {
const stage = e.target.getStage();
if (!stage) {
return;
}
const pos = syncCursorPos(stage);
if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
return;
}
if (tool === 'brush' || tool === 'eraser') {
dispatch(
rgLayerLineAdded({
layerId: selectedLayerId,
points: [pos.x, pos.y, pos.x, pos.y],
tool,
})
);
$isDrawing.set(true);
$lastMouseDownPos.set(pos);
} else if (tool === 'rect') {
$lastMouseDownPos.set(snapPosToStage(pos, stage));
}
},
[dispatch, selectedLayerId, selectedLayerType, tool]
);
const onMouseUp = useCallback(
(e: KonvaEventObject<MouseEvent>) => {
const stage = e.target.getStage();
if (!stage) {
return;
}
const pos = $lastCursorPos.get();
if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
return;
}
const lastPos = $lastMouseDownPos.get();
const tool = $tool.get();
if (lastPos && selectedLayerId && tool === 'rect') {
const snappedPos = snapPosToStage(pos, stage);
dispatch(
rgLayerRectAdded({
layerId: selectedLayerId,
rect: {
x: Math.min(snappedPos.x, lastPos.x),
y: Math.min(snappedPos.y, lastPos.y),
width: Math.abs(snappedPos.x - lastPos.x),
height: Math.abs(snappedPos.y - lastPos.y),
},
})
);
}
$isDrawing.set(false);
$lastMouseDownPos.set(null);
},
[dispatch, selectedLayerId, selectedLayerType]
);
const onMouseMove = useCallback(
(e: KonvaEventObject<MouseEvent>) => {
const stage = e.target.getStage();
if (!stage) {
return;
}
const pos = syncCursorPos(stage);
if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
return;
}
if (getIsFocused(stage) && getIsMouseDown(e) && (tool === 'brush' || tool === 'eraser')) {
if ($isDrawing.get()) {
// Continue the last line
if (lastCursorPosRef.current) {
// Dispatching redux events impacts perf substantially - using brush spacing keeps dispatches to a reasonable number
if (Math.hypot(lastCursorPosRef.current[0] - pos.x, lastCursorPosRef.current[1] - pos.y) < brushSpacingPx) {
return;
}
}
lastCursorPosRef.current = [pos.x, pos.y];
dispatch(rgLayerPointsAdded({ layerId: selectedLayerId, point: lastCursorPosRef.current }));
} else {
// Start a new line
dispatch(rgLayerLineAdded({ layerId: selectedLayerId, points: [pos.x, pos.y, pos.x, pos.y], tool }));
}
$isDrawing.set(true);
}
},
[brushSpacingPx, dispatch, selectedLayerId, selectedLayerType, tool]
);
const onMouseLeave = useCallback(
(e: KonvaEventObject<MouseEvent>) => {
const stage = e.target.getStage();
if (!stage) {
return;
}
const pos = syncCursorPos(stage);
$isDrawing.set(false);
$lastCursorPos.set(null);
$lastMouseDownPos.set(null);
if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
return;
}
if (getIsFocused(stage) && getIsMouseDown(e) && (tool === 'brush' || tool === 'eraser')) {
dispatch(rgLayerPointsAdded({ layerId: selectedLayerId, point: [pos.x, pos.y] }));
}
},
[selectedLayerId, selectedLayerType, tool, dispatch]
);
const onMouseWheel = useCallback(
(e: KonvaEventObject<WheelEvent>) => {
e.evt.preventDefault();
if (selectedLayerType !== 'regional_guidance_layer' || (tool !== 'brush' && tool !== 'eraser')) {
return;
}
// checking for ctrl key is pressed or not,
// so that brush size can be controlled using ctrl + scroll up/down
// Invert the delta if the property is set to true
let delta = e.evt.deltaY;
if (shouldInvertBrushSizeScrollDirection) {
delta = -delta;
}
if ($ctrl.get() || $meta.get()) {
dispatch(brushSizeChanged(calculateNewBrushSize(brushSize, delta)));
}
},
[selectedLayerType, tool, shouldInvertBrushSizeScrollDirection, dispatch, brushSize]
);
const handlers = useMemo(
() => ({ onMouseDown, onMouseUp, onMouseMove, onMouseLeave, onMouseWheel }),
[onMouseDown, onMouseUp, onMouseMove, onMouseLeave, onMouseWheel]
);
return handlers;
};

View File

@ -1,11 +1,10 @@
import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { imageDataToDataURL } from 'features/canvas/util/blobToDataURL'; import { imageDataToDataURL } from 'features/canvas/util/blobToDataURL';
import { RG_LAYER_OBJECT_GROUP_NAME } from 'features/controlLayers/store/controlLayersSlice';
import Konva from 'konva'; import Konva from 'konva';
import type { IRect } from 'konva/lib/types'; import type { IRect } from 'konva/lib/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
const GET_CLIENT_RECT_CONFIG = { skipTransform: true }; import { RG_LAYER_OBJECT_GROUP_NAME } from './naming';
type Extents = { type Extents = {
minX: number; minX: number;
@ -14,10 +13,13 @@ type Extents = {
maxY: number; maxY: number;
}; };
const GET_CLIENT_RECT_CONFIG = { skipTransform: true };
//#region getImageDataBbox
/** /**
* Get the bounding box of an image. * Get the bounding box of an image.
* @param imageData The ImageData object to get the bounding box of. * @param imageData The ImageData object to get the bounding box of.
* @returns The minimum and maximum x and y values of the image's bounding box. * @returns The minimum and maximum x and y values of the image's bounding box, or null if the image has no pixels.
*/ */
const getImageDataBbox = (imageData: ImageData): Extents | null => { const getImageDataBbox = (imageData: ImageData): Extents | null => {
const { data, width, height } = imageData; const { data, width, height } = imageData;
@ -51,7 +53,9 @@ const getImageDataBbox = (imageData: ImageData): Extents | null => {
return isEmpty ? null : { minX, minY, maxX, maxY }; return isEmpty ? null : { minX, minY, maxX, maxY };
}; };
//#endregion
//#region getIsolatedRGLayerClone
/** /**
* Clones a regional guidance konva layer onto an offscreen stage/canvas. This allows the pixel data for a given layer * Clones a regional guidance konva layer onto an offscreen stage/canvas. This allows the pixel data for a given layer
* to be captured, manipulated or analyzed without interference from other layers. * to be captured, manipulated or analyzed without interference from other layers.
@ -88,7 +92,9 @@ const getIsolatedRGLayerClone = (layer: Konva.Layer): { stageClone: Konva.Stage;
return { stageClone, layerClone }; return { stageClone, layerClone };
}; };
//#endregion
//#region getLayerBboxPixels
/** /**
* Get the bounding box of a regional prompt konva layer. This function has special handling for regional prompt layers. * Get the bounding box of a regional prompt konva layer. This function has special handling for regional prompt layers.
* @param layer The konva layer to get the bounding box of. * @param layer The konva layer to get the bounding box of.
@ -137,7 +143,9 @@ export const getLayerBboxPixels = (layer: Konva.Layer, preview: boolean = false)
return correctedLayerBbox; return correctedLayerBbox;
}; };
//#endregion
//#region getLayerBboxFast
/** /**
* Get the bounding box of a konva layer. This function is faster than `getLayerBboxPixels` but less accurate. It * Get the bounding box of a konva layer. This function is faster than `getLayerBboxPixels` but less accurate. It
* should only be used when there are no eraser strokes or shapes in the layer. * should only be used when there are no eraser strokes or shapes in the layer.
@ -153,3 +161,4 @@ export const getLayerBboxFast = (layer: Konva.Layer): IRect => {
height: Math.floor(bbox.height), height: Math.floor(bbox.height),
}; };
}; };
//#endregion

View File

@ -0,0 +1,36 @@
/**
* A transparency checker pattern image.
* This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL
*/
export const TRANSPARENCY_CHECKER_PATTERN =
'';
/**
* The color of a bounding box stroke when its object is selected.
*/
export const BBOX_SELECTED_STROKE = 'rgba(78, 190, 255, 1)';
/**
* The inner border color for the brush preview.
*/
export const BRUSH_BORDER_INNER_COLOR = 'rgba(0,0,0,1)';
/**
* The outer border color for the brush preview.
*/
export const BRUSH_BORDER_OUTER_COLOR = 'rgba(255,255,255,0.8)';
/**
* The target spacing of individual points of brush strokes, as a percentage of the brush size.
*/
export const BRUSH_SPACING_PCT = 10;
/**
* The minimum brush spacing in pixels.
*/
export const MIN_BRUSH_SPACING_PX = 5;
/**
* The maximum brush spacing in pixels.
*/
export const MAX_BRUSH_SPACING_PX = 15;

View File

@ -0,0 +1,201 @@
import { calculateNewBrushSize } from 'features/canvas/hooks/useCanvasZoom';
import {
getIsFocused,
getIsMouseDown,
getScaledFlooredCursorPosition,
snapPosToStage,
} from 'features/controlLayers/konva/util';
import type { AddLineArg, AddPointToLineArg, AddRectArg, Layer, Tool } from 'features/controlLayers/store/types';
import type Konva from 'konva';
import type { Vector2d } from 'konva/lib/types';
import type { WritableAtom } from 'nanostores';
import { TOOL_PREVIEW_LAYER_ID } from './naming';
type SetStageEventHandlersArg = {
stage: Konva.Stage;
$tool: WritableAtom<Tool>;
$isDrawing: WritableAtom<boolean>;
$lastMouseDownPos: WritableAtom<Vector2d | null>;
$lastCursorPos: WritableAtom<Vector2d | null>;
$lastAddedPoint: WritableAtom<Vector2d | null>;
$brushSize: WritableAtom<number>;
$brushSpacingPx: WritableAtom<number>;
$selectedLayerId: WritableAtom<string | null>;
$selectedLayerType: WritableAtom<Layer['type'] | null>;
$shouldInvertBrushSizeScrollDirection: WritableAtom<boolean>;
onRGLayerLineAdded: (arg: AddLineArg) => void;
onRGLayerPointAddedToLine: (arg: AddPointToLineArg) => void;
onRGLayerRectAdded: (arg: AddRectArg) => void;
onBrushSizeChanged: (size: number) => void;
};
const syncCursorPos = (stage: Konva.Stage, $lastCursorPos: WritableAtom<Vector2d | null>) => {
const pos = getScaledFlooredCursorPosition(stage);
if (!pos) {
return null;
}
$lastCursorPos.set(pos);
return pos;
};
export const setStageEventHandlers = ({
stage,
$tool,
$isDrawing,
$lastMouseDownPos,
$lastCursorPos,
$lastAddedPoint,
$brushSize,
$brushSpacingPx,
$selectedLayerId,
$selectedLayerType,
$shouldInvertBrushSizeScrollDirection,
onRGLayerLineAdded,
onRGLayerPointAddedToLine,
onRGLayerRectAdded,
onBrushSizeChanged,
}: SetStageEventHandlersArg): (() => void) => {
stage.on('mouseenter', (e) => {
const stage = e.target.getStage();
if (!stage) {
return;
}
const tool = $tool.get();
stage.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(tool === 'brush' || tool === 'eraser');
});
stage.on('mousedown', (e) => {
const stage = e.target.getStage();
if (!stage) {
return;
}
const tool = $tool.get();
const pos = syncCursorPos(stage, $lastCursorPos);
const selectedLayerId = $selectedLayerId.get();
const selectedLayerType = $selectedLayerType.get();
if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
return;
}
if (tool === 'brush' || tool === 'eraser') {
onRGLayerLineAdded({
layerId: selectedLayerId,
points: [pos.x, pos.y, pos.x, pos.y],
tool,
});
$isDrawing.set(true);
$lastMouseDownPos.set(pos);
} else if (tool === 'rect') {
$lastMouseDownPos.set(snapPosToStage(pos, stage));
}
});
stage.on('mouseup', (e) => {
const stage = e.target.getStage();
if (!stage) {
return;
}
const pos = $lastCursorPos.get();
const selectedLayerId = $selectedLayerId.get();
const selectedLayerType = $selectedLayerType.get();
if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
return;
}
const lastPos = $lastMouseDownPos.get();
const tool = $tool.get();
if (lastPos && selectedLayerId && tool === 'rect') {
const snappedPos = snapPosToStage(pos, stage);
onRGLayerRectAdded({
layerId: selectedLayerId,
rect: {
x: Math.min(snappedPos.x, lastPos.x),
y: Math.min(snappedPos.y, lastPos.y),
width: Math.abs(snappedPos.x - lastPos.x),
height: Math.abs(snappedPos.y - lastPos.y),
},
});
}
$isDrawing.set(false);
$lastMouseDownPos.set(null);
});
stage.on('mousemove', (e) => {
const stage = e.target.getStage();
if (!stage) {
return;
}
const tool = $tool.get();
const pos = syncCursorPos(stage, $lastCursorPos);
const selectedLayerId = $selectedLayerId.get();
const selectedLayerType = $selectedLayerType.get();
stage.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(tool === 'brush' || tool === 'eraser');
if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
return;
}
if (getIsFocused(stage) && getIsMouseDown(e) && (tool === 'brush' || tool === 'eraser')) {
if ($isDrawing.get()) {
// Continue the last line
const lastAddedPoint = $lastAddedPoint.get();
if (lastAddedPoint) {
// Dispatching redux events impacts perf substantially - using brush spacing keeps dispatches to a reasonable number
if (Math.hypot(lastAddedPoint.x - pos.x, lastAddedPoint.y - pos.y) < $brushSpacingPx.get()) {
return;
}
}
$lastAddedPoint.set({ x: pos.x, y: pos.y });
onRGLayerPointAddedToLine({ layerId: selectedLayerId, point: [pos.x, pos.y] });
} else {
// Start a new line
onRGLayerLineAdded({ layerId: selectedLayerId, points: [pos.x, pos.y, pos.x, pos.y], tool });
}
$isDrawing.set(true);
}
});
stage.on('mouseleave', (e) => {
const stage = e.target.getStage();
if (!stage) {
return;
}
const pos = syncCursorPos(stage, $lastCursorPos);
$isDrawing.set(false);
$lastCursorPos.set(null);
$lastMouseDownPos.set(null);
const selectedLayerId = $selectedLayerId.get();
const selectedLayerType = $selectedLayerType.get();
const tool = $tool.get();
stage.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(false);
if (!pos || !selectedLayerId || selectedLayerType !== 'regional_guidance_layer') {
return;
}
if (getIsFocused(stage) && getIsMouseDown(e) && (tool === 'brush' || tool === 'eraser')) {
onRGLayerPointAddedToLine({ layerId: selectedLayerId, point: [pos.x, pos.y] });
}
});
stage.on('wheel', (e) => {
e.evt.preventDefault();
const selectedLayerType = $selectedLayerType.get();
const tool = $tool.get();
if (selectedLayerType !== 'regional_guidance_layer' || (tool !== 'brush' && tool !== 'eraser')) {
return;
}
// Invert the delta if the property is set to true
let delta = e.evt.deltaY;
if ($shouldInvertBrushSizeScrollDirection.get()) {
delta = -delta;
}
if (e.evt.ctrlKey || e.evt.metaKey) {
onBrushSizeChanged(calculateNewBrushSize($brushSize.get(), delta));
}
});
return () => stage.off('mousedown mouseup mousemove mouseenter mouseleave wheel');
};

View File

@ -0,0 +1,21 @@
/**
* Konva filters
* https://konvajs.org/docs/filters/Custom_Filter.html
*/
/**
* Calculates the lightness (HSL) of a given pixel and sets the alpha channel to that value.
* This is useful for edge maps and other masks, to make the black areas transparent.
* @param imageData The image data to apply the filter to
*/
export const LightnessToAlphaFilter = (imageData: ImageData): void => {
const len = imageData.data.length / 4;
for (let i = 0; i < len; i++) {
const r = imageData.data[i * 4 + 0] as number;
const g = imageData.data[i * 4 + 1] as number;
const b = imageData.data[i * 4 + 2] as number;
const cMin = Math.min(r, g, b);
const cMax = Math.max(r, g, b);
imageData.data[i * 4 + 3] = (cMin + cMax) / 2;
}
};

View File

@ -0,0 +1,38 @@
/**
* This file contains IDs, names, and ID getters for konva layers and objects.
*/
// IDs for singleton Konva layers and objects
export const TOOL_PREVIEW_LAYER_ID = 'tool_preview_layer';
export const TOOL_PREVIEW_BRUSH_GROUP_ID = 'tool_preview_layer.brush_group';
export const TOOL_PREVIEW_BRUSH_FILL_ID = 'tool_preview_layer.brush_fill';
export const TOOL_PREVIEW_BRUSH_BORDER_INNER_ID = 'tool_preview_layer.brush_border_inner';
export const TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID = 'tool_preview_layer.brush_border_outer';
export const TOOL_PREVIEW_RECT_ID = 'tool_preview_layer.rect';
export const BACKGROUND_LAYER_ID = 'background_layer';
export const BACKGROUND_RECT_ID = 'background_layer.rect';
export const NO_LAYERS_MESSAGE_LAYER_ID = 'no_layers_message';
// Names for Konva layers and objects (comparable to CSS classes)
export const CA_LAYER_NAME = 'control_adapter_layer';
export const CA_LAYER_IMAGE_NAME = 'control_adapter_layer.image';
export const RG_LAYER_NAME = 'regional_guidance_layer';
export const RG_LAYER_LINE_NAME = 'regional_guidance_layer.line';
export const RG_LAYER_OBJECT_GROUP_NAME = 'regional_guidance_layer.object_group';
export const RG_LAYER_RECT_NAME = 'regional_guidance_layer.rect';
export const INITIAL_IMAGE_LAYER_ID = 'singleton_initial_image_layer';
export const INITIAL_IMAGE_LAYER_NAME = 'initial_image_layer';
export const INITIAL_IMAGE_LAYER_IMAGE_NAME = 'initial_image_layer.image';
export const LAYER_BBOX_NAME = 'layer.bbox';
export const COMPOSITING_RECT_NAME = 'compositing-rect';
// Getters for non-singleton layer and object IDs
export const getRGLayerId = (layerId: string) => `${RG_LAYER_NAME}_${layerId}`;
export const getRGLayerLineId = (layerId: string, lineId: string) => `${layerId}.line_${lineId}`;
export const getRGLayerRectId = (layerId: string, lineId: string) => `${layerId}.rect_${lineId}`;
export const getRGLayerObjectGroupId = (layerId: string, groupId: string) => `${layerId}.objectGroup_${groupId}`;
export const getLayerBboxId = (layerId: string) => `${layerId}.bbox`;
export const getCALayerId = (layerId: string) => `control_adapter_layer_${layerId}`;
export const getCALayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`;
export const getIILayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`;
export const getIPALayerId = (layerId: string) => `ip_adapter_layer_${layerId}`;

View File

@ -1,8 +1,7 @@
import { getStore } from 'app/store/nanostores/store';
import { rgbaColorToString, rgbColorToString } from 'features/canvas/util/colorToString'; import { rgbaColorToString, rgbColorToString } from 'features/canvas/util/colorToString';
import { getScaledFlooredCursorPosition, snapPosToStage } from 'features/controlLayers/hooks/mouseEventHooks'; import { getLayerBboxFast, getLayerBboxPixels } from 'features/controlLayers/konva/bbox';
import { LightnessToAlphaFilter } from 'features/controlLayers/konva/filters';
import { import {
$tool,
BACKGROUND_LAYER_ID, BACKGROUND_LAYER_ID,
BACKGROUND_RECT_ID, BACKGROUND_RECT_ID,
CA_LAYER_IMAGE_NAME, CA_LAYER_IMAGE_NAME,
@ -14,10 +13,6 @@ import {
getRGLayerObjectGroupId, getRGLayerObjectGroupId,
INITIAL_IMAGE_LAYER_IMAGE_NAME, INITIAL_IMAGE_LAYER_IMAGE_NAME,
INITIAL_IMAGE_LAYER_NAME, INITIAL_IMAGE_LAYER_NAME,
isControlAdapterLayer,
isInitialImageLayer,
isRegionalGuidanceLayer,
isRenderableLayer,
LAYER_BBOX_NAME, LAYER_BBOX_NAME,
NO_LAYERS_MESSAGE_LAYER_ID, NO_LAYERS_MESSAGE_LAYER_ID,
RG_LAYER_LINE_NAME, RG_LAYER_LINE_NAME,
@ -30,6 +25,13 @@ import {
TOOL_PREVIEW_BRUSH_GROUP_ID, TOOL_PREVIEW_BRUSH_GROUP_ID,
TOOL_PREVIEW_LAYER_ID, TOOL_PREVIEW_LAYER_ID,
TOOL_PREVIEW_RECT_ID, TOOL_PREVIEW_RECT_ID,
} from 'features/controlLayers/konva/naming';
import { getScaledFlooredCursorPosition, snapPosToStage } from 'features/controlLayers/konva/util';
import {
isControlAdapterLayer,
isInitialImageLayer,
isRegionalGuidanceLayer,
isRenderableLayer,
} from 'features/controlLayers/store/controlLayersSlice'; } from 'features/controlLayers/store/controlLayersSlice';
import type { import type {
ControlAdapterLayer, ControlAdapterLayer,
@ -40,61 +42,46 @@ import type {
VectorMaskLine, VectorMaskLine,
VectorMaskRect, VectorMaskRect,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import { getLayerBboxFast, getLayerBboxPixels } from 'features/controlLayers/util/bbox';
import { t } from 'i18next'; import { t } from 'i18next';
import Konva from 'konva'; import Konva from 'konva';
import type { IRect, Vector2d } from 'konva/lib/types'; import type { IRect, Vector2d } from 'konva/lib/types';
import { debounce } from 'lodash-es'; import { debounce } from 'lodash-es';
import type { RgbColor } from 'react-colorful'; import type { RgbColor } from 'react-colorful';
import { imagesApi } from 'services/api/endpoints/images'; import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
const BBOX_SELECTED_STROKE = 'rgba(78, 190, 255, 1)'; import {
const BRUSH_BORDER_INNER_COLOR = 'rgba(0,0,0,1)'; BBOX_SELECTED_STROKE,
const BRUSH_BORDER_OUTER_COLOR = 'rgba(255,255,255,0.8)'; BRUSH_BORDER_INNER_COLOR,
// This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL BRUSH_BORDER_OUTER_COLOR,
export const STAGE_BG_DATAURL = TRANSPARENCY_CHECKER_PATTERN,
''; } from './constants';
const mapId = (object: { id: string }) => object.id; const mapId = (object: { id: string }): string => object.id;
const selectRenderableLayers = (n: Konva.Node) => /**
* Konva selection callback to select all renderable layers. This includes RG, CA and II layers.
*/
const selectRenderableLayers = (n: Konva.Node): boolean =>
n.name() === RG_LAYER_NAME || n.name() === CA_LAYER_NAME || n.name() === INITIAL_IMAGE_LAYER_NAME; n.name() === RG_LAYER_NAME || n.name() === CA_LAYER_NAME || n.name() === INITIAL_IMAGE_LAYER_NAME;
const selectVectorMaskObjects = (node: Konva.Node) => { /**
* Konva selection callback to select RG mask objects. This includes lines and rects.
*/
const selectVectorMaskObjects = (node: Konva.Node): boolean => {
return node.name() === RG_LAYER_LINE_NAME || node.name() === RG_LAYER_RECT_NAME; return node.name() === RG_LAYER_LINE_NAME || node.name() === RG_LAYER_RECT_NAME;
}; };
/** /**
* Creates the brush preview layer. * Creates the singleton tool preview layer and all its objects.
* @param stage The konva stage to render on. * @param stage The konva stage
* @returns The brush preview layer.
*/ */
const createToolPreviewLayer = (stage: Konva.Stage) => { const createToolPreviewLayer = (stage: Konva.Stage): Konva.Layer => {
// Initialize the brush preview layer & add to the stage // Initialize the brush preview layer & add to the stage
const toolPreviewLayer = new Konva.Layer({ id: TOOL_PREVIEW_LAYER_ID, visible: false, listening: false }); const toolPreviewLayer = new Konva.Layer({ id: TOOL_PREVIEW_LAYER_ID, visible: false, listening: false });
stage.add(toolPreviewLayer); stage.add(toolPreviewLayer);
// Add handlers to show/hide the brush preview layer
stage.on('mousemove', (e) => {
const tool = $tool.get();
e.target
.getStage()
?.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)
?.visible(tool === 'brush' || tool === 'eraser');
});
stage.on('mouseleave', (e) => {
e.target.getStage()?.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(false);
});
stage.on('mouseenter', (e) => {
const tool = $tool.get();
e.target
.getStage()
?.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)
?.visible(tool === 'brush' || tool === 'eraser');
});
// Create the brush preview group & circles // Create the brush preview group & circles
const brushPreviewGroup = new Konva.Group({ id: TOOL_PREVIEW_BRUSH_GROUP_ID }); const brushPreviewGroup = new Konva.Group({ id: TOOL_PREVIEW_BRUSH_GROUP_ID });
const brushPreviewFill = new Konva.Circle({ const brushPreviewFill = new Konva.Circle({
@ -121,7 +108,7 @@ const createToolPreviewLayer = (stage: Konva.Stage) => {
brushPreviewGroup.add(brushPreviewBorderOuter); brushPreviewGroup.add(brushPreviewBorderOuter);
toolPreviewLayer.add(brushPreviewGroup); toolPreviewLayer.add(brushPreviewGroup);
// Create the rect preview // Create the rect preview - this is a rectangle drawn from the last mouse down position to the current cursor position
const rectPreview = new Konva.Rect({ id: TOOL_PREVIEW_RECT_ID, listening: false, stroke: 'white', strokeWidth: 1 }); const rectPreview = new Konva.Rect({ id: TOOL_PREVIEW_RECT_ID, listening: false, stroke: 'white', strokeWidth: 1 });
toolPreviewLayer.add(rectPreview); toolPreviewLayer.add(rectPreview);
@ -130,12 +117,14 @@ const createToolPreviewLayer = (stage: Konva.Stage) => {
/** /**
* Renders the brush preview for the selected tool. * Renders the brush preview for the selected tool.
* @param stage The konva stage to render on. * @param stage The konva stage
* @param tool The selected tool. * @param tool The selected tool
* @param color The selected layer's color. * @param color The selected layer's color
* @param cursorPos The cursor position. * @param selectedLayerType The selected layer's type
* @param lastMouseDownPos The position of the last mouse down event - used for the rect tool. * @param globalMaskLayerOpacity The global mask layer opacity
* @param brushSize The brush size. * @param cursorPos The cursor position
* @param lastMouseDownPos The position of the last mouse down event - used for the rect tool
* @param brushSize The brush size
*/ */
const renderToolPreview = ( const renderToolPreview = (
stage: Konva.Stage, stage: Konva.Stage,
@ -146,7 +135,7 @@ const renderToolPreview = (
cursorPos: Vector2d | null, cursorPos: Vector2d | null,
lastMouseDownPos: Vector2d | null, lastMouseDownPos: Vector2d | null,
brushSize: number brushSize: number
) => { ): void => {
const layerCount = stage.find(selectRenderableLayers).length; const layerCount = stage.find(selectRenderableLayers).length;
// Update the stage's pointer style // Update the stage's pointer style
if (layerCount === 0) { if (layerCount === 0) {
@ -162,7 +151,7 @@ const renderToolPreview = (
// Move rect gets a crosshair // Move rect gets a crosshair
stage.container().style.cursor = 'crosshair'; stage.container().style.cursor = 'crosshair';
} else { } else {
// Else we use the brush preview // Else we hide the native cursor and use the konva-rendered brush preview
stage.container().style.cursor = 'none'; stage.container().style.cursor = 'none';
} }
@ -227,28 +216,29 @@ const renderToolPreview = (
}; };
/** /**
* Creates a vector mask layer. * Creates a regional guidance layer.
* @param stage The konva stage to attach the layer to. * @param stage The konva stage
* @param reduxLayer The redux layer to create the konva layer from. * @param layerState The regional guidance layer state
* @param onLayerPosChanged Callback for when the layer's position changes. * @param onLayerPosChanged Callback for when the layer's position changes
*/ */
const createRegionalGuidanceLayer = ( const createRGLayer = (
stage: Konva.Stage, stage: Konva.Stage,
reduxLayer: RegionalGuidanceLayer, layerState: RegionalGuidanceLayer,
onLayerPosChanged?: (layerId: string, x: number, y: number) => void onLayerPosChanged?: (layerId: string, x: number, y: number) => void
) => { ): Konva.Layer => {
// This layer hasn't been added to the konva state yet // This layer hasn't been added to the konva state yet
const konvaLayer = new Konva.Layer({ const konvaLayer = new Konva.Layer({
id: reduxLayer.id, id: layerState.id,
name: RG_LAYER_NAME, name: RG_LAYER_NAME,
draggable: true, draggable: true,
dragDistance: 0, dragDistance: 0,
}); });
// Create a `dragmove` listener for this layer // When a drag on the layer finishes, update the layer's position in state. During the drag, konva handles changing
// the position - we do not need to call this on the `dragmove` event.
if (onLayerPosChanged) { if (onLayerPosChanged) {
konvaLayer.on('dragend', function (e) { konvaLayer.on('dragend', function (e) {
onLayerPosChanged(reduxLayer.id, Math.floor(e.target.x()), Math.floor(e.target.y())); onLayerPosChanged(layerState.id, Math.floor(e.target.x()), Math.floor(e.target.y()));
}); });
} }
@ -258,7 +248,7 @@ const createRegionalGuidanceLayer = (
if (!cursorPos) { if (!cursorPos) {
return this.getAbsolutePosition(); return this.getAbsolutePosition();
} }
// Prevent the user from dragging the layer out of the stage bounds. // Prevent the user from dragging the layer out of the stage bounds by constaining the cursor position to the stage bounds
if ( if (
cursorPos.x < 0 || cursorPos.x < 0 ||
cursorPos.x > stage.width() / stage.scaleX() || cursorPos.x > stage.width() / stage.scaleX() ||
@ -272,7 +262,7 @@ const createRegionalGuidanceLayer = (
// The object group holds all of the layer's objects (e.g. lines and rects) // The object group holds all of the layer's objects (e.g. lines and rects)
const konvaObjectGroup = new Konva.Group({ const konvaObjectGroup = new Konva.Group({
id: getRGLayerObjectGroupId(reduxLayer.id, uuidv4()), id: getRGLayerObjectGroupId(layerState.id, uuidv4()),
name: RG_LAYER_OBJECT_GROUP_NAME, name: RG_LAYER_OBJECT_GROUP_NAME,
listening: false, listening: false,
}); });
@ -284,47 +274,51 @@ const createRegionalGuidanceLayer = (
}; };
/** /**
* Creates a konva line from a redux vector mask line. * Creates a konva line from a vector mask line.
* @param reduxObject The redux object to create the konva line from. * @param vectorMaskLine The vector mask line state
* @param konvaGroup The konva group to add the line to. * @param layerObjectGroup The konva layer's object group to add the line to
*/ */
const createVectorMaskLine = (reduxObject: VectorMaskLine, konvaGroup: Konva.Group): Konva.Line => { const createVectorMaskLine = (vectorMaskLine: VectorMaskLine, layerObjectGroup: Konva.Group): Konva.Line => {
const vectorMaskLine = new Konva.Line({ const konvaLine = new Konva.Line({
id: reduxObject.id, id: vectorMaskLine.id,
key: reduxObject.id, key: vectorMaskLine.id,
name: RG_LAYER_LINE_NAME, name: RG_LAYER_LINE_NAME,
strokeWidth: reduxObject.strokeWidth, strokeWidth: vectorMaskLine.strokeWidth,
tension: 0, tension: 0,
lineCap: 'round', lineCap: 'round',
lineJoin: 'round', lineJoin: 'round',
shadowForStrokeEnabled: false, shadowForStrokeEnabled: false,
globalCompositeOperation: reduxObject.tool === 'brush' ? 'source-over' : 'destination-out', globalCompositeOperation: vectorMaskLine.tool === 'brush' ? 'source-over' : 'destination-out',
listening: false, listening: false,
}); });
konvaGroup.add(vectorMaskLine); layerObjectGroup.add(konvaLine);
return vectorMaskLine; return konvaLine;
}; };
/** /**
* Creates a konva rect from a redux vector mask rect. * Creates a konva rect from a vector mask rect.
* @param reduxObject The redux object to create the konva rect from. * @param vectorMaskRect The vector mask rect state
* @param konvaGroup The konva group to add the rect to. * @param layerObjectGroup The konva layer's object group to add the line to
*/ */
const createVectorMaskRect = (reduxObject: VectorMaskRect, konvaGroup: Konva.Group): Konva.Rect => { const createVectorMaskRect = (vectorMaskRect: VectorMaskRect, layerObjectGroup: Konva.Group): Konva.Rect => {
const vectorMaskRect = new Konva.Rect({ const konvaRect = new Konva.Rect({
id: reduxObject.id, id: vectorMaskRect.id,
key: reduxObject.id, key: vectorMaskRect.id,
name: RG_LAYER_RECT_NAME, name: RG_LAYER_RECT_NAME,
x: reduxObject.x, x: vectorMaskRect.x,
y: reduxObject.y, y: vectorMaskRect.y,
width: reduxObject.width, width: vectorMaskRect.width,
height: reduxObject.height, height: vectorMaskRect.height,
listening: false, listening: false,
}); });
konvaGroup.add(vectorMaskRect); layerObjectGroup.add(konvaRect);
return vectorMaskRect; return konvaRect;
}; };
/**
* Creates the "compositing rect" for a layer.
* @param konvaLayer The konva layer
*/
const createCompositingRect = (konvaLayer: Konva.Layer): Konva.Rect => { const createCompositingRect = (konvaLayer: Konva.Layer): Konva.Rect => {
const compositingRect = new Konva.Rect({ name: COMPOSITING_RECT_NAME, listening: false }); const compositingRect = new Konva.Rect({ name: COMPOSITING_RECT_NAME, listening: false });
konvaLayer.add(compositingRect); konvaLayer.add(compositingRect);
@ -332,41 +326,41 @@ const createCompositingRect = (konvaLayer: Konva.Layer): Konva.Rect => {
}; };
/** /**
* Renders a vector mask layer. * Renders a regional guidance layer.
* @param stage The konva stage to render on. * @param stage The konva stage
* @param reduxLayer The redux vector mask layer to render. * @param layerState The regional guidance layer state
* @param reduxLayerIndex The index of the layer in the redux store. * @param globalMaskLayerOpacity The global mask layer opacity
* @param globalMaskLayerOpacity The opacity of the global mask layer. * @param tool The current tool
* @param tool The current tool. * @param onLayerPosChanged Callback for when the layer's position changes
*/ */
const renderRegionalGuidanceLayer = ( const renderRGLayer = (
stage: Konva.Stage, stage: Konva.Stage,
reduxLayer: RegionalGuidanceLayer, layerState: RegionalGuidanceLayer,
globalMaskLayerOpacity: number, globalMaskLayerOpacity: number,
tool: Tool, tool: Tool,
onLayerPosChanged?: (layerId: string, x: number, y: number) => void onLayerPosChanged?: (layerId: string, x: number, y: number) => void
): void => { ): void => {
const konvaLayer = const konvaLayer =
stage.findOne<Konva.Layer>(`#${reduxLayer.id}`) ?? stage.findOne<Konva.Layer>(`#${layerState.id}`) ?? createRGLayer(stage, layerState, onLayerPosChanged);
createRegionalGuidanceLayer(stage, reduxLayer, onLayerPosChanged);
// Update the layer's position and listening state // Update the layer's position and listening state
konvaLayer.setAttrs({ konvaLayer.setAttrs({
listening: tool === 'move', // The layer only listens when using the move tool - otherwise the stage is handling mouse events listening: tool === 'move', // The layer only listens when using the move tool - otherwise the stage is handling mouse events
x: Math.floor(reduxLayer.x), x: Math.floor(layerState.x),
y: Math.floor(reduxLayer.y), y: Math.floor(layerState.y),
}); });
// Convert the color to a string, stripping the alpha - the object group will handle opacity. // Convert the color to a string, stripping the alpha - the object group will handle opacity.
const rgbColor = rgbColorToString(reduxLayer.previewColor); const rgbColor = rgbColorToString(layerState.previewColor);
const konvaObjectGroup = konvaLayer.findOne<Konva.Group>(`.${RG_LAYER_OBJECT_GROUP_NAME}`); const konvaObjectGroup = konvaLayer.findOne<Konva.Group>(`.${RG_LAYER_OBJECT_GROUP_NAME}`);
assert(konvaObjectGroup, `Object group not found for layer ${reduxLayer.id}`); assert(konvaObjectGroup, `Object group not found for layer ${layerState.id}`);
// We use caching to handle "global" layer opacity, but caching is expensive and we should only do it when required. // We use caching to handle "global" layer opacity, but caching is expensive and we should only do it when required.
let groupNeedsCache = false; let groupNeedsCache = false;
const objectIds = reduxLayer.maskObjects.map(mapId); const objectIds = layerState.maskObjects.map(mapId);
// Destroy any objects that are no longer in the redux state
for (const objectNode of konvaObjectGroup.find(selectVectorMaskObjects)) { for (const objectNode of konvaObjectGroup.find(selectVectorMaskObjects)) {
if (!objectIds.includes(objectNode.id())) { if (!objectIds.includes(objectNode.id())) {
objectNode.destroy(); objectNode.destroy();
@ -374,15 +368,15 @@ const renderRegionalGuidanceLayer = (
} }
} }
for (const reduxObject of reduxLayer.maskObjects) { for (const maskObject of layerState.maskObjects) {
if (reduxObject.type === 'vector_mask_line') { if (maskObject.type === 'vector_mask_line') {
const vectorMaskLine = const vectorMaskLine =
stage.findOne<Konva.Line>(`#${reduxObject.id}`) ?? createVectorMaskLine(reduxObject, konvaObjectGroup); stage.findOne<Konva.Line>(`#${maskObject.id}`) ?? createVectorMaskLine(maskObject, konvaObjectGroup);
// Only update the points if they have changed. The point values are never mutated, they are only added to the // Only update the points if they have changed. The point values are never mutated, they are only added to the
// array, so checking the length is sufficient to determine if we need to re-cache. // array, so checking the length is sufficient to determine if we need to re-cache.
if (vectorMaskLine.points().length !== reduxObject.points.length) { if (vectorMaskLine.points().length !== maskObject.points.length) {
vectorMaskLine.points(reduxObject.points); vectorMaskLine.points(maskObject.points);
groupNeedsCache = true; groupNeedsCache = true;
} }
// Only update the color if it has changed. // Only update the color if it has changed.
@ -390,9 +384,9 @@ const renderRegionalGuidanceLayer = (
vectorMaskLine.stroke(rgbColor); vectorMaskLine.stroke(rgbColor);
groupNeedsCache = true; groupNeedsCache = true;
} }
} else if (reduxObject.type === 'vector_mask_rect') { } else if (maskObject.type === 'vector_mask_rect') {
const konvaObject = const konvaObject =
stage.findOne<Konva.Rect>(`#${reduxObject.id}`) ?? createVectorMaskRect(reduxObject, konvaObjectGroup); stage.findOne<Konva.Rect>(`#${maskObject.id}`) ?? createVectorMaskRect(maskObject, konvaObjectGroup);
// Only update the color if it has changed. // Only update the color if it has changed.
if (konvaObject.fill() !== rgbColor) { if (konvaObject.fill() !== rgbColor) {
@ -403,8 +397,8 @@ const renderRegionalGuidanceLayer = (
} }
// Only update layer visibility if it has changed. // Only update layer visibility if it has changed.
if (konvaLayer.visible() !== reduxLayer.isEnabled) { if (konvaLayer.visible() !== layerState.isEnabled) {
konvaLayer.visible(reduxLayer.isEnabled); konvaLayer.visible(layerState.isEnabled);
groupNeedsCache = true; groupNeedsCache = true;
} }
@ -428,7 +422,7 @@ const renderRegionalGuidanceLayer = (
* Instead, with the special handling, the effect is as if you drew all the shapes at 100% opacity, flattened them to * Instead, with the special handling, the effect is as if you drew all the shapes at 100% opacity, flattened them to
* a single raster image, and _then_ applied the 50% opacity. * a single raster image, and _then_ applied the 50% opacity.
*/ */
if (reduxLayer.isSelected && tool !== 'move') { if (layerState.isSelected && tool !== 'move') {
// We must clear the cache first so Konva will re-draw the group with the new compositing rect // We must clear the cache first so Konva will re-draw the group with the new compositing rect
if (konvaObjectGroup.isCached()) { if (konvaObjectGroup.isCached()) {
konvaObjectGroup.clearCache(); konvaObjectGroup.clearCache();
@ -438,7 +432,7 @@ const renderRegionalGuidanceLayer = (
compositingRect.setAttrs({ compositingRect.setAttrs({
// The rect should be the size of the layer - use the fast method if we don't have a pixel-perfect bbox already // The rect should be the size of the layer - use the fast method if we don't have a pixel-perfect bbox already
...(!reduxLayer.bboxNeedsUpdate && reduxLayer.bbox ? reduxLayer.bbox : getLayerBboxFast(konvaLayer)), ...(!layerState.bboxNeedsUpdate && layerState.bbox ? layerState.bbox : getLayerBboxFast(konvaLayer)),
fill: rgbColor, fill: rgbColor,
opacity: globalMaskLayerOpacity, opacity: globalMaskLayerOpacity,
// Draw this rect only where there are non-transparent pixels under it (e.g. the mask shapes) // Draw this rect only where there are non-transparent pixels under it (e.g. the mask shapes)
@ -459,9 +453,14 @@ const renderRegionalGuidanceLayer = (
} }
}; };
const createInitialImageLayer = (stage: Konva.Stage, reduxLayer: InitialImageLayer): Konva.Layer => { /**
* Creates an initial image konva layer.
* @param stage The konva stage
* @param layerState The initial image layer state
*/
const createIILayer = (stage: Konva.Stage, layerState: InitialImageLayer): Konva.Layer => {
const konvaLayer = new Konva.Layer({ const konvaLayer = new Konva.Layer({
id: reduxLayer.id, id: layerState.id,
name: INITIAL_IMAGE_LAYER_NAME, name: INITIAL_IMAGE_LAYER_NAME,
imageSmoothingEnabled: true, imageSmoothingEnabled: true,
listening: false, listening: false,
@ -470,20 +469,27 @@ const createInitialImageLayer = (stage: Konva.Stage, reduxLayer: InitialImageLay
return konvaLayer; return konvaLayer;
}; };
const createInitialImageLayerImage = (konvaLayer: Konva.Layer, image: HTMLImageElement): Konva.Image => { /**
* Creates the konva image for an initial image layer.
* @param konvaLayer The konva layer
* @param imageEl The image element
*/
const createIILayerImage = (konvaLayer: Konva.Layer, imageEl: HTMLImageElement): Konva.Image => {
const konvaImage = new Konva.Image({ const konvaImage = new Konva.Image({
name: INITIAL_IMAGE_LAYER_IMAGE_NAME, name: INITIAL_IMAGE_LAYER_IMAGE_NAME,
image, image: imageEl,
}); });
konvaLayer.add(konvaImage); konvaLayer.add(konvaImage);
return konvaImage; return konvaImage;
}; };
const updateInitialImageLayerImageAttrs = ( /**
stage: Konva.Stage, * Updates an initial image layer's attributes (width, height, opacity, visibility).
konvaImage: Konva.Image, * @param stage The konva stage
reduxLayer: InitialImageLayer * @param konvaImage The konva image
) => { * @param layerState The initial image layer state
*/
const updateIILayerImageAttrs = (stage: Konva.Stage, konvaImage: Konva.Image, layerState: InitialImageLayer): void => {
// Konva erroneously reports NaN for width and height when the stage is hidden. This causes errors when caching, // Konva erroneously reports NaN for width and height when the stage is hidden. This causes errors when caching,
// but it doesn't seem to break anything. // but it doesn't seem to break anything.
// TODO(psyche): Investigate and report upstream. // TODO(psyche): Investigate and report upstream.
@ -492,46 +498,55 @@ const updateInitialImageLayerImageAttrs = (
if ( if (
konvaImage.width() !== newWidth || konvaImage.width() !== newWidth ||
konvaImage.height() !== newHeight || konvaImage.height() !== newHeight ||
konvaImage.visible() !== reduxLayer.isEnabled konvaImage.visible() !== layerState.isEnabled
) { ) {
konvaImage.setAttrs({ konvaImage.setAttrs({
opacity: reduxLayer.opacity, opacity: layerState.opacity,
scaleX: 1, scaleX: 1,
scaleY: 1, scaleY: 1,
width: stage.width() / stage.scaleX(), width: stage.width() / stage.scaleX(),
height: stage.height() / stage.scaleY(), height: stage.height() / stage.scaleY(),
visible: reduxLayer.isEnabled, visible: layerState.isEnabled,
}); });
} }
if (konvaImage.opacity() !== reduxLayer.opacity) { if (konvaImage.opacity() !== layerState.opacity) {
konvaImage.opacity(reduxLayer.opacity); konvaImage.opacity(layerState.opacity);
} }
}; };
const updateInitialImageLayerImageSource = async ( /**
* Update an initial image layer's image source when the image changes.
* @param stage The konva stage
* @param konvaLayer The konva layer
* @param layerState The initial image layer state
* @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source
*/
const updateIILayerImageSource = async (
stage: Konva.Stage, stage: Konva.Stage,
konvaLayer: Konva.Layer, konvaLayer: Konva.Layer,
reduxLayer: InitialImageLayer layerState: InitialImageLayer,
) => { getImageDTO: (imageName: string) => Promise<ImageDTO | null>
if (reduxLayer.image) { ): Promise<void> => {
const imageName = reduxLayer.image.name; if (layerState.image) {
const req = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(imageName)); const imageName = layerState.image.name;
const imageDTO = await req.unwrap(); const imageDTO = await getImageDTO(imageName);
req.unsubscribe(); if (!imageDTO) {
return;
}
const imageEl = new Image(); const imageEl = new Image();
const imageId = getIILayerImageId(reduxLayer.id, imageName); const imageId = getIILayerImageId(layerState.id, imageName);
imageEl.onload = () => { imageEl.onload = () => {
// Find the existing image or create a new one - must find using the name, bc the id may have just changed // Find the existing image or create a new one - must find using the name, bc the id may have just changed
const konvaImage = const konvaImage =
konvaLayer.findOne<Konva.Image>(`.${INITIAL_IMAGE_LAYER_IMAGE_NAME}`) ?? konvaLayer.findOne<Konva.Image>(`.${INITIAL_IMAGE_LAYER_IMAGE_NAME}`) ??
createInitialImageLayerImage(konvaLayer, imageEl); createIILayerImage(konvaLayer, imageEl);
// Update the image's attributes // Update the image's attributes
konvaImage.setAttrs({ konvaImage.setAttrs({
id: imageId, id: imageId,
image: imageEl, image: imageEl,
}); });
updateInitialImageLayerImageAttrs(stage, konvaImage, reduxLayer); updateIILayerImageAttrs(stage, konvaImage, layerState);
imageEl.id = imageId; imageEl.id = imageId;
}; };
imageEl.src = imageDTO.image_url; imageEl.src = imageDTO.image_url;
@ -540,14 +555,24 @@ const updateInitialImageLayerImageSource = async (
} }
}; };
const renderInitialImageLayer = (stage: Konva.Stage, reduxLayer: InitialImageLayer) => { /**
const konvaLayer = stage.findOne<Konva.Layer>(`#${reduxLayer.id}`) ?? createInitialImageLayer(stage, reduxLayer); * Renders an initial image layer.
* @param stage The konva stage
* @param layerState The initial image layer state
* @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source
*/
const renderIILayer = (
stage: Konva.Stage,
layerState: InitialImageLayer,
getImageDTO: (imageName: string) => Promise<ImageDTO | null>
): void => {
const konvaLayer = stage.findOne<Konva.Layer>(`#${layerState.id}`) ?? createIILayer(stage, layerState);
const konvaImage = konvaLayer.findOne<Konva.Image>(`.${INITIAL_IMAGE_LAYER_IMAGE_NAME}`); const konvaImage = konvaLayer.findOne<Konva.Image>(`.${INITIAL_IMAGE_LAYER_IMAGE_NAME}`);
const canvasImageSource = konvaImage?.image(); const canvasImageSource = konvaImage?.image();
let imageSourceNeedsUpdate = false; let imageSourceNeedsUpdate = false;
if (canvasImageSource instanceof HTMLImageElement) { if (canvasImageSource instanceof HTMLImageElement) {
const image = reduxLayer.image; const image = layerState.image;
if (image && canvasImageSource.id !== getCALayerImageId(reduxLayer.id, image.name)) { if (image && canvasImageSource.id !== getCALayerImageId(layerState.id, image.name)) {
imageSourceNeedsUpdate = true; imageSourceNeedsUpdate = true;
} else if (!image) { } else if (!image) {
imageSourceNeedsUpdate = true; imageSourceNeedsUpdate = true;
@ -557,15 +582,20 @@ const renderInitialImageLayer = (stage: Konva.Stage, reduxLayer: InitialImageLay
} }
if (imageSourceNeedsUpdate) { if (imageSourceNeedsUpdate) {
updateInitialImageLayerImageSource(stage, konvaLayer, reduxLayer); updateIILayerImageSource(stage, konvaLayer, layerState, getImageDTO);
} else if (konvaImage) { } else if (konvaImage) {
updateInitialImageLayerImageAttrs(stage, konvaImage, reduxLayer); updateIILayerImageAttrs(stage, konvaImage, layerState);
} }
}; };
const createControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLayer): Konva.Layer => { /**
* Creates a control adapter layer.
* @param stage The konva stage
* @param layerState The control adapter layer state
*/
const createCALayer = (stage: Konva.Stage, layerState: ControlAdapterLayer): Konva.Layer => {
const konvaLayer = new Konva.Layer({ const konvaLayer = new Konva.Layer({
id: reduxLayer.id, id: layerState.id,
name: CA_LAYER_NAME, name: CA_LAYER_NAME,
imageSmoothingEnabled: true, imageSmoothingEnabled: true,
listening: false, listening: false,
@ -574,39 +604,53 @@ const createControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLay
return konvaLayer; return konvaLayer;
}; };
const createControlNetLayerImage = (konvaLayer: Konva.Layer, image: HTMLImageElement): Konva.Image => { /**
* Creates a control adapter layer image.
* @param konvaLayer The konva layer
* @param imageEl The image element
*/
const createCALayerImage = (konvaLayer: Konva.Layer, imageEl: HTMLImageElement): Konva.Image => {
const konvaImage = new Konva.Image({ const konvaImage = new Konva.Image({
name: CA_LAYER_IMAGE_NAME, name: CA_LAYER_IMAGE_NAME,
image, image: imageEl,
}); });
konvaLayer.add(konvaImage); konvaLayer.add(konvaImage);
return konvaImage; return konvaImage;
}; };
const updateControlNetLayerImageSource = async ( /**
* Updates the image source for a control adapter layer. This includes loading the image from the server and updating the konva image.
* @param stage The konva stage
* @param konvaLayer The konva layer
* @param layerState The control adapter layer state
* @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source
*/
const updateCALayerImageSource = async (
stage: Konva.Stage, stage: Konva.Stage,
konvaLayer: Konva.Layer, konvaLayer: Konva.Layer,
reduxLayer: ControlAdapterLayer layerState: ControlAdapterLayer,
) => { getImageDTO: (imageName: string) => Promise<ImageDTO | null>
const image = reduxLayer.controlAdapter.processedImage ?? reduxLayer.controlAdapter.image; ): Promise<void> => {
const image = layerState.controlAdapter.processedImage ?? layerState.controlAdapter.image;
if (image) { if (image) {
const imageName = image.name; const imageName = image.name;
const req = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(imageName)); const imageDTO = await getImageDTO(imageName);
const imageDTO = await req.unwrap(); if (!imageDTO) {
req.unsubscribe(); return;
}
const imageEl = new Image(); const imageEl = new Image();
const imageId = getCALayerImageId(reduxLayer.id, imageName); const imageId = getCALayerImageId(layerState.id, imageName);
imageEl.onload = () => { imageEl.onload = () => {
// Find the existing image or create a new one - must find using the name, bc the id may have just changed // Find the existing image or create a new one - must find using the name, bc the id may have just changed
const konvaImage = const konvaImage =
konvaLayer.findOne<Konva.Image>(`.${CA_LAYER_IMAGE_NAME}`) ?? createControlNetLayerImage(konvaLayer, imageEl); konvaLayer.findOne<Konva.Image>(`.${CA_LAYER_IMAGE_NAME}`) ?? createCALayerImage(konvaLayer, imageEl);
// Update the image's attributes // Update the image's attributes
konvaImage.setAttrs({ konvaImage.setAttrs({
id: imageId, id: imageId,
image: imageEl, image: imageEl,
}); });
updateControlNetLayerImageAttrs(stage, konvaImage, reduxLayer); updateCALayerImageAttrs(stage, konvaImage, layerState);
// Must cache after this to apply the filters // Must cache after this to apply the filters
konvaImage.cache(); konvaImage.cache();
imageEl.id = imageId; imageEl.id = imageId;
@ -617,11 +661,17 @@ const updateControlNetLayerImageSource = async (
} }
}; };
const updateControlNetLayerImageAttrs = ( /**
* Updates the image attributes for a control adapter layer's image (width, height, visibility, opacity, filters).
* @param stage The konva stage
* @param konvaImage The konva image
* @param layerState The control adapter layer state
*/
const updateCALayerImageAttrs = (
stage: Konva.Stage, stage: Konva.Stage,
konvaImage: Konva.Image, konvaImage: Konva.Image,
reduxLayer: ControlAdapterLayer layerState: ControlAdapterLayer
) => { ): void => {
let needsCache = false; let needsCache = false;
// Konva erroneously reports NaN for width and height when the stage is hidden. This causes errors when caching, // Konva erroneously reports NaN for width and height when the stage is hidden. This causes errors when caching,
// but it doesn't seem to break anything. // but it doesn't seem to break anything.
@ -632,36 +682,47 @@ const updateControlNetLayerImageAttrs = (
if ( if (
konvaImage.width() !== newWidth || konvaImage.width() !== newWidth ||
konvaImage.height() !== newHeight || konvaImage.height() !== newHeight ||
konvaImage.visible() !== reduxLayer.isEnabled || konvaImage.visible() !== layerState.isEnabled ||
hasFilter !== reduxLayer.isFilterEnabled hasFilter !== layerState.isFilterEnabled
) { ) {
konvaImage.setAttrs({ konvaImage.setAttrs({
opacity: reduxLayer.opacity, opacity: layerState.opacity,
scaleX: 1, scaleX: 1,
scaleY: 1, scaleY: 1,
width: stage.width() / stage.scaleX(), width: stage.width() / stage.scaleX(),
height: stage.height() / stage.scaleY(), height: stage.height() / stage.scaleY(),
visible: reduxLayer.isEnabled, visible: layerState.isEnabled,
filters: reduxLayer.isFilterEnabled ? [LightnessToAlphaFilter] : [], filters: layerState.isFilterEnabled ? [LightnessToAlphaFilter] : [],
}); });
needsCache = true; needsCache = true;
} }
if (konvaImage.opacity() !== reduxLayer.opacity) { if (konvaImage.opacity() !== layerState.opacity) {
konvaImage.opacity(reduxLayer.opacity); konvaImage.opacity(layerState.opacity);
} }
if (needsCache) { if (needsCache) {
konvaImage.cache(); konvaImage.cache();
} }
}; };
const renderControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLayer) => { /**
const konvaLayer = stage.findOne<Konva.Layer>(`#${reduxLayer.id}`) ?? createControlNetLayer(stage, reduxLayer); * Renders a control adapter layer. If the layer doesn't already exist, it is created. Otherwise, the layer is updated
* with the current image source and attributes.
* @param stage The konva stage
* @param layerState The control adapter layer state
* @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source
*/
const renderCALayer = (
stage: Konva.Stage,
layerState: ControlAdapterLayer,
getImageDTO: (imageName: string) => Promise<ImageDTO | null>
): void => {
const konvaLayer = stage.findOne<Konva.Layer>(`#${layerState.id}`) ?? createCALayer(stage, layerState);
const konvaImage = konvaLayer.findOne<Konva.Image>(`.${CA_LAYER_IMAGE_NAME}`); const konvaImage = konvaLayer.findOne<Konva.Image>(`.${CA_LAYER_IMAGE_NAME}`);
const canvasImageSource = konvaImage?.image(); const canvasImageSource = konvaImage?.image();
let imageSourceNeedsUpdate = false; let imageSourceNeedsUpdate = false;
if (canvasImageSource instanceof HTMLImageElement) { if (canvasImageSource instanceof HTMLImageElement) {
const image = reduxLayer.controlAdapter.processedImage ?? reduxLayer.controlAdapter.image; const image = layerState.controlAdapter.processedImage ?? layerState.controlAdapter.image;
if (image && canvasImageSource.id !== getCALayerImageId(reduxLayer.id, image.name)) { if (image && canvasImageSource.id !== getCALayerImageId(layerState.id, image.name)) {
imageSourceNeedsUpdate = true; imageSourceNeedsUpdate = true;
} else if (!image) { } else if (!image) {
imageSourceNeedsUpdate = true; imageSourceNeedsUpdate = true;
@ -671,44 +732,46 @@ const renderControlNetLayer = (stage: Konva.Stage, reduxLayer: ControlAdapterLay
} }
if (imageSourceNeedsUpdate) { if (imageSourceNeedsUpdate) {
updateControlNetLayerImageSource(stage, konvaLayer, reduxLayer); updateCALayerImageSource(stage, konvaLayer, layerState, getImageDTO);
} else if (konvaImage) { } else if (konvaImage) {
updateControlNetLayerImageAttrs(stage, konvaImage, reduxLayer); updateCALayerImageAttrs(stage, konvaImage, layerState);
} }
}; };
/** /**
* Renders the layers on the stage. * Renders the layers on the stage.
* @param stage The konva stage to render on. * @param stage The konva stage
* @param reduxLayers Array of the layers from the redux store. * @param layerStates Array of all layer states
* @param layerOpacity The opacity of the layer. * @param globalMaskLayerOpacity The global mask layer opacity
* @param onLayerPosChanged Callback for when the layer's position changes. This is optional to allow for offscreen rendering. * @param tool The current tool
* @returns * @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source
* @param onLayerPosChanged Callback for when the layer's position changes
*/ */
const renderLayers = ( const renderLayers = (
stage: Konva.Stage, stage: Konva.Stage,
reduxLayers: Layer[], layerStates: Layer[],
globalMaskLayerOpacity: number, globalMaskLayerOpacity: number,
tool: Tool, tool: Tool,
getImageDTO: (imageName: string) => Promise<ImageDTO | null>,
onLayerPosChanged?: (layerId: string, x: number, y: number) => void onLayerPosChanged?: (layerId: string, x: number, y: number) => void
) => { ): void => {
const reduxLayerIds = reduxLayers.filter(isRenderableLayer).map(mapId); const layerIds = layerStates.filter(isRenderableLayer).map(mapId);
// Remove un-rendered layers // Remove un-rendered layers
for (const konvaLayer of stage.find<Konva.Layer>(selectRenderableLayers)) { for (const konvaLayer of stage.find<Konva.Layer>(selectRenderableLayers)) {
if (!reduxLayerIds.includes(konvaLayer.id())) { if (!layerIds.includes(konvaLayer.id())) {
konvaLayer.destroy(); konvaLayer.destroy();
} }
} }
for (const reduxLayer of reduxLayers) { for (const layer of layerStates) {
if (isRegionalGuidanceLayer(reduxLayer)) { if (isRegionalGuidanceLayer(layer)) {
renderRegionalGuidanceLayer(stage, reduxLayer, globalMaskLayerOpacity, tool, onLayerPosChanged); renderRGLayer(stage, layer, globalMaskLayerOpacity, tool, onLayerPosChanged);
} }
if (isControlAdapterLayer(reduxLayer)) { if (isControlAdapterLayer(layer)) {
renderControlNetLayer(stage, reduxLayer); renderCALayer(stage, layer, getImageDTO);
} }
if (isInitialImageLayer(reduxLayer)) { if (isInitialImageLayer(layer)) {
renderInitialImageLayer(stage, reduxLayer); renderIILayer(stage, layer, getImageDTO);
} }
// IP Adapter layers are not rendered // IP Adapter layers are not rendered
} }
@ -716,13 +779,12 @@ const renderLayers = (
/** /**
* Creates a bounding box rect for a layer. * Creates a bounding box rect for a layer.
* @param reduxLayer The redux layer to create the bounding box for. * @param layerState The layer state for the layer to create the bounding box for
* @param konvaLayer The konva layer to attach the bounding box to. * @param konvaLayer The konva layer to attach the bounding box to
* @param onBboxMouseDown Callback for when the bounding box is clicked.
*/ */
const createBboxRect = (reduxLayer: Layer, konvaLayer: Konva.Layer) => { const createBboxRect = (layerState: Layer, konvaLayer: Konva.Layer): Konva.Rect => {
const rect = new Konva.Rect({ const rect = new Konva.Rect({
id: getLayerBboxId(reduxLayer.id), id: getLayerBboxId(layerState.id),
name: LAYER_BBOX_NAME, name: LAYER_BBOX_NAME,
strokeWidth: 1, strokeWidth: 1,
visible: false, visible: false,
@ -733,12 +795,12 @@ const createBboxRect = (reduxLayer: Layer, konvaLayer: Konva.Layer) => {
/** /**
* Renders the bounding boxes for the layers. * Renders the bounding boxes for the layers.
* @param stage The konva stage to render on * @param stage The konva stage
* @param reduxLayers An array of all redux layers to draw bboxes for * @param layerStates An array of layers to draw bboxes for
* @param tool The current tool * @param tool The current tool
* @returns * @returns
*/ */
const renderBboxes = (stage: Konva.Stage, reduxLayers: Layer[], tool: Tool) => { const renderBboxes = (stage: Konva.Stage, layerStates: Layer[], tool: Tool): void => {
// Hide all bboxes so they don't interfere with getClientRect // Hide all bboxes so they don't interfere with getClientRect
for (const bboxRect of stage.find<Konva.Rect>(`.${LAYER_BBOX_NAME}`)) { for (const bboxRect of stage.find<Konva.Rect>(`.${LAYER_BBOX_NAME}`)) {
bboxRect.visible(false); bboxRect.visible(false);
@ -749,39 +811,39 @@ const renderBboxes = (stage: Konva.Stage, reduxLayers: Layer[], tool: Tool) => {
return; return;
} }
for (const reduxLayer of reduxLayers.filter(isRegionalGuidanceLayer)) { for (const layer of layerStates.filter(isRegionalGuidanceLayer)) {
if (!reduxLayer.bbox) { if (!layer.bbox) {
continue; continue;
} }
const konvaLayer = stage.findOne<Konva.Layer>(`#${reduxLayer.id}`); const konvaLayer = stage.findOne<Konva.Layer>(`#${layer.id}`);
assert(konvaLayer, `Layer ${reduxLayer.id} not found in stage`); assert(konvaLayer, `Layer ${layer.id} not found in stage`);
const bboxRect = konvaLayer.findOne<Konva.Rect>(`.${LAYER_BBOX_NAME}`) ?? createBboxRect(reduxLayer, konvaLayer); const bboxRect = konvaLayer.findOne<Konva.Rect>(`.${LAYER_BBOX_NAME}`) ?? createBboxRect(layer, konvaLayer);
bboxRect.setAttrs({ bboxRect.setAttrs({
visible: !reduxLayer.bboxNeedsUpdate, visible: !layer.bboxNeedsUpdate,
listening: reduxLayer.isSelected, listening: layer.isSelected,
x: reduxLayer.bbox.x, x: layer.bbox.x,
y: reduxLayer.bbox.y, y: layer.bbox.y,
width: reduxLayer.bbox.width, width: layer.bbox.width,
height: reduxLayer.bbox.height, height: layer.bbox.height,
stroke: reduxLayer.isSelected ? BBOX_SELECTED_STROKE : '', stroke: layer.isSelected ? BBOX_SELECTED_STROKE : '',
}); });
} }
}; };
/** /**
* Calculates the bbox of each regional guidance layer. Only calculates if the mask has changed. * Calculates the bbox of each regional guidance layer. Only calculates if the mask has changed.
* @param stage The konva stage to render on. * @param stage The konva stage
* @param reduxLayers An array of redux layers to calculate bboxes for * @param layerStates An array of layers to calculate bboxes for
* @param onBboxChanged Callback for when the bounding box changes * @param onBboxChanged Callback for when the bounding box changes
*/ */
const updateBboxes = ( const updateBboxes = (
stage: Konva.Stage, stage: Konva.Stage,
reduxLayers: Layer[], layerStates: Layer[],
onBboxChanged: (layerId: string, bbox: IRect | null) => void onBboxChanged: (layerId: string, bbox: IRect | null) => void
) => { ): void => {
for (const rgLayer of reduxLayers.filter(isRegionalGuidanceLayer)) { for (const rgLayer of layerStates.filter(isRegionalGuidanceLayer)) {
const konvaLayer = stage.findOne<Konva.Layer>(`#${rgLayer.id}`); const konvaLayer = stage.findOne<Konva.Layer>(`#${rgLayer.id}`);
assert(konvaLayer, `Layer ${rgLayer.id} not found in stage`); assert(konvaLayer, `Layer ${rgLayer.id} not found in stage`);
// We only need to recalculate the bbox if the layer has changed // We only need to recalculate the bbox if the layer has changed
@ -808,7 +870,7 @@ const updateBboxes = (
/** /**
* Creates the background layer for the stage. * Creates the background layer for the stage.
* @param stage The konva stage to render on * @param stage The konva stage
*/ */
const createBackgroundLayer = (stage: Konva.Stage): Konva.Layer => { const createBackgroundLayer = (stage: Konva.Stage): Konva.Layer => {
const layer = new Konva.Layer({ const layer = new Konva.Layer({
@ -829,17 +891,17 @@ const createBackgroundLayer = (stage: Konva.Stage): Konva.Layer => {
image.onload = () => { image.onload = () => {
background.fillPatternImage(image); background.fillPatternImage(image);
}; };
image.src = STAGE_BG_DATAURL; image.src = TRANSPARENCY_CHECKER_PATTERN;
return layer; return layer;
}; };
/** /**
* Renders the background layer for the stage. * Renders the background layer for the stage.
* @param stage The konva stage to render on * @param stage The konva stage
* @param width The unscaled width of the canvas * @param width The unscaled width of the canvas
* @param height The unscaled height of the canvas * @param height The unscaled height of the canvas
*/ */
const renderBackground = (stage: Konva.Stage, width: number, height: number) => { const renderBackground = (stage: Konva.Stage, width: number, height: number): void => {
const layer = stage.findOne<Konva.Layer>(`#${BACKGROUND_LAYER_ID}`) ?? createBackgroundLayer(stage); const layer = stage.findOne<Konva.Layer>(`#${BACKGROUND_LAYER_ID}`) ?? createBackgroundLayer(stage);
const background = layer.findOne<Konva.Rect>(`#${BACKGROUND_RECT_ID}`); const background = layer.findOne<Konva.Rect>(`#${BACKGROUND_RECT_ID}`);
@ -880,6 +942,10 @@ const arrangeLayers = (stage: Konva.Stage, layerIds: string[]): void => {
stage.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.zIndex(nextZIndex++); stage.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.zIndex(nextZIndex++);
}; };
/**
* Creates the "no layers" fallback layer
* @param stage The konva stage
*/
const createNoLayersMessageLayer = (stage: Konva.Stage): Konva.Layer => { const createNoLayersMessageLayer = (stage: Konva.Stage): Konva.Layer => {
const noLayersMessageLayer = new Konva.Layer({ const noLayersMessageLayer = new Konva.Layer({
id: NO_LAYERS_MESSAGE_LAYER_ID, id: NO_LAYERS_MESSAGE_LAYER_ID,
@ -891,7 +957,7 @@ const createNoLayersMessageLayer = (stage: Konva.Stage): Konva.Layer => {
y: 0, y: 0,
align: 'center', align: 'center',
verticalAlign: 'middle', verticalAlign: 'middle',
text: t('controlLayers.noLayersAdded'), text: t('controlLayers.noLayersAdded', 'No Layers Added'),
fontFamily: '"Inter Variable", sans-serif', fontFamily: '"Inter Variable", sans-serif',
fontStyle: '600', fontStyle: '600',
fill: 'white', fill: 'white',
@ -901,7 +967,14 @@ const createNoLayersMessageLayer = (stage: Konva.Stage): Konva.Layer => {
return noLayersMessageLayer; return noLayersMessageLayer;
}; };
const renderNoLayersMessage = (stage: Konva.Stage, layerCount: number, width: number, height: number) => { /**
* Renders the "no layers" message when there are no layers to render
* @param stage The konva stage
* @param layerCount The current number of layers
* @param width The target width of the text
* @param height The target height of the text
*/
const renderNoLayersMessage = (stage: Konva.Stage, layerCount: number, width: number, height: number): void => {
const noLayersMessageLayer = const noLayersMessageLayer =
stage.findOne<Konva.Layer>(`#${NO_LAYERS_MESSAGE_LAYER_ID}`) ?? createNoLayersMessageLayer(stage); stage.findOne<Konva.Layer>(`#${NO_LAYERS_MESSAGE_LAYER_ID}`) ?? createNoLayersMessageLayer(stage);
if (layerCount === 0) { if (layerCount === 0) {
@ -936,20 +1009,3 @@ export const debouncedRenderers = {
arrangeLayers: debounce(arrangeLayers, DEBOUNCE_MS), arrangeLayers: debounce(arrangeLayers, DEBOUNCE_MS),
updateBboxes: debounce(updateBboxes, DEBOUNCE_MS), updateBboxes: debounce(updateBboxes, DEBOUNCE_MS),
}; };
/**
* Calculates the lightness (HSL) of a given pixel and sets the alpha channel to that value.
* This is useful for edge maps and other masks, to make the black areas transparent.
* @param imageData The image data to apply the filter to
*/
const LightnessToAlphaFilter = (imageData: ImageData) => {
const len = imageData.data.length / 4;
for (let i = 0; i < len; i++) {
const r = imageData.data[i * 4 + 0] as number;
const g = imageData.data[i * 4 + 1] as number;
const b = imageData.data[i * 4 + 2] as number;
const cMin = Math.min(r, g, b);
const cMax = Math.max(r, g, b);
imageData.data[i * 4 + 3] = (cMin + cMax) / 2;
}
};

View File

@ -0,0 +1,67 @@
import type Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import type { Vector2d } from 'konva/lib/types';
//#region getScaledFlooredCursorPosition
/**
* Gets the scaled and floored cursor position on the stage. If the cursor is not currently over the stage, returns null.
* @param stage The konva stage
*/
export const getScaledFlooredCursorPosition = (stage: Konva.Stage): Vector2d | null => {
const pointerPosition = stage.getPointerPosition();
const stageTransform = stage.getAbsoluteTransform().copy();
if (!pointerPosition) {
return null;
}
const scaledCursorPosition = stageTransform.invert().point(pointerPosition);
return {
x: Math.floor(scaledCursorPosition.x),
y: Math.floor(scaledCursorPosition.y),
};
};
//#endregion
//#region snapPosToStage
/**
* Snaps a position to the edge of the stage if within a threshold of the edge
* @param pos The position to snap
* @param stage The konva stage
* @param snapPx The snap threshold in pixels
*/
export const snapPosToStage = (pos: Vector2d, stage: Konva.Stage, snapPx = 10): Vector2d => {
const snappedPos = { ...pos };
// Get the normalized threshold for snapping to the edge of the stage
const thresholdX = snapPx / stage.scaleX();
const thresholdY = snapPx / stage.scaleY();
const stageWidth = stage.width() / stage.scaleX();
const stageHeight = stage.height() / stage.scaleY();
// Snap to the edge of the stage if within threshold
if (pos.x - thresholdX < 0) {
snappedPos.x = 0;
} else if (pos.x + thresholdX > stageWidth) {
snappedPos.x = Math.floor(stageWidth);
}
if (pos.y - thresholdY < 0) {
snappedPos.y = 0;
} else if (pos.y + thresholdY > stageHeight) {
snappedPos.y = Math.floor(stageHeight);
}
return snappedPos;
};
//#endregion
//#region getIsMouseDown
/**
* Checks if the left mouse button is currently pressed
* @param e The konva event
*/
export const getIsMouseDown = (e: KonvaEventObject<MouseEvent>): boolean => e.evt.buttons === 1;
//#endregion
//#region getIsFocused
/**
* Checks if the stage is currently focused
* @param stage The konva stage
*/
export const getIsFocused = (stage: Konva.Stage): boolean => stage.container().contains(document.activeElement);
//#endregion

View File

@ -4,6 +4,14 @@ import type { PersistConfig, RootState } from 'app/store/store';
import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils'; import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { roundDownToMultiple } from 'common/util/roundDownToMultiple'; import { roundDownToMultiple } from 'common/util/roundDownToMultiple';
import {
getCALayerId,
getIPALayerId,
getRGLayerId,
getRGLayerLineId,
getRGLayerRectId,
INITIAL_IMAGE_LAYER_ID,
} from 'features/controlLayers/konva/naming';
import type { import type {
CLIPVisionModelV2, CLIPVisionModelV2,
ControlModeV2, ControlModeV2,
@ -36,6 +44,9 @@ import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import type { import type {
AddLineArg,
AddPointToLineArg,
AddRectArg,
ControlAdapterLayer, ControlAdapterLayer,
ControlLayersState, ControlLayersState,
DrawingTool, DrawingTool,
@ -492,11 +503,11 @@ export const controlLayersSlice = createSlice({
layer.bboxNeedsUpdate = true; layer.bboxNeedsUpdate = true;
layer.uploadedMaskImage = null; layer.uploadedMaskImage = null;
}, },
prepare: (payload: { layerId: string; points: [number, number, number, number]; tool: DrawingTool }) => ({ prepare: (payload: AddLineArg) => ({
payload: { ...payload, lineUuid: uuidv4() }, payload: { ...payload, lineUuid: uuidv4() },
}), }),
}, },
rgLayerPointsAdded: (state, action: PayloadAction<{ layerId: string; point: [number, number] }>) => { rgLayerPointsAdded: (state, action: PayloadAction<AddPointToLineArg>) => {
const { layerId, point } = action.payload; const { layerId, point } = action.payload;
const layer = selectRGLayerOrThrow(state, layerId); const layer = selectRGLayerOrThrow(state, layerId);
const lastLine = layer.maskObjects.findLast(isLine); const lastLine = layer.maskObjects.findLast(isLine);
@ -529,7 +540,7 @@ export const controlLayersSlice = createSlice({
layer.bboxNeedsUpdate = true; layer.bboxNeedsUpdate = true;
layer.uploadedMaskImage = null; layer.uploadedMaskImage = null;
}, },
prepare: (payload: { layerId: string; rect: IRect }) => ({ payload: { ...payload, rectUuid: uuidv4() } }), prepare: (payload: AddRectArg) => ({ payload: { ...payload, rectUuid: uuidv4() } }),
}, },
rgLayerMaskImageUploaded: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO }>) => { rgLayerMaskImageUploaded: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO }>) => {
const { layerId, imageDTO } = action.payload; const { layerId, imageDTO } = action.payload;
@ -883,45 +894,21 @@ const migrateControlLayersState = (state: any): any => {
return state; return state;
}; };
// Ephemeral interaction state
export const $isDrawing = atom(false); export const $isDrawing = atom(false);
export const $lastMouseDownPos = atom<Vector2d | null>(null); export const $lastMouseDownPos = atom<Vector2d | null>(null);
export const $tool = atom<Tool>('brush'); export const $tool = atom<Tool>('brush');
export const $lastCursorPos = atom<Vector2d | null>(null); export const $lastCursorPos = atom<Vector2d | null>(null);
export const $isPreviewVisible = atom(true);
export const $lastAddedPoint = atom<Vector2d | null>(null);
// IDs for singleton Konva layers and objects // Some nanostores that are manually synced to redux state to provide imperative access
export const TOOL_PREVIEW_LAYER_ID = 'tool_preview_layer'; // TODO(psyche): This is a hack, figure out another way to handle this...
export const TOOL_PREVIEW_BRUSH_GROUP_ID = 'tool_preview_layer.brush_group'; export const $brushSize = atom<number>(0);
export const TOOL_PREVIEW_BRUSH_FILL_ID = 'tool_preview_layer.brush_fill'; export const $brushSpacingPx = atom<number>(0);
export const TOOL_PREVIEW_BRUSH_BORDER_INNER_ID = 'tool_preview_layer.brush_border_inner'; export const $selectedLayerId = atom<string | null>(null);
export const TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID = 'tool_preview_layer.brush_border_outer'; export const $selectedLayerType = atom<Layer['type'] | null>(null);
export const TOOL_PREVIEW_RECT_ID = 'tool_preview_layer.rect'; export const $shouldInvertBrushSizeScrollDirection = atom(false);
export const BACKGROUND_LAYER_ID = 'background_layer';
export const BACKGROUND_RECT_ID = 'background_layer.rect';
export const NO_LAYERS_MESSAGE_LAYER_ID = 'no_layers_message';
// Names (aka classes) for Konva layers and objects
export const CA_LAYER_NAME = 'control_adapter_layer';
export const CA_LAYER_IMAGE_NAME = 'control_adapter_layer.image';
export const RG_LAYER_NAME = 'regional_guidance_layer';
export const RG_LAYER_LINE_NAME = 'regional_guidance_layer.line';
export const RG_LAYER_OBJECT_GROUP_NAME = 'regional_guidance_layer.object_group';
export const RG_LAYER_RECT_NAME = 'regional_guidance_layer.rect';
export const INITIAL_IMAGE_LAYER_ID = 'singleton_initial_image_layer';
export const INITIAL_IMAGE_LAYER_NAME = 'initial_image_layer';
export const INITIAL_IMAGE_LAYER_IMAGE_NAME = 'initial_image_layer.image';
export const LAYER_BBOX_NAME = 'layer.bbox';
export const COMPOSITING_RECT_NAME = 'compositing-rect';
// Getters for non-singleton layer and object IDs
export const getRGLayerId = (layerId: string) => `${RG_LAYER_NAME}_${layerId}`;
const getRGLayerLineId = (layerId: string, lineId: string) => `${layerId}.line_${lineId}`;
const getRGLayerRectId = (layerId: string, lineId: string) => `${layerId}.rect_${lineId}`;
export const getRGLayerObjectGroupId = (layerId: string, groupId: string) => `${layerId}.objectGroup_${groupId}`;
export const getLayerBboxId = (layerId: string) => `${layerId}.bbox`;
export const getCALayerId = (layerId: string) => `control_adapter_layer_${layerId}`;
export const getCALayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`;
export const getIILayerImageId = (layerId: string, imageName: string) => `${layerId}.image_${imageName}`;
export const getIPALayerId = (layerId: string) => `ip_adapter_layer_${layerId}`;
export const controlLayersPersistConfig: PersistConfig<ControlLayersState> = { export const controlLayersPersistConfig: PersistConfig<ControlLayersState> = {
name: controlLayersSlice.name, name: controlLayersSlice.name,

View File

@ -17,6 +17,7 @@ import {
zParameterPositivePrompt, zParameterPositivePrompt,
zParameterStrength, zParameterStrength,
} from 'features/parameters/types/parameterSchemas'; } from 'features/parameters/types/parameterSchemas';
import type { IRect } from 'konva/lib/types';
import { z } from 'zod'; import { z } from 'zod';
const zTool = z.enum(['brush', 'eraser', 'move', 'rect']); const zTool = z.enum(['brush', 'eraser', 'move', 'rect']);
@ -129,3 +130,7 @@ export type ControlLayersState = {
aspectRatio: AspectRatioState; aspectRatio: AspectRatioState;
}; };
}; };
export type AddLineArg = { layerId: string; points: [number, number, number, number]; tool: DrawingTool };
export type AddPointToLineArg = { layerId: string; point: [number, number] };
export type AddRectArg = { layerId: string; rect: IRect };

View File

@ -1,66 +0,0 @@
import { getStore } from 'app/store/nanostores/store';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import { isRegionalGuidanceLayer, RG_LAYER_NAME } from 'features/controlLayers/store/controlLayersSlice';
import { renderers } from 'features/controlLayers/util/renderers';
import Konva from 'konva';
import { assert } from 'tsafe';
/**
* Get the blobs of all regional prompt layers. Only visible layers are returned.
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
* @param preview Whether to open a new tab displaying each layer.
* @returns A map of layer IDs to blobs.
*/
export const getRegionalPromptLayerBlobs = async (
layerIds?: string[],
preview: boolean = false
): Promise<Record<string, Blob>> => {
const state = getStore().getState();
const { layers } = state.controlLayers.present;
const { width, height } = state.controlLayers.present.size;
const reduxLayers = layers.filter(isRegionalGuidanceLayer);
const container = document.createElement('div');
const stage = new Konva.Stage({ container, width, height });
renderers.renderLayers(stage, reduxLayers, 1, 'brush');
const konvaLayers = stage.find<Konva.Layer>(`.${RG_LAYER_NAME}`);
const blobs: Record<string, Blob> = {};
// First remove all layers
for (const layer of konvaLayers) {
layer.remove();
}
// Next render each layer to a blob
for (const layer of konvaLayers) {
if (layerIds && !layerIds.includes(layer.id())) {
continue;
}
const reduxLayer = reduxLayers.find((l) => l.id === layer.id());
assert(reduxLayer, `Redux layer ${layer.id()} not found`);
stage.add(layer);
const blob = await new Promise<Blob>((resolve) => {
stage.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
},
});
});
if (preview) {
const base64 = await blobToDataURL(blob);
openBase64ImageInTab([
{
base64,
caption: `${reduxLayer.id}: ${reduxLayer.positivePrompt} / ${reduxLayer.negativePrompt}`,
},
]);
}
layer.remove();
blobs[layer.id()] = blob;
}
return blobs;
};

View File

@ -28,7 +28,9 @@ const ImageMetadataGraphTabContent = ({ image }: Props) => {
return <IAINoContentFallback label={t('nodes.noGraph')} />; return <IAINoContentFallback label={t('nodes.noGraph')} />;
} }
return <DataViewer data={graph} label={t('nodes.graph')} />; return (
<DataViewer fileName={`${image.image_name.replace('.png', '')}_graph`} data={graph} label={t('nodes.graph')} />
);
}; };
export default memo(ImageMetadataGraphTabContent); export default memo(ImageMetadataGraphTabContent);

View File

@ -68,14 +68,22 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
</TabPanel> </TabPanel>
<TabPanel> <TabPanel>
{metadata ? ( {metadata ? (
<DataViewer data={metadata} label={t('metadata.metadata')} /> <DataViewer
fileName={`${image.image_name.replace('.png', '')}_metadata`}
data={metadata}
label={t('metadata.metadata')}
/>
) : ( ) : (
<IAINoContentFallback label={t('metadata.noMetaData')} /> <IAINoContentFallback label={t('metadata.noMetaData')} />
)} )}
</TabPanel> </TabPanel>
<TabPanel> <TabPanel>
{image ? ( {image ? (
<DataViewer data={image} label={t('metadata.imageDetails')} /> <DataViewer
fileName={`${image.image_name.replace('.png', '')}_details`}
data={image}
label={t('metadata.imageDetails')}
/>
) : ( ) : (
<IAINoContentFallback label={t('metadata.noImageDetails')} /> <IAINoContentFallback label={t('metadata.noImageDetails')} />
)} )}

View File

@ -28,7 +28,13 @@ const ImageMetadataWorkflowTabContent = ({ image }: Props) => {
return <IAINoContentFallback label={t('nodes.noWorkflow')} />; return <IAINoContentFallback label={t('nodes.noWorkflow')} />;
} }
return <DataViewer data={workflow} label={t('metadata.workflow')} />; return (
<DataViewer
fileName={`${image.image_name.replace('.png', '')}_workflow`}
data={workflow}
label={t('metadata.workflow')}
/>
);
}; };
export default memo(ImageMetadataWorkflowTabContent); export default memo(ImageMetadataWorkflowTabContent);

View File

@ -3,7 +3,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { useBoolean } from 'common/hooks/useBoolean'; import { useBoolean } from 'common/hooks/useBoolean';
import { preventDefault } from 'common/util/stopPropagation'; import { preventDefault } from 'common/util/stopPropagation';
import type { Dimensions } from 'features/canvas/store/canvasTypes'; import type { Dimensions } from 'features/canvas/store/canvasTypes';
import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers'; import { TRANSPARENCY_CHECKER_PATTERN } from 'features/controlLayers/konva/constants';
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel'; import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
import { memo, useMemo, useRef } from 'react'; import { memo, useMemo, useRef } from 'react';
@ -78,7 +78,7 @@ export const ImageComparisonHover = memo(({ firstImage, secondImage, containerDi
left={0} left={0}
right={0} right={0}
bottom={0} bottom={0}
backgroundImage={STAGE_BG_DATAURL} backgroundImage={TRANSPARENCY_CHECKER_PATTERN}
backgroundRepeat="repeat" backgroundRepeat="repeat"
opacity={0.2} opacity={0.2}
/> />

View File

@ -2,7 +2,7 @@ import { Box, Flex, Icon, Image } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { preventDefault } from 'common/util/stopPropagation'; import { preventDefault } from 'common/util/stopPropagation';
import type { Dimensions } from 'features/canvas/store/canvasTypes'; import type { Dimensions } from 'features/canvas/store/canvasTypes';
import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers'; import { TRANSPARENCY_CHECKER_PATTERN } from 'features/controlLayers/konva/constants';
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel'; import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi'; import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
@ -120,7 +120,7 @@ export const ImageComparisonSlider = memo(({ firstImage, secondImage, containerD
left={0} left={0}
right={0} right={0}
bottom={0} bottom={0}
backgroundImage={STAGE_BG_DATAURL} backgroundImage={TRANSPARENCY_CHECKER_PATTERN}
backgroundRepeat="repeat" backgroundRepeat="repeat"
opacity={0.2} opacity={0.2}
/> />

View File

@ -1,4 +1,7 @@
import { getStore } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone';
import { objectKeys } from 'common/util/objectKeys'; import { objectKeys } from 'common/util/objectKeys';
import { shouldConcatPromptsChanged } from 'features/controlLayers/store/controlLayersSlice';
import type { Layer } from 'features/controlLayers/store/types'; import type { Layer } from 'features/controlLayers/store/types';
import type { LoRA } from 'features/lora/store/loraSlice'; import type { LoRA } from 'features/lora/store/loraSlice';
import type { import type {
@ -16,6 +19,7 @@ import { validators } from 'features/metadata/util/validators';
import type { ModelIdentifierField } from 'features/nodes/types/common'; import type { ModelIdentifierField } from 'features/nodes/types/common';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { size } from 'lodash-es';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import { parsers } from './parsers'; import { parsers } from './parsers';
@ -376,54 +380,25 @@ export const handlers = {
}), }),
} as const; } as const;
type ParsedValue = Awaited<ReturnType<(typeof handlers)[keyof typeof handlers]['parse']>>;
type RecallResults = Partial<Record<keyof typeof handlers, ParsedValue>>;
export const parseAndRecallPrompts = async (metadata: unknown) => { export const parseAndRecallPrompts = async (metadata: unknown) => {
const results = await Promise.allSettled([ const keysToRecall: (keyof typeof handlers)[] = [
handlers.positivePrompt.parse(metadata).then((positivePrompt) => { 'positivePrompt',
if (!handlers.positivePrompt.recall) { 'negativePrompt',
return; 'sdxlPositiveStylePrompt',
} 'sdxlNegativeStylePrompt',
handlers.positivePrompt?.recall(positivePrompt); ];
}), const recalled = await recallKeys(keysToRecall, metadata);
handlers.negativePrompt.parse(metadata).then((negativePrompt) => { if (size(recalled) > 0) {
if (!handlers.negativePrompt.recall) {
return;
}
handlers.negativePrompt?.recall(negativePrompt);
}),
handlers.sdxlPositiveStylePrompt.parse(metadata).then((sdxlPositiveStylePrompt) => {
if (!handlers.sdxlPositiveStylePrompt.recall) {
return;
}
handlers.sdxlPositiveStylePrompt?.recall(sdxlPositiveStylePrompt);
}),
handlers.sdxlNegativeStylePrompt.parse(metadata).then((sdxlNegativeStylePrompt) => {
if (!handlers.sdxlNegativeStylePrompt.recall) {
return;
}
handlers.sdxlNegativeStylePrompt?.recall(sdxlNegativeStylePrompt);
}),
]);
if (results.some((result) => result.status === 'fulfilled')) {
parameterSetToast(t('metadata.allPrompts')); parameterSetToast(t('metadata.allPrompts'));
} }
}; };
export const parseAndRecallImageDimensions = async (metadata: unknown) => { export const parseAndRecallImageDimensions = async (metadata: unknown) => {
const results = await Promise.allSettled([ const recalled = recallKeys(['width', 'height'], metadata);
handlers.width.parse(metadata).then((width) => { if (size(recalled) > 0) {
if (!handlers.width.recall) {
return;
}
handlers.width?.recall(width);
}),
handlers.height.parse(metadata).then((height) => {
if (!handlers.height.recall) {
return;
}
handlers.height?.recall(height);
}),
]);
if (results.some((result) => result.status === 'fulfilled')) {
parameterSetToast(t('metadata.imageDimensions')); parameterSetToast(t('metadata.imageDimensions'));
} }
}; };
@ -438,28 +413,20 @@ export const parseAndRecallAllMetadata = async (
toControlLayers: boolean, toControlLayers: boolean,
skip: (keyof typeof handlers)[] = [] skip: (keyof typeof handlers)[] = []
) => { ) => {
const skipKeys = skip ?? []; const skipKeys = deepClone(skip);
if (toControlLayers) { if (toControlLayers) {
skipKeys.push(...TO_CONTROL_LAYERS_SKIP_KEYS); skipKeys.push(...TO_CONTROL_LAYERS_SKIP_KEYS);
} else { } else {
skipKeys.push(...NOT_TO_CONTROL_LAYERS_SKIP_KEYS); skipKeys.push(...NOT_TO_CONTROL_LAYERS_SKIP_KEYS);
} }
const results = await Promise.allSettled(
objectKeys(handlers)
.filter((key) => !skipKeys.includes(key))
.map((key) => {
const { parse, recall } = handlers[key];
return parse(metadata).then((value) => {
if (!recall) {
return;
}
/* @ts-expect-error The return type of parse and the input type of recall are guaranteed to be compatible. */
recall(value);
});
})
);
if (results.some((result) => result.status === 'fulfilled')) { // We may need to take some further action depending on what was recalled. For example, we need to disable SDXL prompt
// concat if the negative or positive style prompt was set. Because the recalling is all async, we need to collect all
// results
const keysToRecall = objectKeys(handlers).filter((key) => !skipKeys.includes(key));
const recalled = await recallKeys(keysToRecall, metadata);
if (size(recalled) > 0) {
toast({ toast({
id: 'PARAMETER_SET', id: 'PARAMETER_SET',
title: t('toast.parametersSet'), title: t('toast.parametersSet'),
@ -473,3 +440,43 @@ export const parseAndRecallAllMetadata = async (
}); });
} }
}; };
/**
* Recalls a set of keys from metadata.
* Includes special handling for some metadata where recalling may have side effects. For example, recalling a "style"
* prompt that is different from the "positive" or "negative" prompt should disable prompt concatenation.
* @param keysToRecall An array of keys to recall.
* @param metadata The metadata to recall from
* @returns A promise that resolves to an object containing the recalled values.
*/
const recallKeys = async (keysToRecall: (keyof typeof handlers)[], metadata: unknown): Promise<RecallResults> => {
const { dispatch } = getStore();
const recalled: RecallResults = {};
for (const key of keysToRecall) {
const { parse, recall } = handlers[key];
if (!recall) {
continue;
}
try {
const value = await parse(metadata);
/* @ts-expect-error The return type of parse and the input type of recall are guaranteed to be compatible. */
await recall(value);
recalled[key] = value;
} catch {
// no-op
}
}
if (
(recalled['sdxlPositiveStylePrompt'] && recalled['sdxlPositiveStylePrompt'] !== recalled['positivePrompt']) ||
(recalled['sdxlNegativeStylePrompt'] && recalled['sdxlNegativeStylePrompt'] !== recalled['negativePrompt'])
) {
// If we set the negative style prompt or positive style prompt, we should disable prompt concat
dispatch(shouldConcatPromptsChanged(false));
} else {
// Otherwise, we should enable prompt concat
dispatch(shouldConcatPromptsChanged(true));
}
return recalled;
};

View File

@ -1,6 +1,7 @@
import { getStore } from 'app/store/nanostores/store'; import { getStore } from 'app/store/nanostores/store';
import type { ModelIdentifierField } from 'features/nodes/types/common'; import type { ModelIdentifierField } from 'features/nodes/types/common';
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common'; import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
import type { ModelIdentifier } from 'features/nodes/types/v2/common';
import { modelsApi } from 'services/api/endpoints/models'; import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types'; import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
@ -107,19 +108,30 @@ export const fetchModelConfigWithTypeGuard = async <T extends AnyModelConfig>(
/** /**
* Fetches the model key from a model identifier. This includes fetching the key for MM1 format model identifiers. * Fetches the model key from a model identifier. This includes fetching the key for MM1 format model identifiers.
* @param modelIdentifier The model identifier. The MM2 format `{key: string}` simply extracts the key. The MM1 format * @param modelIdentifier The model identifier. This can be a MM1 or MM2 identifier. In every case, we attempt to fetch
* `{model_name: string, base_model: BaseModelType}` must do a network request to fetch the key. * the model config from the server to ensure that the model identifier is valid and represents an installed model.
* @param type The type of model to fetch. This is used to fetch the key for MM1 format model identifiers. * @param type The type of model to fetch. This is used to fetch the key for MM1 format model identifiers.
* @param message An optional custom message to include in the error if the model identifier is invalid. * @param message An optional custom message to include in the error if the model identifier is invalid.
* @returns A promise that resolves to the model key. * @returns A promise that resolves to the model key.
* @throws {InvalidModelConfigError} If the model identifier is invalid. * @throws {InvalidModelConfigError} If the model identifier is invalid.
*/ */
export const getModelKey = async (modelIdentifier: unknown, type: ModelType, message?: string): Promise<string> => { export const getModelKey = async (
modelIdentifier: unknown | ModelIdentifierField | ModelIdentifier,
type: ModelType,
message?: string
): Promise<string> => {
if (isModelIdentifier(modelIdentifier)) { if (isModelIdentifier(modelIdentifier)) {
return modelIdentifier.key; try {
// Check if the model exists by key
return (await fetchModelConfig(modelIdentifier.key)).key;
} catch {
// If not, fetch the model key by name and base model
return (await fetchModelConfigByAttrs(modelIdentifier.name, modelIdentifier.base, type)).key;
} }
if (isModelIdentifierV2(modelIdentifier)) { } else if (isModelIdentifierV2(modelIdentifier)) {
// Try by old-format model identifier
return (await fetchModelConfigByAttrs(modelIdentifier.model_name, modelIdentifier.base_model, type)).key; return (await fetchModelConfigByAttrs(modelIdentifier.model_name, modelIdentifier.base_model, type)).key;
} }
// Nope, couldn't find it
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`); throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
}; };

View File

@ -4,7 +4,7 @@ import {
initialT2IAdapter, initialT2IAdapter,
} from 'features/controlAdapters/util/buildControlAdapter'; } from 'features/controlAdapters/util/buildControlAdapter';
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor'; import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
import { getCALayerId, getIPALayerId, INITIAL_IMAGE_LAYER_ID } from 'features/controlLayers/store/controlLayersSlice'; import { getCALayerId, getIPALayerId, INITIAL_IMAGE_LAYER_ID } from 'features/controlLayers/konva/naming';
import type { ControlAdapterLayer, InitialImageLayer, IPAdapterLayer, Layer } from 'features/controlLayers/store/types'; import type { ControlAdapterLayer, InitialImageLayer, IPAdapterLayer, Layer } from 'features/controlLayers/store/types';
import { zLayer } from 'features/controlLayers/store/types'; import { zLayer } from 'features/controlLayers/store/types';
import { import {

View File

@ -6,12 +6,10 @@ import {
ipAdaptersReset, ipAdaptersReset,
t2iAdaptersReset, t2iAdaptersReset,
} from 'features/controlAdapters/store/controlAdaptersSlice'; } from 'features/controlAdapters/store/controlAdaptersSlice';
import { getCALayerId, getIPALayerId, getRGLayerId } from 'features/controlLayers/konva/naming';
import { import {
allLayersDeleted, allLayersDeleted,
caLayerRecalled, caLayerRecalled,
getCALayerId,
getIPALayerId,
getRGLayerId,
heightChanged, heightChanged,
iiLayerRecalled, iiLayerRecalled,
ipaLayerRecalled, ipaLayerRecalled,

View File

@ -1,6 +1,10 @@
import { getStore } from 'app/store/nanostores/store'; import { getStore } from 'app/store/nanostores/store';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import { RG_LAYER_NAME } from 'features/controlLayers/konva/naming';
import { renderers } from 'features/controlLayers/konva/renderers';
import { import {
isControlAdapterLayer, isControlAdapterLayer,
isInitialImageLayer, isInitialImageLayer,
@ -16,7 +20,6 @@ import type {
ProcessorConfig, ProcessorConfig,
T2IAdapterConfigV2, T2IAdapterConfigV2,
} from 'features/controlLayers/util/controlAdapters'; } from 'features/controlLayers/util/controlAdapters';
import { getRegionalPromptLayerBlobs } from 'features/controlLayers/util/getLayerBlobs';
import type { ImageField } from 'features/nodes/types/common'; import type { ImageField } from 'features/nodes/types/common';
import { import {
CONTROL_NET_COLLECT, CONTROL_NET_COLLECT,
@ -31,11 +34,13 @@ import {
T2I_ADAPTER_COLLECT, T2I_ADAPTER_COLLECT,
} from 'features/nodes/util/graph/constants'; } from 'features/nodes/util/graph/constants';
import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import Konva from 'konva';
import { size } from 'lodash-es'; import { size } from 'lodash-es';
import { getImageDTO, imagesApi } from 'services/api/endpoints/images'; import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types'; import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
//#region addControlLayers
/** /**
* Adds the control layers to the graph * Adds the control layers to the graph
* @param state The app root state * @param state The app root state
@ -90,7 +95,7 @@ export const addControlLayers = async (
const validRGLayers = validLayers.filter(isRegionalGuidanceLayer); const validRGLayers = validLayers.filter(isRegionalGuidanceLayer);
const layerIds = validRGLayers.map((l) => l.id); const layerIds = validRGLayers.map((l) => l.id);
const blobs = await getRegionalPromptLayerBlobs(layerIds); const blobs = await getRGLayerBlobs(layerIds);
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs'); assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
for (const layer of validRGLayers) { for (const layer of validRGLayers) {
@ -257,6 +262,7 @@ export const addControlLayers = async (
g.upsertMetadata({ control_layers: { layers: validLayers, version: state.controlLayers.present._version } }); g.upsertMetadata({ control_layers: { layers: validLayers, version: state.controlLayers.present._version } });
return validLayers; return validLayers;
}; };
//#endregion
//#region Control Adapters //#region Control Adapters
const addGlobalControlAdapterToGraph = ( const addGlobalControlAdapterToGraph = (
@ -509,7 +515,7 @@ const isValidLayer = (layer: Layer, base: BaseModelType) => {
}; };
//#endregion //#endregion
//#region Helpers //#region getMaskImage
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => { const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
if (layer.uploadedMaskImage) { if (layer.uploadedMaskImage) {
const imageDTO = await getImageDTO(layer.uploadedMaskImage.name); const imageDTO = await getImageDTO(layer.uploadedMaskImage.name);
@ -529,7 +535,9 @@ const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<I
dispatch(rgLayerMaskImageUploaded({ layerId: layer.id, imageDTO })); dispatch(rgLayerMaskImageUploaded({ layerId: layer.id, imageDTO }));
return imageDTO; return imageDTO;
}; };
//#endregion
//#region buildControlImage
const buildControlImage = ( const buildControlImage = (
image: ImageWithDims | null, image: ImageWithDims | null,
processedImage: ImageWithDims | null, processedImage: ImageWithDims | null,
@ -549,3 +557,61 @@ const buildControlImage = (
assert(false, 'Attempted to add unprocessed control image'); assert(false, 'Attempted to add unprocessed control image');
}; };
//#endregion //#endregion
//#region getRGLayerBlobs
/**
* Get the blobs of all regional prompt layers. Only visible layers are returned.
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
* @param preview Whether to open a new tab displaying each layer.
* @returns A map of layer IDs to blobs.
*/
const getRGLayerBlobs = async (layerIds?: string[], preview: boolean = false): Promise<Record<string, Blob>> => {
const state = getStore().getState();
const { layers } = state.controlLayers.present;
const { width, height } = state.controlLayers.present.size;
const reduxLayers = layers.filter(isRegionalGuidanceLayer);
const container = document.createElement('div');
const stage = new Konva.Stage({ container, width, height });
renderers.renderLayers(stage, reduxLayers, 1, 'brush', getImageDTO);
const konvaLayers = stage.find<Konva.Layer>(`.${RG_LAYER_NAME}`);
const blobs: Record<string, Blob> = {};
// First remove all layers
for (const layer of konvaLayers) {
layer.remove();
}
// Next render each layer to a blob
for (const layer of konvaLayers) {
if (layerIds && !layerIds.includes(layer.id())) {
continue;
}
const reduxLayer = reduxLayers.find((l) => l.id === layer.id());
assert(reduxLayer, `Redux layer ${layer.id()} not found`);
stage.add(layer);
const blob = await new Promise<Blob>((resolve) => {
stage.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
},
});
});
if (preview) {
const base64 = await blobToDataURL(blob);
openBase64ImageInTab([
{
base64,
caption: `${reduxLayer.id}: ${reduxLayer.positivePrompt} / ${reduxLayer.negativePrompt}`,
},
]);
}
layer.remove();
blobs[layer.id()] = blob;
}
return blobs;
};
//#endregion

View File

@ -1,8 +1,17 @@
import { Flex } from '@invoke-ai/ui-library'; import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { StageComponent } from 'features/controlLayers/components/StageComponent'; import { StageComponent } from 'features/controlLayers/components/StageComponent';
import { $isPreviewVisible } from 'features/controlLayers/store/controlLayersSlice';
import { AspectRatioIconPreview } from 'features/parameters/components/ImageSize/AspectRatioIconPreview';
import { memo } from 'react'; import { memo } from 'react';
export const AspectRatioCanvasPreview = memo(() => { export const AspectRatioCanvasPreview = memo(() => {
const isPreviewVisible = useStore($isPreviewVisible);
if (!isPreviewVisible) {
return <AspectRatioIconPreview />;
}
return ( return (
<Flex w="full" h="full" alignItems="center" justifyContent="center" position="relative"> <Flex w="full" h="full" alignItems="center" justifyContent="center" position="relative">
<StageComponent asPreview /> <StageComponent asPreview />

View File

@ -3,15 +3,12 @@ import { aspectRatioChanged, heightChanged, widthChanged } from 'features/contro
import { ParamHeight } from 'features/parameters/components/Core/ParamHeight'; import { ParamHeight } from 'features/parameters/components/Core/ParamHeight';
import { ParamWidth } from 'features/parameters/components/Core/ParamWidth'; import { ParamWidth } from 'features/parameters/components/Core/ParamWidth';
import { AspectRatioCanvasPreview } from 'features/parameters/components/ImageSize/AspectRatioCanvasPreview'; import { AspectRatioCanvasPreview } from 'features/parameters/components/ImageSize/AspectRatioCanvasPreview';
import { AspectRatioIconPreview } from 'features/parameters/components/ImageSize/AspectRatioIconPreview';
import { ImageSize } from 'features/parameters/components/ImageSize/ImageSize'; import { ImageSize } from 'features/parameters/components/ImageSize/ImageSize';
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types'; import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
export const ImageSizeLinear = memo(() => { export const ImageSizeLinear = memo(() => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const tab = useAppSelector(activeTabNameSelector);
const width = useAppSelector((s) => s.controlLayers.present.size.width); const width = useAppSelector((s) => s.controlLayers.present.size.width);
const height = useAppSelector((s) => s.controlLayers.present.size.height); const height = useAppSelector((s) => s.controlLayers.present.size.height);
const aspectRatioState = useAppSelector((s) => s.controlLayers.present.size.aspectRatio); const aspectRatioState = useAppSelector((s) => s.controlLayers.present.size.aspectRatio);
@ -50,7 +47,7 @@ export const ImageSizeLinear = memo(() => {
aspectRatioState={aspectRatioState} aspectRatioState={aspectRatioState}
heightComponent={<ParamHeight />} heightComponent={<ParamHeight />}
widthComponent={<ParamWidth />} widthComponent={<ParamWidth />}
previewComponent={tab === 'generation' ? <AspectRatioCanvasPreview /> : <AspectRatioIconPreview />} previewComponent={<AspectRatioCanvasPreview />}
onChangeAspectRatioState={onChangeAspectRatioState} onChangeAspectRatioState={onChangeAspectRatioState}
onChangeWidth={onChangeWidth} onChangeWidth={onChangeWidth}
onChangeHeight={onChangeHeight} onChangeHeight={onChangeHeight}

View File

@ -3,6 +3,7 @@ import { Box, Flex, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/u
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants'; import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
import { ControlLayersPanelContent } from 'features/controlLayers/components/ControlLayersPanelContent'; import { ControlLayersPanelContent } from 'features/controlLayers/components/ControlLayersPanelContent';
import { $isPreviewVisible } from 'features/controlLayers/store/controlLayersSlice';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice'; import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { Prompts } from 'features/parameters/components/Prompts/Prompts'; import { Prompts } from 'features/parameters/components/Prompts/Prompts';
import QueueControls from 'features/queue/components/QueueControls'; import QueueControls from 'features/queue/components/QueueControls';
@ -53,6 +54,7 @@ const ParametersPanelTextToImage = () => {
if (i === 1) { if (i === 1) {
dispatch(isImageViewerOpenChanged(false)); dispatch(isImageViewerOpenChanged(false));
} }
$isPreviewVisible.set(i === 0);
}, },
[dispatch] [dispatch]
); );
@ -66,6 +68,7 @@ const ParametersPanelTextToImage = () => {
<Flex gap={2} flexDirection="column" h="full" w="full"> <Flex gap={2} flexDirection="column" h="full" w="full">
{isSDXL ? <SDXLPrompts /> : <Prompts />} {isSDXL ? <SDXLPrompts /> : <Prompts />}
<Tabs <Tabs
defaultIndex={0}
variant="enclosed" variant="enclosed"
display="flex" display="flex"
flexDir="column" flexDir="column"

Some files were not shown because too many files have changed in this diff Show More