mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into ryan/model-cache-logging-only
This commit is contained in:
commit
096d195d6e
@ -10,6 +10,20 @@ To use a community workflow, download the the `.json` node graph file and load i
|
|||||||
|
|
||||||
--------------------------------
|
--------------------------------
|
||||||
|
|
||||||
|
--------------------------------
|
||||||
|
### Make 3D
|
||||||
|
|
||||||
|
**Description:** Create compelling 3D stereo images from 2D originals.
|
||||||
|
|
||||||
|
**Node Link:** [https://gitlab.com/srcrr/shift3d/-/raw/main/make3d.py](https://gitlab.com/srcrr/shift3d)
|
||||||
|
|
||||||
|
**Example Node Graph:** https://gitlab.com/srcrr/shift3d/-/raw/main/example-workflow.json?ref_type=heads&inline=false
|
||||||
|
|
||||||
|
**Output Examples**
|
||||||
|
|
||||||
|
![Painting of a cozy delapidated house](https://gitlab.com/srcrr/shift3d/-/raw/main/example-1.png){: style="height:512px;width:512px"}
|
||||||
|
![Photo of cute puppies](https://gitlab.com/srcrr/shift3d/-/raw/main/example-2.png){: style="height:512px;width:512px"}
|
||||||
|
|
||||||
--------------------------------
|
--------------------------------
|
||||||
### Ideal Size
|
### Ideal Size
|
||||||
|
|
||||||
|
@ -68,6 +68,7 @@ class FieldDescriptions:
|
|||||||
height = "Height of output (px)"
|
height = "Height of output (px)"
|
||||||
control = "ControlNet(s) to apply"
|
control = "ControlNet(s) to apply"
|
||||||
ip_adapter = "IP-Adapter to apply"
|
ip_adapter = "IP-Adapter to apply"
|
||||||
|
t2i_adapter = "T2I-Adapter(s) to apply"
|
||||||
denoised_latents = "Denoised latents tensor"
|
denoised_latents = "Denoised latents tensor"
|
||||||
latents = "Latents tensor"
|
latents = "Latents tensor"
|
||||||
strength = "Strength of denoising (proportional to steps)"
|
strength = "Strength of denoising (proportional to steps)"
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import math
|
import math
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, TypedDict
|
from typing import Optional, TypedDict
|
||||||
@ -11,6 +10,7 @@ from PIL import Image, ImageDraw, ImageFilter, ImageFont, ImageOps
|
|||||||
from PIL.Image import Image as ImageType
|
from PIL.Image import Image as ImageType
|
||||||
from pydantic import validator
|
from pydantic import validator
|
||||||
|
|
||||||
|
import invokeai.assets.fonts as font_assets
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
InputField,
|
InputField,
|
||||||
@ -138,6 +138,7 @@ def generate_face_box_mask(
|
|||||||
chunk_x_offset: int = 0,
|
chunk_x_offset: int = 0,
|
||||||
chunk_y_offset: int = 0,
|
chunk_y_offset: int = 0,
|
||||||
draw_mesh: bool = True,
|
draw_mesh: bool = True,
|
||||||
|
check_bounds: bool = True,
|
||||||
) -> list[FaceResultData]:
|
) -> list[FaceResultData]:
|
||||||
result = []
|
result = []
|
||||||
mask_pil = None
|
mask_pil = None
|
||||||
@ -217,7 +218,7 @@ def generate_face_box_mask(
|
|||||||
im_width, im_height = pil_image.size
|
im_width, im_height = pil_image.size
|
||||||
over_w = im_width * 0.1
|
over_w = im_width * 0.1
|
||||||
over_h = im_height * 0.1
|
over_h = im_height * 0.1
|
||||||
if (
|
if not check_bounds or (
|
||||||
(left_side >= -over_w)
|
(left_side >= -over_w)
|
||||||
and (right_side < im_width + over_w)
|
and (right_side < im_width + over_w)
|
||||||
and (top_side >= -over_h)
|
and (top_side >= -over_h)
|
||||||
@ -345,6 +346,7 @@ def get_faces_list(
|
|||||||
chunk_x_offset=0,
|
chunk_x_offset=0,
|
||||||
chunk_y_offset=0,
|
chunk_y_offset=0,
|
||||||
draw_mesh=draw_mesh,
|
draw_mesh=draw_mesh,
|
||||||
|
check_bounds=False,
|
||||||
)
|
)
|
||||||
if should_chunk or len(result) == 0:
|
if should_chunk or len(result) == 0:
|
||||||
context.services.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).")
|
context.services.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).")
|
||||||
@ -402,7 +404,7 @@ def get_faces_list(
|
|||||||
return all_faces
|
return all_faces
|
||||||
|
|
||||||
|
|
||||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.0.0")
|
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.0.1")
|
||||||
class FaceOffInvocation(BaseInvocation):
|
class FaceOffInvocation(BaseInvocation):
|
||||||
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
|
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
|
||||||
|
|
||||||
@ -496,7 +498,7 @@ class FaceOffInvocation(BaseInvocation):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.0.0")
|
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.0.1")
|
||||||
class FaceMaskInvocation(BaseInvocation):
|
class FaceMaskInvocation(BaseInvocation):
|
||||||
"""Face mask creation using mediapipe face detection"""
|
"""Face mask creation using mediapipe face detection"""
|
||||||
|
|
||||||
@ -614,7 +616,7 @@ class FaceMaskInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.0.0"
|
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.0.1"
|
||||||
)
|
)
|
||||||
class FaceIdentifierInvocation(BaseInvocation):
|
class FaceIdentifierInvocation(BaseInvocation):
|
||||||
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
|
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
|
||||||
@ -641,9 +643,9 @@ class FaceIdentifierInvocation(BaseInvocation):
|
|||||||
draw_mesh=False,
|
draw_mesh=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
path = Path(__file__).resolve().parent.parent.parent
|
# Note - font may be found either in the repo if running an editable install, or in the venv if running a package install
|
||||||
font_path = os.path.abspath(path / "assets/fonts/inter/Inter-Regular.ttf")
|
font_path = [x for x in [Path(y, "inter/Inter-Regular.ttf") for y in font_assets.__path__] if x.exists()]
|
||||||
font = ImageFont.truetype(font_path, FONT_SIZE)
|
font = ImageFont.truetype(font_path[0].as_posix(), FONT_SIZE)
|
||||||
|
|
||||||
# Paste face IDs on the output image
|
# Paste face IDs on the output image
|
||||||
draw = ImageDraw.Draw(image)
|
draw = ImageDraw.Draw(image)
|
||||||
|
@ -11,6 +11,7 @@ import torchvision.transforms as T
|
|||||||
from diffusers import AutoencoderKL, AutoencoderTiny
|
from diffusers import AutoencoderKL, AutoencoderTiny
|
||||||
from diffusers.image_processor import VaeImageProcessor
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
from diffusers.models import UNet2DConditionModel
|
from diffusers.models import UNet2DConditionModel
|
||||||
|
from diffusers.models.adapter import FullAdapterXL, T2IAdapter
|
||||||
from diffusers.models.attention_processor import (
|
from diffusers.models.attention_processor import (
|
||||||
AttnProcessor2_0,
|
AttnProcessor2_0,
|
||||||
LoRAAttnProcessor2_0,
|
LoRAAttnProcessor2_0,
|
||||||
@ -33,6 +34,7 @@ from invokeai.app.invocations.primitives import (
|
|||||||
LatentsOutput,
|
LatentsOutput,
|
||||||
build_latents_output,
|
build_latents_output,
|
||||||
)
|
)
|
||||||
|
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||||
@ -47,6 +49,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
|
|||||||
ControlNetData,
|
ControlNetData,
|
||||||
IPAdapterData,
|
IPAdapterData,
|
||||||
StableDiffusionGeneratorPipeline,
|
StableDiffusionGeneratorPipeline,
|
||||||
|
T2IAdapterData,
|
||||||
image_resized_to_grid_as_tensor,
|
image_resized_to_grid_as_tensor,
|
||||||
)
|
)
|
||||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||||
@ -196,7 +199,7 @@ def get_scheduler(
|
|||||||
title="Denoise Latents",
|
title="Denoise Latents",
|
||||||
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||||
category="latents",
|
category="latents",
|
||||||
version="1.1.0",
|
version="1.2.0",
|
||||||
)
|
)
|
||||||
class DenoiseLatentsInvocation(BaseInvocation):
|
class DenoiseLatentsInvocation(BaseInvocation):
|
||||||
"""Denoises noisy latents to decodable images"""
|
"""Denoises noisy latents to decodable images"""
|
||||||
@ -226,9 +229,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
ip_adapter: Optional[IPAdapterField] = InputField(
|
ip_adapter: Optional[IPAdapterField] = InputField(
|
||||||
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6
|
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6
|
||||||
)
|
)
|
||||||
|
t2i_adapter: Union[T2IAdapterField, list[T2IAdapterField]] = InputField(
|
||||||
|
description=FieldDescriptions.t2i_adapter, title="T2I-Adapter", default=None, input=Input.Connection, ui_order=7
|
||||||
|
)
|
||||||
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||||
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=7
|
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=8
|
||||||
)
|
)
|
||||||
|
|
||||||
@validator("cfg_scale")
|
@validator("cfg_scale")
|
||||||
@ -451,6 +457,91 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
end_step_percent=ip_adapter.end_step_percent,
|
end_step_percent=ip_adapter.end_step_percent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def run_t2i_adapters(
|
||||||
|
self,
|
||||||
|
context: InvocationContext,
|
||||||
|
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
|
||||||
|
latents_shape: list[int],
|
||||||
|
do_classifier_free_guidance: bool,
|
||||||
|
) -> Optional[list[T2IAdapterData]]:
|
||||||
|
if t2i_adapter is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Handle the possibility that t2i_adapter could be a list or a single T2IAdapterField.
|
||||||
|
if isinstance(t2i_adapter, T2IAdapterField):
|
||||||
|
t2i_adapter = [t2i_adapter]
|
||||||
|
|
||||||
|
if len(t2i_adapter) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
t2i_adapter_data = []
|
||||||
|
for t2i_adapter_field in t2i_adapter:
|
||||||
|
t2i_adapter_model_info = context.services.model_manager.get_model(
|
||||||
|
model_name=t2i_adapter_field.t2i_adapter_model.model_name,
|
||||||
|
model_type=ModelType.T2IAdapter,
|
||||||
|
base_model=t2i_adapter_field.t2i_adapter_model.base_model,
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name)
|
||||||
|
|
||||||
|
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||||
|
if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
|
||||||
|
max_unet_downscale = 8
|
||||||
|
elif t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusionXL:
|
||||||
|
max_unet_downscale = 4
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'."
|
||||||
|
)
|
||||||
|
|
||||||
|
t2i_adapter_model: T2IAdapter
|
||||||
|
with t2i_adapter_model_info as t2i_adapter_model:
|
||||||
|
total_downscale_factor = t2i_adapter_model.total_downscale_factor
|
||||||
|
if isinstance(t2i_adapter_model.adapter, FullAdapterXL):
|
||||||
|
# HACK(ryand): Work around a bug in FullAdapterXL. This is being addressed upstream in diffusers by
|
||||||
|
# this PR: https://github.com/huggingface/diffusers/pull/5134.
|
||||||
|
total_downscale_factor = total_downscale_factor // 2
|
||||||
|
|
||||||
|
# Resize the T2I-Adapter input image.
|
||||||
|
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
|
||||||
|
# result will match the latent image's dimensions after max_unet_downscale is applied.
|
||||||
|
t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor
|
||||||
|
t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor
|
||||||
|
|
||||||
|
# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
|
||||||
|
# a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
|
||||||
|
# T2I-Adapter model.
|
||||||
|
#
|
||||||
|
# Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
|
||||||
|
# of the same requirements (e.g. preserving binary masks during resize).
|
||||||
|
t2i_image = prepare_control_image(
|
||||||
|
image=image,
|
||||||
|
do_classifier_free_guidance=False,
|
||||||
|
width=t2i_input_width,
|
||||||
|
height=t2i_input_height,
|
||||||
|
num_channels=t2i_adapter_model.config.in_channels,
|
||||||
|
device=t2i_adapter_model.device,
|
||||||
|
dtype=t2i_adapter_model.dtype,
|
||||||
|
resize_mode=t2i_adapter_field.resize_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
adapter_state = t2i_adapter_model(t2i_image)
|
||||||
|
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
for idx, value in enumerate(adapter_state):
|
||||||
|
adapter_state[idx] = torch.cat([value] * 2, dim=0)
|
||||||
|
|
||||||
|
t2i_adapter_data.append(
|
||||||
|
T2IAdapterData(
|
||||||
|
adapter_state=adapter_state,
|
||||||
|
weight=t2i_adapter_field.weight,
|
||||||
|
begin_step_percent=t2i_adapter_field.begin_step_percent,
|
||||||
|
end_step_percent=t2i_adapter_field.end_step_percent,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return t2i_adapter_data
|
||||||
|
|
||||||
# original idea by https://github.com/AmericanPresidentJimmyCarter
|
# original idea by https://github.com/AmericanPresidentJimmyCarter
|
||||||
# TODO: research more for second order schedulers timesteps
|
# TODO: research more for second order schedulers timesteps
|
||||||
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
|
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
|
||||||
@ -522,6 +613,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
mask, masked_latents = self.prep_inpaint_mask(context, latents)
|
mask, masked_latents = self.prep_inpaint_mask(context, latents)
|
||||||
|
|
||||||
|
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||||
|
# below. Investigate whether this is appropriate.
|
||||||
|
t2i_adapter_data = self.run_t2i_adapters(
|
||||||
|
context, self.t2i_adapter, latents.shape, do_classifier_free_guidance=True
|
||||||
|
)
|
||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
# Get the source node id (we are invoking the prepared node)
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
@ -602,8 +699,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
masked_latents=masked_latents,
|
masked_latents=masked_latents,
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
control_data=controlnet_data, # list[ControlNetData],
|
control_data=controlnet_data,
|
||||||
ip_adapter_data=ip_adapter_data, # IPAdapterData,
|
ip_adapter_data=ip_adapter_data,
|
||||||
|
t2i_adapter_data=t2i_adapter_data,
|
||||||
callback=step_callback,
|
callback=step_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
83
invokeai/app/invocations/t2i_adapter.py
Normal file
83
invokeai/app/invocations/t2i_adapter.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UIType,
|
||||||
|
invocation,
|
||||||
|
invocation_output,
|
||||||
|
)
|
||||||
|
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
|
||||||
|
from invokeai.app.invocations.primitives import ImageField
|
||||||
|
from invokeai.backend.model_management.models.base import BaseModelType
|
||||||
|
|
||||||
|
|
||||||
|
class T2IAdapterModelField(BaseModel):
|
||||||
|
model_name: str = Field(description="Name of the T2I-Adapter model")
|
||||||
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
|
||||||
|
class T2IAdapterField(BaseModel):
|
||||||
|
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
||||||
|
t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.")
|
||||||
|
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
|
||||||
|
begin_step_percent: float = Field(
|
||||||
|
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
|
||||||
|
)
|
||||||
|
end_step_percent: float = Field(
|
||||||
|
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
|
||||||
|
)
|
||||||
|
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("t2i_adapter_output")
|
||||||
|
class T2IAdapterOutput(BaseInvocationOutput):
|
||||||
|
t2i_adapter: T2IAdapterField = OutputField(description=FieldDescriptions.t2i_adapter, title="T2I Adapter")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.0"
|
||||||
|
)
|
||||||
|
class T2IAdapterInvocation(BaseInvocation):
|
||||||
|
"""Collects T2I-Adapter info to pass to other nodes."""
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
image: ImageField = InputField(description="The IP-Adapter image prompt.")
|
||||||
|
ip_adapter_model: T2IAdapterModelField = InputField(
|
||||||
|
description="The T2I-Adapter model.",
|
||||||
|
title="T2I-Adapter Model",
|
||||||
|
input=Input.Direct,
|
||||||
|
ui_order=-1,
|
||||||
|
)
|
||||||
|
weight: Union[float, list[float]] = InputField(
|
||||||
|
default=1, ge=0, description="The weight given to the T2I-Adapter", ui_type=UIType.Float, title="Weight"
|
||||||
|
)
|
||||||
|
begin_step_percent: float = InputField(
|
||||||
|
default=0, ge=-1, le=2, description="When the T2I-Adapter is first applied (% of total steps)"
|
||||||
|
)
|
||||||
|
end_step_percent: float = InputField(
|
||||||
|
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
|
||||||
|
)
|
||||||
|
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(
|
||||||
|
default="just_resize",
|
||||||
|
description="The resize mode applied to the T2I-Adapter input image so that it matches the target output size.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> T2IAdapterOutput:
|
||||||
|
return T2IAdapterOutput(
|
||||||
|
t2i_adapter=T2IAdapterField(
|
||||||
|
image=self.image,
|
||||||
|
t2i_adapter_model=self.ip_adapter_model,
|
||||||
|
weight=self.weight,
|
||||||
|
begin_step_percent=self.begin_step_percent,
|
||||||
|
end_step_percent=self.end_step_percent,
|
||||||
|
resize_mode=self.resize_mode,
|
||||||
|
)
|
||||||
|
)
|
@ -4,12 +4,14 @@ from typing import Literal
|
|||||||
|
|
||||||
import cv2 as cv
|
import cv2 as cv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
@ -22,13 +24,19 @@ ESRGAN_MODELS = Literal[
|
|||||||
"RealESRGAN_x2plus.pth",
|
"RealESRGAN_x2plus.pth",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if choose_torch_device() == torch.device("mps"):
|
||||||
|
from torch import mps
|
||||||
|
|
||||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.0.0")
|
|
||||||
|
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.1.0")
|
||||||
class ESRGANInvocation(BaseInvocation):
|
class ESRGANInvocation(BaseInvocation):
|
||||||
"""Upscales an image using RealESRGAN."""
|
"""Upscales an image using RealESRGAN."""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The input image")
|
image: ImageField = InputField(description="The input image")
|
||||||
model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
|
model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
|
||||||
|
tile_size: int = InputField(
|
||||||
|
default=400, ge=0, description="Tile size for tiled ESRGAN upscaling (0=tiling disabled)"
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -86,9 +94,11 @@ class ESRGANInvocation(BaseInvocation):
|
|||||||
model_path=str(models_path / esrgan_model_path),
|
model_path=str(models_path / esrgan_model_path),
|
||||||
model=rrdbnet_model,
|
model=rrdbnet_model,
|
||||||
half=False,
|
half=False,
|
||||||
|
tile=self.tile_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
|
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
|
||||||
|
# TODO: This strips the alpha... is that okay?
|
||||||
cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
|
cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
|
||||||
|
|
||||||
# We can pass an `outscale` value here, but it just resizes the image by that factor after
|
# We can pass an `outscale` value here, but it just resizes the image by that factor after
|
||||||
@ -99,6 +109,10 @@ class ESRGANInvocation(BaseInvocation):
|
|||||||
# back to PIL
|
# back to PIL
|
||||||
pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")
|
pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
if choose_torch_device() == torch.device("mps"):
|
||||||
|
mps.empty_cache()
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=pil_image,
|
image=pil_image,
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
|
@ -255,6 +255,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
|
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
|
||||||
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
||||||
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
||||||
|
png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", category="Generation", )
|
||||||
|
|
||||||
# QUEUE
|
# QUEUE
|
||||||
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", )
|
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", )
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
from typing import Annotated, Any, Optional, Union, cast, get_args, get_origin, get_type_hints
|
from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from pydantic import BaseModel, root_validator, validator
|
from pydantic import BaseModel, root_validator, validator
|
||||||
@ -170,6 +170,18 @@ class NodeIdMismatchError(ValueError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidSubGraphError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CyclicalGraphError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownGraphValidationError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# TODO: Create and use an Empty output?
|
# TODO: Create and use an Empty output?
|
||||||
@invocation_output("graph_output")
|
@invocation_output("graph_output")
|
||||||
class GraphInvocationOutput(BaseInvocationOutput):
|
class GraphInvocationOutput(BaseInvocationOutput):
|
||||||
@ -254,59 +266,6 @@ class Graph(BaseModel):
|
|||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
|
|
||||||
@root_validator
|
|
||||||
def validate_nodes_and_edges(cls, values):
|
|
||||||
"""Validates that all edges match nodes in the graph"""
|
|
||||||
nodes = cast(Optional[dict[str, BaseInvocation]], values.get("nodes"))
|
|
||||||
edges = cast(Optional[list[Edge]], values.get("edges"))
|
|
||||||
|
|
||||||
if nodes is not None:
|
|
||||||
# Validate that all node ids are unique
|
|
||||||
node_ids = [n.id for n in nodes.values()]
|
|
||||||
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
|
|
||||||
if duplicate_node_ids:
|
|
||||||
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
|
|
||||||
|
|
||||||
# Validate that all node ids match the keys in the nodes dict
|
|
||||||
for k, v in nodes.items():
|
|
||||||
if k != v.id:
|
|
||||||
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
|
|
||||||
|
|
||||||
if edges is not None and nodes is not None:
|
|
||||||
# Validate that all edges match nodes in the graph
|
|
||||||
node_ids = set([e.source.node_id for e in edges] + [e.destination.node_id for e in edges])
|
|
||||||
missing_node_ids = [node_id for node_id in node_ids if node_id not in nodes]
|
|
||||||
if missing_node_ids:
|
|
||||||
raise NodeNotFoundError(
|
|
||||||
f"All edges must reference nodes in the graph, missing nodes: {missing_node_ids}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate that all edge fields match node fields in the graph
|
|
||||||
for edge in edges:
|
|
||||||
source_node = nodes.get(edge.source.node_id, None)
|
|
||||||
if source_node is None:
|
|
||||||
raise NodeFieldNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph")
|
|
||||||
|
|
||||||
destination_node = nodes.get(edge.destination.node_id, None)
|
|
||||||
if destination_node is None:
|
|
||||||
raise NodeFieldNotFoundError(
|
|
||||||
f"Edge destination node {edge.destination.node_id} does not exist in the graph"
|
|
||||||
)
|
|
||||||
|
|
||||||
# output fields are not on the node object directly, they are on the output type
|
|
||||||
if edge.source.field not in source_node.get_output_type().__fields__:
|
|
||||||
raise NodeFieldNotFoundError(
|
|
||||||
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# input fields are on the node
|
|
||||||
if edge.destination.field not in destination_node.__fields__:
|
|
||||||
raise NodeFieldNotFoundError(
|
|
||||||
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return values
|
|
||||||
|
|
||||||
def add_node(self, node: BaseInvocation) -> None:
|
def add_node(self, node: BaseInvocation) -> None:
|
||||||
"""Adds a node to a graph
|
"""Adds a node to a graph
|
||||||
|
|
||||||
@ -377,53 +336,108 @@ class Graph(BaseModel):
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def is_valid(self) -> bool:
|
def validate_self(self) -> None:
|
||||||
"""Validates the graph."""
|
"""
|
||||||
|
Validates the graph.
|
||||||
|
|
||||||
|
Raises an exception if the graph is invalid:
|
||||||
|
- `DuplicateNodeIdError`
|
||||||
|
- `NodeIdMismatchError`
|
||||||
|
- `InvalidSubGraphError`
|
||||||
|
- `NodeNotFoundError`
|
||||||
|
- `NodeFieldNotFoundError`
|
||||||
|
- `CyclicalGraphError`
|
||||||
|
- `InvalidEdgeError`
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Validate that all node ids are unique
|
||||||
|
node_ids = [n.id for n in self.nodes.values()]
|
||||||
|
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
|
||||||
|
if duplicate_node_ids:
|
||||||
|
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
|
||||||
|
|
||||||
|
# Validate that all node ids match the keys in the nodes dict
|
||||||
|
for k, v in self.nodes.items():
|
||||||
|
if k != v.id:
|
||||||
|
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
|
||||||
|
|
||||||
# Validate all subgraphs
|
# Validate all subgraphs
|
||||||
for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)):
|
for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)):
|
||||||
if not gn.graph.is_valid():
|
try:
|
||||||
return False
|
gn.graph.validate_self()
|
||||||
|
except Exception as e:
|
||||||
|
raise InvalidSubGraphError(f"Subgraph {gn.id} is invalid") from e
|
||||||
|
|
||||||
# Validate all edges reference nodes in the graph
|
# Validate that all edges match nodes and fields in the graph
|
||||||
node_ids = set([e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges])
|
for edge in self.edges:
|
||||||
if not all((self.has_node(node_id) for node_id in node_ids)):
|
source_node = self.nodes.get(edge.source.node_id, None)
|
||||||
return False
|
if source_node is None:
|
||||||
|
raise NodeNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph")
|
||||||
|
|
||||||
|
destination_node = self.nodes.get(edge.destination.node_id, None)
|
||||||
|
if destination_node is None:
|
||||||
|
raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph")
|
||||||
|
|
||||||
|
# output fields are not on the node object directly, they are on the output type
|
||||||
|
if edge.source.field not in source_node.get_output_type().__fields__:
|
||||||
|
raise NodeFieldNotFoundError(
|
||||||
|
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# input fields are on the node
|
||||||
|
if edge.destination.field not in destination_node.__fields__:
|
||||||
|
raise NodeFieldNotFoundError(
|
||||||
|
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
|
||||||
|
)
|
||||||
|
|
||||||
# Validate there are no cycles
|
# Validate there are no cycles
|
||||||
g = self.nx_graph_flat()
|
g = self.nx_graph_flat()
|
||||||
if not nx.is_directed_acyclic_graph(g):
|
if not nx.is_directed_acyclic_graph(g):
|
||||||
return False
|
raise CyclicalGraphError("Graph contains cycles")
|
||||||
|
|
||||||
# Validate all edge connections are valid
|
# Validate all edge connections are valid
|
||||||
if not all(
|
for e in self.edges:
|
||||||
(
|
if not are_connections_compatible(
|
||||||
are_connections_compatible(
|
|
||||||
self.get_node(e.source.node_id),
|
self.get_node(e.source.node_id),
|
||||||
e.source.field,
|
e.source.field,
|
||||||
self.get_node(e.destination.node_id),
|
self.get_node(e.destination.node_id),
|
||||||
e.destination.field,
|
e.destination.field,
|
||||||
|
):
|
||||||
|
raise InvalidEdgeError(
|
||||||
|
f"Invalid edge from {e.source.node_id}.{e.source.field} to {e.destination.node_id}.{e.destination.field}"
|
||||||
)
|
)
|
||||||
for e in self.edges
|
|
||||||
)
|
|
||||||
):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Validate all iterators
|
# Validate all iterators & collectors
|
||||||
# TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
|
# TODO: may need to validate all iterators & collectors in subgraphs so edge connections in parent graphs will be available
|
||||||
if not all(
|
for n in self.nodes.values():
|
||||||
(self._is_iterator_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, IterateInvocation))
|
if isinstance(n, IterateInvocation) and not self._is_iterator_connection_valid(n.id):
|
||||||
):
|
raise InvalidEdgeError(f"Invalid iterator node {n.id}")
|
||||||
return False
|
if isinstance(n, CollectInvocation) and not self._is_collector_connection_valid(n.id):
|
||||||
|
raise InvalidEdgeError(f"Invalid collector node {n.id}")
|
||||||
|
|
||||||
# Validate all collectors
|
return None
|
||||||
# TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
|
|
||||||
if not all(
|
|
||||||
(self._is_collector_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, CollectInvocation))
|
|
||||||
):
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
def is_valid(self) -> bool:
|
||||||
|
"""
|
||||||
|
Checks if the graph is valid.
|
||||||
|
|
||||||
|
Raises `UnknownGraphValidationError` if there is a problem validating the graph (not a validation error).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.validate_self()
|
||||||
return True
|
return True
|
||||||
|
except (
|
||||||
|
DuplicateNodeIdError,
|
||||||
|
NodeIdMismatchError,
|
||||||
|
InvalidSubGraphError,
|
||||||
|
NodeNotFoundError,
|
||||||
|
NodeFieldNotFoundError,
|
||||||
|
CyclicalGraphError,
|
||||||
|
InvalidEdgeError,
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
raise UnknownGraphValidationError(f"Problem validating graph {e}") from e
|
||||||
|
|
||||||
def _validate_edge(self, edge: Edge):
|
def _validate_edge(self, edge: Edge):
|
||||||
"""Validates that a new edge doesn't create a cycle in the graph"""
|
"""Validates that a new edge doesn't create a cycle in the graph"""
|
||||||
@ -804,6 +818,12 @@ class GraphExecutionState(BaseModel):
|
|||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@validator("graph")
|
||||||
|
def graph_is_valid(cls, v: Graph):
|
||||||
|
"""Validates that the graph is valid"""
|
||||||
|
v.validate_self()
|
||||||
|
return v
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"required": [
|
"required": [
|
||||||
|
@ -9,6 +9,7 @@ from PIL import Image, PngImagePlugin
|
|||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
from send2trash import send2trash
|
from send2trash import send2trash
|
||||||
|
|
||||||
|
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
|
||||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||||
|
|
||||||
|
|
||||||
@ -79,6 +80,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||||
__cache: Dict[Path, PILImageType]
|
__cache: Dict[Path, PILImageType]
|
||||||
__max_cache_size: int
|
__max_cache_size: int
|
||||||
|
__compress_level: int
|
||||||
|
|
||||||
def __init__(self, output_folder: Union[str, Path]):
|
def __init__(self, output_folder: Union[str, Path]):
|
||||||
self.__cache = dict()
|
self.__cache = dict()
|
||||||
@ -87,7 +89,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
|
|
||||||
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||||
self.__thumbnails_folder = self.__output_folder / "thumbnails"
|
self.__thumbnails_folder = self.__output_folder / "thumbnails"
|
||||||
|
self.__compress_level = InvokeAIAppConfig.get_config().png_compress_level
|
||||||
# Validate required output folders at launch
|
# Validate required output folders at launch
|
||||||
self.__validate_storage_folders()
|
self.__validate_storage_folders()
|
||||||
|
|
||||||
@ -134,7 +136,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
if original_workflow is not None:
|
if original_workflow is not None:
|
||||||
pnginfo.add_text("invokeai_workflow", original_workflow)
|
pnginfo.add_text("invokeai_workflow", original_workflow)
|
||||||
|
|
||||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
image.save(image_path, "PNG", pnginfo=pnginfo, compress_level=self.__compress_level)
|
||||||
|
|
||||||
thumbnail_name = get_thumbnail_name(image_name)
|
thumbnail_name = get_thumbnail_name(image_name)
|
||||||
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
|
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
|
||||||
|
@ -584,7 +584,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
FROM images
|
FROM images
|
||||||
JOIN board_images ON images.image_name = board_images.image_name
|
JOIN board_images ON images.image_name = board_images.image_name
|
||||||
WHERE board_images.board_id = ?
|
WHERE board_images.board_id = ?
|
||||||
ORDER BY images.created_at DESC
|
ORDER BY images.starred DESC, images.created_at DESC
|
||||||
LIMIT 1;
|
LIMIT 1;
|
||||||
""",
|
""",
|
||||||
(board_id,),
|
(board_id,),
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import traceback
|
||||||
from threading import BoundedSemaphore
|
from threading import BoundedSemaphore
|
||||||
from threading import Event as ThreadEvent
|
from threading import Event as ThreadEvent
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
@ -123,6 +124,10 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
continue
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.__invoker.services.logger.error(f"Error in session processor: {e}")
|
self.__invoker.services.logger.error(f"Error in session processor: {e}")
|
||||||
|
if queue_item is not None:
|
||||||
|
self.__invoker.services.session_queue.cancel_queue_item(
|
||||||
|
queue_item.item_id, error=traceback.format_exc()
|
||||||
|
)
|
||||||
poll_now_event.wait(POLLING_INTERVAL)
|
poll_now_event.wait(POLLING_INTERVAL)
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -80,7 +80,7 @@ class SessionQueueBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
|
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem:
|
||||||
"""Cancels a session queue item"""
|
"""Cancels a session queue item"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -123,6 +123,11 @@ class Batch(BaseModel):
|
|||||||
raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
|
raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@validator("graph")
|
||||||
|
def validate_graph(cls, v: Graph):
|
||||||
|
v.validate_self()
|
||||||
|
return v
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"required": [
|
"required": [
|
||||||
|
@ -555,10 +555,11 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
self.__lock.release()
|
self.__lock.release()
|
||||||
return PruneResult(deleted=count)
|
return PruneResult(deleted=count)
|
||||||
|
|
||||||
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
|
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem:
|
||||||
queue_item = self.get_queue_item(item_id)
|
queue_item = self.get_queue_item(item_id)
|
||||||
if queue_item.status not in ["canceled", "failed", "completed"]:
|
if queue_item.status not in ["canceled", "failed", "completed"]:
|
||||||
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
|
status = "failed" if error is not None else "canceled"
|
||||||
|
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error)
|
||||||
self.__invoker.services.queue.cancel(queue_item.session_id)
|
self.__invoker.services.queue.cancel(queue_item.session_id)
|
||||||
self.__invoker.services.events.emit_session_canceled(
|
self.__invoker.services.events.emit_session_canceled(
|
||||||
queue_item_id=queue_item.item_id,
|
queue_item_id=queue_item.item_id,
|
||||||
|
@ -265,22 +265,41 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
|
|||||||
|
|
||||||
|
|
||||||
def prepare_control_image(
|
def prepare_control_image(
|
||||||
# image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
|
|
||||||
# but now should be able to assume that image is a single PIL.Image, which simplifies things
|
|
||||||
image: Image,
|
image: Image,
|
||||||
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
|
width: int,
|
||||||
# latents_to_match_resolution, # TorchTensor of shape (batch_size, 3, height, width)
|
height: int,
|
||||||
width=512, # should be 8 * latent.shape[3]
|
num_channels: int = 3,
|
||||||
height=512, # should be 8 * latent height[2]
|
|
||||||
# batch_size=1, # currently no batching
|
|
||||||
# num_images_per_prompt=1, # currently only single image
|
|
||||||
device="cuda",
|
device="cuda",
|
||||||
dtype=torch.float16,
|
dtype=torch.float16,
|
||||||
do_classifier_free_guidance=True,
|
do_classifier_free_guidance=True,
|
||||||
control_mode="balanced",
|
control_mode="balanced",
|
||||||
resize_mode="just_resize_simple",
|
resize_mode="just_resize_simple",
|
||||||
):
|
):
|
||||||
# FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
|
"""Pre-process images for ControlNets or T2I-Adapters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (Image): The PIL image to pre-process.
|
||||||
|
width (int): The target width in pixels.
|
||||||
|
height (int): The target height in pixels.
|
||||||
|
num_channels (int, optional): The target number of image channels. This is achieved by converting the input
|
||||||
|
image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting a
|
||||||
|
RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
|
||||||
|
device (str, optional): The target device for the output image. Defaults to "cuda".
|
||||||
|
dtype (_type_, optional): The dtype for the output image. Defaults to torch.float16.
|
||||||
|
do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
|
||||||
|
Defaults to True.
|
||||||
|
control_mode (str, optional): Defaults to "balanced".
|
||||||
|
resize_mode (str, optional): Defaults to "just_resize_simple".
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: If resize_mode == "crop_resize_simple".
|
||||||
|
NotImplementedError: If resize_mode == "fill_resize_simple".
|
||||||
|
ValueError: If `resize_mode` is not recognized.
|
||||||
|
ValueError: If `num_channels` is out of range.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The pre-processed input tensor.
|
||||||
|
"""
|
||||||
if (
|
if (
|
||||||
resize_mode == "just_resize_simple"
|
resize_mode == "just_resize_simple"
|
||||||
or resize_mode == "crop_resize_simple"
|
or resize_mode == "crop_resize_simple"
|
||||||
@ -289,10 +308,10 @@ def prepare_control_image(
|
|||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
if resize_mode == "just_resize_simple":
|
if resize_mode == "just_resize_simple":
|
||||||
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
||||||
elif resize_mode == "crop_resize_simple": # not yet implemented
|
elif resize_mode == "crop_resize_simple":
|
||||||
pass
|
raise NotImplementedError(f"prepare_control_image is not implemented for resize_mode='{resize_mode}'.")
|
||||||
elif resize_mode == "fill_resize_simple": # not yet implemented
|
elif resize_mode == "fill_resize_simple":
|
||||||
pass
|
raise NotImplementedError(f"prepare_control_image is not implemented for resize_mode='{resize_mode}'.")
|
||||||
nimage = np.array(image)
|
nimage = np.array(image)
|
||||||
nimage = nimage[None, :]
|
nimage = nimage[None, :]
|
||||||
nimage = np.concatenate([nimage], axis=0)
|
nimage = np.concatenate([nimage], axis=0)
|
||||||
@ -313,9 +332,11 @@ def prepare_control_image(
|
|||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
pass
|
raise ValueError(f"Unsupported resize_mode: '{resize_mode}'.")
|
||||||
print("ERROR: invalid resize_mode ==> ", resize_mode)
|
|
||||||
exit(1)
|
if timage.shape[1] < num_channels or num_channels <= 0:
|
||||||
|
raise ValueError(f"Cannot achieve the target of num_channels={num_channels}.")
|
||||||
|
timage = timage[:, :num_channels, :, :]
|
||||||
|
|
||||||
timage = timage.to(device=device, dtype=dtype)
|
timage = timage.to(device=device, dtype=dtype)
|
||||||
cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
|
cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
|
||||||
|
@ -218,6 +218,20 @@ class IPAdapterPlus(IPAdapter):
|
|||||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||||
|
|
||||||
|
|
||||||
|
class IPAdapterPlusXL(IPAdapterPlus):
|
||||||
|
"""IP-Adapter Plus for SDXL."""
|
||||||
|
|
||||||
|
def _init_image_proj_model(self, state_dict):
|
||||||
|
return Resampler.from_state_dict(
|
||||||
|
state_dict=state_dict,
|
||||||
|
depth=4,
|
||||||
|
dim_head=64,
|
||||||
|
heads=20,
|
||||||
|
num_queries=self._num_tokens,
|
||||||
|
ff_mult=4,
|
||||||
|
).to(self.device, dtype=self.dtype)
|
||||||
|
|
||||||
|
|
||||||
def build_ip_adapter(
|
def build_ip_adapter(
|
||||||
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
|
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
|
||||||
) -> Union[IPAdapter, IPAdapterPlus]:
|
) -> Union[IPAdapter, IPAdapterPlus]:
|
||||||
@ -228,6 +242,14 @@ def build_ip_adapter(
|
|||||||
is_plus = "proj.weight" not in state_dict["image_proj"]
|
is_plus = "proj.weight" not in state_dict["image_proj"]
|
||||||
|
|
||||||
if is_plus:
|
if is_plus:
|
||||||
|
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
|
||||||
|
if cross_attention_dim == 768:
|
||||||
|
# SD1 IP-Adapter Plus
|
||||||
return IPAdapterPlus(state_dict, device=device, dtype=dtype)
|
return IPAdapterPlus(state_dict, device=device, dtype=dtype)
|
||||||
|
elif cross_attention_dim == 2048:
|
||||||
|
# SDXL IP-Adapter Plus
|
||||||
|
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
|
||||||
else:
|
else:
|
||||||
return IPAdapter(state_dict, device=device, dtype=dtype)
|
return IPAdapter(state_dict, device=device, dtype=dtype)
|
||||||
|
@ -57,6 +57,7 @@ class ModelProbe(object):
|
|||||||
"AutoencoderTiny": ModelType.Vae,
|
"AutoencoderTiny": ModelType.Vae,
|
||||||
"ControlNetModel": ModelType.ControlNet,
|
"ControlNetModel": ModelType.ControlNet,
|
||||||
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
||||||
|
"T2IAdapter": ModelType.T2IAdapter,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -408,6 +409,11 @@ class CLIPVisionCheckpointProbe(CheckpointProbeBase):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
########################################################
|
########################################################
|
||||||
# classes for probing folders
|
# classes for probing folders
|
||||||
#######################################################
|
#######################################################
|
||||||
@ -595,6 +601,26 @@ class CLIPVisionFolderProbe(FolderProbeBase):
|
|||||||
return BaseModelType.Any
|
return BaseModelType.Any
|
||||||
|
|
||||||
|
|
||||||
|
class T2IAdapterFolderProbe(FolderProbeBase):
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
config_file = self.folder_path / "config.json"
|
||||||
|
if not config_file.exists():
|
||||||
|
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||||
|
with open(config_file, "r") as file:
|
||||||
|
config = json.load(file)
|
||||||
|
|
||||||
|
adapter_type = config.get("adapter_type", None)
|
||||||
|
if adapter_type == "full_adapter_xl":
|
||||||
|
return BaseModelType.StableDiffusionXL
|
||||||
|
elif adapter_type == "full_adapter" or "light_adapter":
|
||||||
|
# I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
|
||||||
|
return BaseModelType.StableDiffusion1
|
||||||
|
else:
|
||||||
|
raise InvalidModelException(
|
||||||
|
f"Unable to determine base model for '{self.folder_path}' (adapter_type = {adapter_type})."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
############## register probe classes ######
|
############## register probe classes ######
|
||||||
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||||
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
|
||||||
@ -603,6 +629,7 @@ ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInvers
|
|||||||
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
||||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||||
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
||||||
|
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||||
|
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
||||||
@ -611,5 +638,6 @@ ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInver
|
|||||||
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
|
||||||
|
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
|
||||||
|
|
||||||
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
||||||
|
@ -25,6 +25,7 @@ from .lora import LoRAModel
|
|||||||
from .sdxl import StableDiffusionXLModel
|
from .sdxl import StableDiffusionXLModel
|
||||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||||
from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model
|
from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model
|
||||||
|
from .t2i_adapter import T2IAdapterModel
|
||||||
from .textual_inversion import TextualInversionModel
|
from .textual_inversion import TextualInversionModel
|
||||||
from .vae import VaeModel
|
from .vae import VaeModel
|
||||||
|
|
||||||
@ -38,6 +39,7 @@ MODEL_CLASSES = {
|
|||||||
ModelType.TextualInversion: TextualInversionModel,
|
ModelType.TextualInversion: TextualInversionModel,
|
||||||
ModelType.IPAdapter: IPAdapterModel,
|
ModelType.IPAdapter: IPAdapterModel,
|
||||||
ModelType.CLIPVision: CLIPVisionModel,
|
ModelType.CLIPVision: CLIPVisionModel,
|
||||||
|
ModelType.T2IAdapter: T2IAdapterModel,
|
||||||
},
|
},
|
||||||
BaseModelType.StableDiffusion2: {
|
BaseModelType.StableDiffusion2: {
|
||||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||||
@ -48,6 +50,7 @@ MODEL_CLASSES = {
|
|||||||
ModelType.TextualInversion: TextualInversionModel,
|
ModelType.TextualInversion: TextualInversionModel,
|
||||||
ModelType.IPAdapter: IPAdapterModel,
|
ModelType.IPAdapter: IPAdapterModel,
|
||||||
ModelType.CLIPVision: CLIPVisionModel,
|
ModelType.CLIPVision: CLIPVisionModel,
|
||||||
|
ModelType.T2IAdapter: T2IAdapterModel,
|
||||||
},
|
},
|
||||||
BaseModelType.StableDiffusionXL: {
|
BaseModelType.StableDiffusionXL: {
|
||||||
ModelType.Main: StableDiffusionXLModel,
|
ModelType.Main: StableDiffusionXLModel,
|
||||||
@ -59,6 +62,7 @@ MODEL_CLASSES = {
|
|||||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||||
ModelType.IPAdapter: IPAdapterModel,
|
ModelType.IPAdapter: IPAdapterModel,
|
||||||
ModelType.CLIPVision: CLIPVisionModel,
|
ModelType.CLIPVision: CLIPVisionModel,
|
||||||
|
ModelType.T2IAdapter: T2IAdapterModel,
|
||||||
},
|
},
|
||||||
BaseModelType.StableDiffusionXLRefiner: {
|
BaseModelType.StableDiffusionXLRefiner: {
|
||||||
ModelType.Main: StableDiffusionXLModel,
|
ModelType.Main: StableDiffusionXLModel,
|
||||||
@ -70,6 +74,7 @@ MODEL_CLASSES = {
|
|||||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||||
ModelType.IPAdapter: IPAdapterModel,
|
ModelType.IPAdapter: IPAdapterModel,
|
||||||
ModelType.CLIPVision: CLIPVisionModel,
|
ModelType.CLIPVision: CLIPVisionModel,
|
||||||
|
ModelType.T2IAdapter: T2IAdapterModel,
|
||||||
},
|
},
|
||||||
BaseModelType.Any: {
|
BaseModelType.Any: {
|
||||||
ModelType.CLIPVision: CLIPVisionModel,
|
ModelType.CLIPVision: CLIPVisionModel,
|
||||||
@ -81,6 +86,7 @@ MODEL_CLASSES = {
|
|||||||
ModelType.ControlNet: ControlNetModel,
|
ModelType.ControlNet: ControlNetModel,
|
||||||
ModelType.TextualInversion: TextualInversionModel,
|
ModelType.TextualInversion: TextualInversionModel,
|
||||||
ModelType.IPAdapter: IPAdapterModel,
|
ModelType.IPAdapter: IPAdapterModel,
|
||||||
|
ModelType.T2IAdapter: T2IAdapterModel,
|
||||||
},
|
},
|
||||||
# BaseModelType.Kandinsky2_1: {
|
# BaseModelType.Kandinsky2_1: {
|
||||||
# ModelType.Main: Kandinsky2_1Model,
|
# ModelType.Main: Kandinsky2_1Model,
|
||||||
|
@ -53,6 +53,7 @@ class ModelType(str, Enum):
|
|||||||
TextualInversion = "embedding"
|
TextualInversion = "embedding"
|
||||||
IPAdapter = "ip_adapter"
|
IPAdapter = "ip_adapter"
|
||||||
CLIPVision = "clip_vision"
|
CLIPVision = "clip_vision"
|
||||||
|
T2IAdapter = "t2i_adapter"
|
||||||
|
|
||||||
|
|
||||||
class SubModelType(str, Enum):
|
class SubModelType(str, Enum):
|
||||||
|
102
invokeai/backend/model_management/models/t2i_adapter.py
Normal file
102
invokeai/backend/model_management/models/t2i_adapter.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
import os
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from diffusers import T2IAdapter
|
||||||
|
|
||||||
|
from invokeai.backend.model_management.models.base import (
|
||||||
|
BaseModelType,
|
||||||
|
EmptyConfigLoader,
|
||||||
|
InvalidModelException,
|
||||||
|
ModelBase,
|
||||||
|
ModelConfigBase,
|
||||||
|
ModelNotFoundException,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
|
calc_model_size_by_data,
|
||||||
|
calc_model_size_by_fs,
|
||||||
|
classproperty,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class T2IAdapterModelFormat(str, Enum):
|
||||||
|
Diffusers = "diffusers"
|
||||||
|
|
||||||
|
|
||||||
|
class T2IAdapterModel(ModelBase):
|
||||||
|
class DiffusersConfig(ModelConfigBase):
|
||||||
|
model_format: Literal[T2IAdapterModelFormat.Diffusers]
|
||||||
|
|
||||||
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
|
assert model_type == ModelType.T2IAdapter
|
||||||
|
super().__init__(model_path, base_model, model_type)
|
||||||
|
|
||||||
|
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
|
||||||
|
|
||||||
|
model_class_name = config.get("_class_name", None)
|
||||||
|
if model_class_name not in {"T2IAdapter"}:
|
||||||
|
raise InvalidModelException(f"Invalid T2I-Adapter model. Unknown _class_name: '{model_class_name}'.")
|
||||||
|
|
||||||
|
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
|
||||||
|
self.model_size = calc_model_size_by_fs(self.model_path)
|
||||||
|
|
||||||
|
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||||
|
if child_type is not None:
|
||||||
|
raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.")
|
||||||
|
return self.model_size
|
||||||
|
|
||||||
|
def get_model(
|
||||||
|
self,
|
||||||
|
torch_dtype: Optional[torch.dtype],
|
||||||
|
child_type: Optional[SubModelType] = None,
|
||||||
|
) -> T2IAdapter:
|
||||||
|
if child_type is not None:
|
||||||
|
raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.")
|
||||||
|
|
||||||
|
model = None
|
||||||
|
for variant in ["fp16", None]:
|
||||||
|
try:
|
||||||
|
model = self.model_class.from_pretrained(
|
||||||
|
self.model_path,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
variant=variant,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if not model:
|
||||||
|
raise ModelNotFoundException()
|
||||||
|
|
||||||
|
# Calculate a more accurate size after loading the model into memory.
|
||||||
|
self.model_size = calc_model_size_by_data(model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
@classproperty
|
||||||
|
def save_to_config(cls) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def detect_format(cls, path: str):
|
||||||
|
if not os.path.exists(path):
|
||||||
|
raise ModelNotFoundException(f"Model not found at '{path}'.")
|
||||||
|
|
||||||
|
if os.path.isdir(path):
|
||||||
|
if os.path.exists(os.path.join(path, "config.json")):
|
||||||
|
return T2IAdapterModelFormat.Diffusers
|
||||||
|
|
||||||
|
raise InvalidModelException(f"Unsupported T2I-Adapter format: '{path}'.")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_if_required(
|
||||||
|
cls,
|
||||||
|
model_path: str,
|
||||||
|
output_path: str,
|
||||||
|
config: ModelConfigBase,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
) -> str:
|
||||||
|
format = cls.detect_format(model_path)
|
||||||
|
if format == T2IAdapterModelFormat.Diffusers:
|
||||||
|
return model_path
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported format: '{format}'.")
|
@ -173,6 +173,16 @@ class IPAdapterData:
|
|||||||
end_step_percent: float = Field(default=1.0)
|
end_step_percent: float = Field(default=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class T2IAdapterData:
|
||||||
|
"""A structure containing the information required to apply conditioning from a single T2I-Adapter model."""
|
||||||
|
|
||||||
|
adapter_state: dict[torch.Tensor] = Field()
|
||||||
|
weight: Union[float, list[float]] = Field(default=1.0)
|
||||||
|
begin_step_percent: float = Field(default=0.0)
|
||||||
|
end_step_percent: float = Field(default=1.0)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
|
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
|
||||||
r"""
|
r"""
|
||||||
@ -327,6 +337,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
ip_adapter_data: Optional[IPAdapterData] = None,
|
ip_adapter_data: Optional[IPAdapterData] = None,
|
||||||
|
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||||
mask: Optional[torch.Tensor] = None,
|
mask: Optional[torch.Tensor] = None,
|
||||||
masked_latents: Optional[torch.Tensor] = None,
|
masked_latents: Optional[torch.Tensor] = None,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
@ -379,6 +390,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
additional_guidance=additional_guidance,
|
additional_guidance=additional_guidance,
|
||||||
control_data=control_data,
|
control_data=control_data,
|
||||||
ip_adapter_data=ip_adapter_data,
|
ip_adapter_data=ip_adapter_data,
|
||||||
|
t2i_adapter_data=t2i_adapter_data,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
@ -399,6 +411,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
ip_adapter_data: Optional[IPAdapterData] = None,
|
ip_adapter_data: Optional[IPAdapterData] = None,
|
||||||
|
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||||
):
|
):
|
||||||
self._adjust_memory_efficient_attention(latents)
|
self._adjust_memory_efficient_attention(latents)
|
||||||
@ -454,6 +467,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
additional_guidance=additional_guidance,
|
additional_guidance=additional_guidance,
|
||||||
control_data=control_data,
|
control_data=control_data,
|
||||||
ip_adapter_data=ip_adapter_data,
|
ip_adapter_data=ip_adapter_data,
|
||||||
|
t2i_adapter_data=t2i_adapter_data,
|
||||||
)
|
)
|
||||||
latents = step_output.prev_sample
|
latents = step_output.prev_sample
|
||||||
|
|
||||||
@ -500,6 +514,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
ip_adapter_data: Optional[IPAdapterData] = None,
|
ip_adapter_data: Optional[IPAdapterData] = None,
|
||||||
|
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||||
):
|
):
|
||||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||||
timestep = t[0]
|
timestep = t[0]
|
||||||
@ -527,11 +542,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
# otherwise, set IP-Adapter scale to 0, so it has no effect
|
# otherwise, set IP-Adapter scale to 0, so it has no effect
|
||||||
ip_adapter_data.ip_adapter_model.set_scale(0.0)
|
ip_adapter_data.ip_adapter_model.set_scale(0.0)
|
||||||
|
|
||||||
# handle ControlNet(s)
|
# Handle ControlNet(s) and T2I-Adapter(s)
|
||||||
# default is no controlnet, so set controlnet processing output to None
|
down_block_additional_residuals = None
|
||||||
controlnet_down_block_samples, controlnet_mid_block_sample = None, None
|
mid_block_additional_residual = None
|
||||||
if control_data is not None:
|
if control_data is not None and t2i_adapter_data is not None:
|
||||||
controlnet_down_block_samples, controlnet_mid_block_sample = self.invokeai_diffuser.do_controlnet_step(
|
# TODO(ryand): This is a limitation of the UNet2DConditionModel API, not a fundamental incompatibility
|
||||||
|
# between ControlNets and T2I-Adapters. We will try to fix this upstream in diffusers.
|
||||||
|
raise Exception("ControlNet(s) and T2I-Adapter(s) cannot be used simultaneously (yet).")
|
||||||
|
elif control_data is not None:
|
||||||
|
down_block_additional_residuals, mid_block_additional_residual = self.invokeai_diffuser.do_controlnet_step(
|
||||||
control_data=control_data,
|
control_data=control_data,
|
||||||
sample=latent_model_input,
|
sample=latent_model_input,
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
@ -539,6 +558,32 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
total_step_count=total_step_count,
|
total_step_count=total_step_count,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
)
|
)
|
||||||
|
elif t2i_adapter_data is not None:
|
||||||
|
accum_adapter_state = None
|
||||||
|
for single_t2i_adapter_data in t2i_adapter_data:
|
||||||
|
# Determine the T2I-Adapter weights for the current denoising step.
|
||||||
|
first_t2i_adapter_step = math.floor(single_t2i_adapter_data.begin_step_percent * total_step_count)
|
||||||
|
last_t2i_adapter_step = math.ceil(single_t2i_adapter_data.end_step_percent * total_step_count)
|
||||||
|
t2i_adapter_weight = (
|
||||||
|
single_t2i_adapter_data.weight[step_index]
|
||||||
|
if isinstance(single_t2i_adapter_data.weight, list)
|
||||||
|
else single_t2i_adapter_data.weight
|
||||||
|
)
|
||||||
|
if step_index < first_t2i_adapter_step or step_index > last_t2i_adapter_step:
|
||||||
|
# If the current step is outside of the T2I-Adapter's begin/end step range, then set its weight to 0
|
||||||
|
# so it has no effect.
|
||||||
|
t2i_adapter_weight = 0.0
|
||||||
|
|
||||||
|
# Apply the t2i_adapter_weight, and accumulate.
|
||||||
|
if accum_adapter_state is None:
|
||||||
|
# Handle the first T2I-Adapter.
|
||||||
|
accum_adapter_state = [val * t2i_adapter_weight for val in single_t2i_adapter_data.adapter_state]
|
||||||
|
else:
|
||||||
|
# Add to the previous adapter states.
|
||||||
|
for idx, value in enumerate(single_t2i_adapter_data.adapter_state):
|
||||||
|
accum_adapter_state[idx] += value * t2i_adapter_weight
|
||||||
|
|
||||||
|
down_block_additional_residuals = accum_adapter_state
|
||||||
|
|
||||||
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
|
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
|
||||||
sample=latent_model_input,
|
sample=latent_model_input,
|
||||||
@ -547,8 +592,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
total_step_count=total_step_count,
|
total_step_count=total_step_count,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
# extra:
|
# extra:
|
||||||
down_block_additional_residuals=controlnet_down_block_samples, # from controlnet(s)
|
down_block_additional_residuals=down_block_additional_residuals,
|
||||||
mid_block_additional_residual=controlnet_mid_block_sample, # from controlnet(s)
|
mid_block_additional_residual=mid_block_additional_residual,
|
||||||
)
|
)
|
||||||
|
|
||||||
guidance_scale = conditioning_data.guidance_scale
|
guidance_scale = conditioning_data.guidance_scale
|
||||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -1,4 +1,4 @@
|
|||||||
import{w as s,hY as T,v as l,a2 as I,hZ as R,ae as V,h_ as z,h$ as j,i0 as D,i1 as F,i2 as G,i3 as W,i4 as K,aG as Y,i5 as Z,i6 as H}from"./index-94062f76.js";import{M as U}from"./MantineProvider-a057bfc9.js";var P=String.raw,E=P`
|
import{w as s,h$ as T,v as l,a2 as I,i0 as R,ae as V,i1 as z,i2 as j,i3 as D,i4 as F,i5 as G,i6 as W,i7 as K,aG as H,i8 as U,i9 as Y}from"./index-6f7e7659.js";import{M as Z}from"./MantineProvider-2072a471.js";var P=String.raw,E=P`
|
||||||
:root,
|
:root,
|
||||||
:host {
|
:host {
|
||||||
--chakra-vh: 100vh;
|
--chakra-vh: 100vh;
|
||||||
@ -277,4 +277,4 @@ import{w as s,hY as T,v as l,a2 as I,hZ as R,ae as V,h_ as z,h$ as j,i0 as D,i1
|
|||||||
}
|
}
|
||||||
|
|
||||||
${E}
|
${E}
|
||||||
`}),g={light:"chakra-ui-light",dark:"chakra-ui-dark"};function Q(e={}){const{preventTransition:o=!0}=e,n={setDataset:r=>{const t=o?n.preventTransition():void 0;document.documentElement.dataset.theme=r,document.documentElement.style.colorScheme=r,t==null||t()},setClassName(r){document.body.classList.add(r?g.dark:g.light),document.body.classList.remove(r?g.light:g.dark)},query(){return window.matchMedia("(prefers-color-scheme: dark)")},getSystemTheme(r){var t;return((t=n.query().matches)!=null?t:r==="dark")?"dark":"light"},addListener(r){const t=n.query(),i=a=>{r(a.matches?"dark":"light")};return typeof t.addListener=="function"?t.addListener(i):t.addEventListener("change",i),()=>{typeof t.removeListener=="function"?t.removeListener(i):t.removeEventListener("change",i)}},preventTransition(){const r=document.createElement("style");return r.appendChild(document.createTextNode("*{-webkit-transition:none!important;-moz-transition:none!important;-o-transition:none!important;-ms-transition:none!important;transition:none!important}")),document.head.appendChild(r),()=>{window.getComputedStyle(document.body),requestAnimationFrame(()=>{requestAnimationFrame(()=>{document.head.removeChild(r)})})}}};return n}var X="chakra-ui-color-mode";function L(e){return{ssr:!1,type:"localStorage",get(o){if(!(globalThis!=null&&globalThis.document))return o;let n;try{n=localStorage.getItem(e)||o}catch{}return n||o},set(o){try{localStorage.setItem(e,o)}catch{}}}}var ee=L(X),M=()=>{};function S(e,o){return e.type==="cookie"&&e.ssr?e.get(o):o}function O(e){const{value:o,children:n,options:{useSystemColorMode:r,initialColorMode:t,disableTransitionOnChange:i}={},colorModeManager:a=ee}=e,d=t==="dark"?"dark":"light",[u,p]=l.useState(()=>S(a,d)),[y,b]=l.useState(()=>S(a)),{getSystemTheme:w,setClassName:k,setDataset:x,addListener:$}=l.useMemo(()=>Q({preventTransition:i}),[i]),v=t==="system"&&!u?y:u,c=l.useCallback(h=>{const f=h==="system"?w():h;p(f),k(f==="dark"),x(f),a.set(f)},[a,w,k,x]);I(()=>{t==="system"&&b(w())},[]),l.useEffect(()=>{const h=a.get();if(h){c(h);return}if(t==="system"){c("system");return}c(d)},[a,d,t,c]);const C=l.useCallback(()=>{c(v==="dark"?"light":"dark")},[v,c]);l.useEffect(()=>{if(r)return $(c)},[r,$,c]);const A=l.useMemo(()=>({colorMode:o??v,toggleColorMode:o?M:C,setColorMode:o?M:c,forced:o!==void 0}),[v,C,c,o]);return s.jsx(R.Provider,{value:A,children:n})}O.displayName="ColorModeProvider";var te=["borders","breakpoints","colors","components","config","direction","fonts","fontSizes","fontWeights","letterSpacings","lineHeights","radii","shadows","sizes","space","styles","transition","zIndices"];function re(e){return V(e)?te.every(o=>Object.prototype.hasOwnProperty.call(e,o)):!1}function m(e){return typeof e=="function"}function oe(...e){return o=>e.reduce((n,r)=>r(n),o)}var ne=e=>function(...n){let r=[...n],t=n[n.length-1];return re(t)&&r.length>1?r=r.slice(0,r.length-1):t=e,oe(...r.map(i=>a=>m(i)?i(a):ae(a,i)))(t)},ie=ne(j);function ae(...e){return z({},...e,_)}function _(e,o,n,r){if((m(e)||m(o))&&Object.prototype.hasOwnProperty.call(r,n))return(...t)=>{const i=m(e)?e(...t):e,a=m(o)?o(...t):o;return z({},i,a,_)}}var q=l.createContext({getDocument(){return document},getWindow(){return window}});q.displayName="EnvironmentContext";function N(e){const{children:o,environment:n,disabled:r}=e,t=l.useRef(null),i=l.useMemo(()=>n||{getDocument:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument)!=null?u:document},getWindow:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument.defaultView)!=null?u:window}},[n]),a=!r||!n;return s.jsxs(q.Provider,{value:i,children:[o,a&&s.jsx("span",{id:"__chakra_env",hidden:!0,ref:t})]})}N.displayName="EnvironmentProvider";var se=e=>{const{children:o,colorModeManager:n,portalZIndex:r,resetScope:t,resetCSS:i=!0,theme:a={},environment:d,cssVarsRoot:u,disableEnvironment:p,disableGlobalStyle:y}=e,b=s.jsx(N,{environment:d,disabled:p,children:o});return s.jsx(D,{theme:a,cssVarsRoot:u,children:s.jsxs(O,{colorModeManager:n,options:a.config,children:[i?s.jsx(J,{scope:t}):s.jsx(B,{}),!y&&s.jsx(F,{}),r?s.jsx(G,{zIndex:r,children:b}):b]})})},le=e=>function({children:n,theme:r=e,toastOptions:t,...i}){return s.jsxs(se,{theme:r,...i,children:[s.jsx(W,{value:t==null?void 0:t.defaultOptions,children:n}),s.jsx(K,{...t})]})},de=le(j);const ue=()=>l.useMemo(()=>({colorScheme:"dark",fontFamily:"'Inter Variable', sans-serif",components:{ScrollArea:{defaultProps:{scrollbarSize:10},styles:{scrollbar:{"&:hover":{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}},thumb:{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}}}}}),[]),ce=L("@@invokeai-color-mode");function he({children:e}){const{i18n:o}=Y(),n=o.dir(),r=l.useMemo(()=>ie({...Z,direction:n}),[n]);l.useEffect(()=>{document.body.dir=n},[n]);const t=ue();return s.jsx(U,{theme:t,children:s.jsx(de,{theme:r,colorModeManager:ce,toastOptions:H,children:e})})}const ve=l.memo(he);export{ve as default};
|
`}),g={light:"chakra-ui-light",dark:"chakra-ui-dark"};function Q(e={}){const{preventTransition:o=!0}=e,n={setDataset:r=>{const t=o?n.preventTransition():void 0;document.documentElement.dataset.theme=r,document.documentElement.style.colorScheme=r,t==null||t()},setClassName(r){document.body.classList.add(r?g.dark:g.light),document.body.classList.remove(r?g.light:g.dark)},query(){return window.matchMedia("(prefers-color-scheme: dark)")},getSystemTheme(r){var t;return((t=n.query().matches)!=null?t:r==="dark")?"dark":"light"},addListener(r){const t=n.query(),i=a=>{r(a.matches?"dark":"light")};return typeof t.addListener=="function"?t.addListener(i):t.addEventListener("change",i),()=>{typeof t.removeListener=="function"?t.removeListener(i):t.removeEventListener("change",i)}},preventTransition(){const r=document.createElement("style");return r.appendChild(document.createTextNode("*{-webkit-transition:none!important;-moz-transition:none!important;-o-transition:none!important;-ms-transition:none!important;transition:none!important}")),document.head.appendChild(r),()=>{window.getComputedStyle(document.body),requestAnimationFrame(()=>{requestAnimationFrame(()=>{document.head.removeChild(r)})})}}};return n}var X="chakra-ui-color-mode";function L(e){return{ssr:!1,type:"localStorage",get(o){if(!(globalThis!=null&&globalThis.document))return o;let n;try{n=localStorage.getItem(e)||o}catch{}return n||o},set(o){try{localStorage.setItem(e,o)}catch{}}}}var ee=L(X),M=()=>{};function S(e,o){return e.type==="cookie"&&e.ssr?e.get(o):o}function O(e){const{value:o,children:n,options:{useSystemColorMode:r,initialColorMode:t,disableTransitionOnChange:i}={},colorModeManager:a=ee}=e,d=t==="dark"?"dark":"light",[u,p]=l.useState(()=>S(a,d)),[y,b]=l.useState(()=>S(a)),{getSystemTheme:w,setClassName:k,setDataset:x,addListener:$}=l.useMemo(()=>Q({preventTransition:i}),[i]),v=t==="system"&&!u?y:u,c=l.useCallback(h=>{const f=h==="system"?w():h;p(f),k(f==="dark"),x(f),a.set(f)},[a,w,k,x]);I(()=>{t==="system"&&b(w())},[]),l.useEffect(()=>{const h=a.get();if(h){c(h);return}if(t==="system"){c("system");return}c(d)},[a,d,t,c]);const C=l.useCallback(()=>{c(v==="dark"?"light":"dark")},[v,c]);l.useEffect(()=>{if(r)return $(c)},[r,$,c]);const A=l.useMemo(()=>({colorMode:o??v,toggleColorMode:o?M:C,setColorMode:o?M:c,forced:o!==void 0}),[v,C,c,o]);return s.jsx(R.Provider,{value:A,children:n})}O.displayName="ColorModeProvider";var te=["borders","breakpoints","colors","components","config","direction","fonts","fontSizes","fontWeights","letterSpacings","lineHeights","radii","shadows","sizes","space","styles","transition","zIndices"];function re(e){return V(e)?te.every(o=>Object.prototype.hasOwnProperty.call(e,o)):!1}function m(e){return typeof e=="function"}function oe(...e){return o=>e.reduce((n,r)=>r(n),o)}var ne=e=>function(...n){let r=[...n],t=n[n.length-1];return re(t)&&r.length>1?r=r.slice(0,r.length-1):t=e,oe(...r.map(i=>a=>m(i)?i(a):ae(a,i)))(t)},ie=ne(j);function ae(...e){return z({},...e,_)}function _(e,o,n,r){if((m(e)||m(o))&&Object.prototype.hasOwnProperty.call(r,n))return(...t)=>{const i=m(e)?e(...t):e,a=m(o)?o(...t):o;return z({},i,a,_)}}var q=l.createContext({getDocument(){return document},getWindow(){return window}});q.displayName="EnvironmentContext";function N(e){const{children:o,environment:n,disabled:r}=e,t=l.useRef(null),i=l.useMemo(()=>n||{getDocument:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument)!=null?u:document},getWindow:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument.defaultView)!=null?u:window}},[n]),a=!r||!n;return s.jsxs(q.Provider,{value:i,children:[o,a&&s.jsx("span",{id:"__chakra_env",hidden:!0,ref:t})]})}N.displayName="EnvironmentProvider";var se=e=>{const{children:o,colorModeManager:n,portalZIndex:r,resetScope:t,resetCSS:i=!0,theme:a={},environment:d,cssVarsRoot:u,disableEnvironment:p,disableGlobalStyle:y}=e,b=s.jsx(N,{environment:d,disabled:p,children:o});return s.jsx(D,{theme:a,cssVarsRoot:u,children:s.jsxs(O,{colorModeManager:n,options:a.config,children:[i?s.jsx(J,{scope:t}):s.jsx(B,{}),!y&&s.jsx(F,{}),r?s.jsx(G,{zIndex:r,children:b}):b]})})},le=e=>function({children:n,theme:r=e,toastOptions:t,...i}){return s.jsxs(se,{theme:r,...i,children:[s.jsx(W,{value:t==null?void 0:t.defaultOptions,children:n}),s.jsx(K,{...t})]})},de=le(j);const ue=()=>l.useMemo(()=>({colorScheme:"dark",fontFamily:"'Inter Variable', sans-serif",components:{ScrollArea:{defaultProps:{scrollbarSize:10},styles:{scrollbar:{"&:hover":{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}},thumb:{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}}}}}),[]),ce=L("@@invokeai-color-mode");function he({children:e}){const{i18n:o}=H(),n=o.dir(),r=l.useMemo(()=>ie({...U,direction:n}),[n]);l.useEffect(()=>{document.body.dir=n},[n]);const t=ue();return s.jsx(Z,{theme:t,children:s.jsx(de,{theme:r,colorModeManager:ce,toastOptions:Y,children:e})})}const ve=l.memo(he);export{ve as default};
|
158
invokeai/frontend/web/dist/assets/index-6f7e7659.js
vendored
Normal file
158
invokeai/frontend/web/dist/assets/index-6f7e7659.js
vendored
Normal file
File diff suppressed because one or more lines are too long
158
invokeai/frontend/web/dist/assets/index-94062f76.js
vendored
158
invokeai/frontend/web/dist/assets/index-94062f76.js
vendored
File diff suppressed because one or more lines are too long
5
invokeai/frontend/web/dist/index.html
vendored
5
invokeai/frontend/web/dist/index.html
vendored
@ -3,6 +3,9 @@
|
|||||||
<head>
|
<head>
|
||||||
<meta charset="UTF-8" />
|
<meta charset="UTF-8" />
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
|
<meta http-equiv="Cache-Control" content="no-cache, no-store, must-revalidate">
|
||||||
|
<meta http-equiv="Pragma" content="no-cache">
|
||||||
|
<meta http-equiv="Expires" content="0">
|
||||||
<title>InvokeAI - A Stable Diffusion Toolkit</title>
|
<title>InvokeAI - A Stable Diffusion Toolkit</title>
|
||||||
<link rel="shortcut icon" type="icon" href="./assets/favicon-0d253ced.ico" />
|
<link rel="shortcut icon" type="icon" href="./assets/favicon-0d253ced.ico" />
|
||||||
<style>
|
<style>
|
||||||
@ -12,7 +15,7 @@
|
|||||||
margin: 0;
|
margin: 0;
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
<script type="module" crossorigin src="./assets/index-94062f76.js"></script>
|
<script type="module" crossorigin src="./assets/index-6f7e7659.js"></script>
|
||||||
</head>
|
</head>
|
||||||
|
|
||||||
<body dir="ltr">
|
<body dir="ltr">
|
||||||
|
2
invokeai/frontend/web/dist/locales/en.json
vendored
2
invokeai/frontend/web/dist/locales/en.json
vendored
@ -697,7 +697,7 @@
|
|||||||
"noLoRAsAvailable": "No LoRAs available",
|
"noLoRAsAvailable": "No LoRAs available",
|
||||||
"noMatchingLoRAs": "No matching LoRAs",
|
"noMatchingLoRAs": "No matching LoRAs",
|
||||||
"noMatchingModels": "No matching Models",
|
"noMatchingModels": "No matching Models",
|
||||||
"noModelsAvailable": "No Modelss available",
|
"noModelsAvailable": "No models available",
|
||||||
"selectLoRA": "Select a LoRA",
|
"selectLoRA": "Select a LoRA",
|
||||||
"selectModel": "Select a Model"
|
"selectModel": "Select a Model"
|
||||||
},
|
},
|
||||||
|
@ -3,6 +3,9 @@
|
|||||||
<head>
|
<head>
|
||||||
<meta charset="UTF-8" />
|
<meta charset="UTF-8" />
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
|
<meta http-equiv="Cache-Control" content="no-cache, no-store, must-revalidate">
|
||||||
|
<meta http-equiv="Pragma" content="no-cache">
|
||||||
|
<meta http-equiv="Expires" content="0">
|
||||||
<title>InvokeAI - A Stable Diffusion Toolkit</title>
|
<title>InvokeAI - A Stable Diffusion Toolkit</title>
|
||||||
<link rel="shortcut icon" type="icon" href="favicon.ico" />
|
<link rel="shortcut icon" type="icon" href="favicon.ico" />
|
||||||
<style>
|
<style>
|
||||||
|
@ -697,7 +697,7 @@
|
|||||||
"noLoRAsAvailable": "No LoRAs available",
|
"noLoRAsAvailable": "No LoRAs available",
|
||||||
"noMatchingLoRAs": "No matching LoRAs",
|
"noMatchingLoRAs": "No matching LoRAs",
|
||||||
"noMatchingModels": "No matching Models",
|
"noMatchingModels": "No matching Models",
|
||||||
"noModelsAvailable": "No Modelss available",
|
"noModelsAvailable": "No models available",
|
||||||
"selectLoRA": "Select a LoRA",
|
"selectLoRA": "Select a LoRA",
|
||||||
"selectModel": "Select a Model"
|
"selectModel": "Select a Model"
|
||||||
},
|
},
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
|
import {
|
||||||
|
controlNetRemoved,
|
||||||
|
ipAdapterModelChanged,
|
||||||
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
import { loraRemoved } from 'features/lora/store/loraSlice';
|
import { loraRemoved } from 'features/lora/store/loraSlice';
|
||||||
import {
|
import {
|
||||||
modelChanged,
|
modelChanged,
|
||||||
@ -16,12 +19,14 @@ import {
|
|||||||
} from 'features/sdxl/store/sdxlSlice';
|
} from 'features/sdxl/store/sdxlSlice';
|
||||||
import { forEach, some } from 'lodash-es';
|
import { forEach, some } from 'lodash-es';
|
||||||
import {
|
import {
|
||||||
|
ipAdapterModelsAdapter,
|
||||||
mainModelsAdapter,
|
mainModelsAdapter,
|
||||||
modelsApi,
|
modelsApi,
|
||||||
vaeModelsAdapter,
|
vaeModelsAdapter,
|
||||||
} from 'services/api/endpoints/models';
|
} from 'services/api/endpoints/models';
|
||||||
import { TypeGuardFor } from 'services/api/types';
|
import { TypeGuardFor } from 'services/api/types';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
import { zIPAdapterModel } from 'features/nodes/types/types';
|
||||||
|
|
||||||
export const addModelsLoadedListener = () => {
|
export const addModelsLoadedListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
@ -234,6 +239,50 @@ export const addModelsLoadedListener = () => {
|
|||||||
});
|
});
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
startAppListening({
|
||||||
|
matcher: modelsApi.endpoints.getIPAdapterModels.matchFulfilled,
|
||||||
|
effect: async (action, { getState, dispatch }) => {
|
||||||
|
// ControlNet models loaded - need to remove missing ControlNets from state
|
||||||
|
const log = logger('models');
|
||||||
|
log.info(
|
||||||
|
{ models: action.payload.entities },
|
||||||
|
`IP Adapter models loaded (${action.payload.ids.length})`
|
||||||
|
);
|
||||||
|
|
||||||
|
const { model } = getState().controlNet.ipAdapterInfo;
|
||||||
|
|
||||||
|
const isModelAvailable = some(
|
||||||
|
action.payload.entities,
|
||||||
|
(m) =>
|
||||||
|
m?.model_name === model?.model_name &&
|
||||||
|
m?.base_model === model?.base_model
|
||||||
|
);
|
||||||
|
|
||||||
|
if (isModelAvailable) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const firstModel = ipAdapterModelsAdapter
|
||||||
|
.getSelectors()
|
||||||
|
.selectAll(action.payload)[0];
|
||||||
|
|
||||||
|
if (!firstModel) {
|
||||||
|
dispatch(ipAdapterModelChanged(null));
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = zIPAdapterModel.safeParse(firstModel);
|
||||||
|
|
||||||
|
if (!result.success) {
|
||||||
|
log.error(
|
||||||
|
{ error: result.error.format() },
|
||||||
|
'Failed to parse IP Adapter model'
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(ipAdapterModelChanged(result.data));
|
||||||
|
},
|
||||||
|
});
|
||||||
startAppListening({
|
startAppListening({
|
||||||
matcher: modelsApi.endpoints.getTextualInversionModels.matchFulfilled,
|
matcher: modelsApi.endpoints.getTextualInversionModels.matchFulfilled,
|
||||||
effect: async (action) => {
|
effect: async (action) => {
|
||||||
|
@ -8,6 +8,7 @@ import {
|
|||||||
} from 'features/gallery/store/gallerySlice';
|
} from 'features/gallery/store/gallerySlice';
|
||||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||||
import { CANVAS_OUTPUT } from 'features/nodes/util/graphBuilders/constants';
|
import { CANVAS_OUTPUT } from 'features/nodes/util/graphBuilders/constants';
|
||||||
|
import { boardsApi } from 'services/api/endpoints/boards';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { isImageOutput } from 'services/api/guards';
|
import { isImageOutput } from 'services/api/guards';
|
||||||
import { imagesAdapter } from 'services/api/util';
|
import { imagesAdapter } from 'services/api/util';
|
||||||
@ -70,11 +71,21 @@ export const addInvocationCompleteEventListener = () => {
|
|||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// update the total images for the board
|
||||||
|
dispatch(
|
||||||
|
boardsApi.util.updateQueryData(
|
||||||
|
'getBoardImagesTotal',
|
||||||
|
imageDTO.board_id ?? 'none',
|
||||||
|
(draft) => {
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||||
|
draft.total += 1;
|
||||||
|
}
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
imagesApi.util.invalidateTags([
|
imagesApi.util.invalidateTags([
|
||||||
{ type: 'BoardImagesTotal', id: imageDTO.board_id },
|
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
|
||||||
{ type: 'BoardAssetsTotal', id: imageDTO.board_id },
|
|
||||||
{ type: 'Board', id: imageDTO.board_id },
|
|
||||||
])
|
])
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -5,8 +5,23 @@ import ParamIPAdapterFeatureToggle from './ParamIPAdapterFeatureToggle';
|
|||||||
import ParamIPAdapterImage from './ParamIPAdapterImage';
|
import ParamIPAdapterImage from './ParamIPAdapterImage';
|
||||||
import ParamIPAdapterModelSelect from './ParamIPAdapterModelSelect';
|
import ParamIPAdapterModelSelect from './ParamIPAdapterModelSelect';
|
||||||
import ParamIPAdapterWeight from './ParamIPAdapterWeight';
|
import ParamIPAdapterWeight from './ParamIPAdapterWeight';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from '../../../../app/store/store';
|
||||||
|
import { defaultSelectorOptions } from '../../../../app/store/util/defaultMemoizeOptions';
|
||||||
|
import { useAppSelector } from '../../../../app/store/storeHooks';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
(state) => {
|
||||||
|
const { isIPAdapterEnabled } = state.controlNet;
|
||||||
|
|
||||||
|
return { isIPAdapterEnabled };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
const IPAdapterPanel = () => {
|
const IPAdapterPanel = () => {
|
||||||
|
const { isIPAdapterEnabled } = useAppSelector(selector);
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
sx={{
|
sx={{
|
||||||
@ -14,7 +29,6 @@ const IPAdapterPanel = () => {
|
|||||||
gap: 3,
|
gap: 3,
|
||||||
paddingInline: 3,
|
paddingInline: 3,
|
||||||
paddingBlock: 2,
|
paddingBlock: 2,
|
||||||
paddingBottom: 5,
|
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
position: 'relative',
|
position: 'relative',
|
||||||
bg: 'base.250',
|
bg: 'base.250',
|
||||||
@ -24,11 +38,27 @@ const IPAdapterPanel = () => {
|
|||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<ParamIPAdapterFeatureToggle />
|
<ParamIPAdapterFeatureToggle />
|
||||||
<ParamIPAdapterImage />
|
{isIPAdapterEnabled && (
|
||||||
|
<>
|
||||||
<ParamIPAdapterModelSelect />
|
<ParamIPAdapterModelSelect />
|
||||||
|
<Flex gap="3">
|
||||||
|
<Flex
|
||||||
|
flexDirection="column"
|
||||||
|
sx={{
|
||||||
|
h: 28,
|
||||||
|
w: 'full',
|
||||||
|
gap: 4,
|
||||||
|
mb: 4,
|
||||||
|
}}
|
||||||
|
>
|
||||||
<ParamIPAdapterWeight />
|
<ParamIPAdapterWeight />
|
||||||
<ParamIPAdapterBeginEnd />
|
<ParamIPAdapterBeginEnd />
|
||||||
</Flex>
|
</Flex>
|
||||||
|
<ParamIPAdapterImage />
|
||||||
|
</Flex>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -66,7 +66,8 @@ const ParamIPAdapterImage = () => {
|
|||||||
layerStyle="second"
|
layerStyle="second"
|
||||||
sx={{
|
sx={{
|
||||||
position: 'relative',
|
position: 'relative',
|
||||||
w: 'full',
|
h: 28,
|
||||||
|
w: 28,
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
justifyContent: 'center',
|
justifyContent: 'center',
|
||||||
aspectRatio: '1/1',
|
aspectRatio: '1/1',
|
||||||
|
@ -88,12 +88,16 @@ const ParamIPAdapterModelSelect = () => {
|
|||||||
className="nowheel nodrag"
|
className="nowheel nodrag"
|
||||||
tooltip={selectedModel?.description}
|
tooltip={selectedModel?.description}
|
||||||
value={selectedModel?.id ?? null}
|
value={selectedModel?.id ?? null}
|
||||||
placeholder="Pick one"
|
placeholder={
|
||||||
error={!selectedModel}
|
data.length > 0
|
||||||
|
? t('models.selectModel')
|
||||||
|
: t('models.noModelsAvailable')
|
||||||
|
}
|
||||||
|
error={!selectedModel && data.length > 0}
|
||||||
data={data}
|
data={data}
|
||||||
onChange={handleValueChanged}
|
onChange={handleValueChanged}
|
||||||
sx={{ width: '100%' }}
|
sx={{ width: '100%' }}
|
||||||
disabled={!isEnabled}
|
disabled={!isEnabled || data.length === 0}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -53,8 +53,12 @@ export const isValidDrop = (
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (payloadType === 'IMAGE_DTOS') {
|
if (payloadType === 'IMAGE_DTOS') {
|
||||||
// TODO (multi-select)
|
// Assume all images are on the same board - this is true for the moment
|
||||||
return true;
|
const { imageDTOs } = active.data.current.payload;
|
||||||
|
const currentBoard = imageDTOs[0]?.board_id ?? 'none';
|
||||||
|
const destinationBoard = overData.context.boardId;
|
||||||
|
|
||||||
|
return currentBoard !== destinationBoard;
|
||||||
}
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
@ -71,14 +75,17 @@ export const isValidDrop = (
|
|||||||
// Check if the image's board is the board we are dragging onto
|
// Check if the image's board is the board we are dragging onto
|
||||||
if (payloadType === 'IMAGE_DTO') {
|
if (payloadType === 'IMAGE_DTO') {
|
||||||
const { imageDTO } = active.data.current.payload;
|
const { imageDTO } = active.data.current.payload;
|
||||||
const currentBoard = imageDTO.board_id;
|
const currentBoard = imageDTO.board_id ?? 'none';
|
||||||
|
|
||||||
return currentBoard !== 'none';
|
return currentBoard !== 'none';
|
||||||
}
|
}
|
||||||
|
|
||||||
if (payloadType === 'IMAGE_DTOS') {
|
if (payloadType === 'IMAGE_DTOS') {
|
||||||
// TODO (multi-select)
|
// Assume all images are on the same board - this is true for the moment
|
||||||
return true;
|
const { imageDTOs } = active.data.current.payload;
|
||||||
|
const currentBoard = imageDTOs[0]?.board_id ?? 'none';
|
||||||
|
|
||||||
|
return currentBoard !== 'none';
|
||||||
}
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
|
@ -77,12 +77,12 @@ const GalleryBoard = ({
|
|||||||
const { data: imagesTotal } = useGetBoardImagesTotalQuery(board.board_id);
|
const { data: imagesTotal } = useGetBoardImagesTotalQuery(board.board_id);
|
||||||
const { data: assetsTotal } = useGetBoardAssetsTotalQuery(board.board_id);
|
const { data: assetsTotal } = useGetBoardAssetsTotalQuery(board.board_id);
|
||||||
const tooltip = useMemo(() => {
|
const tooltip = useMemo(() => {
|
||||||
if (!imagesTotal || !assetsTotal) {
|
if (imagesTotal?.total === undefined || assetsTotal?.total === undefined) {
|
||||||
return undefined;
|
return undefined;
|
||||||
}
|
}
|
||||||
return `${imagesTotal} image${
|
return `${imagesTotal.total} image${imagesTotal.total === 1 ? '' : 's'}, ${
|
||||||
imagesTotal > 1 ? 's' : ''
|
assetsTotal.total
|
||||||
}, ${assetsTotal} asset${assetsTotal > 1 ? 's' : ''}`;
|
} asset${assetsTotal.total === 1 ? '' : 's'}`;
|
||||||
}, [assetsTotal, imagesTotal]);
|
}, [assetsTotal, imagesTotal]);
|
||||||
|
|
||||||
const { currentData: coverImage } = useGetImageDTOQuery(
|
const { currentData: coverImage } = useGetImageDTOQuery(
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Box, Flex, Image, Text } from '@chakra-ui/react';
|
import { Box, Flex, Image, Text, Tooltip } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
@ -15,6 +15,10 @@ import { memo, useCallback, useMemo, useState } from 'react';
|
|||||||
import { useBoardName } from 'services/api/hooks/useBoardName';
|
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||||
import AutoAddIcon from '../AutoAddIcon';
|
import AutoAddIcon from '../AutoAddIcon';
|
||||||
import BoardContextMenu from '../BoardContextMenu';
|
import BoardContextMenu from '../BoardContextMenu';
|
||||||
|
import {
|
||||||
|
useGetBoardAssetsTotalQuery,
|
||||||
|
useGetBoardImagesTotalQuery,
|
||||||
|
} from 'services/api/endpoints/boards';
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
isSelected: boolean;
|
isSelected: boolean;
|
||||||
@ -41,6 +45,17 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
|
|||||||
}, [dispatch, autoAssignBoardOnClick]);
|
}, [dispatch, autoAssignBoardOnClick]);
|
||||||
const [isHovered, setIsHovered] = useState(false);
|
const [isHovered, setIsHovered] = useState(false);
|
||||||
|
|
||||||
|
const { data: imagesTotal } = useGetBoardImagesTotalQuery('none');
|
||||||
|
const { data: assetsTotal } = useGetBoardAssetsTotalQuery('none');
|
||||||
|
const tooltip = useMemo(() => {
|
||||||
|
if (imagesTotal?.total === undefined || assetsTotal?.total === undefined) {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
return `${imagesTotal.total} image${imagesTotal.total === 1 ? '' : 's'}, ${
|
||||||
|
assetsTotal.total
|
||||||
|
} asset${assetsTotal.total === 1 ? '' : 's'}`;
|
||||||
|
}, [assetsTotal, imagesTotal]);
|
||||||
|
|
||||||
const handleMouseOver = useCallback(() => {
|
const handleMouseOver = useCallback(() => {
|
||||||
setIsHovered(true);
|
setIsHovered(true);
|
||||||
}, []);
|
}, []);
|
||||||
@ -74,6 +89,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
|
|||||||
>
|
>
|
||||||
<BoardContextMenu board_id="none">
|
<BoardContextMenu board_id="none">
|
||||||
{(ref) => (
|
{(ref) => (
|
||||||
|
<Tooltip label={tooltip} openDelay={1000} hasArrow>
|
||||||
<Flex
|
<Flex
|
||||||
ref={ref}
|
ref={ref}
|
||||||
onClick={handleSelectBoard}
|
onClick={handleSelectBoard}
|
||||||
@ -139,12 +155,16 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
|
|||||||
>
|
>
|
||||||
{boardName}
|
{boardName}
|
||||||
</Flex>
|
</Flex>
|
||||||
<SelectionOverlay isSelected={isSelected} isHovered={isHovered} />
|
<SelectionOverlay
|
||||||
|
isSelected={isSelected}
|
||||||
|
isHovered={isHovered}
|
||||||
|
/>
|
||||||
<IAIDroppable
|
<IAIDroppable
|
||||||
data={droppableData}
|
data={droppableData}
|
||||||
dropLabel={<Text fontSize="md">Move</Text>}
|
dropLabel={<Text fontSize="md">Move</Text>}
|
||||||
/>
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
</Tooltip>
|
||||||
)}
|
)}
|
||||||
</BoardContextMenu>
|
</BoardContextMenu>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -3,13 +3,12 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
|
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||||
import { uniq } from 'lodash-es';
|
|
||||||
import { MouseEvent, useCallback, useMemo } from 'react';
|
import { MouseEvent, useCallback, useMemo } from 'react';
|
||||||
import { useListImagesQuery } from 'services/api/endpoints/images';
|
import { useListImagesQuery } from 'services/api/endpoints/images';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import { selectionChanged } from '../store/gallerySlice';
|
|
||||||
import { imagesSelectors } from 'services/api/util';
|
import { imagesSelectors } from 'services/api/util';
|
||||||
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
|
||||||
|
import { selectionChanged } from '../store/gallerySlice';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector, selectListImagesBaseQueryArgs],
|
[stateSelector, selectListImagesBaseQueryArgs],
|
||||||
@ -60,7 +59,7 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
|
|||||||
const start = Math.min(lastClickedIndex, currentClickedIndex);
|
const start = Math.min(lastClickedIndex, currentClickedIndex);
|
||||||
const end = Math.max(lastClickedIndex, currentClickedIndex);
|
const end = Math.max(lastClickedIndex, currentClickedIndex);
|
||||||
const imagesToSelect = imageDTOs.slice(start, end + 1);
|
const imagesToSelect = imageDTOs.slice(start, end + 1);
|
||||||
dispatch(selectionChanged(uniq(selection.concat(imagesToSelect))));
|
dispatch(selectionChanged(selection.concat(imagesToSelect)));
|
||||||
}
|
}
|
||||||
} else if (e.ctrlKey || e.metaKey) {
|
} else if (e.ctrlKey || e.metaKey) {
|
||||||
if (
|
if (
|
||||||
@ -73,7 +72,7 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
|
|||||||
)
|
)
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
dispatch(selectionChanged(uniq(selection.concat(imageDTO))));
|
dispatch(selectionChanged(selection.concat(imageDTO)));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
dispatch(selectionChanged([imageDTO]));
|
dispatch(selectionChanged([imageDTO]));
|
||||||
|
@ -20,7 +20,7 @@ export const nextPrevImageButtonsSelector = createSelector(
|
|||||||
const { data, status } =
|
const { data, status } =
|
||||||
imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
|
imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
|
||||||
|
|
||||||
const { data: total } =
|
const { data: totalsData } =
|
||||||
state.gallery.galleryView === 'images'
|
state.gallery.galleryView === 'images'
|
||||||
? boardsApi.endpoints.getBoardImagesTotal.select(
|
? boardsApi.endpoints.getBoardImagesTotal.select(
|
||||||
baseQueryArgs.board_id ?? 'none'
|
baseQueryArgs.board_id ?? 'none'
|
||||||
@ -34,7 +34,7 @@ export const nextPrevImageButtonsSelector = createSelector(
|
|||||||
|
|
||||||
const isFetching = status === 'pending';
|
const isFetching = status === 'pending';
|
||||||
|
|
||||||
if (!data || !lastSelectedImage || total === 0) {
|
if (!data || !lastSelectedImage || totalsData?.total === 0) {
|
||||||
return {
|
return {
|
||||||
isFetching,
|
isFetching,
|
||||||
queryArgs: baseQueryArgs,
|
queryArgs: baseQueryArgs,
|
||||||
@ -74,7 +74,7 @@ export const nextPrevImageButtonsSelector = createSelector(
|
|||||||
return {
|
return {
|
||||||
loadedImagesCount: images.length,
|
loadedImagesCount: images.length,
|
||||||
currentImageIndex,
|
currentImageIndex,
|
||||||
areMoreImagesAvailable: (total ?? 0) > imagesLength,
|
areMoreImagesAvailable: (totalsData?.total ?? 0) > imagesLength,
|
||||||
isFetching: status === 'pending',
|
isFetching: status === 'pending',
|
||||||
nextImage,
|
nextImage,
|
||||||
prevImage,
|
prevImage,
|
||||||
|
@ -4,6 +4,7 @@ import { boardsApi } from 'services/api/endpoints/boards';
|
|||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import { BoardId, GalleryState, GalleryView } from './types';
|
import { BoardId, GalleryState, GalleryView } from './types';
|
||||||
|
import { uniqBy } from 'lodash-es';
|
||||||
|
|
||||||
export const initialGalleryState: GalleryState = {
|
export const initialGalleryState: GalleryState = {
|
||||||
selection: [],
|
selection: [],
|
||||||
@ -24,7 +25,7 @@ export const gallerySlice = createSlice({
|
|||||||
state.selection = action.payload ? [action.payload] : [];
|
state.selection = action.payload ? [action.payload] : [];
|
||||||
},
|
},
|
||||||
selectionChanged: (state, action: PayloadAction<ImageDTO[]>) => {
|
selectionChanged: (state, action: PayloadAction<ImageDTO[]>) => {
|
||||||
state.selection = action.payload;
|
state.selection = uniqBy(action.payload, (i) => i.image_name);
|
||||||
},
|
},
|
||||||
shouldAutoSwitchChanged: (state, action: PayloadAction<boolean>) => {
|
shouldAutoSwitchChanged: (state, action: PayloadAction<boolean>) => {
|
||||||
state.shouldAutoSwitch = action.payload;
|
state.shouldAutoSwitch = action.payload;
|
||||||
|
@ -16,6 +16,7 @@ import SchedulerInputField from './inputs/SchedulerInputField';
|
|||||||
import StringInputField from './inputs/StringInputField';
|
import StringInputField from './inputs/StringInputField';
|
||||||
import VaeModelInputField from './inputs/VaeModelInputField';
|
import VaeModelInputField from './inputs/VaeModelInputField';
|
||||||
import IPAdapterModelInputField from './inputs/IPAdapterModelInputField';
|
import IPAdapterModelInputField from './inputs/IPAdapterModelInputField';
|
||||||
|
import T2IAdapterModelInputField from './inputs/T2IAdapterModelInputField';
|
||||||
import BoardInputField from './inputs/BoardInputField';
|
import BoardInputField from './inputs/BoardInputField';
|
||||||
|
|
||||||
type InputFieldProps = {
|
type InputFieldProps = {
|
||||||
@ -188,6 +189,18 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
field?.type === 'T2IAdapterModelField' &&
|
||||||
|
fieldTemplate?.type === 'T2IAdapterModelField'
|
||||||
|
) {
|
||||||
|
return (
|
||||||
|
<T2IAdapterModelInputField
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
fieldTemplate={fieldTemplate}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
|
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
|
||||||
return (
|
return (
|
||||||
<ColorInputField
|
<ColorInputField
|
||||||
|
@ -0,0 +1,19 @@
|
|||||||
|
import {
|
||||||
|
T2IAdapterInputFieldTemplate,
|
||||||
|
T2IAdapterInputFieldValue,
|
||||||
|
T2IAdapterPolymorphicInputFieldTemplate,
|
||||||
|
T2IAdapterPolymorphicInputFieldValue,
|
||||||
|
FieldComponentProps,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
|
const T2IAdapterInputFieldComponent = (
|
||||||
|
_props: FieldComponentProps<
|
||||||
|
T2IAdapterInputFieldValue | T2IAdapterPolymorphicInputFieldValue,
|
||||||
|
T2IAdapterInputFieldTemplate | T2IAdapterPolymorphicInputFieldTemplate
|
||||||
|
>
|
||||||
|
) => {
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(T2IAdapterInputFieldComponent);
|
@ -0,0 +1,100 @@
|
|||||||
|
import { SelectItem } from '@mantine/core';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import {
|
||||||
|
T2IAdapterModelInputFieldTemplate,
|
||||||
|
T2IAdapterModelInputFieldValue,
|
||||||
|
FieldComponentProps,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
|
import { modelIdToT2IAdapterModelParam } from 'features/parameters/util/modelIdToT2IAdapterModelParam';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
const T2IAdapterModelInputFieldComponent = (
|
||||||
|
props: FieldComponentProps<
|
||||||
|
T2IAdapterModelInputFieldValue,
|
||||||
|
T2IAdapterModelInputFieldTemplate
|
||||||
|
>
|
||||||
|
) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
const t2iAdapterModel = field.value;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery();
|
||||||
|
|
||||||
|
// grab the full model entity from the RTK Query cache
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() =>
|
||||||
|
t2iAdapterModels?.entities[
|
||||||
|
`${t2iAdapterModel?.base_model}/t2i_adapter/${t2iAdapterModel?.model_name}`
|
||||||
|
] ?? null,
|
||||||
|
[
|
||||||
|
t2iAdapterModel?.base_model,
|
||||||
|
t2iAdapterModel?.model_name,
|
||||||
|
t2iAdapterModels?.entities,
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!t2iAdapterModels) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(t2iAdapterModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.model_name,
|
||||||
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [t2iAdapterModels]);
|
||||||
|
|
||||||
|
const handleValueChanged = useCallback(
|
||||||
|
(v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const newT2IAdapterModel = modelIdToT2IAdapterModelParam(v);
|
||||||
|
|
||||||
|
if (!newT2IAdapterModel) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
fieldT2IAdapterModelValueChanged({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
value: newT2IAdapterModel,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch, field.name, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIMantineSelect
|
||||||
|
className="nowheel nodrag"
|
||||||
|
tooltip={selectedModel?.description}
|
||||||
|
value={selectedModel?.id ?? null}
|
||||||
|
placeholder="Pick one"
|
||||||
|
error={!selectedModel}
|
||||||
|
data={data}
|
||||||
|
onChange={handleValueChanged}
|
||||||
|
sx={{ width: '100%' }}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(T2IAdapterModelInputFieldComponent);
|
@ -55,6 +55,7 @@ import {
|
|||||||
SchedulerInputFieldValue,
|
SchedulerInputFieldValue,
|
||||||
SDXLRefinerModelInputFieldValue,
|
SDXLRefinerModelInputFieldValue,
|
||||||
StringInputFieldValue,
|
StringInputFieldValue,
|
||||||
|
T2IAdapterModelInputFieldValue,
|
||||||
VaeModelInputFieldValue,
|
VaeModelInputFieldValue,
|
||||||
Workflow,
|
Workflow,
|
||||||
} from '../types/types';
|
} from '../types/types';
|
||||||
@ -645,6 +646,12 @@ const nodesSlice = createSlice({
|
|||||||
) => {
|
) => {
|
||||||
fieldValueReducer(state, action);
|
fieldValueReducer(state, action);
|
||||||
},
|
},
|
||||||
|
fieldT2IAdapterModelValueChanged: (
|
||||||
|
state,
|
||||||
|
action: FieldValueAction<T2IAdapterModelInputFieldValue>
|
||||||
|
) => {
|
||||||
|
fieldValueReducer(state, action);
|
||||||
|
},
|
||||||
fieldEnumModelValueChanged: (
|
fieldEnumModelValueChanged: (
|
||||||
state,
|
state,
|
||||||
action: FieldValueAction<EnumInputFieldValue>
|
action: FieldValueAction<EnumInputFieldValue>
|
||||||
@ -1009,6 +1016,7 @@ export const {
|
|||||||
fieldEnumModelValueChanged,
|
fieldEnumModelValueChanged,
|
||||||
fieldImageValueChanged,
|
fieldImageValueChanged,
|
||||||
fieldIPAdapterModelValueChanged,
|
fieldIPAdapterModelValueChanged,
|
||||||
|
fieldT2IAdapterModelValueChanged,
|
||||||
fieldLabelChanged,
|
fieldLabelChanged,
|
||||||
fieldLoRAModelValueChanged,
|
fieldLoRAModelValueChanged,
|
||||||
fieldMainModelValueChanged,
|
fieldMainModelValueChanged,
|
||||||
|
@ -31,6 +31,7 @@ export const COLLECTION_TYPES: FieldType[] = [
|
|||||||
'ConditioningCollection',
|
'ConditioningCollection',
|
||||||
'ControlCollection',
|
'ControlCollection',
|
||||||
'ColorCollection',
|
'ColorCollection',
|
||||||
|
'T2IAdapterCollection',
|
||||||
];
|
];
|
||||||
|
|
||||||
export const POLYMORPHIC_TYPES: FieldType[] = [
|
export const POLYMORPHIC_TYPES: FieldType[] = [
|
||||||
@ -43,6 +44,7 @@ export const POLYMORPHIC_TYPES: FieldType[] = [
|
|||||||
'ConditioningPolymorphic',
|
'ConditioningPolymorphic',
|
||||||
'ControlPolymorphic',
|
'ControlPolymorphic',
|
||||||
'ColorPolymorphic',
|
'ColorPolymorphic',
|
||||||
|
'T2IAdapterPolymorphic',
|
||||||
];
|
];
|
||||||
|
|
||||||
export const MODEL_TYPES: FieldType[] = [
|
export const MODEL_TYPES: FieldType[] = [
|
||||||
@ -57,6 +59,7 @@ export const MODEL_TYPES: FieldType[] = [
|
|||||||
'UNetField',
|
'UNetField',
|
||||||
'VaeField',
|
'VaeField',
|
||||||
'ClipField',
|
'ClipField',
|
||||||
|
'T2IAdapterModelField',
|
||||||
];
|
];
|
||||||
|
|
||||||
export const COLLECTION_MAP: FieldTypeMapWithNumber = {
|
export const COLLECTION_MAP: FieldTypeMapWithNumber = {
|
||||||
@ -70,6 +73,7 @@ export const COLLECTION_MAP: FieldTypeMapWithNumber = {
|
|||||||
ConditioningField: 'ConditioningCollection',
|
ConditioningField: 'ConditioningCollection',
|
||||||
ControlField: 'ControlCollection',
|
ControlField: 'ControlCollection',
|
||||||
ColorField: 'ColorCollection',
|
ColorField: 'ColorCollection',
|
||||||
|
T2IAdapterField: 'T2IAdapterCollection',
|
||||||
};
|
};
|
||||||
export const isCollectionItemType = (
|
export const isCollectionItemType = (
|
||||||
itemType: string | undefined
|
itemType: string | undefined
|
||||||
@ -87,6 +91,7 @@ export const SINGLE_TO_POLYMORPHIC_MAP: FieldTypeMapWithNumber = {
|
|||||||
ConditioningField: 'ConditioningPolymorphic',
|
ConditioningField: 'ConditioningPolymorphic',
|
||||||
ControlField: 'ControlPolymorphic',
|
ControlField: 'ControlPolymorphic',
|
||||||
ColorField: 'ColorPolymorphic',
|
ColorField: 'ColorPolymorphic',
|
||||||
|
T2IAdapterField: 'T2IAdapterPolymorphic',
|
||||||
};
|
};
|
||||||
|
|
||||||
export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
|
export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
|
||||||
@ -99,6 +104,7 @@ export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
|
|||||||
ConditioningPolymorphic: 'ConditioningField',
|
ConditioningPolymorphic: 'ConditioningField',
|
||||||
ControlPolymorphic: 'ControlField',
|
ControlPolymorphic: 'ControlField',
|
||||||
ColorPolymorphic: 'ColorField',
|
ColorPolymorphic: 'ColorField',
|
||||||
|
T2IAdapterPolymorphic: 'T2IAdapterField',
|
||||||
};
|
};
|
||||||
|
|
||||||
export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [
|
export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [
|
||||||
@ -123,6 +129,7 @@ export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [
|
|||||||
'Scheduler',
|
'Scheduler',
|
||||||
'IPAdapterModelField',
|
'IPAdapterModelField',
|
||||||
'BoardField',
|
'BoardField',
|
||||||
|
'T2IAdapterModelField',
|
||||||
];
|
];
|
||||||
|
|
||||||
export const isPolymorphicItemType = (
|
export const isPolymorphicItemType = (
|
||||||
@ -272,7 +279,7 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
|||||||
title: t('nodes.integerPolymorphic'),
|
title: t('nodes.integerPolymorphic'),
|
||||||
},
|
},
|
||||||
IPAdapterField: {
|
IPAdapterField: {
|
||||||
color: 'green.300',
|
color: 'teal.500',
|
||||||
description: 'IP-Adapter info passed between nodes.',
|
description: 'IP-Adapter info passed between nodes.',
|
||||||
title: 'IP-Adapter',
|
title: 'IP-Adapter',
|
||||||
},
|
},
|
||||||
@ -341,6 +348,26 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
|||||||
description: t('nodes.stringPolymorphicDescription'),
|
description: t('nodes.stringPolymorphicDescription'),
|
||||||
title: t('nodes.stringPolymorphic'),
|
title: t('nodes.stringPolymorphic'),
|
||||||
},
|
},
|
||||||
|
T2IAdapterCollection: {
|
||||||
|
color: 'teal.500',
|
||||||
|
description: t('nodes.t2iAdapterCollectionDescription'),
|
||||||
|
title: t('nodes.t2iAdapterCollection'),
|
||||||
|
},
|
||||||
|
T2IAdapterField: {
|
||||||
|
color: 'teal.500',
|
||||||
|
description: t('nodes.t2iAdapterFieldDescription'),
|
||||||
|
title: t('nodes.t2iAdapterField'),
|
||||||
|
},
|
||||||
|
T2IAdapterModelField: {
|
||||||
|
color: 'teal.500',
|
||||||
|
description: 'TODO',
|
||||||
|
title: 'T2I-Adapter',
|
||||||
|
},
|
||||||
|
T2IAdapterPolymorphic: {
|
||||||
|
color: 'teal.500',
|
||||||
|
description: 'T2I-Adapter info passed between nodes.',
|
||||||
|
title: 'T2I-Adapter Polymorphic',
|
||||||
|
},
|
||||||
UNetField: {
|
UNetField: {
|
||||||
color: 'red.500',
|
color: 'red.500',
|
||||||
description: t('nodes.uNetFieldDescription'),
|
description: t('nodes.uNetFieldDescription'),
|
||||||
|
@ -114,6 +114,10 @@ export const zFieldType = z.enum([
|
|||||||
'string',
|
'string',
|
||||||
'StringCollection',
|
'StringCollection',
|
||||||
'StringPolymorphic',
|
'StringPolymorphic',
|
||||||
|
'T2IAdapterCollection',
|
||||||
|
'T2IAdapterField',
|
||||||
|
'T2IAdapterModelField',
|
||||||
|
'T2IAdapterPolymorphic',
|
||||||
'UNetField',
|
'UNetField',
|
||||||
'VaeField',
|
'VaeField',
|
||||||
'VaeModelField',
|
'VaeModelField',
|
||||||
@ -426,6 +430,48 @@ export type IPAdapterInputFieldValue = z.infer<
|
|||||||
typeof zIPAdapterInputFieldValue
|
typeof zIPAdapterInputFieldValue
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
export const zT2IAdapterModel = zModelIdentifier;
|
||||||
|
export type T2IAdapterModel = z.infer<typeof zT2IAdapterModel>;
|
||||||
|
|
||||||
|
export const zT2IAdapterField = z.object({
|
||||||
|
image: zImageField,
|
||||||
|
t2i_adapter_model: zT2IAdapterModel,
|
||||||
|
weight: z.union([z.number(), z.array(z.number())]).optional(),
|
||||||
|
begin_step_percent: z.number().optional(),
|
||||||
|
end_step_percent: z.number().optional(),
|
||||||
|
resize_mode: z
|
||||||
|
.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple'])
|
||||||
|
.optional(),
|
||||||
|
});
|
||||||
|
export type T2IAdapterField = z.infer<typeof zT2IAdapterField>;
|
||||||
|
|
||||||
|
export const zT2IAdapterInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('T2IAdapterField'),
|
||||||
|
value: zT2IAdapterField.optional(),
|
||||||
|
});
|
||||||
|
export type T2IAdapterInputFieldValue = z.infer<
|
||||||
|
typeof zT2IAdapterInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
|
export const zT2IAdapterPolymorphicInputFieldValue =
|
||||||
|
zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('T2IAdapterPolymorphic'),
|
||||||
|
value: z.union([zT2IAdapterField, z.array(zT2IAdapterField)]).optional(),
|
||||||
|
});
|
||||||
|
export type T2IAdapterPolymorphicInputFieldValue = z.infer<
|
||||||
|
typeof zT2IAdapterPolymorphicInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
|
export const zT2IAdapterCollectionInputFieldValue = zInputFieldValueBase.extend(
|
||||||
|
{
|
||||||
|
type: z.literal('T2IAdapterCollection'),
|
||||||
|
value: z.array(zT2IAdapterField).optional(),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
export type T2IAdapterCollectionInputFieldValue = z.infer<
|
||||||
|
typeof zT2IAdapterCollectionInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
export const zModelType = z.enum([
|
export const zModelType = z.enum([
|
||||||
'onnx',
|
'onnx',
|
||||||
'main',
|
'main',
|
||||||
@ -592,6 +638,17 @@ export type IPAdapterModelInputFieldValue = z.infer<
|
|||||||
typeof zIPAdapterModelInputFieldValue
|
typeof zIPAdapterModelInputFieldValue
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
export const zT2IAdapterModelField = zModelIdentifier;
|
||||||
|
export type T2IAdapterModelField = z.infer<typeof zT2IAdapterModelField>;
|
||||||
|
|
||||||
|
export const zT2IAdapterModelInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('T2IAdapterModelField'),
|
||||||
|
value: zT2IAdapterModelField.optional(),
|
||||||
|
});
|
||||||
|
export type T2IAdapterModelInputFieldValue = z.infer<
|
||||||
|
typeof zT2IAdapterModelInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
export const zCollectionInputFieldValue = zInputFieldValueBase.extend({
|
export const zCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||||
type: z.literal('Collection'),
|
type: z.literal('Collection'),
|
||||||
value: z.array(z.any()).optional(), // TODO: should this field ever have a value?
|
value: z.array(z.any()).optional(), // TODO: should this field ever have a value?
|
||||||
@ -688,6 +745,10 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
|
|||||||
zStringCollectionInputFieldValue,
|
zStringCollectionInputFieldValue,
|
||||||
zStringPolymorphicInputFieldValue,
|
zStringPolymorphicInputFieldValue,
|
||||||
zStringInputFieldValue,
|
zStringInputFieldValue,
|
||||||
|
zT2IAdapterInputFieldValue,
|
||||||
|
zT2IAdapterModelInputFieldValue,
|
||||||
|
zT2IAdapterCollectionInputFieldValue,
|
||||||
|
zT2IAdapterPolymorphicInputFieldValue,
|
||||||
zUNetInputFieldValue,
|
zUNetInputFieldValue,
|
||||||
zVaeInputFieldValue,
|
zVaeInputFieldValue,
|
||||||
zVaeModelInputFieldValue,
|
zVaeModelInputFieldValue,
|
||||||
@ -889,6 +950,24 @@ export type IPAdapterInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'IPAdapterField';
|
type: 'IPAdapterField';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type T2IAdapterInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: undefined;
|
||||||
|
type: 'T2IAdapterField';
|
||||||
|
};
|
||||||
|
|
||||||
|
export type T2IAdapterCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: undefined;
|
||||||
|
type: 'T2IAdapterCollection';
|
||||||
|
item_default?: T2IAdapterField;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type T2IAdapterPolymorphicInputFieldTemplate = Omit<
|
||||||
|
T2IAdapterInputFieldTemplate,
|
||||||
|
'type'
|
||||||
|
> & {
|
||||||
|
type: 'T2IAdapterPolymorphic';
|
||||||
|
};
|
||||||
|
|
||||||
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
|
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: string;
|
default: string;
|
||||||
type: 'enum';
|
type: 'enum';
|
||||||
@ -931,6 +1010,11 @@ export type IPAdapterModelInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'IPAdapterModelField';
|
type: 'IPAdapterModelField';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type T2IAdapterModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: string;
|
||||||
|
type: 'T2IAdapterModelField';
|
||||||
|
};
|
||||||
|
|
||||||
export type CollectionInputFieldTemplate = InputFieldTemplateBase & {
|
export type CollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: [];
|
default: [];
|
||||||
type: 'Collection';
|
type: 'Collection';
|
||||||
@ -1016,6 +1100,10 @@ export type InputFieldTemplate =
|
|||||||
| StringCollectionInputFieldTemplate
|
| StringCollectionInputFieldTemplate
|
||||||
| StringPolymorphicInputFieldTemplate
|
| StringPolymorphicInputFieldTemplate
|
||||||
| StringInputFieldTemplate
|
| StringInputFieldTemplate
|
||||||
|
| T2IAdapterInputFieldTemplate
|
||||||
|
| T2IAdapterCollectionInputFieldTemplate
|
||||||
|
| T2IAdapterModelInputFieldTemplate
|
||||||
|
| T2IAdapterPolymorphicInputFieldTemplate
|
||||||
| UNetInputFieldTemplate
|
| UNetInputFieldTemplate
|
||||||
| VaeInputFieldTemplate
|
| VaeInputFieldTemplate
|
||||||
| VaeModelInputFieldTemplate;
|
| VaeModelInputFieldTemplate;
|
||||||
|
@ -62,6 +62,11 @@ import {
|
|||||||
ConditioningField,
|
ConditioningField,
|
||||||
IPAdapterInputFieldTemplate,
|
IPAdapterInputFieldTemplate,
|
||||||
IPAdapterModelInputFieldTemplate,
|
IPAdapterModelInputFieldTemplate,
|
||||||
|
T2IAdapterField,
|
||||||
|
T2IAdapterInputFieldTemplate,
|
||||||
|
T2IAdapterModelInputFieldTemplate,
|
||||||
|
T2IAdapterPolymorphicInputFieldTemplate,
|
||||||
|
T2IAdapterCollectionInputFieldTemplate,
|
||||||
BoardInputFieldTemplate,
|
BoardInputFieldTemplate,
|
||||||
InputFieldTemplate,
|
InputFieldTemplate,
|
||||||
} from '../types/types';
|
} from '../types/types';
|
||||||
@ -452,6 +457,19 @@ const buildIPAdapterModelInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildT2IAdapterModelInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): T2IAdapterModelInputFieldTemplate => {
|
||||||
|
const template: T2IAdapterModelInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'T2IAdapterModelField',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildBoardInputFieldTemplate = ({
|
const buildBoardInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -691,6 +709,46 @@ const buildIPAdapterInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildT2IAdapterInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): T2IAdapterInputFieldTemplate => {
|
||||||
|
const template: T2IAdapterInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'T2IAdapterField',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildT2IAdapterPolymorphicInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): T2IAdapterPolymorphicInputFieldTemplate => {
|
||||||
|
const template: T2IAdapterPolymorphicInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'T2IAdapterPolymorphic',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildT2IAdapterCollectionInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): T2IAdapterCollectionInputFieldTemplate => {
|
||||||
|
const template: T2IAdapterCollectionInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'T2IAdapterCollection',
|
||||||
|
default: schemaObject.default ?? [],
|
||||||
|
item_default: (schemaObject.item_default as T2IAdapterField) ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildEnumInputFieldTemplate = ({
|
const buildEnumInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -910,6 +968,10 @@ const TEMPLATE_BUILDER_MAP: {
|
|||||||
string: buildStringInputFieldTemplate,
|
string: buildStringInputFieldTemplate,
|
||||||
StringCollection: buildStringCollectionInputFieldTemplate,
|
StringCollection: buildStringCollectionInputFieldTemplate,
|
||||||
StringPolymorphic: buildStringPolymorphicInputFieldTemplate,
|
StringPolymorphic: buildStringPolymorphicInputFieldTemplate,
|
||||||
|
T2IAdapterCollection: buildT2IAdapterCollectionInputFieldTemplate,
|
||||||
|
T2IAdapterField: buildT2IAdapterInputFieldTemplate,
|
||||||
|
T2IAdapterModelField: buildT2IAdapterModelInputFieldTemplate,
|
||||||
|
T2IAdapterPolymorphic: buildT2IAdapterPolymorphicInputFieldTemplate,
|
||||||
UNetField: buildUNetInputFieldTemplate,
|
UNetField: buildUNetInputFieldTemplate,
|
||||||
VaeField: buildVaeInputFieldTemplate,
|
VaeField: buildVaeInputFieldTemplate,
|
||||||
VaeModelField: buildVaeModelInputFieldTemplate,
|
VaeModelField: buildVaeModelInputFieldTemplate,
|
||||||
|
@ -45,6 +45,10 @@ const FIELD_VALUE_FALLBACK_MAP: {
|
|||||||
string: '',
|
string: '',
|
||||||
StringCollection: [],
|
StringCollection: [],
|
||||||
StringPolymorphic: '',
|
StringPolymorphic: '',
|
||||||
|
T2IAdapterCollection: [],
|
||||||
|
T2IAdapterField: undefined,
|
||||||
|
T2IAdapterModelField: undefined,
|
||||||
|
T2IAdapterPolymorphic: undefined,
|
||||||
UNetField: undefined,
|
UNetField: undefined,
|
||||||
VaeField: undefined,
|
VaeField: undefined,
|
||||||
VaeModelField: undefined,
|
VaeModelField: undefined,
|
||||||
|
@ -340,6 +340,17 @@ export const zIPAdapterModel = z.object({
|
|||||||
* Type alias for model parameter, inferred from its zod schema
|
* Type alias for model parameter, inferred from its zod schema
|
||||||
*/
|
*/
|
||||||
export type IPAdapterModelParam = z.infer<typeof zIPAdapterModel>;
|
export type IPAdapterModelParam = z.infer<typeof zIPAdapterModel>;
|
||||||
|
/**
|
||||||
|
* Zod schema for T2I-Adapter models
|
||||||
|
*/
|
||||||
|
export const zT2IAdapterModel = z.object({
|
||||||
|
model_name: z.string().min(1),
|
||||||
|
base_model: zBaseModel,
|
||||||
|
});
|
||||||
|
/**
|
||||||
|
* Type alias for model parameter, inferred from its zod schema
|
||||||
|
*/
|
||||||
|
export type T2IAdapterModelParam = z.infer<typeof zT2IAdapterModel>;
|
||||||
/**
|
/**
|
||||||
* Zod schema for l2l strength parameter
|
* Zod schema for l2l strength parameter
|
||||||
*/
|
*/
|
||||||
|
@ -0,0 +1,29 @@
|
|||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { zT2IAdapterModel } from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { T2IAdapterModelField } from 'services/api/types';
|
||||||
|
|
||||||
|
export const modelIdToT2IAdapterModelParam = (
|
||||||
|
t2iAdapterModelId: string
|
||||||
|
): T2IAdapterModelField | undefined => {
|
||||||
|
const log = logger('models');
|
||||||
|
const [base_model, _model_type, model_name] = t2iAdapterModelId.split('/');
|
||||||
|
|
||||||
|
const result = zT2IAdapterModel.safeParse({
|
||||||
|
base_model,
|
||||||
|
model_name,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!result.success) {
|
||||||
|
log.error(
|
||||||
|
{
|
||||||
|
t2iAdapterModelId,
|
||||||
|
errors: result.error.format(),
|
||||||
|
},
|
||||||
|
'Failed to parse T2I-Adapter model id'
|
||||||
|
);
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.data;
|
||||||
|
};
|
@ -25,6 +25,9 @@ export const useCopyImageToClipboard = () => {
|
|||||||
try {
|
try {
|
||||||
const getImageBlob = async () => {
|
const getImageBlob = async () => {
|
||||||
const response = await fetch(image_url);
|
const response = await fetch(image_url);
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Problem retrieving image data`);
|
||||||
|
}
|
||||||
return await response.blob();
|
return await response.blob();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ export const boardsApi = api.injectEndpoints({
|
|||||||
keepUnusedDataFor: 0,
|
keepUnusedDataFor: 0,
|
||||||
}),
|
}),
|
||||||
|
|
||||||
getBoardImagesTotal: build.query<number, string | undefined>({
|
getBoardImagesTotal: build.query<{ total: number }, string | undefined>({
|
||||||
query: (board_id) => ({
|
query: (board_id) => ({
|
||||||
url: getListImagesUrl({
|
url: getListImagesUrl({
|
||||||
board_id: board_id ?? 'none',
|
board_id: board_id ?? 'none',
|
||||||
@ -85,11 +85,11 @@ export const boardsApi = api.injectEndpoints({
|
|||||||
{ type: 'BoardImagesTotal', id: arg ?? 'none' },
|
{ type: 'BoardImagesTotal', id: arg ?? 'none' },
|
||||||
],
|
],
|
||||||
transformResponse: (response: OffsetPaginatedResults_ImageDTO_) => {
|
transformResponse: (response: OffsetPaginatedResults_ImageDTO_) => {
|
||||||
return response.total;
|
return { total: response.total };
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
|
|
||||||
getBoardAssetsTotal: build.query<number, string | undefined>({
|
getBoardAssetsTotal: build.query<{ total: number }, string | undefined>({
|
||||||
query: (board_id) => ({
|
query: (board_id) => ({
|
||||||
url: getListImagesUrl({
|
url: getListImagesUrl({
|
||||||
board_id: board_id ?? 'none',
|
board_id: board_id ?? 'none',
|
||||||
@ -104,7 +104,7 @@ export const boardsApi = api.injectEndpoints({
|
|||||||
{ type: 'BoardAssetsTotal', id: arg ?? 'none' },
|
{ type: 'BoardAssetsTotal', id: arg ?? 'none' },
|
||||||
],
|
],
|
||||||
transformResponse: (response: OffsetPaginatedResults_ImageDTO_) => {
|
transformResponse: (response: OffsetPaginatedResults_ImageDTO_) => {
|
||||||
return response.total;
|
return { total: response.total };
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
|
|
||||||
|
@ -103,6 +103,9 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
query: () => ({ url: getListImagesUrl({ is_intermediate: true }) }),
|
query: () => ({ url: getListImagesUrl({ is_intermediate: true }) }),
|
||||||
providesTags: ['IntermediatesCount'],
|
providesTags: ['IntermediatesCount'],
|
||||||
transformResponse: (response: OffsetPaginatedResults_ImageDTO_) => {
|
transformResponse: (response: OffsetPaginatedResults_ImageDTO_) => {
|
||||||
|
// TODO: This is storing a primitive value in the cache. `immer` cannot track state changes, so
|
||||||
|
// attempts to use manual cache updates on this value will fail. This should be changed into an
|
||||||
|
// object.
|
||||||
return response.total;
|
return response.total;
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
@ -191,35 +194,51 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
url: `images/i/${image_name}`,
|
url: `images/i/${image_name}`,
|
||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
}),
|
}),
|
||||||
invalidatesTags: (result, error, { board_id }) => [
|
|
||||||
{ type: 'BoardImagesTotal', id: board_id ?? 'none' },
|
|
||||||
{ type: 'BoardAssetsTotal', id: board_id ?? 'none' },
|
|
||||||
],
|
|
||||||
async onQueryStarted(imageDTO, { dispatch, queryFulfilled }) {
|
async onQueryStarted(imageDTO, { dispatch, queryFulfilled }) {
|
||||||
/**
|
/**
|
||||||
* Cache changes for `deleteImage`:
|
* Cache changes for `deleteImage`:
|
||||||
* - NOT POSSIBLE: *remove* from getImageDTO
|
* - NOT POSSIBLE: *remove* from getImageDTO
|
||||||
* - $cache = [board_id|no_board]/[images|assets]
|
* - $cache = [board_id|no_board]/[images|assets]
|
||||||
* - *remove* from $cache
|
* - *remove* from $cache
|
||||||
|
* - decrement the image's board's total
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const { image_name, board_id } = imageDTO;
|
const { image_name, board_id } = imageDTO;
|
||||||
|
const isAsset = ASSETS_CATEGORIES.includes(imageDTO.image_category);
|
||||||
|
|
||||||
const queryArg = {
|
const queryArg = {
|
||||||
board_id: board_id ?? 'none',
|
board_id: board_id ?? 'none',
|
||||||
categories: getCategories(imageDTO),
|
categories: getCategories(imageDTO),
|
||||||
};
|
};
|
||||||
|
|
||||||
const patch = dispatch(
|
const patches: PatchCollection[] = [];
|
||||||
|
|
||||||
|
patches.push(
|
||||||
|
dispatch(
|
||||||
imagesApi.util.updateQueryData('listImages', queryArg, (draft) => {
|
imagesApi.util.updateQueryData('listImages', queryArg, (draft) => {
|
||||||
imagesAdapter.removeOne(draft, image_name);
|
imagesAdapter.removeOne(draft, image_name);
|
||||||
})
|
})
|
||||||
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
patches.push(
|
||||||
|
dispatch(
|
||||||
|
boardsApi.util.updateQueryData(
|
||||||
|
isAsset ? 'getBoardAssetsTotal' : 'getBoardImagesTotal',
|
||||||
|
imageDTO.board_id ?? 'none',
|
||||||
|
(draft) => {
|
||||||
|
draft.total = Math.max(draft.total - 1, 0);
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
); // decrement the image board's total
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await queryFulfilled;
|
await queryFulfilled;
|
||||||
} catch {
|
} catch {
|
||||||
|
patches.forEach((patch) => {
|
||||||
patch.undo();
|
patch.undo();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
@ -237,18 +256,11 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
invalidatesTags: (result, error, { imageDTOs }) => {
|
|
||||||
// for now, assume bulk delete is all on one board
|
|
||||||
const boardId = imageDTOs[0]?.board_id;
|
|
||||||
return [
|
|
||||||
{ type: 'BoardImagesTotal', id: boardId ?? 'none' },
|
|
||||||
{ type: 'BoardAssetsTotal', id: boardId ?? 'none' },
|
|
||||||
];
|
|
||||||
},
|
|
||||||
async onQueryStarted({ imageDTOs }, { dispatch, queryFulfilled }) {
|
async onQueryStarted({ imageDTOs }, { dispatch, queryFulfilled }) {
|
||||||
/**
|
/**
|
||||||
* Cache changes for `deleteImages`:
|
* Cache changes for `deleteImages`:
|
||||||
* - *remove* the deleted images from their boards
|
* - *remove* the deleted images from their boards
|
||||||
|
* - decrement the images' board's totals
|
||||||
*
|
*
|
||||||
* Unfortunately we cannot do an optimistic update here due to how immer handles patching
|
* Unfortunately we cannot do an optimistic update here due to how immer handles patching
|
||||||
* arrays. You have to undo *all* patches, else the entity adapter's `ids` array is borked.
|
* arrays. You have to undo *all* patches, else the entity adapter's `ids` array is borked.
|
||||||
@ -279,6 +291,21 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const isAsset = ASSETS_CATEGORIES.includes(
|
||||||
|
imageDTO.image_category
|
||||||
|
);
|
||||||
|
|
||||||
|
// decrement the image board's total
|
||||||
|
dispatch(
|
||||||
|
boardsApi.util.updateQueryData(
|
||||||
|
isAsset ? 'getBoardAssetsTotal' : 'getBoardImagesTotal',
|
||||||
|
imageDTO.board_id ?? 'none',
|
||||||
|
(draft) => {
|
||||||
|
draft.total = Math.max(draft.total - 1, 0);
|
||||||
|
}
|
||||||
|
)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
} catch {
|
} catch {
|
||||||
@ -298,10 +325,6 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
method: 'PATCH',
|
method: 'PATCH',
|
||||||
body: { is_intermediate },
|
body: { is_intermediate },
|
||||||
}),
|
}),
|
||||||
invalidatesTags: (result, error, { imageDTO }) => [
|
|
||||||
{ type: 'BoardImagesTotal', id: imageDTO.board_id ?? 'none' },
|
|
||||||
{ type: 'BoardAssetsTotal', id: imageDTO.board_id ?? 'none' },
|
|
||||||
],
|
|
||||||
async onQueryStarted(
|
async onQueryStarted(
|
||||||
{ imageDTO, is_intermediate },
|
{ imageDTO, is_intermediate },
|
||||||
{ dispatch, queryFulfilled, getState }
|
{ dispatch, queryFulfilled, getState }
|
||||||
@ -312,9 +335,11 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
* - $cache = [board_id|no_board]/[images|assets]
|
* - $cache = [board_id|no_board]/[images|assets]
|
||||||
* - IF it is being changed to an intermediate:
|
* - IF it is being changed to an intermediate:
|
||||||
* - remove from $cache
|
* - remove from $cache
|
||||||
|
* - decrement the image's board's total
|
||||||
* - ELSE (it is being changed to a non-intermediate):
|
* - ELSE (it is being changed to a non-intermediate):
|
||||||
* - IF it eligible for insertion into existing $cache:
|
* - IF it eligible for insertion into existing $cache:
|
||||||
* - *upsert* to $cache
|
* - *upsert* to $cache
|
||||||
|
* - increment the image's board's total
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Store patches so we can undo if the query fails
|
// Store patches so we can undo if the query fails
|
||||||
@ -335,6 +360,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
|
|
||||||
// $cache = [board_id|no_board]/[images|assets]
|
// $cache = [board_id|no_board]/[images|assets]
|
||||||
const categories = getCategories(imageDTO);
|
const categories = getCategories(imageDTO);
|
||||||
|
const isAsset = ASSETS_CATEGORIES.includes(imageDTO.image_category);
|
||||||
|
|
||||||
if (is_intermediate) {
|
if (is_intermediate) {
|
||||||
// IF it is being changed to an intermediate:
|
// IF it is being changed to an intermediate:
|
||||||
@ -350,8 +376,35 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// decrement the image board's total
|
||||||
|
patches.push(
|
||||||
|
dispatch(
|
||||||
|
boardsApi.util.updateQueryData(
|
||||||
|
isAsset ? 'getBoardAssetsTotal' : 'getBoardImagesTotal',
|
||||||
|
imageDTO.board_id ?? 'none',
|
||||||
|
(draft) => {
|
||||||
|
draft.total = Math.max(draft.total - 1, 0);
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
// ELSE (it is being changed to a non-intermediate):
|
// ELSE (it is being changed to a non-intermediate):
|
||||||
|
|
||||||
|
// increment the image board's total
|
||||||
|
patches.push(
|
||||||
|
dispatch(
|
||||||
|
boardsApi.util.updateQueryData(
|
||||||
|
isAsset ? 'getBoardAssetsTotal' : 'getBoardImagesTotal',
|
||||||
|
imageDTO.board_id ?? 'none',
|
||||||
|
(draft) => {
|
||||||
|
draft.total += 1;
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
const queryArgs = {
|
const queryArgs = {
|
||||||
board_id: imageDTO.board_id ?? 'none',
|
board_id: imageDTO.board_id ?? 'none',
|
||||||
categories,
|
categories,
|
||||||
@ -361,9 +414,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
getState()
|
getState()
|
||||||
);
|
);
|
||||||
|
|
||||||
const { data: total } = IMAGE_CATEGORIES.includes(
|
const { data } = IMAGE_CATEGORIES.includes(imageDTO.image_category)
|
||||||
imageDTO.image_category
|
|
||||||
)
|
|
||||||
? boardsApi.endpoints.getBoardImagesTotal.select(
|
? boardsApi.endpoints.getBoardImagesTotal.select(
|
||||||
imageDTO.board_id ?? 'none'
|
imageDTO.board_id ?? 'none'
|
||||||
)(getState())
|
)(getState())
|
||||||
@ -378,7 +429,8 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
// - The image's `created_at` is within the range of the cached images
|
// - The image's `created_at` is within the range of the cached images
|
||||||
|
|
||||||
const isCacheFullyPopulated =
|
const isCacheFullyPopulated =
|
||||||
currentCache.data && currentCache.data.ids.length >= (total ?? 0);
|
currentCache.data &&
|
||||||
|
currentCache.data.ids.length >= (data?.total ?? 0);
|
||||||
|
|
||||||
const isInDateRange = getIsImageInDateRange(
|
const isInDateRange = getIsImageInDateRange(
|
||||||
currentCache.data,
|
currentCache.data,
|
||||||
@ -420,10 +472,6 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
method: 'PATCH',
|
method: 'PATCH',
|
||||||
body: { session_id },
|
body: { session_id },
|
||||||
}),
|
}),
|
||||||
invalidatesTags: (result, error, { imageDTO }) => [
|
|
||||||
{ type: 'BoardImagesTotal', id: imageDTO.board_id ?? 'none' },
|
|
||||||
{ type: 'BoardAssetsTotal', id: imageDTO.board_id ?? 'none' },
|
|
||||||
],
|
|
||||||
async onQueryStarted(
|
async onQueryStarted(
|
||||||
{ imageDTO, session_id },
|
{ imageDTO, session_id },
|
||||||
{ dispatch, queryFulfilled }
|
{ dispatch, queryFulfilled }
|
||||||
@ -473,6 +521,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
if (images[0]) {
|
if (images[0]) {
|
||||||
const categories = getCategories(images[0]);
|
const categories = getCategories(images[0]);
|
||||||
const boardId = images[0].board_id;
|
const boardId = images[0].board_id;
|
||||||
|
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
type: 'ImageList',
|
type: 'ImageList',
|
||||||
@ -481,6 +530,10 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
categories,
|
categories,
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
type: 'Board',
|
||||||
|
id: boardId,
|
||||||
|
},
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
return [];
|
return [];
|
||||||
@ -530,9 +583,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
queryArgs
|
queryArgs
|
||||||
)(getState());
|
)(getState());
|
||||||
|
|
||||||
const { data: previousTotal } = IMAGE_CATEGORIES.includes(
|
const { data } = IMAGE_CATEGORIES.includes(imageDTO.image_category)
|
||||||
imageDTO.image_category
|
|
||||||
)
|
|
||||||
? boardsApi.endpoints.getBoardImagesTotal.select(
|
? boardsApi.endpoints.getBoardImagesTotal.select(
|
||||||
boardId ?? 'none'
|
boardId ?? 'none'
|
||||||
)(getState())
|
)(getState())
|
||||||
@ -542,10 +593,10 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
|
|
||||||
const isCacheFullyPopulated =
|
const isCacheFullyPopulated =
|
||||||
currentCache.data &&
|
currentCache.data &&
|
||||||
currentCache.data.ids.length >= (previousTotal ?? 0);
|
currentCache.data.ids.length >= (data?.total ?? 0);
|
||||||
|
|
||||||
const isInDateRange =
|
const isInDateRange =
|
||||||
(previousTotal || 0) >= IMAGE_LIMIT
|
(data?.total ?? 0) >= IMAGE_LIMIT
|
||||||
? getIsImageInDateRange(currentCache.data, imageDTO)
|
? getIsImageInDateRange(currentCache.data, imageDTO)
|
||||||
: true;
|
: true;
|
||||||
|
|
||||||
@ -595,6 +646,10 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
categories,
|
categories,
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
type: 'Board',
|
||||||
|
id: boardId,
|
||||||
|
},
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
return [];
|
return [];
|
||||||
@ -643,9 +698,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
queryArgs
|
queryArgs
|
||||||
)(getState());
|
)(getState());
|
||||||
|
|
||||||
const { data: previousTotal } = IMAGE_CATEGORIES.includes(
|
const { data } = IMAGE_CATEGORIES.includes(imageDTO.image_category)
|
||||||
imageDTO.image_category
|
|
||||||
)
|
|
||||||
? boardsApi.endpoints.getBoardImagesTotal.select(
|
? boardsApi.endpoints.getBoardImagesTotal.select(
|
||||||
boardId ?? 'none'
|
boardId ?? 'none'
|
||||||
)(getState())
|
)(getState())
|
||||||
@ -655,10 +708,10 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
|
|
||||||
const isCacheFullyPopulated =
|
const isCacheFullyPopulated =
|
||||||
currentCache.data &&
|
currentCache.data &&
|
||||||
currentCache.data.ids.length >= (previousTotal ?? 0);
|
currentCache.data.ids.length >= (data?.total ?? 0);
|
||||||
|
|
||||||
const isInDateRange =
|
const isInDateRange =
|
||||||
(previousTotal || 0) >= IMAGE_LIMIT
|
(data?.total ?? 0) >= IMAGE_LIMIT
|
||||||
? getIsImageInDateRange(currentCache.data, imageDTO)
|
? getIsImageInDateRange(currentCache.data, imageDTO)
|
||||||
: true;
|
: true;
|
||||||
|
|
||||||
@ -727,6 +780,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
* - BAIL OUT
|
* - BAIL OUT
|
||||||
* - *add* to `getImageDTO`
|
* - *add* to `getImageDTO`
|
||||||
* - *add* to no_board/assets
|
* - *add* to no_board/assets
|
||||||
|
* - update the image's board's assets total
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const { data: imageDTO } = await queryFulfilled;
|
const { data: imageDTO } = await queryFulfilled;
|
||||||
@ -761,11 +815,15 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// increment new board's total
|
||||||
dispatch(
|
dispatch(
|
||||||
imagesApi.util.invalidateTags([
|
boardsApi.util.updateQueryData(
|
||||||
{ type: 'BoardImagesTotal', id: imageDTO.board_id ?? 'none' },
|
'getBoardAssetsTotal',
|
||||||
{ type: 'BoardAssetsTotal', id: imageDTO.board_id ?? 'none' },
|
imageDTO.board_id ?? 'none',
|
||||||
])
|
(draft) => {
|
||||||
|
draft.total += 1;
|
||||||
|
}
|
||||||
|
)
|
||||||
);
|
);
|
||||||
} catch {
|
} catch {
|
||||||
// query failed, no action needed
|
// query failed, no action needed
|
||||||
@ -792,8 +850,6 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
categories: ASSETS_CATEGORIES,
|
categories: ASSETS_CATEGORIES,
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
{ type: 'BoardImagesTotal', id: 'none' },
|
|
||||||
{ type: 'BoardAssetsTotal', id: 'none' },
|
|
||||||
],
|
],
|
||||||
async onQueryStarted(board_id, { dispatch, queryFulfilled }) {
|
async onQueryStarted(board_id, { dispatch, queryFulfilled }) {
|
||||||
/**
|
/**
|
||||||
@ -806,6 +862,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
* have access to the deleted images DTOs - only the names, and a network request
|
* have access to the deleted images DTOs - only the names, and a network request
|
||||||
* for all of a board's DTOs could be very large. Instead, we invalidate the 'No Board'
|
* for all of a board's DTOs could be very large. Instead, we invalidate the 'No Board'
|
||||||
* cache.
|
* cache.
|
||||||
|
* - set the board's totals to zero
|
||||||
*/
|
*/
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@ -825,6 +882,28 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// set the board's asset total to 0 (feels unnecessary since we are deleting it?)
|
||||||
|
dispatch(
|
||||||
|
boardsApi.util.updateQueryData(
|
||||||
|
'getBoardAssetsTotal',
|
||||||
|
board_id,
|
||||||
|
(draft) => {
|
||||||
|
draft.total = 0;
|
||||||
|
}
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
// set the board's images total to 0 (feels unnecessary since we are deleting it?)
|
||||||
|
dispatch(
|
||||||
|
boardsApi.util.updateQueryData(
|
||||||
|
'getBoardImagesTotal',
|
||||||
|
board_id,
|
||||||
|
(draft) => {
|
||||||
|
draft.total = 0;
|
||||||
|
}
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
// update 'All Images' & 'All Assets' caches
|
// update 'All Images' & 'All Assets' caches
|
||||||
const queryArgsToUpdate = [
|
const queryArgsToUpdate = [
|
||||||
{
|
{
|
||||||
@ -881,8 +960,6 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
categories: ASSETS_CATEGORIES,
|
categories: ASSETS_CATEGORIES,
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
{ type: 'BoardImagesTotal', id: 'none' },
|
|
||||||
{ type: 'BoardAssetsTotal', id: 'none' },
|
|
||||||
],
|
],
|
||||||
async onQueryStarted(board_id, { dispatch, queryFulfilled }) {
|
async onQueryStarted(board_id, { dispatch, queryFulfilled }) {
|
||||||
/**
|
/**
|
||||||
@ -892,6 +969,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
* Instead, we rely on the UI to remove all components that use the deleted images.
|
* Instead, we rely on the UI to remove all components that use the deleted images.
|
||||||
* - Remove every image in the 'All Images' cache that has the board_id
|
* - Remove every image in the 'All Images' cache that has the board_id
|
||||||
* - Remove every image in the 'All Assets' cache that has the board_id
|
* - Remove every image in the 'All Assets' cache that has the board_id
|
||||||
|
* - set the board's totals to zero
|
||||||
*/
|
*/
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@ -919,6 +997,28 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
)
|
)
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// set the board's asset total to 0 (feels unnecessary since we are deleting it?)
|
||||||
|
dispatch(
|
||||||
|
boardsApi.util.updateQueryData(
|
||||||
|
'getBoardAssetsTotal',
|
||||||
|
board_id,
|
||||||
|
(draft) => {
|
||||||
|
draft.total = 0;
|
||||||
|
}
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
// set the board's images total to 0 (feels unnecessary since we are deleting it?)
|
||||||
|
dispatch(
|
||||||
|
boardsApi.util.updateQueryData(
|
||||||
|
'getBoardImagesTotal',
|
||||||
|
board_id,
|
||||||
|
(draft) => {
|
||||||
|
draft.total = 0;
|
||||||
|
}
|
||||||
|
)
|
||||||
|
);
|
||||||
} catch {
|
} catch {
|
||||||
//no-op
|
//no-op
|
||||||
}
|
}
|
||||||
@ -936,15 +1036,9 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
body: { board_id, image_name },
|
body: { board_id, image_name },
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
invalidatesTags: (result, error, { board_id, imageDTO }) => [
|
invalidatesTags: (result, error, { board_id }) => [
|
||||||
// refresh the board itself
|
// refresh the board itself
|
||||||
{ type: 'Board', id: board_id },
|
{ type: 'Board', id: board_id },
|
||||||
// update old board totals
|
|
||||||
{ type: 'BoardImagesTotal', id: board_id },
|
|
||||||
{ type: 'BoardAssetsTotal', id: board_id },
|
|
||||||
// update new board totals
|
|
||||||
{ type: 'BoardImagesTotal', id: imageDTO.board_id ?? 'none' },
|
|
||||||
{ type: 'BoardAssetsTotal', id: imageDTO.board_id ?? 'none' },
|
|
||||||
],
|
],
|
||||||
async onQueryStarted(
|
async onQueryStarted(
|
||||||
{ board_id, imageDTO },
|
{ board_id, imageDTO },
|
||||||
@ -961,11 +1055,13 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
* - $cache = board_id/[images|assets]
|
* - $cache = board_id/[images|assets]
|
||||||
* - IF it eligible for insertion into existing $cache:
|
* - IF it eligible for insertion into existing $cache:
|
||||||
* - THEN *add* to $cache
|
* - THEN *add* to $cache
|
||||||
|
* - decrement both old board's total
|
||||||
|
* - increment the new board's total
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const patches: PatchCollection[] = [];
|
const patches: PatchCollection[] = [];
|
||||||
const categories = getCategories(imageDTO);
|
const categories = getCategories(imageDTO);
|
||||||
|
const isAsset = ASSETS_CATEGORIES.includes(imageDTO.image_category);
|
||||||
// *update* getImageDTO
|
// *update* getImageDTO
|
||||||
patches.push(
|
patches.push(
|
||||||
dispatch(
|
dispatch(
|
||||||
@ -996,6 +1092,32 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// decrement old board's total
|
||||||
|
patches.push(
|
||||||
|
dispatch(
|
||||||
|
boardsApi.util.updateQueryData(
|
||||||
|
isAsset ? 'getBoardAssetsTotal' : 'getBoardImagesTotal',
|
||||||
|
imageDTO.board_id ?? 'none',
|
||||||
|
(draft) => {
|
||||||
|
draft.total = Math.max(draft.total - 1, 0);
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
// increment new board's total
|
||||||
|
patches.push(
|
||||||
|
dispatch(
|
||||||
|
boardsApi.util.updateQueryData(
|
||||||
|
isAsset ? 'getBoardAssetsTotal' : 'getBoardImagesTotal',
|
||||||
|
board_id ?? 'none',
|
||||||
|
(draft) => {
|
||||||
|
draft.total += 1;
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
// $cache = board_id/[images|assets]
|
// $cache = board_id/[images|assets]
|
||||||
const queryArgs = { board_id: board_id ?? 'none', categories };
|
const queryArgs = { board_id: board_id ?? 'none', categories };
|
||||||
const currentCache = imagesApi.endpoints.listImages.select(queryArgs)(
|
const currentCache = imagesApi.endpoints.listImages.select(queryArgs)(
|
||||||
@ -1008,9 +1130,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
// OR
|
// OR
|
||||||
// - The image's `created_at` is within the range of the cached images
|
// - The image's `created_at` is within the range of the cached images
|
||||||
|
|
||||||
const { data: total } = IMAGE_CATEGORIES.includes(
|
const { data } = IMAGE_CATEGORIES.includes(imageDTO.image_category)
|
||||||
imageDTO.image_category
|
|
||||||
)
|
|
||||||
? boardsApi.endpoints.getBoardImagesTotal.select(
|
? boardsApi.endpoints.getBoardImagesTotal.select(
|
||||||
imageDTO.board_id ?? 'none'
|
imageDTO.board_id ?? 'none'
|
||||||
)(getState())
|
)(getState())
|
||||||
@ -1019,7 +1139,8 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
)(getState());
|
)(getState());
|
||||||
|
|
||||||
const isCacheFullyPopulated =
|
const isCacheFullyPopulated =
|
||||||
currentCache.data && currentCache.data.ids.length >= (total ?? 0);
|
currentCache.data &&
|
||||||
|
currentCache.data.ids.length >= (data?.total ?? 0);
|
||||||
|
|
||||||
const isInDateRange = getIsImageInDateRange(
|
const isInDateRange = getIsImageInDateRange(
|
||||||
currentCache.data,
|
currentCache.data,
|
||||||
@ -1063,12 +1184,6 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
return [
|
return [
|
||||||
// invalidate the image's old board
|
// invalidate the image's old board
|
||||||
{ type: 'Board', id: board_id ?? 'none' },
|
{ type: 'Board', id: board_id ?? 'none' },
|
||||||
// update old board totals
|
|
||||||
{ type: 'BoardImagesTotal', id: board_id ?? 'none' },
|
|
||||||
{ type: 'BoardAssetsTotal', id: board_id ?? 'none' },
|
|
||||||
// update the no_board totals
|
|
||||||
{ type: 'BoardImagesTotal', id: 'none' },
|
|
||||||
{ type: 'BoardAssetsTotal', id: 'none' },
|
|
||||||
];
|
];
|
||||||
},
|
},
|
||||||
async onQueryStarted(
|
async onQueryStarted(
|
||||||
@ -1082,10 +1197,13 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
* - $cache = no_board/[images|assets]
|
* - $cache = no_board/[images|assets]
|
||||||
* - IF it eligible for insertion into existing $cache:
|
* - IF it eligible for insertion into existing $cache:
|
||||||
* - THEN *upsert* to $cache
|
* - THEN *upsert* to $cache
|
||||||
|
* - decrement old board's total
|
||||||
|
* - increment the new board's total (no board)
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const categories = getCategories(imageDTO);
|
const categories = getCategories(imageDTO);
|
||||||
const patches: PatchCollection[] = [];
|
const patches: PatchCollection[] = [];
|
||||||
|
const isAsset = ASSETS_CATEGORIES.includes(imageDTO.image_category);
|
||||||
|
|
||||||
// *update* getImageDTO
|
// *update* getImageDTO
|
||||||
patches.push(
|
patches.push(
|
||||||
@ -1116,6 +1234,32 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// decrement old board's total
|
||||||
|
patches.push(
|
||||||
|
dispatch(
|
||||||
|
boardsApi.util.updateQueryData(
|
||||||
|
isAsset ? 'getBoardAssetsTotal' : 'getBoardImagesTotal',
|
||||||
|
imageDTO.board_id ?? 'none',
|
||||||
|
(draft) => {
|
||||||
|
draft.total = Math.max(draft.total - 1, 0);
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
// increment new board's total (no board)
|
||||||
|
patches.push(
|
||||||
|
dispatch(
|
||||||
|
boardsApi.util.updateQueryData(
|
||||||
|
isAsset ? 'getBoardAssetsTotal' : 'getBoardImagesTotal',
|
||||||
|
'none',
|
||||||
|
(draft) => {
|
||||||
|
draft.total += 1;
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
// $cache = no_board/[images|assets]
|
// $cache = no_board/[images|assets]
|
||||||
const queryArgs = { board_id: 'none', categories };
|
const queryArgs = { board_id: 'none', categories };
|
||||||
const currentCache = imagesApi.endpoints.listImages.select(queryArgs)(
|
const currentCache = imagesApi.endpoints.listImages.select(queryArgs)(
|
||||||
@ -1128,9 +1272,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
// OR
|
// OR
|
||||||
// - The image's `created_at` is within the range of the cached images
|
// - The image's `created_at` is within the range of the cached images
|
||||||
|
|
||||||
const { data: total } = IMAGE_CATEGORIES.includes(
|
const { data } = IMAGE_CATEGORIES.includes(imageDTO.image_category)
|
||||||
imageDTO.image_category
|
|
||||||
)
|
|
||||||
? boardsApi.endpoints.getBoardImagesTotal.select(
|
? boardsApi.endpoints.getBoardImagesTotal.select(
|
||||||
imageDTO.board_id ?? 'none'
|
imageDTO.board_id ?? 'none'
|
||||||
)(getState())
|
)(getState())
|
||||||
@ -1139,7 +1281,8 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
)(getState());
|
)(getState());
|
||||||
|
|
||||||
const isCacheFullyPopulated =
|
const isCacheFullyPopulated =
|
||||||
currentCache.data && currentCache.data.ids.length >= (total ?? 0);
|
currentCache.data &&
|
||||||
|
currentCache.data.ids.length >= (data?.total ?? 0);
|
||||||
|
|
||||||
const isInDateRange = getIsImageInDateRange(
|
const isInDateRange = getIsImageInDateRange(
|
||||||
currentCache.data,
|
currentCache.data,
|
||||||
@ -1183,21 +1326,10 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
board_id,
|
board_id,
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
invalidatesTags: (result, error, { imageDTOs, board_id }) => {
|
invalidatesTags: (result, error, { board_id }) => {
|
||||||
//assume all images are being moved from one board for now
|
|
||||||
const oldBoardId = imageDTOs[0]?.board_id;
|
|
||||||
return [
|
return [
|
||||||
// update the destination board
|
// update the destination board
|
||||||
{ type: 'Board', id: board_id ?? 'none' },
|
{ type: 'Board', id: board_id ?? 'none' },
|
||||||
// update new board totals
|
|
||||||
{ type: 'BoardImagesTotal', id: board_id ?? 'none' },
|
|
||||||
{ type: 'BoardAssetsTotal', id: board_id ?? 'none' },
|
|
||||||
// update old board totals
|
|
||||||
{ type: 'BoardImagesTotal', id: oldBoardId ?? 'none' },
|
|
||||||
{ type: 'BoardAssetsTotal', id: oldBoardId ?? 'none' },
|
|
||||||
// update the no_board totals
|
|
||||||
{ type: 'BoardImagesTotal', id: 'none' },
|
|
||||||
{ type: 'BoardAssetsTotal', id: 'none' },
|
|
||||||
];
|
];
|
||||||
},
|
},
|
||||||
async onQueryStarted(
|
async onQueryStarted(
|
||||||
@ -1213,6 +1345,8 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
* - *update* getImageDTO for each image
|
* - *update* getImageDTO for each image
|
||||||
* - *add* to board_id/[images|assets]
|
* - *add* to board_id/[images|assets]
|
||||||
* - *remove* from [old_board_id|no_board]/[images|assets]
|
* - *remove* from [old_board_id|no_board]/[images|assets]
|
||||||
|
* - decrement old board's totals for each image
|
||||||
|
* - increment new board's totals for each image
|
||||||
*/
|
*/
|
||||||
|
|
||||||
added_image_names.forEach((image_name) => {
|
added_image_names.forEach((image_name) => {
|
||||||
@ -1221,7 +1355,8 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
'getImageDTO',
|
'getImageDTO',
|
||||||
image_name,
|
image_name,
|
||||||
(draft) => {
|
(draft) => {
|
||||||
draft.board_id = new_board_id;
|
draft.board_id =
|
||||||
|
new_board_id === 'none' ? undefined : new_board_id;
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
@ -1234,6 +1369,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
|
|
||||||
const categories = getCategories(imageDTO);
|
const categories = getCategories(imageDTO);
|
||||||
const old_board_id = imageDTO.board_id;
|
const old_board_id = imageDTO.board_id;
|
||||||
|
const isAsset = ASSETS_CATEGORIES.includes(imageDTO.image_category);
|
||||||
|
|
||||||
// remove from the old board
|
// remove from the old board
|
||||||
dispatch(
|
dispatch(
|
||||||
@ -1246,6 +1382,28 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// decrement old board's total
|
||||||
|
dispatch(
|
||||||
|
boardsApi.util.updateQueryData(
|
||||||
|
isAsset ? 'getBoardAssetsTotal' : 'getBoardImagesTotal',
|
||||||
|
old_board_id ?? 'none',
|
||||||
|
(draft) => {
|
||||||
|
draft.total = Math.max(draft.total - 1, 0);
|
||||||
|
}
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
// increment new board's total
|
||||||
|
dispatch(
|
||||||
|
boardsApi.util.updateQueryData(
|
||||||
|
isAsset ? 'getBoardAssetsTotal' : 'getBoardImagesTotal',
|
||||||
|
new_board_id ?? 'none',
|
||||||
|
(draft) => {
|
||||||
|
draft.total += 1;
|
||||||
|
}
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
const queryArgs = {
|
const queryArgs = {
|
||||||
board_id: new_board_id,
|
board_id: new_board_id,
|
||||||
categories,
|
categories,
|
||||||
@ -1255,9 +1413,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
queryArgs
|
queryArgs
|
||||||
)(getState());
|
)(getState());
|
||||||
|
|
||||||
const { data: previousTotal } = IMAGE_CATEGORIES.includes(
|
const { data } = IMAGE_CATEGORIES.includes(imageDTO.image_category)
|
||||||
imageDTO.image_category
|
|
||||||
)
|
|
||||||
? boardsApi.endpoints.getBoardImagesTotal.select(
|
? boardsApi.endpoints.getBoardImagesTotal.select(
|
||||||
new_board_id ?? 'none'
|
new_board_id ?? 'none'
|
||||||
)(getState())
|
)(getState())
|
||||||
@ -1267,10 +1423,10 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
|
|
||||||
const isCacheFullyPopulated =
|
const isCacheFullyPopulated =
|
||||||
currentCache.data &&
|
currentCache.data &&
|
||||||
currentCache.data.ids.length >= (previousTotal ?? 0);
|
currentCache.data.ids.length >= (data?.total ?? 0);
|
||||||
|
|
||||||
const isInDateRange =
|
const isInDateRange =
|
||||||
(previousTotal || 0) >= IMAGE_LIMIT
|
(data?.total ?? 0) >= IMAGE_LIMIT
|
||||||
? getIsImageInDateRange(currentCache.data, imageDTO)
|
? getIsImageInDateRange(currentCache.data, imageDTO)
|
||||||
: true;
|
: true;
|
||||||
|
|
||||||
@ -1310,10 +1466,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
}),
|
}),
|
||||||
invalidatesTags: (result, error, { imageDTOs }) => {
|
invalidatesTags: (result, error, { imageDTOs }) => {
|
||||||
const touchedBoardIds: string[] = [];
|
const touchedBoardIds: string[] = [];
|
||||||
const tags: ApiTagDescription[] = [
|
const tags: ApiTagDescription[] = [];
|
||||||
{ type: 'BoardImagesTotal', id: 'none' },
|
|
||||||
{ type: 'BoardAssetsTotal', id: 'none' },
|
|
||||||
];
|
|
||||||
|
|
||||||
result?.removed_image_names.forEach((image_name) => {
|
result?.removed_image_names.forEach((image_name) => {
|
||||||
const board_id = imageDTOs.find((i) => i.image_name === image_name)
|
const board_id = imageDTOs.find((i) => i.image_name === image_name)
|
||||||
@ -1324,8 +1477,6 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
}
|
}
|
||||||
|
|
||||||
tags.push({ type: 'Board', id: board_id });
|
tags.push({ type: 'Board', id: board_id });
|
||||||
tags.push({ type: 'BoardImagesTotal', id: board_id });
|
|
||||||
tags.push({ type: 'BoardAssetsTotal', id: board_id });
|
|
||||||
});
|
});
|
||||||
|
|
||||||
return tags;
|
return tags;
|
||||||
@ -1343,6 +1494,8 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
* - *update* getImageDTO for each image
|
* - *update* getImageDTO for each image
|
||||||
* - *remove* from old_board_id/[images|assets]
|
* - *remove* from old_board_id/[images|assets]
|
||||||
* - *add* to no_board/[images|assets]
|
* - *add* to no_board/[images|assets]
|
||||||
|
* - decrement old board's totals for each image
|
||||||
|
* - increment new board's (no board) totals for each image
|
||||||
*/
|
*/
|
||||||
|
|
||||||
removed_image_names.forEach((image_name) => {
|
removed_image_names.forEach((image_name) => {
|
||||||
@ -1363,6 +1516,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
}
|
}
|
||||||
|
|
||||||
const categories = getCategories(imageDTO);
|
const categories = getCategories(imageDTO);
|
||||||
|
const isAsset = ASSETS_CATEGORIES.includes(imageDTO.image_category);
|
||||||
|
|
||||||
// remove from the old board
|
// remove from the old board
|
||||||
dispatch(
|
dispatch(
|
||||||
@ -1375,6 +1529,28 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// decrement old board's total
|
||||||
|
dispatch(
|
||||||
|
boardsApi.util.updateQueryData(
|
||||||
|
isAsset ? 'getBoardAssetsTotal' : 'getBoardImagesTotal',
|
||||||
|
imageDTO.board_id ?? 'none',
|
||||||
|
(draft) => {
|
||||||
|
draft.total = Math.max(draft.total - 1, 0);
|
||||||
|
}
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
// increment new board's total (no board)
|
||||||
|
dispatch(
|
||||||
|
boardsApi.util.updateQueryData(
|
||||||
|
isAsset ? 'getBoardAssetsTotal' : 'getBoardImagesTotal',
|
||||||
|
'none',
|
||||||
|
(draft) => {
|
||||||
|
draft.total += 1;
|
||||||
|
}
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
// add to `no_board`
|
// add to `no_board`
|
||||||
const queryArgs = {
|
const queryArgs = {
|
||||||
board_id: 'none',
|
board_id: 'none',
|
||||||
@ -1385,9 +1561,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
queryArgs
|
queryArgs
|
||||||
)(getState());
|
)(getState());
|
||||||
|
|
||||||
const { data: total } = IMAGE_CATEGORIES.includes(
|
const { data } = IMAGE_CATEGORIES.includes(imageDTO.image_category)
|
||||||
imageDTO.image_category
|
|
||||||
)
|
|
||||||
? boardsApi.endpoints.getBoardImagesTotal.select(
|
? boardsApi.endpoints.getBoardImagesTotal.select(
|
||||||
imageDTO.board_id ?? 'none'
|
imageDTO.board_id ?? 'none'
|
||||||
)(getState())
|
)(getState())
|
||||||
@ -1396,10 +1570,11 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
)(getState());
|
)(getState());
|
||||||
|
|
||||||
const isCacheFullyPopulated =
|
const isCacheFullyPopulated =
|
||||||
currentCache.data && currentCache.data.ids.length >= (total ?? 0);
|
currentCache.data &&
|
||||||
|
currentCache.data.ids.length >= (data?.total ?? 0);
|
||||||
|
|
||||||
const isInDateRange =
|
const isInDateRange =
|
||||||
(total || 0) >= IMAGE_LIMIT
|
(data?.total ?? 0) >= IMAGE_LIMIT
|
||||||
? getIsImageInDateRange(currentCache.data, imageDTO)
|
? getIsImageInDateRange(currentCache.data, imageDTO)
|
||||||
: true;
|
: true;
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import {
|
|||||||
CheckpointModelConfig,
|
CheckpointModelConfig,
|
||||||
ControlNetModelConfig,
|
ControlNetModelConfig,
|
||||||
IPAdapterModelConfig,
|
IPAdapterModelConfig,
|
||||||
|
T2IAdapterModelConfig,
|
||||||
DiffusersModelConfig,
|
DiffusersModelConfig,
|
||||||
ImportModelConfig,
|
ImportModelConfig,
|
||||||
LoRAModelConfig,
|
LoRAModelConfig,
|
||||||
@ -41,6 +42,10 @@ export type IPAdapterModelConfigEntity = IPAdapterModelConfig & {
|
|||||||
id: string;
|
id: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type T2IAdapterModelConfigEntity = T2IAdapterModelConfig & {
|
||||||
|
id: string;
|
||||||
|
};
|
||||||
|
|
||||||
export type TextualInversionModelConfigEntity = TextualInversionModelConfig & {
|
export type TextualInversionModelConfigEntity = TextualInversionModelConfig & {
|
||||||
id: string;
|
id: string;
|
||||||
};
|
};
|
||||||
@ -53,6 +58,7 @@ type AnyModelConfigEntity =
|
|||||||
| LoRAModelConfigEntity
|
| LoRAModelConfigEntity
|
||||||
| ControlNetModelConfigEntity
|
| ControlNetModelConfigEntity
|
||||||
| IPAdapterModelConfigEntity
|
| IPAdapterModelConfigEntity
|
||||||
|
| T2IAdapterModelConfigEntity
|
||||||
| TextualInversionModelConfigEntity
|
| TextualInversionModelConfigEntity
|
||||||
| VaeModelConfigEntity;
|
| VaeModelConfigEntity;
|
||||||
|
|
||||||
@ -145,6 +151,10 @@ export const ipAdapterModelsAdapter =
|
|||||||
createEntityAdapter<IPAdapterModelConfigEntity>({
|
createEntityAdapter<IPAdapterModelConfigEntity>({
|
||||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||||
});
|
});
|
||||||
|
export const t2iAdapterModelsAdapter =
|
||||||
|
createEntityAdapter<T2IAdapterModelConfigEntity>({
|
||||||
|
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||||
|
});
|
||||||
export const textualInversionModelsAdapter =
|
export const textualInversionModelsAdapter =
|
||||||
createEntityAdapter<TextualInversionModelConfigEntity>({
|
createEntityAdapter<TextualInversionModelConfigEntity>({
|
||||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||||
@ -470,6 +480,37 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
);
|
);
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
|
getT2IAdapterModels: build.query<
|
||||||
|
EntityState<T2IAdapterModelConfigEntity>,
|
||||||
|
void
|
||||||
|
>({
|
||||||
|
query: () => ({ url: 'models/', params: { model_type: 't2i_adapter' } }),
|
||||||
|
providesTags: (result) => {
|
||||||
|
const tags: ApiTagDescription[] = [
|
||||||
|
{ type: 'T2IAdapterModel', id: LIST_TAG },
|
||||||
|
];
|
||||||
|
|
||||||
|
if (result) {
|
||||||
|
tags.push(
|
||||||
|
...result.ids.map((id) => ({
|
||||||
|
type: 'T2IAdapterModel' as const,
|
||||||
|
id,
|
||||||
|
}))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return tags;
|
||||||
|
},
|
||||||
|
transformResponse: (response: { models: T2IAdapterModelConfig[] }) => {
|
||||||
|
const entities = createModelEntities<T2IAdapterModelConfigEntity>(
|
||||||
|
response.models
|
||||||
|
);
|
||||||
|
return t2iAdapterModelsAdapter.setAll(
|
||||||
|
t2iAdapterModelsAdapter.getInitialState(),
|
||||||
|
entities
|
||||||
|
);
|
||||||
|
},
|
||||||
|
}),
|
||||||
getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
|
getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
|
||||||
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
|
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
|
||||||
providesTags: (result) => {
|
providesTags: (result) => {
|
||||||
@ -567,6 +608,7 @@ export const {
|
|||||||
useGetOnnxModelsQuery,
|
useGetOnnxModelsQuery,
|
||||||
useGetControlNetModelsQuery,
|
useGetControlNetModelsQuery,
|
||||||
useGetIPAdapterModelsQuery,
|
useGetIPAdapterModelsQuery,
|
||||||
|
useGetT2IAdapterModelsQuery,
|
||||||
useGetLoRAModelsQuery,
|
useGetLoRAModelsQuery,
|
||||||
useGetTextualInversionModelsQuery,
|
useGetTextualInversionModelsQuery,
|
||||||
useGetVaeModelsQuery,
|
useGetVaeModelsQuery,
|
||||||
|
@ -13,7 +13,7 @@ export const useBoardTotal = (board_id: BoardId) => {
|
|||||||
const { data: totalAssets } = useGetBoardAssetsTotalQuery(board_id);
|
const { data: totalAssets } = useGetBoardAssetsTotalQuery(board_id);
|
||||||
|
|
||||||
const currentViewTotal = useMemo(
|
const currentViewTotal = useMemo(
|
||||||
() => (galleryView === 'images' ? totalImages : totalAssets),
|
() => (galleryView === 'images' ? totalImages?.total : totalAssets?.total),
|
||||||
[galleryView, totalAssets, totalImages]
|
[galleryView, totalAssets, totalImages]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
594
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
594
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
File diff suppressed because one or more lines are too long
@ -67,6 +67,7 @@ export type VAEModelField = s['VAEModelField'];
|
|||||||
export type LoRAModelField = s['LoRAModelField'];
|
export type LoRAModelField = s['LoRAModelField'];
|
||||||
export type ControlNetModelField = s['ControlNetModelField'];
|
export type ControlNetModelField = s['ControlNetModelField'];
|
||||||
export type IPAdapterModelField = s['IPAdapterModelField'];
|
export type IPAdapterModelField = s['IPAdapterModelField'];
|
||||||
|
export type T2IAdapterModelField = s['T2IAdapterModelField'];
|
||||||
export type ModelsList = s['ModelsList'];
|
export type ModelsList = s['ModelsList'];
|
||||||
export type ControlField = s['ControlField'];
|
export type ControlField = s['ControlField'];
|
||||||
export type IPAdapterField = s['IPAdapterField'];
|
export type IPAdapterField = s['IPAdapterField'];
|
||||||
@ -83,6 +84,9 @@ export type ControlNetModelConfig =
|
|||||||
| ControlNetModelDiffusersConfig;
|
| ControlNetModelDiffusersConfig;
|
||||||
export type IPAdapterModelInvokeAIConfig = s['IPAdapterModelInvokeAIConfig'];
|
export type IPAdapterModelInvokeAIConfig = s['IPAdapterModelInvokeAIConfig'];
|
||||||
export type IPAdapterModelConfig = IPAdapterModelInvokeAIConfig;
|
export type IPAdapterModelConfig = IPAdapterModelInvokeAIConfig;
|
||||||
|
export type T2IAdapterModelDiffusersConfig =
|
||||||
|
s['T2IAdapterModelDiffusersConfig'];
|
||||||
|
export type T2IAdapterModelConfig = T2IAdapterModelDiffusersConfig;
|
||||||
export type TextualInversionModelConfig = s['TextualInversionModelConfig'];
|
export type TextualInversionModelConfig = s['TextualInversionModelConfig'];
|
||||||
export type DiffusersModelConfig =
|
export type DiffusersModelConfig =
|
||||||
| s['StableDiffusion1ModelDiffusersConfig']
|
| s['StableDiffusion1ModelDiffusersConfig']
|
||||||
@ -99,6 +103,7 @@ export type AnyModelConfig =
|
|||||||
| VaeModelConfig
|
| VaeModelConfig
|
||||||
| ControlNetModelConfig
|
| ControlNetModelConfig
|
||||||
| IPAdapterModelConfig
|
| IPAdapterModelConfig
|
||||||
|
| T2IAdapterModelConfig
|
||||||
| TextualInversionModelConfig
|
| TextualInversionModelConfig
|
||||||
| MainModelConfig
|
| MainModelConfig
|
||||||
| OnnxModelConfig;
|
| OnnxModelConfig;
|
||||||
|
@ -161,16 +161,16 @@ version = { attr = "invokeai.version.__version__" }
|
|||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
"where" = ["."]
|
"where" = ["."]
|
||||||
"include" = [
|
"include" = [
|
||||||
"invokeai.assets.web*","invokeai.version*",
|
"invokeai.assets.fonts*","invokeai.version*",
|
||||||
"invokeai.generator*","invokeai.backend*",
|
"invokeai.generator*","invokeai.backend*",
|
||||||
"invokeai.frontend*", "invokeai.frontend.web.dist*",
|
"invokeai.frontend*", "invokeai.frontend.web.dist*",
|
||||||
"invokeai.frontend.web.static*",
|
"invokeai.frontend.web.static*",
|
||||||
"invokeai.configs*",
|
"invokeai.configs*",
|
||||||
"invokeai.app*","ldm*",
|
"invokeai.app*",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
"invokeai.assets.web" = ["**.png","**.js","**.woff2","**.css"]
|
"invokeai.assets.fonts" = ["**/*.ttf"]
|
||||||
"invokeai.backend" = ["**.png"]
|
"invokeai.backend" = ["**.png"]
|
||||||
"invokeai.configs" = ["*.example", "**/*.yaml", "*.txt"]
|
"invokeai.configs" = ["*.example", "**/*.yaml", "*.txt"]
|
||||||
"invokeai.frontend.web.dist" = ["**"]
|
"invokeai.frontend.web.dist" = ["**"]
|
||||||
|
0
tests/app/__init__.py
Normal file
0
tests/app/__init__.py
Normal file
0
tests/app/util/__init__.py
Normal file
0
tests/app/util/__init__.py
Normal file
42
tests/app/util/test_controlnet_utils.py
Normal file
42
tests/app/util/test_controlnet_utils.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_channels", [1, 2, 3])
|
||||||
|
def test_prepare_control_image_num_channels(num_channels):
|
||||||
|
"""Test that the `num_channels` parameter is applied correctly in prepare_control_image(...)."""
|
||||||
|
np_image = np.zeros((256, 256, 3), dtype=np.uint8)
|
||||||
|
pil_image = Image.fromarray(np_image)
|
||||||
|
|
||||||
|
torch_image = prepare_control_image(
|
||||||
|
image=pil_image,
|
||||||
|
width=256,
|
||||||
|
height=256,
|
||||||
|
num_channels=num_channels,
|
||||||
|
device="cpu",
|
||||||
|
do_classifier_free_guidance=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch_image.shape == (1, num_channels, 256, 256)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_channels", [0, 4])
|
||||||
|
def test_prepare_control_image_num_channels_too_large(num_channels):
|
||||||
|
"""Test that an exception is raised in prepare_control_image(...) if the `num_channels` parameter is out of the
|
||||||
|
supported range.
|
||||||
|
"""
|
||||||
|
np_image = np.zeros((256, 256, 3), dtype=np.uint8)
|
||||||
|
pil_image = Image.fromarray(np_image)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_ = prepare_control_image(
|
||||||
|
image=pil_image,
|
||||||
|
width=256,
|
||||||
|
height=256,
|
||||||
|
num_channels=num_channels,
|
||||||
|
device="cpu",
|
||||||
|
do_classifier_free_guidance=False,
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user