Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Compare commits: psyche/fea...main (144 commits)
Commits (SHA1):
87261bdbc9, 4e4b6c6dbc, 5e8cf9fb6a, c738fe051f, 29fe1533f2, 77090070bd, 6ba9b1b6b0, c578b8df1e, cad9a41433, 5fefb3b0f4,
5284a870b0, e064377c05, 3e569c8312, 16825ee6e9, 3f5340fa53, f2a1a39b33, 326de55d3e, b2df909570, 026ac36b06, 92125e5fd2,
c0c139da88, 404ad6a7fd, fc39086fb4, cd215700fe, e97fd85904, 0a263fa5b1, fae3836a8d, b3d2eb4178, 576f1cbb75, 50085b40bb,
cff382715a, 54d54d1bf2, e84ea68282, 160dd36782, 65bb46bcca, 2d185fb766, 2ba9b02932, 849da67cc7, 3ea6c9666e, cf633e4ef2,
bbf934d980, 620f733110, 67928609a3, 5f15afb7db, 635d2f480d, 70c278c810, 56b9906e2e, a808ce81fd, 83f82c5ddf, 101de8c25d,
3339a4baf0, dff4a88baa, a21f6c4964, 97562504b7, 75d8ac378c, b9dd354e2b, 33c2fbd201, 5063be92bf, 1047584b3e, 6764dcfdaa,
012864ceb1, a0bf20bcee, 14ab339b33, 25c91efbb6, 1c1f2c6664, d7c22b3bf7, 185f2a395f, 0c5649491e, 94aba5892a, ef093dde29,
34451e5f27, 1f9bdd1a9a, c27d59baf7, f130ddec7c, a0a259eef1, b66f19d4d1, 4105a78b83, 19a68afb3a, fd68a2475b, 28ff7ba830,
5d0b248fdb, 01a4e0f6ef, 91e0731506, d1f904d41f, 269388c9f4, b8486379ce, 400eb94d3b, e210c96485, 5f567f41f4, 5fed573a29,
cfac7c8189, 1787de6836, ac96f187bd, 72398350b4, df9445c351, 87b7a2e39b, f7e46622a1, 71f18353a9, 4228de707b, b6a05629ef,
fbaa820643, db2a2d5e38, 8ba6e6b1f8, 57168d719b, dee6d2c98e, e49105ece5, 0c5e11f521, a63f842a13, 4bd7fda694, 81f0886d6f,
2eb87f3306, 723f3ab0a9, 1bd90e0fd4, 436f18ff55, cde9696214, 2d9042fb93, 9ed53af520, 56fda669fd, 1d8545a76c, 5f59a828f9,
1fa6bddc89, d3a5ca5247, f01f56a98e, 99b0f79784, e1eb104345, 5c2f95ef50, b63df9bab9, a52c899c6d, eeabb7ebe5, 8b1cef978c,
152da482cd, 3cf0365a35, 5870742bb9, 01d8c62c57, 55a242b2d6, 45263b339f, 3319491861, e687afac90, b39031ea53, 0b77511271,
c99cd989c1, 317fdadb21, 4e294f9e3e, 526e0f30a0
.github/workflows/python-tests.yml (vendored, 2 changed lines)

@@ -60,7 +60,7 @@ jobs:
           extra-index-url: 'https://download.pytorch.org/whl/cpu'
           github-env: $GITHUB_ENV
         - platform: macos-default
-          os: macOS-12
+          os: macOS-14
           github-env: $GITHUB_ENV
         - platform: windows-cpu
           os: windows-2022
@@ -11,7 +11,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
     Batch,
     BatchStatus,
     CancelByBatchIDsResult,
-    CancelByOriginResult,
     ClearResult,
     EnqueueBatchResult,
     PruneResult,
@@ -106,19 +105,6 @@ async def cancel_by_batch_ids(
     return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)


-@session_queue_router.put(
-    "/{queue_id}/cancel_by_origin",
-    operation_id="cancel_by_origin",
-    responses={200: {"model": CancelByBatchIDsResult}},
-)
-async def cancel_by_origin(
-    queue_id: str = Path(description="The queue id to perform this operation on"),
-    origin: str = Query(description="The origin to cancel all queue items for"),
-) -> CancelByOriginResult:
-    """Immediately cancels all queue items with the given origin"""
-    return ApiDependencies.invoker.services.session_queue.cancel_by_origin(queue_id=queue_id, origin=origin)
-
-
 @session_queue_router.put(
     "/{queue_id}/clear",
     operation_id="clear",
@@ -40,14 +40,18 @@ class UIType(str, Enum, metaclass=MetaEnum):

     # region Model Field Types
     MainModel = "MainModelField"
+    FluxMainModel = "FluxMainModelField"
     SDXLMainModel = "SDXLMainModelField"
     SDXLRefinerModel = "SDXLRefinerModelField"
     ONNXModel = "ONNXModelField"
     VAEModel = "VAEModelField"
+    FluxVAEModel = "FluxVAEModelField"
     LoRAModel = "LoRAModelField"
     ControlNetModel = "ControlNetModelField"
     IPAdapterModel = "IPAdapterModelField"
     T2IAdapterModel = "T2IAdapterModelField"
+    T5EncoderModel = "T5EncoderModelField"
+    CLIPEmbedModel = "CLIPEmbedModelField"
     SpandrelImageToImageModel = "SpandrelImageToImageModelField"
     # endregion

@@ -125,13 +129,17 @@ class FieldDescriptions:
     negative_cond = "Negative conditioning tensor"
     noise = "Noise tensor"
     clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
+    t5_encoder = "T5 tokenizer and text encoder"
+    clip_embed_model = "CLIP Embed loader"
     unet = "UNet (scheduler, LoRAs)"
+    transformer = "Transformer"
     vae = "VAE"
     cond = "Conditioning tensor"
     controlnet_model = "ControlNet model to load"
     vae_model = "VAE model to load"
     lora_model = "LoRA model to load"
     main_model = "Main model (UNet, VAE, CLIP) to load"
+    flux_model = "Flux model (Transformer) to load"
     sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
     sdxl_refiner_model = "SDXL Refiner Main Model (UNet, VAE, CLIP2) to load"
     onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -231,6 +239,12 @@ class ColorField(BaseModel):
         return (self.r, self.g, self.b, self.a)


+class FluxConditioningField(BaseModel):
+    """A conditioning tensor primitive value"""
+
+    conditioning_name: str = Field(description="The name of conditioning tensor")
+
+
 class ConditioningField(BaseModel):
     """A conditioning tensor primitive value"""
invokeai/app/invocations/flux_text_encoder.py (new file, 92 lines)

@@ -0,0 +1,92 @@
from typing import Literal

import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo


@invocation(
    "flux_text_encoder",
    title="FLUX Text Encoding",
    tags=["prompt", "conditioning", "flux"],
    category="conditioning",
    version="1.0.0",
    classification=Classification.Prototype,
)
class FluxTextEncoderInvocation(BaseInvocation):
    """Encodes and preps a prompt for a flux image."""

    clip: CLIPField = InputField(
        title="CLIP",
        description=FieldDescriptions.clip,
        input=Input.Connection,
    )
    t5_encoder: T5EncoderField = InputField(
        title="T5Encoder",
        description=FieldDescriptions.t5_encoder,
        input=Input.Connection,
    )
    t5_max_seq_len: Literal[256, 512] = InputField(
        description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
    )
    prompt: str = InputField(description="Text prompt to encode.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
        # Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
        # scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
        t5_embeddings = self._t5_encode(context)
        clip_embeddings = self._clip_encode(context)
        conditioning_data = ConditioningFieldData(
            conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
        )

        conditioning_name = context.conditioning.save(conditioning_data)
        return FluxConditioningOutput.build(conditioning_name)

    def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
        t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
        t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)

        prompt = [self.prompt]

        with (
            t5_text_encoder_info as t5_text_encoder,
            t5_tokenizer_info as t5_tokenizer,
        ):
            assert isinstance(t5_text_encoder, T5EncoderModel)
            assert isinstance(t5_tokenizer, T5Tokenizer)

            t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)

            prompt_embeds = t5_encoder(prompt)

        assert isinstance(prompt_embeds, torch.Tensor)
        return prompt_embeds

    def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
        clip_text_encoder_info = context.models.load(self.clip.text_encoder)

        prompt = [self.prompt]

        with (
            clip_text_encoder_info as clip_text_encoder,
            clip_tokenizer_info as clip_tokenizer,
        ):
            assert isinstance(clip_text_encoder, CLIPTextModel)
            assert isinstance(clip_tokenizer, CLIPTokenizer)

            clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)

            pooled_prompt_embeds = clip_encoder(prompt)

        assert isinstance(pooled_prompt_embeds, torch.Tensor)
        return pooled_prompt_embeds
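The note inside invoke() above describes a deliberate pattern: each encoder keeps its model reference locally scoped, so the T5 weights can be garbage-collected before the CLIP model loads. A minimal, self-contained sketch of that pattern, where load_model and encode are hypothetical stand-ins for context.models.load(...) and the HFEncoder call:

import gc

import torch


def encode_with_scoped_model(load_model, encode):
    """Load a model, run it, and drop every local reference before returning."""
    model = load_model()  # hypothetical stand-in for context.models.load(...)
    try:
        with torch.no_grad():
            result = encode(model)  # hypothetical stand-in for HFEncoder(...)(prompt)
    finally:
        del model  # drop the local reference to the weights
        gc.collect()  # allow them to be reclaimed before the next model loads
    return result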
invokeai/app/invocations/flux_text_to_image.py (new file, 169 lines)

@@ -0,0 +1,169 @@
import torch
from einops import rearrange
from PIL import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    FluxConditioningField,
    Input,
    InputField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "flux_text_to_image",
    title="FLUX Text to Image",
    tags=["image", "flux"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Text-to-image generation using a FLUX model."""

    transformer: TransformerField = InputField(
        description=FieldDescriptions.flux_model,
        input=Input.Connection,
        title="Transformer",
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )
    positive_text_conditioning: FluxConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    num_steps: int = InputField(
        default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
    )
    guidance: float = InputField(
        default=4.0,
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
    )
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = self._run_diffusion(context)
        image = self._run_vae_decoding(context, latents)
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)

    def _run_diffusion(
        self,
        context: InvocationContext,
    ):
        inference_dtype = torch.bfloat16

        # Load the conditioning data.
        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
        assert len(cond_data.conditionings) == 1
        flux_conditioning = cond_data.conditionings[0]
        assert isinstance(flux_conditioning, FLUXConditioningInfo)
        flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
        t5_embeddings = flux_conditioning.t5_embeds
        clip_embeddings = flux_conditioning.clip_embeds

        transformer_info = context.models.load(self.transformer.transformer)

        # Prepare input noise.
        x = get_noise(
            num_samples=1,
            height=self.height,
            width=self.width,
            device=TorchDevice.choose_torch_device(),
            dtype=inference_dtype,
            seed=self.seed,
        )

        x, img_ids = prepare_latent_img_patches(x)

        is_schnell = "schnell" in transformer_info.config.config_path

        timesteps = get_schedule(
            num_steps=self.num_steps,
            image_seq_len=x.shape[1],
            shift=not is_schnell,
        )

        bs, t5_seq_len, _ = t5_embeddings.shape
        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())

        with transformer_info as transformer:
            assert isinstance(transformer, Flux)

            def step_callback() -> None:
                if context.util.is_canceled():
                    raise CanceledException

                # TODO: Make this look like the image before re-enabling
                # latent_image = unpack(img.float(), self.height, self.width)
                # latent_image = latent_image.squeeze()  # Remove unnecessary dimensions
                # flattened_tensor = latent_image.reshape(-1)  # Flatten to shape [48*128*128]

                # # Create a new tensor of the required shape [255, 255, 3]
                # latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3)  # Reshape to RGB format

                # # Convert to a NumPy array and then to a PIL Image
                # image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))

                # (width, height) = image.size
                # width *= 8
                # height *= 8

                # dataURL = image_to_dataURL(image, image_format="JPEG")

                # # TODO: move this whole function to invocation context to properly reference these variables
                # context._services.events.emit_invocation_denoise_progress(
                #     context._data.queue_item,
                #     context._data.invocation,
                #     state,
                #     ProgressImage(dataURL=dataURL, width=width, height=height),
                # )

            x = denoise(
                model=transformer,
                img=x,
                img_ids=img_ids,
                txt=t5_embeddings,
                txt_ids=txt_ids,
                vec=clip_embeddings,
                timesteps=timesteps,
                step_callback=step_callback,
                guidance=self.guidance,
            )

        x = unpack(x.float(), self.height, self.width)

        return x

    def _run_vae_decoding(
        self,
        context: InvocationContext,
        latents: torch.Tensor,
    ) -> Image.Image:
        vae_info = context.models.load(self.vae.vae)
        with vae_info as vae:
            assert isinstance(vae, AutoEncoder)
            latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
            img = vae.decode(latents)

        img = img.clamp(-1, 1)
        img = rearrange(img[0], "c h w -> h w c")
        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

        return img_pil
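The last few lines of _run_vae_decoding convert a decoded image tensor in [-1, 1] to an 8-bit PIL image. As a quick, standalone sanity check of that conversion (random data stands in for a real VAE decode):

import torch
from einops import rearrange
from PIL import Image


def tensor_to_pil(img: torch.Tensor) -> Image.Image:
    """Convert a (1, C, H, W) tensor in [-1, 1] to a PIL image, mirroring the decode tail above."""
    img = img.clamp(-1, 1)
    img = rearrange(img[0], "c h w -> h w c")  # channels-last layout for PIL
    return Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())


pil = tensor_to_pil(torch.rand(1, 3, 64, 64) * 2 - 1)  # random stand-in for vae.decode(...)
print(pil.size)  # (64, 64)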
@@ -6,19 +6,13 @@ import cv2
 import numpy
 from PIL import Image, ImageChops, ImageFilter, ImageOps

-from invokeai.app.invocations.baseinvocation import (
-    BaseInvocation,
-    Classification,
-    invocation,
-    invocation_output,
-)
+from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
 from invokeai.app.invocations.constants import IMAGE_MODES
 from invokeai.app.invocations.fields import (
     ColorField,
     FieldDescriptions,
     ImageField,
     InputField,
-    OutputField,
     WithBoard,
     WithMetadata,
 )
@@ -1013,66 +1007,3 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
         image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)

         return ImageOutput.build(image_dto)
-
-
-@invocation_output("canvas_v2_mask_and_crop_output")
-class CanvasV2MaskAndCropOutput(ImageOutput):
-    offset_x: int = OutputField(description="The x offset of the image, after cropping")
-    offset_y: int = OutputField(description="The y offset of the image, after cropping")
-
-
-@invocation(
-    "canvas_v2_mask_and_crop",
-    title="Canvas V2 Mask and Crop",
-    tags=["image", "mask", "id"],
-    category="image",
-    version="1.0.0",
-    classification=Classification.Prototype,
-)
-class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
-    """Handles Canvas V2 image output masking and cropping"""
-
-    image: ImageField = InputField(description="The image to apply the mask to")
-    mask: ImageField = InputField(description="The mask to apply")
-    mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")
-
-    def _prepare_mask(self, mask: Image.Image) -> Image.Image:
-        mask_array = numpy.array(mask)
-        kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
-        dilated_mask_array = cv2.erode(mask_array, kernel, iterations=3)
-        dilated_mask = Image.fromarray(dilated_mask_array)
-        if self.mask_blur > 0:
-            mask = dilated_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
-        return ImageOps.invert(mask.convert("L"))
-
-    def invoke(self, context: InvocationContext) -> CanvasV2MaskAndCropOutput:
-        image = context.images.get_pil(self.image.image_name)
-        mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
-        image.putalpha(mask)
-        # bbox = image.getbbox()
-        # image = image.crop(bbox)
-        image_dto = context.images.save(image=image)
-
-        return CanvasV2MaskAndCropOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            offset_x=0,
-            offset_y=0,
-            width=image.width,
-            height=image.height,
-        )
-
-    # def invoke(self, context: InvocationContext) -> CanvasV2MaskAndCropOutput:
-    #     image = context.images.get_pil(self.image.image_name)
-    #     mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
-    #     image.putalpha(mask)
-    #     bbox = image.getbbox()
-    #     image = image.crop(bbox)
-    #     image_dto = context.images.save(image=image)
-
-    #     return CanvasV2MaskAndCropOutput(
-    #         image=ImageField(image_name=image_dto.image_name),
-    #         offset_x=bbox[0],
-    #         offset_y=bbox[1],
-    #         width=image.width,
-    #         height=image.height,
-    #     )
@@ -1,5 +1,5 @@
 import copy
-from typing import List, Optional
+from typing import List, Literal, Optional

 from pydantic import BaseModel, Field

@@ -13,7 +13,14 @@ from invokeai.app.invocations.baseinvocation import (
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
+from invokeai.backend.flux.util import max_seq_lengths
+from invokeai.backend.model_manager.config import (
+    AnyModelConfig,
+    BaseModelType,
+    CheckpointConfigBase,
+    ModelType,
+    SubModelType,
+)


 class ModelIdentifierField(BaseModel):
@@ -60,6 +67,15 @@ class CLIPField(BaseModel):
     loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")


+class TransformerField(BaseModel):
+    transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
+
+
+class T5EncoderField(BaseModel):
+    tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
+    text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
+
+
 class VAEField(BaseModel):
     vae: ModelIdentifierField = Field(description="Info to load vae submodel")
     seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@@ -122,6 +138,78 @@ class ModelIdentifierInvocation(BaseInvocation):
         return ModelIdentifierOutput(model=self.model)


+@invocation_output("flux_model_loader_output")
+class FluxModelLoaderOutput(BaseInvocationOutput):
+    """Flux base model loader output"""
+
+    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
+    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
+    t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
+    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
+    max_seq_len: Literal[256, 512] = OutputField(
+        description="The max sequence length to use for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
+        title="Max Seq Length",
+    )
+
+
+@invocation(
+    "flux_model_loader",
+    title="Flux Main Model",
+    tags=["model", "flux"],
+    category="model",
+    version="1.0.4",
+    classification=Classification.Prototype,
+)
+class FluxModelLoaderInvocation(BaseInvocation):
+    """Loads a flux base model, outputting its submodels."""
+
+    model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.flux_model,
+        ui_type=UIType.FluxMainModel,
+        input=Input.Direct,
+    )
+
+    t5_encoder_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
+    )
+
+    clip_embed_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.clip_embed_model,
+        ui_type=UIType.CLIPEmbedModel,
+        input=Input.Direct,
+        title="CLIP Embed",
+    )
+
+    vae_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
+    )
+
+    def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
+        for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
+            if not context.models.exists(key):
+                raise ValueError(f"Unknown model: {key}")
+
+        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
+        vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
+
+        tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
+        clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
+
+        tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
+        t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
+
+        transformer_config = context.models.get_config(transformer)
+        assert isinstance(transformer_config, CheckpointConfigBase)
+
+        return FluxModelLoaderOutput(
+            transformer=TransformerField(transformer=transformer),
+            clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
+            t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
+            vae=VAEField(vae=vae),
+            max_seq_len=max_seq_lengths[transformer_config.config_path],
+        )
+
+
 @invocation(
     "main_model_loader",
     title="Main Model",
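The loader resolves max_seq_len by looking the transformer's config_path up in invokeai.backend.flux.util.max_seq_lengths. A hypothetical stand-in for that mapping, just to make the lookup concrete (the real keys are whatever config paths the FLUX model configs use; the keys below are assumptions):

from typing import Dict, Literal

# Hypothetical stand-in for invokeai.backend.flux.util.max_seq_lengths:
# schnell transformers pair with a 256-token T5 sequence, dev with 512.
max_seq_lengths: Dict[str, Literal[256, 512]] = {
    "flux-schnell": 256,  # assumed key
    "flux-dev": 512,  # assumed key
}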
@@ -12,6 +12,7 @@ from invokeai.app.invocations.fields import (
     ConditioningField,
     DenoiseMaskField,
     FieldDescriptions,
+    FluxConditioningField,
     ImageField,
     Input,
     InputField,
@@ -414,6 +415,17 @@ class MaskOutput(BaseInvocationOutput):
     height: int = OutputField(description="The height of the mask in pixels.")


+@invocation_output("flux_conditioning_output")
+class FluxConditioningOutput(BaseInvocationOutput):
+    """Base class for nodes that output a single conditioning tensor"""
+
+    conditioning: FluxConditioningField = OutputField(description=FieldDescriptions.cond)
+
+    @classmethod
+    def build(cls, conditioning_name: str) -> "FluxConditioningOutput":
+        return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
+
+
 @invocation_output("conditioning_output")
 class ConditioningOutput(BaseInvocationOutput):
     """Base class for nodes that output a single conditioning tensor"""
@@ -88,7 +88,6 @@ class QueueItemEventBase(QueueEventBase):

     item_id: int = Field(description="The ID of the queue item")
     batch_id: str = Field(description="The ID of the queue batch")
-    origin: str | None = Field(default=None, description="The origin of the batch")


 class InvocationEventBase(QueueItemEventBase):
@@ -96,6 +95,8 @@ class InvocationEventBase(QueueItemEventBase):

-    session_id: str = Field(description="The ID of the session (aka graph execution state)")
     queue_id: str = Field(description="The ID of the queue")
+    item_id: int = Field(description="The ID of the queue item")
+    batch_id: str = Field(description="The ID of the queue batch")
+    session_id: str = Field(description="The ID of the session (aka graph execution state)")
     invocation: AnyInvocation = Field(description="The ID of the invocation")
     invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
@@ -113,7 +114,6 @@ class InvocationStartedEvent(InvocationEventBase):
             queue_id=queue_item.queue_id,
             item_id=queue_item.item_id,
             batch_id=queue_item.batch_id,
-            origin=queue_item.origin,
             session_id=queue_item.session_id,
             invocation=invocation,
             invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -147,7 +147,6 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
             queue_id=queue_item.queue_id,
             item_id=queue_item.item_id,
             batch_id=queue_item.batch_id,
-            origin=queue_item.origin,
             session_id=queue_item.session_id,
             invocation=invocation,
             invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -185,7 +184,6 @@ class InvocationCompleteEvent(InvocationEventBase):
             queue_id=queue_item.queue_id,
             item_id=queue_item.item_id,
             batch_id=queue_item.batch_id,
-            origin=queue_item.origin,
             session_id=queue_item.session_id,
             invocation=invocation,
             invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -218,7 +216,6 @@ class InvocationErrorEvent(InvocationEventBase):
             queue_id=queue_item.queue_id,
             item_id=queue_item.item_id,
             batch_id=queue_item.batch_id,
-            origin=queue_item.origin,
             session_id=queue_item.session_id,
             invocation=invocation,
             invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -256,7 +253,6 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
             queue_id=queue_item.queue_id,
             item_id=queue_item.item_id,
             batch_id=queue_item.batch_id,
-            origin=queue_item.origin,
             session_id=queue_item.session_id,
             status=queue_item.status,
             error_type=queue_item.error_type,
@@ -283,14 +279,12 @@ class BatchEnqueuedEvent(QueueEventBase):
         description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
     )
     priority: int = Field(description="The priority of the batch")
-    origin: str | None = Field(default=None, description="The origin of the batch")

     @classmethod
     def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
         return cls(
             queue_id=enqueue_result.queue_id,
             batch_id=enqueue_result.batch.batch_id,
-            origin=enqueue_result.batch.origin,
             enqueued=enqueue_result.enqueued,
             requested=enqueue_result.requested,
             priority=enqueue_result.priority,
@@ -783,8 +783,9 @@ class ModelInstallService(ModelInstallServiceBase):
         # So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
         if subfolder:
             top = Path(remote_files[0].path.parts[0])  # e.g. "sdxl-turbo/"
-            path_to_remove = top / subfolder.parts[-1]  # sdxl-turbo/vae/
-            path_to_add = Path(f"{top}_{subfolder}")
+            path_to_remove = top / subfolder  # sdxl-turbo/vae/
+            subfolder_rename = subfolder.name.replace("/", "_").replace("\\", "_")
+            path_to_add = Path(f"{top}_{subfolder_rename}")
         else:
             path_to_remove = Path(".")
             path_to_add = Path(".")
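The replacement lines flatten a remote subfolder like sdxl-turbo/vae into a sibling folder name. A small pathlib-only sketch of the rewrite (the file and folder names are illustrative, not taken from the source):

from pathlib import Path

remote_path = Path("sdxl-turbo/vae/diffusion_pytorch_model.safetensors")  # illustrative
subfolder = Path("vae")

top = Path(remote_path.parts[0])  # "sdxl-turbo"
path_to_remove = top / subfolder  # "sdxl-turbo/vae"
subfolder_rename = subfolder.name.replace("/", "_").replace("\\", "_")
path_to_add = Path(f"{top}_{subfolder_rename}")  # "sdxl-turbo_vae"

dest = path_to_add / remote_path.relative_to(path_to_remove)
print(dest)  # sdxl-turbo_vae/diffusion_pytorch_model.safetensors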
@@ -77,6 +77,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
     type: Optional[ModelType] = Field(description="Type of model", default=None)
     key: Optional[str] = Field(description="Database ID for this model", default=None)
     hash: Optional[str] = Field(description="hash of model file", default=None)
+    format: Optional[str] = Field(description="format of model file", default=None)
     trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
     default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
         description="Default settings for this model", default=None
@@ -6,7 +6,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
     Batch,
     BatchStatus,
     CancelByBatchIDsResult,
-    CancelByOriginResult,
     CancelByQueueIDResult,
     ClearResult,
     EnqueueBatchResult,
@@ -96,11 +95,6 @@ class SessionQueueBase(ABC):
         """Cancels all queue items with matching batch IDs"""
         pass

-    @abstractmethod
-    def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
-        """Cancels all queue items with the given batch origin"""
-        pass
-
     @abstractmethod
     def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
         """Cancels all queue items with matching queue ID"""
@@ -77,7 +77,6 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]]

 class Batch(BaseModel):
     batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
-    origin: str | None = Field(default=None, description="The origin of this batch.")
     data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
     graph: Graph = Field(description="The graph to initialize the session with")
     workflow: Optional[WorkflowWithoutID] = Field(
@@ -196,7 +195,6 @@ class SessionQueueItemWithoutGraph(BaseModel):
     status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
     priority: int = Field(default=0, description="The priority of this queue item")
     batch_id: str = Field(description="The ID of the batch associated with this queue item")
-    origin: str | None = Field(default=None, description="The origin of this queue item. ")
     session_id: str = Field(
         description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
     )
@@ -296,7 +294,6 @@ class SessionQueueStatus(BaseModel):
 class BatchStatus(BaseModel):
     queue_id: str = Field(..., description="The ID of the queue")
     batch_id: str = Field(..., description="The ID of the batch")
-    origin: str | None = Field(..., description="The origin of the batch")
     pending: int = Field(..., description="Number of queue items with status 'pending'")
     in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
     completed: int = Field(..., description="Number of queue items with status 'complete'")
@@ -331,12 +328,6 @@ class CancelByBatchIDsResult(BaseModel):
     canceled: int = Field(..., description="Number of queue items canceled")


-class CancelByOriginResult(BaseModel):
-    """Result of canceling by list of batch ids"""
-
-    canceled: int = Field(..., description="Number of queue items canceled")
-
-
 class CancelByQueueIDResult(CancelByBatchIDsResult):
     """Result of canceling by queue id"""

@@ -442,7 +433,6 @@ class SessionQueueValueToInsert(NamedTuple):
     field_values: Optional[str]  # field_values json
     priority: int  # priority
     workflow: Optional[str]  # workflow json
-    origin: str | None


 ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
@@ -463,7 +453,6 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
                 json.dumps(field_values, default=to_jsonable_python) if field_values else None,  # field_values (json)
                 priority,  # priority
                 json.dumps(workflow, default=to_jsonable_python) if workflow else None,  # workflow (json)
-                batch.origin,  # origin
             )
         )
     return values_to_insert
@@ -10,7 +10,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
     Batch,
     BatchStatus,
     CancelByBatchIDsResult,
-    CancelByOriginResult,
     CancelByQueueIDResult,
     ClearResult,
     EnqueueBatchResult,
@@ -128,8 +127,8 @@ class SqliteSessionQueue(SessionQueueBase):

         self.__cursor.executemany(
             """--sql
-            INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin)
-            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+            INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow)
+            VALUES (?, ?, ?, ?, ?, ?, ?)
             """,
             values_to_insert,
         )
@@ -418,7 +417,11 @@ class SqliteSessionQueue(SessionQueueBase):
             )
             self.__conn.commit()
             if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
                 self._set_queue_item_status(current_queue_item.item_id, "canceled")
+                batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
+                queue_status = self.get_queue_status(queue_id=queue_id)
+                self.__invoker.services.events.emit_queue_item_status_changed(
+                    current_queue_item, batch_status, queue_status
+                )
         except Exception:
             self.__conn.rollback()
             raise
@@ -426,46 +429,6 @@ class SqliteSessionQueue(SessionQueueBase):
             self.__lock.release()
         return CancelByBatchIDsResult(canceled=count)

-    def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
-        try:
-            current_queue_item = self.get_current(queue_id)
-            self.__lock.acquire()
-            where = """--sql
-                WHERE
-                    queue_id == ?
-                    AND origin == ?
-                    AND status != 'canceled'
-                    AND status != 'completed'
-                    AND status != 'failed'
-                """
-            params = (queue_id, origin)
-            self.__cursor.execute(
-                f"""--sql
-                SELECT COUNT(*)
-                FROM session_queue
-                {where};
-                """,
-                params,
-            )
-            count = self.__cursor.fetchone()[0]
-            self.__cursor.execute(
-                f"""--sql
-                UPDATE session_queue
-                SET status = 'canceled'
-                {where};
-                """,
-                params,
-            )
-            self.__conn.commit()
-            if current_queue_item is not None and current_queue_item.origin == origin:
-                self._set_queue_item_status(current_queue_item.item_id, "canceled")
-        except Exception:
-            self.__conn.rollback()
-            raise
-        finally:
-            self.__lock.release()
-        return CancelByOriginResult(canceled=count)
-
     def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
         try:
             current_queue_item = self.get_current(queue_id)
@@ -578,8 +541,7 @@ class SqliteSessionQueue(SessionQueueBase):
                 started_at,
                 session_id,
                 batch_id,
-                queue_id,
-                origin
+                queue_id
             FROM session_queue
             WHERE queue_id = ?
             """
@@ -659,7 +621,7 @@ class SqliteSessionQueue(SessionQueueBase):
             self.__lock.acquire()
             self.__cursor.execute(
                 """--sql
-                SELECT status, count(*), origin
+                SELECT status, count(*)
                 FROM session_queue
                 WHERE
                     queue_id = ?
@@ -671,7 +633,6 @@ class SqliteSessionQueue(SessionQueueBase):
             result = cast(list[sqlite3.Row], self.__cursor.fetchall())
             total = sum(row[1] for row in result)
             counts: dict[str, int] = {row[0]: row[1] for row in result}
-            origin = result[0]["origin"] if result else None
         except Exception:
             self.__conn.rollback()
             raise
@@ -680,7 +641,6 @@ class SqliteSessionQueue(SessionQueueBase):

         return BatchStatus(
             batch_id=batch_id,
-            origin=origin,
             queue_id=queue_id,
             pending=counts.get("pending", 0),
             in_progress=counts.get("in_progress", 0),
|
@ -17,7 +17,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_15 import build_migration_15
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@ -52,7 +51,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_12(app_config=config))
|
||||
migrator.register_migration(build_migration_13())
|
||||
migrator.register_migration(build_migration_14())
|
||||
migrator.register_migration(build_migration_15())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
invokeai/app/services/shared/sqlite_migrator/migrations/migration_15.py (deleted file, 31 lines)

@@ -1,31 +0,0 @@
import sqlite3

from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration


class Migration15Callback:
    def __call__(self, cursor: sqlite3.Cursor) -> None:
        self._add_origin_col(cursor)

    def _add_origin_col(self, cursor: sqlite3.Cursor) -> None:
        """
        - Adds `origin` column to the session queue table.
        """

        cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;")


def build_migration_15() -> Migration:
    """
    Build the migration from database version 14 to 15.

    This migration does the following:
    - Adds `origin` column to the session queue table.
    """
    migration_15 = Migration(
        from_version=14,
        to_version=15,
        callback=Migration15Callback(),
    )

    return migration_15
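The deleted file is a standard add-column migration. A self-contained sketch of the same pattern against plain sqlite3 (InvokeAI's Migration wrapper is omitted, and the table schema here is a toy stand-in):

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE session_queue (item_id INTEGER PRIMARY KEY, status TEXT)")  # toy schema
cur.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;")  # the migration body above
conn.commit()
print([row[1] for row in cur.execute("PRAGMA table_info(session_queue)")])
# ['item_id', 'status', 'origin']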
(new file, 260 lines)

@@ -0,0 +1,260 @@
{
  "name": "FLUX Text to Image",
  "author": "InvokeAI",
  "description": "A simple text-to-image workflow using FLUX dev or schnell models. Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
  "version": "1.0.4",
  "contact": "",
  "tags": "text2image, flux",
  "notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
  "exposedFields": [
    { "nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90", "fieldName": "model" },
    { "nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", "fieldName": "prompt" },
    { "nodeId": "159bdf1b-79e7-4174-b86e-d40e646964c8", "fieldName": "num_steps" },
    { "nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90", "fieldName": "t5_encoder_model" }
  ],
  "meta": { "version": "3.0.0", "category": "default" },
  "nodes": [
    {
      "id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
      "type": "invocation",
      "data": {
        "id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
        "type": "flux_model_loader",
        "version": "1.0.4",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": true,
        "useCache": false,
        "inputs": {
          "model": { "name": "model", "label": "" },
          "t5_encoder_model": { "name": "t5_encoder_model", "label": "" },
          "clip_embed_model": { "name": "clip_embed_model", "label": "" },
          "vae_model": { "name": "vae_model", "label": "" }
        }
      },
      "position": { "x": 381.1882713063478, "y": -95.89663532854017 }
    },
    {
      "id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
      "type": "invocation",
      "data": {
        "id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
        "type": "flux_text_encoder",
        "version": "1.0.0",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": true,
        "useCache": true,
        "inputs": {
          "clip": { "name": "clip", "label": "" },
          "t5_encoder": { "name": "t5_encoder", "label": "" },
          "t5_max_seq_len": { "name": "t5_max_seq_len", "label": "T5 Max Seq Len", "value": 256 },
          "prompt": { "name": "prompt", "label": "", "value": "a cat" }
        }
      },
      "position": { "x": 824.1970602278849, "y": 146.98251001061735 }
    },
    {
      "id": "4754c534-a5f3-4ad0-9382-7887985e668c",
      "type": "invocation",
      "data": {
        "id": "4754c534-a5f3-4ad0-9382-7887985e668c",
        "type": "rand_int",
        "version": "1.0.1",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": true,
        "useCache": false,
        "inputs": {
          "low": { "name": "low", "label": "", "value": 0 },
          "high": { "name": "high", "label": "", "value": 2147483647 }
        }
      },
      "position": { "x": 822.9899179655476, "y": 360.9657214885052 }
    },
    {
      "id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
      "type": "invocation",
      "data": {
        "id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
        "type": "flux_text_to_image",
        "version": "1.0.0",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": false,
        "useCache": true,
        "inputs": {
          "board": { "name": "board", "label": "" },
          "metadata": { "name": "metadata", "label": "" },
          "transformer": { "name": "transformer", "label": "" },
          "vae": { "name": "vae", "label": "" },
          "positive_text_conditioning": { "name": "positive_text_conditioning", "label": "" },
          "width": { "name": "width", "label": "", "value": 1024 },
          "height": { "name": "height", "label": "", "value": 1024 },
          "num_steps": { "name": "num_steps", "label": "Steps (Recommend 30 for Dev, 4 for Schnell)", "value": 30 },
          "guidance": { "name": "guidance", "label": "", "value": 4 },
          "seed": { "name": "seed", "label": "", "value": 0 }
        }
      },
      "position": { "x": 1216.3900791301849, "y": 5.500841807102248 }
    }
  ],
  "edges": [
    {
      "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
      "type": "default",
      "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
      "target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
      "sourceHandle": "max_seq_len",
      "targetHandle": "t5_max_seq_len"
    },
    {
      "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
      "type": "default",
      "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
      "target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
      "sourceHandle": "vae",
      "targetHandle": "vae"
    },
    {
      "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
      "type": "default",
      "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
      "target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
      "sourceHandle": "t5_encoder",
      "targetHandle": "t5_encoder"
    },
    {
      "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
      "type": "default",
      "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
      "target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
      "sourceHandle": "clip",
      "targetHandle": "clip"
    },
    {
      "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
      "type": "default",
      "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
      "target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
      "sourceHandle": "transformer",
      "targetHandle": "transformer"
    },
    {
      "id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-159bdf1b-79e7-4174-b86e-d40e646964c8positive_text_conditioning",
      "type": "default",
      "source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
      "target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
      "sourceHandle": "conditioning",
      "targetHandle": "positive_text_conditioning"
    },
    {
      "id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-159bdf1b-79e7-4174-b86e-d40e646964c8seed",
      "type": "default",
      "source": "4754c534-a5f3-4ad0-9382-7887985e668c",
      "target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
      "sourceHandle": "value",
      "targetHandle": "seed"
    }
  ]
}
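Since the workflow is plain JSON, its graph wiring can be sanity-checked in a few lines of Python (the file name is hypothetical; save the JSON above locally first):

import json

with open("flux_text_to_image.json") as f:  # hypothetical local copy of the workflow above
    wf = json.load(f)

node_types = {n["id"]: n["data"]["type"] for n in wf["nodes"]}
for e in wf["edges"]:
    src = f'{node_types[e["source"]]}.{e["sourceHandle"]}'
    dst = f'{node_types[e["target"]]}.{e["targetHandle"]}'
    print(f"{src} -> {dst}")
# e.g. flux_model_loader.vae -> flux_text_to_image.vae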
invokeai/backend/flux/math.py (new file, 32 lines)

@@ -0,0 +1,32 @@
# Initially pulled from https://github.com/black-forest-labs/flux

import torch
from einops import rearrange
from torch import Tensor


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
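A quick shape sanity check for rope() and attention() above; the dimensions are illustrative, not the sizes FLUX actually uses. The extra pe[:, None] axis mirrors how positional rotations are broadcast over attention heads:

import torch

from invokeai.backend.flux.math import attention, rope

B, H, L, D = 1, 2, 8, 16  # batch, heads, sequence length, head dim (illustrative)
pos = torch.arange(L, dtype=torch.float64)[None, :]  # (B, L) token positions
pe = rope(pos, D, theta=10_000)[:, None]  # (B, 1, L, D/2, 2, 2), broadcast over heads

q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)
print(attention(q, k, v, pe).shape)  # torch.Size([1, 8, 32]) == (B, L, H*D)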
invokeai/backend/flux/model.py (new file, 117 lines)

@@ -0,0 +1,117 @@
# Initially pulled from https://github.com/black-forest-labs/flux

from dataclasses import dataclass

import torch
from torch import Tensor, nn

from invokeai.backend.flux.modules.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)


@dataclass
class FluxParams:
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool


class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, params: FluxParams):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img
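To make the wiring concrete, here is a tiny, illustrative instantiation of Flux; these parameters are far smaller than any real FLUX dev or schnell checkpoint and are chosen only to satisfy the constructor's checks (hidden_size / num_heads = 32 = sum(axes_dim)):

import torch

from invokeai.backend.flux.model import Flux, FluxParams

params = FluxParams(  # toy values, NOT a real FLUX configuration
    in_channels=64, vec_in_dim=32, context_in_dim=32, hidden_size=128,
    mlp_ratio=4.0, num_heads=4, depth=1, depth_single_blocks=1,
    axes_dim=[8, 12, 12], theta=10_000, qkv_bias=True, guidance_embed=False,
)
model = Flux(params)

img = torch.randn(1, 16, 64)  # (batch, image tokens, in_channels)
img_ids = torch.zeros(1, 16, 3)  # per-token positional ids over 3 axes
txt = torch.randn(1, 8, 32)  # (batch, text tokens, context_in_dim)
txt_ids = torch.zeros(1, 8, 3)
out = model(img, img_ids, txt, txt_ids, timesteps=torch.tensor([0.5]), y=torch.randn(1, 32))
print(out.shape)  # torch.Size([1, 16, 64])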
310
invokeai/backend/flux/modules/autoencoder.py
Normal file
310
invokeai/backend/flux/modules/autoencoder.py
Normal file
@ -0,0 +1,310 @@
# Initially pulled from https://github.com/black-forest-labs/flux

from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn


@dataclass
class AutoEncoderParams:
    resolution: int
    in_channels: int
    ch: int
    out_ch: int
    ch_mult: list[int]
    num_res_blocks: int
    z_channels: int
    scale_factor: float
    shift_factor: float


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = torch.nn.functional.silu(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = torch.nn.functional.silu(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = torch.nn.functional.silu(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = torch.nn.functional.silu(h)
        h = self.conv_out(h)
        return h


class DiagonalGaussian(nn.Module):
    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
        else:
            return mean


class AutoEncoder(nn.Module):
    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian()

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        z = self.reg(self.encoder(x))
        z = self.scale_factor * (z - self.shift_factor)
        return z

    def decode(self, z: Tensor) -> Tensor:
        z = z / self.scale_factor + self.shift_factor
        return self.decoder(z)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))
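A round-trip shape sketch for the autoencoder, assuming the `ae_params["flux"]` configuration defined in `invokeai/backend/flux/util.py` below (a `ch_mult` of length 4 gives an 8x spatial downsample; `z_channels=16`):

```python
import torch

from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.util import ae_params

ae = AutoEncoder(ae_params["flux"]).eval()  # randomly initialized, shapes only
x = torch.randn(1, 3, 256, 256)             # RGB image tensor
with torch.no_grad():
    z = ae.encode(x)                        # scaled, shifted latent
    x_hat = ae.decode(z)
assert z.shape == (1, 16, 32, 32)           # 8x downsample, 16 latent channels
assert x_hat.shape == x.shape
```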
33 invokeai/backend/flux/modules/conditioner.py Normal file
@@ -0,0 +1,33 @@
# Initially pulled from https://github.com/black-forest-labs/flux

from torch import Tensor, nn
from transformers import PreTrainedModel, PreTrainedTokenizer


class HFEncoder(nn.Module):
    def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
        super().__init__()
        self.max_length = max_length
        self.is_clip = is_clip
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
        self.tokenizer = tokenizer
        self.hf_module = encoder
        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
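A usage sketch for `HFEncoder`; the checkpoint id here is illustrative (any CLIP text model with a 768-dim pooled output would behave the same), not necessarily what InvokeAI installs:

```python
from transformers import CLIPTextModel, CLIPTokenizer

from invokeai.backend.flux.modules.conditioner import HFEncoder

name = "openai/clip-vit-large-patch14"  # hypothetical choice of weights
clip = HFEncoder(
    encoder=CLIPTextModel.from_pretrained(name),
    tokenizer=CLIPTokenizer.from_pretrained(name),
    is_clip=True,                       # -> returns "pooler_output"
    max_length=77,
)
vec = clip(["a photo of a cat"])        # shape (1, 768)
```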
253 invokeai/backend/flux/modules/layers.py Normal file
@@ -0,0 +1,253 @@
# Initially pulled from https://github.com/black-forest-labs/flux

import math
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn

from invokeai.backend.flux.math import attention, rope


class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )

        return emb.unsqueeze(1)


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.

    :param t: a 1-D Tensor of N indices, one per batch element.
        These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding


class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        return (x * rrms).to(dtype=x_dtype) * self.scale


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x


@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )


class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        return img, txt


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod.gate * output


class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
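A quick property check of `timestep_embedding` (a sketch, not part of the diff): for `dim=256`, the first 128 columns are cosines and the last 128 are sines of `1000 * t` at geometrically spaced frequencies.

```python
import torch

from invokeai.backend.flux.modules.layers import timestep_embedding

t = torch.tensor([0.0, 0.5, 1.0])     # fractional timesteps in [0, 1]
emb = timestep_embedding(t, dim=256)  # internally scaled by time_factor=1000
assert emb.shape == (3, 256)
assert torch.allclose(emb[0, :128], torch.ones(128))  # cos(0) == 1 in every band
```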
167 invokeai/backend/flux/sampling.py Normal file
@@ -0,0 +1,167 @@
# Initially pulled from https://github.com/black-forest-labs/flux

import math
from typing import Callable

import torch
from einops import rearrange, repeat
from torch import Tensor
from tqdm import tqdm

from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.conditioner import HFEncoder


def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    # We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
    rand_device = "cpu"
    rand_dtype = torch.float16
    return torch.randn(
        num_samples,
        16,
        # allow for packing
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
        device=rand_device,
        dtype=rand_dtype,
        generator=torch.Generator(device=rand_device).manual_seed(seed),
    ).to(device=device, dtype=dtype)


def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
    bs, c, h, w = img.shape
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    if isinstance(prompt, str):
        prompt = [prompt]
    txt = t5(prompt)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    return {
        "img": img,
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "vec": vec.to(img.device),
    }


def time_shift(mu: float, sigma: float, t: Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)
    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu based on linear interpolation between two points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()

def denoise(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    step_callback: Callable[[], None],
    guidance: float = 4.0,
):
    # guidance_vec is ignored for schnell.
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )

        img = img + (t_prev - t_curr) * pred
        step_callback()

    return img


def unpack(x: Tensor, height: int, width: int) -> Tensor:
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )


def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert an input image in latent space to patches for diffusion.

    This implementation was extracted from:
    https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32

    Returns:
        tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
    """
    bs, c, h, w = latent_img.shape

    # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
    img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    # Generate patch position ids.
    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    return img, img_ids
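A worked example of the shifted schedule (a sketch under the defaults above): a 1024x1024 image packs to (1024/16)^2 = 4096 latent patches, which is exactly the `x2` anchor of `get_lin_function`, so `mu` comes out at `max_shift = 1.15`.

```python
from invokeai.backend.flux.sampling import get_lin_function, get_schedule

mu = get_lin_function(y1=0.5, y2=1.15)(4096)  # == 1.15 at the x2 anchor
ts = get_schedule(num_steps=4, image_seq_len=4096)
assert len(ts) == 5                           # num_steps + 1 boundaries
assert ts[0] == 1.0 and ts[-1] == 0.0         # time_shift keeps the endpoints
```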
71 invokeai/backend/flux/util.py Normal file
@@ -0,0 +1,71 @@
# Initially pulled from https://github.com/black-forest-labs/flux

from dataclasses import dataclass
from typing import Dict, Literal

from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams


@dataclass
class ModelSpec:
    params: FluxParams
    ae_params: AutoEncoderParams
    ckpt_path: str | None
    ae_path: str | None
    repo_id: str | None
    repo_flow: str | None
    repo_ae: str | None


max_seq_lengths: Dict[str, Literal[256, 512]] = {
    "flux-dev": 512,
    "flux-schnell": 256,
}


ae_params = {
    "flux": AutoEncoderParams(
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )
}


params = {
    "flux-dev": FluxParams(
        in_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=[16, 56, 56],
        theta=10_000,
        qkv_bias=True,
        guidance_embed=True,
    ),
    "flux-schnell": FluxParams(
        in_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=[16, 56, 56],
        theta=10_000,
        qkv_bias=True,
        guidance_embed=False,
    ),
}
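A small sketch of consuming these tables; `accelerate.init_empty_weights()` builds the module graphs on meta tensors so no RAM is allocated until real weights are assigned, mirroring what the loaders later in this diff do:

```python
import accelerate

from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.util import ae_params, params

with accelerate.init_empty_weights():
    transformer = Flux(params["flux-schnell"])  # schnell: guidance_embed=False
    ae = AutoEncoder(ae_params["flux"])
# Real weights are then attached with load_state_dict(..., assign=True).
```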
@@ -52,6 +52,7 @@ class BaseModelType(str, Enum):
    StableDiffusion2 = "sd-2"
    StableDiffusionXL = "sdxl"
    StableDiffusionXLRefiner = "sdxl-refiner"
    Flux = "flux"
    # Kandinsky2_1 = "kandinsky-2.1"


@@ -66,7 +67,9 @@ class ModelType(str, Enum):
    TextualInversion = "embedding"
    IPAdapter = "ip_adapter"
    CLIPVision = "clip_vision"
    CLIPEmbed = "clip_embed"
    T2IAdapter = "t2i_adapter"
    T5Encoder = "t5_encoder"
    SpandrelImageToImage = "spandrel_image_to_image"


@@ -74,6 +77,7 @@ class SubModelType(str, Enum):
    """Submodel type."""

    UNet = "unet"
    Transformer = "transformer"
    TextEncoder = "text_encoder"
    TextEncoder2 = "text_encoder_2"
    Tokenizer = "tokenizer"
@@ -104,6 +108,9 @@ class ModelFormat(str, Enum):
    EmbeddingFile = "embedding_file"
    EmbeddingFolder = "embedding_folder"
    InvokeAI = "invokeai"
    T5Encoder = "t5_encoder"
    BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
    BnbQuantizednf4b = "bnb_quantized_nf4b"


class SchedulerPredictionType(str, Enum):
@@ -186,7 +193,9 @@ class ModelConfigBase(BaseModel):
class CheckpointConfigBase(ModelConfigBase):
    """Model config for checkpoint-style models."""

    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
    format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(
        description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
    )
    config_path: str = Field(description="path to the checkpoint model config file")
    converted_at: Optional[float] = Field(
        description="When this model was last converted to diffusers", default_factory=time.time
@@ -205,6 +214,26 @@ class LoRAConfigBase(ModelConfigBase):
    trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)


class T5EncoderConfigBase(ModelConfigBase):
    type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder


class T5EncoderConfig(T5EncoderConfigBase):
    format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder.value}")


class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase):
    format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.BnbQuantizedLlmInt8b.value}")


class LoRALyCORISConfig(LoRAConfigBase):
    """Model config for LoRA/Lycoris models."""

@@ -229,7 +258,6 @@ class VAECheckpointConfig(CheckpointConfigBase):
    """Model config for standalone VAE models."""

    type: Literal[ModelType.VAE] = ModelType.VAE
    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint

    @staticmethod
    def get_tag() -> Tag:
@@ -268,7 +296,6 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase)
    """Model config for ControlNet models (diffusers version)."""

    type: Literal[ModelType.ControlNet] = ModelType.ControlNet
    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint

    @staticmethod
    def get_tag() -> Tag:
@@ -317,6 +344,21 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
        return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")


class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
    """Model config for main checkpoint models."""

    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.format = ModelFormat.BnbQuantizednf4b

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")


class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
    """Model config for main diffusers models."""

@@ -350,6 +392,17 @@ class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
        return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")


class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
    """Model config for Clip Embeddings."""

    type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")


class CLIPVisionDiffusersConfig(DiffusersConfigBase):
    """Model config for CLIPVision."""

@@ -408,12 +461,15 @@ AnyModelConfig = Annotated[
    Union[
        Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
        Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
        Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
        Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
        Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
        Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
        Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
        Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
        Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
        Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
        Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
        Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
        Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
        Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
@@ -421,6 +477,7 @@ AnyModelConfig = Annotated[
        Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
        Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
        Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
        Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
    ],
    Discriminator(get_model_discriminator_value),
]
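For orientation, the discriminator strings that the tagged union above dispatches on are just `"<type>.<format>"` pairs. A sketch, assuming `ModelType.Main.value == "main"` as elsewhere in InvokeAI; the other values follow the enum members shown in this diff:

```python
from invokeai.backend.model_manager.config import (
    MainBnbQuantized4bCheckpointConfig,
    T5EncoderBnbQuantizedLlmInt8bConfig,
)

MainBnbQuantized4bCheckpointConfig.get_tag()   # Tag('main.bnb_quantized_nf4b')
T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()  # Tag('t5_encoder.bnb_quantized_int8b')
```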
@@ -72,6 +72,7 @@ class ModelLoader(ModelLoaderBase):
            pass

        config.path = str(self._get_model_path(config))
        self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
        loaded_model = self._load_model(config, submodel_type)

        self._ram_cache.put(
@@ -193,15 +193,6 @@ class ModelCacheBase(ABC, Generic[T]):
        """
        pass

    @abstractmethod
    def exists(
        self,
        key: str,
        submodel_type: Optional[SubModelType] = None,
    ) -> bool:
        """Return true if the model identified by key and submodel_type is in the cache."""
        pass

    @abstractmethod
    def cache_size(self) -> int:
        """Get the total size of the models currently cached."""
@@ -1,22 +1,6 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
# TODO: Add Stalker's proper name to copyright
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.

The cache returns context manager generators designed to load the
model into the GPU within the context, and unload outside the
context. Use like this:

    cache = ModelCache(max_cache_size=7.5)
    with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
        cache.get_model('stabilityai/stable-diffusion-2') as SD2:
        do_something_in_GPU(SD1,SD2)

"""
""" """

import gc
import math
@@ -40,45 +24,64 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger

# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0

# amount of GPU memory to hold in reserve for use by generations (GB)
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75

# actual size of a gig
GIG = 1073741824
# Size of a GB in bytes.
GB = 2**30

# Size of a MB in bytes.
MB = 2**20


class ModelCache(ModelCacheBase[AnyModel]):
    """Implementation of ModelCacheBase."""
    """A cache for managing models in memory.

    The cache is based on two levels of model storage:
    - execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
    - storage_device: The device where models are offloaded when not in active use (typically "cpu").

    The model cache is based on the following assumptions:
    - storage_device_mem_size > execution_device_mem_size
    - disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time

    A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
    the execution_device.

    Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
    on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
    policy. The storage_device cache uses a least-recently-used (LRU) offload policy.

    Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
    policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
    configuration.

    The cache returns context manager generators designed to load the model into the execution device (often GPU) within
    the context, and unload outside the context.

    Example usage:
    ```
    cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
    with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
        do_something_on_gpu(SD1)
    ```
    """

    def __init__(
        self,
        max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
        max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
        max_cache_size: float,
        max_vram_cache_size: float,
        execution_device: torch.device = torch.device("cuda"),
        storage_device: torch.device = torch.device("cpu"),
        precision: torch.dtype = torch.float16,
        sequential_offload: bool = False,
        lazy_offloading: bool = True,
        sha_chunksize: int = 16777216,
        log_memory_usage: bool = False,
        logger: Optional[Logger] = None,
    ):
        """
        Initialize the model RAM cache.

        :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
        :param max_cache_size: Maximum size of the storage_device cache in GBs.
        :param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
        :param execution_device: Torch device to load active model into [torch.device('cuda')]
        :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
        :param precision: Precision for loaded models [torch.float16]
        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
        :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded.
        :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
            operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
            snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
@@ -86,7 +89,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
        """
        # allow lazy offloading only when vram cache enabled
        self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
        self._precision: torch.dtype = precision
        self._max_cache_size: float = max_cache_size
        self._max_vram_cache_size: float = max_vram_cache_size
        self._execution_device: torch.device = execution_device
@@ -145,15 +147,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
            total += cache_record.size
        return total

    def exists(
        self,
        key: str,
        submodel_type: Optional[SubModelType] = None,
    ) -> bool:
        """Return true if the model identified by key and submodel_type is in the cache."""
        key = self._make_cache_key(key, submodel_type)
        return key in self._cached_models

    def put(
        self,
        key: str,
@@ -203,7 +196,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
        # more stats
        if self.stats:
            stats_name = stats_name or key
            self.stats.cache_size = int(self._max_cache_size * GIG)
            self.stats.cache_size = int(self._max_cache_size * GB)
            self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
            self.stats.in_cache = len(self._cached_models)
            self.stats.loaded_model_sizes[stats_name] = max(
@@ -231,10 +224,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
        return model_key

    def offload_unlocked_models(self, size_required: int) -> None:
        """Move any unused models from VRAM."""
        reserved = self._max_vram_cache_size * GIG
        """Offload models from the execution_device to make room for size_required.

        :param size_required: The amount of space to clear in the execution_device cache, in bytes.
        """
        reserved = self._max_vram_cache_size * GB
        vram_in_use = torch.cuda.memory_allocated() + size_required
        self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
        self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
        for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
            if vram_in_use <= reserved:
                break
@@ -245,7 +241,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
                cache_entry.loaded = False
                vram_in_use = torch.cuda.memory_allocated() + size_required
                self.logger.debug(
                    f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
                    f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
                )

        TorchDevice.empty_cache()
@@ -303,7 +299,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
            self.logger.debug(
                f"Moved model '{cache_entry.key}' from {source_device} to"
                f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
                f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
                f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
                f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
            )

@@ -326,14 +322,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
                f"Moving model '{cache_entry.key}' from {source_device} to"
                f" {target_device} caused an unexpected change in VRAM usage. The model's"
                " estimated size may be incorrect. Estimated model size:"
                f" {(cache_entry.size/GIG):.3f} GB.\n"
                f" {(cache_entry.size/GB):.3f} GB.\n"
                f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
            )

    def print_cuda_stats(self) -> None:
        """Log CUDA diagnostics."""
        vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
        ram = "%4.2fG" % (self.cache_size() / GIG)
        vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
        ram = "%4.2fG" % (self.cache_size() / GB)

        in_ram_models = 0
        in_vram_models = 0
@@ -353,17 +349,20 @@ class ModelCache(ModelCacheBase[AnyModel]):
        )

    def make_room(self, size: int) -> None:
        """Make enough room in the cache to accommodate a new model of indicated size."""
        # calculate how much memory this model will require
        # multiplier = 2 if self.precision==torch.float32 else 1
        """Make enough room in the cache to accommodate a new model of indicated size.

        Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
        external references to the model, there's nothing that the cache can do about it, and those models will not be
        garbage-collected.
        """
        bytes_needed = size
        maximum_size = self.max_cache_size * GIG  # stored in GB, convert to bytes
        maximum_size = self.max_cache_size * GB  # stored in GB, convert to bytes
        current_size = self.cache_size()

        if current_size + bytes_needed > maximum_size:
            self.logger.debug(
                f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
                f" {(bytes_needed/GIG):.2f} GB"
                f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
                f" {(bytes_needed/GB):.2f} GB"
            )

            self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
@@ -380,7 +379,7 @@ class ModelCache(ModelCacheBase[AnyModel]):

            if not cache_entry.locked:
                self.logger.debug(
                    f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
                    f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
                )
                current_size -= cache_entry.size
                models_cleared += 1
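A one-line sanity check of the unit math behind the `GIG` to `GB` rename (a sketch): cache sizes are configured in GB but tracked internally in bytes.

```python
GB = 2**30                  # bytes per GB, as defined above
maximum_size = 6.0 * GB     # a 6.0 GB max_cache_size, in bytes, as used by make_room()
assert maximum_size == 6_442_450_944
```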
234 invokeai/backend/model_manager/load/model_loaders/flux.py Normal file
@@ -0,0 +1,234 @@
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Class for Flux model loading in InvokeAI."""

from pathlib import Path
from typing import Optional

import accelerate
import torch
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.util import ae_params, params
from invokeai.backend.model_manager import (
    AnyModel,
    AnyModelConfig,
    BaseModelType,
    ModelFormat,
    ModelType,
    SubModelType,
)
from invokeai.backend.model_manager.config import (
    CheckpointConfigBase,
    CLIPEmbedDiffusersConfig,
    MainBnbQuantized4bCheckpointConfig,
    MainCheckpointConfig,
    T5EncoderBnbQuantizedLlmInt8bConfig,
    T5EncoderConfig,
    VAECheckpointConfig,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.util.silence_warnings import SilenceWarnings

try:
    from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
    from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4

    bnb_available = True
except ImportError:
    bnb_available = False

app_config = get_config()


@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class FluxVAELoader(ModelLoader):
    """Class to load VAE models."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, VAECheckpointConfig):
            raise ValueError("Only VAECheckpointConfig models are currently supported here.")
        model_path = Path(config.path)

        with SilenceWarnings():
            model = AutoEncoder(ae_params[config.config_path])
            sd = load_file(model_path)
            model.load_state_dict(sd, assign=True)
            model.to(dtype=self._torch_dtype)

        return model


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
class ClipCheckpointModel(ModelLoader):
"""Class to load main models."""
|
||||

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, CLIPEmbedDiffusersConfig):
            raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.")

        match submodel_type:
            case SubModelType.Tokenizer:
                return CLIPTokenizer.from_pretrained(Path(config.path) / "tokenizer")
            case SubModelType.TextEncoder:
                return CLIPTextModel.from_pretrained(Path(config.path) / "text_encoder")

        raise ValueError(
            f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.BnbQuantizedLlmInt8b)
class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
"""Class to load main models."""
|
||||

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, T5EncoderBnbQuantizedLlmInt8bConfig):
            raise ValueError("Only T5EncoderBnbQuantizedLlmInt8bConfig models are currently supported here.")
        if not bnb_available:
            raise ImportError(
                "The bnb modules are not available. Please install bitsandbytes if available on your platform."
            )
        match submodel_type:
            case SubModelType.Tokenizer2:
                return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
            case SubModelType.TextEncoder2:
                te2_model_path = Path(config.path) / "text_encoder_2"
                model_config = AutoConfig.from_pretrained(te2_model_path)
                with accelerate.init_empty_weights():
                    model = AutoModelForTextEncoding.from_config(model_config)
                    model = quantize_model_llm_int8(model, modules_to_not_convert=set())

                state_dict_path = te2_model_path / "bnb_llm_int8_model.safetensors"
                state_dict = load_file(state_dict_path)
                self._load_state_dict_into_t5(model, state_dict)

                return model

        raise ValueError(
            f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )

    @classmethod
    def _load_state_dict_into_t5(cls, model: T5EncoderModel, state_dict: dict[str, torch.Tensor]):
        # There is a shared reference to a single weight tensor in the model.
        # Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
        # be present in the state_dict.
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
        assert len(unexpected_keys) == 0
        assert set(missing_keys) == {"encoder.embed_tokens.weight"}
        # Assert that the layers we expect to be shared are actually shared.
        assert model.encoder.embed_tokens.weight is model.shared.weight


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
class T5EncoderCheckpointModel(ModelLoader):
"""Class to load main models."""
|
||||

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, T5EncoderConfig):
            raise ValueError("Only T5EncoderConfig models are currently supported here.")

        match submodel_type:
            case SubModelType.Tokenizer2:
                return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
            case SubModelType.TextEncoder2:
                return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2")

        raise ValueError(
            f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )


@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
class FluxCheckpointModel(ModelLoader):
    """Class to load main models."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, CheckpointConfigBase):
            raise ValueError("Only CheckpointConfigBase models are currently supported here.")

        match submodel_type:
            case SubModelType.Transformer:
                return self._load_from_singlefile(config)

        raise ValueError(
            f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )

    def _load_from_singlefile(
        self,
        config: AnyModelConfig,
    ) -> AnyModel:
        assert isinstance(config, MainCheckpointConfig)
        model_path = Path(config.path)

        with SilenceWarnings():
            model = Flux(params[config.config_path])
            sd = load_file(model_path)
            model.load_state_dict(sd, assign=True)
        return model


@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
    """Class to load main models."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, CheckpointConfigBase):
            raise ValueError("Only CheckpointConfigBase models are currently supported here.")

        match submodel_type:
            case SubModelType.Transformer:
                return self._load_from_singlefile(config)

        raise ValueError(
            f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )

    def _load_from_singlefile(
        self,
        config: AnyModelConfig,
    ) -> AnyModel:
        assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
        if not bnb_available:
            raise ImportError(
                "The bnb modules are not available. Please install bitsandbytes if available on your platform."
            )
        model_path = Path(config.path)

        with SilenceWarnings():
            with accelerate.init_empty_weights():
                model = Flux(params[config.config_path])
                model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
            sd = load_file(model_path)
            model.load_state_dict(sd, assign=True)
        return model
@@ -78,7 +78,12 @@ class GenericDiffusersLoader(ModelLoader):

    # TO DO: Add exception handling
    def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin:  # fix with correct type
        if module in ["diffusers", "transformers"]:
        if module in [
            "diffusers",
            "transformers",
            "invokeai.backend.quantization.fast_quantized_transformers_model",
            "invokeai.backend.quantization.fast_quantized_diffusion_model",
        ]:
            res_type = sys.modules[module]
        else:
            res_type = sys.modules["diffusers"].pipelines
@ -36,8 +36,18 @@ VARIANT_TO_IN_CHANNEL_MAP = {
|
||||
}
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
|
||||
)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint
|
||||
)
|
||||
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
|
@ -9,7 +9,7 @@ from typing import Optional
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast

from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
@ -50,6 +50,17 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
        ),
    ):
        return model.calc_size()
    elif isinstance(
        model,
        (
            T5TokenizerFast,
            T5Tokenizer,
        ),
    ):
        # HACK(ryand): len(model) just returns the vocabulary size, so this is blatantly wrong. It should be small
        # relative to the text encoder that it's used with, so shouldn't matter too much, but we should fix this at some
        # point.
        return len(model)
    else:
        # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
        # supported model types.
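To make the HACK note concrete: `len()` on a Hugging Face tokenizer is its vocabulary size, not a byte count. A quick illustration (the checkpoint name is just an example and triggers a download):

```python
from transformers import T5TokenizerFast

tok = T5TokenizerFast.from_pretrained("google/t5-v1_1-small")
print(len(tok))  # vocabulary size (~32100), which calc_model_size_by_data() reports as the "size"
```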
@ -95,6 +95,7 @@ class ModelProbe(object):
    }

    CLASS2TYPE = {
        "FluxPipeline": ModelType.Main,
        "StableDiffusionPipeline": ModelType.Main,
        "StableDiffusionInpaintPipeline": ModelType.Main,
        "StableDiffusionXLPipeline": ModelType.Main,
@ -106,6 +107,7 @@ class ModelProbe(object):
        "ControlNetModel": ModelType.ControlNet,
        "CLIPVisionModelWithProjection": ModelType.CLIPVision,
        "T2IAdapter": ModelType.T2IAdapter,
        "CLIPModel": ModelType.CLIPEmbed,
    }

    @classmethod
@ -161,7 +163,7 @@ class ModelProbe(object):
        fields["description"] = (
            fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
        )
        fields["format"] = fields.get("format") or probe.get_format()
        fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format()
        fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)

        fields["default_settings"] = fields.get("default_settings")
@ -176,10 +178,10 @@ class ModelProbe(object):
        fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()

        # additional fields needed for main and controlnet models
        if (
            fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
            and fields["format"] is ModelFormat.Checkpoint
        ):
        if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [
            ModelFormat.Checkpoint,
            ModelFormat.BnbQuantizednf4b,
        ]:
            ckpt_config_path = cls._get_checkpoint_config_path(
                model_path,
                model_type=fields["type"],
@ -222,7 +224,8 @@ class ModelProbe(object):
        ckpt = ckpt.get("state_dict", ckpt)

        for key in [str(k) for k in ckpt.keys()]:
            if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
            if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.", "double_blocks.")):
                # Keys starting with double_blocks are associated with Flux models
                return ModelType.Main
            elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
                return ModelType.VAE
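The Flux addition above keeps the existing detection scheme: model type falls out of a pure key-prefix scan over the state dict. A toy version of the decision (string results stand in for the `ModelType` enum):

```python
def detect_model_type(state_dict: dict) -> str:
    for key in state_dict:
        if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.", "double_blocks.")):
            return "main"  # double_blocks.* prefixes are the FLUX transformer signature
        elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
            return "vae"
    return "unknown"


print(detect_model_type({"double_blocks.0.img_attn.qkv.weight": None}))  # -> main
```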
@ -321,10 +324,27 @@ class ModelProbe(object):
            return possible_conf.absolute()

        if model_type is ModelType.Main:
            config_file = LEGACY_CONFIGS[base_type][variant_type]
            if isinstance(config_file, dict):  # need another tier for sd-2.x models
                config_file = config_file[prediction_type]
            config_file = f"stable-diffusion/{config_file}"
            if base_type == BaseModelType.Flux:
                # TODO: Decide between dev/schnell
                checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
                state_dict = checkpoint.get("state_dict") or checkpoint
                if "guidance_in.out_layer.weight" in state_dict:
                    # For flux, this is a key in invokeai.backend.flux.util.params
                    # Due to model type and format being the discriminator for model configs this
                    # is used rather than attempting to support flux with separate model types and format
                    # If changed in the future, please fix me
                    config_file = "flux-dev"
                else:
                    # For flux, this is a key in invokeai.backend.flux.util.params
                    # Due to model type and format being the discriminator for model configs this
                    # is used rather than attempting to support flux with separate model types and format
                    # If changed in the future, please fix me
                    config_file = "flux-schnell"
            else:
                config_file = LEGACY_CONFIGS[base_type][variant_type]
                if isinstance(config_file, dict):  # need another tier for sd-2.x models
                    config_file = config_file[prediction_type]
                config_file = f"stable-diffusion/{config_file}"
        elif model_type is ModelType.ControlNet:
            config_file = (
                "controlnet/cldm_v15.yaml"
@ -333,7 +353,13 @@ class ModelProbe(object):
            )
        elif model_type is ModelType.VAE:
            config_file = (
                "stable-diffusion/v1-inference.yaml"
                # For flux, this is a key in invokeai.backend.flux.util.ae_params
                # Due to model type and format being the discriminator for model configs this
                # is used rather than attempting to support flux with separate model types and format
                # If changed in the future, please fix me
                "flux"
                if base_type is BaseModelType.Flux
                else "stable-diffusion/v1-inference.yaml"
                if base_type is BaseModelType.StableDiffusion1
                else "stable-diffusion/sd_xl_base.yaml"
                if base_type is BaseModelType.StableDiffusionXL
@ -416,11 +442,15 @@ class CheckpointProbeBase(ProbeBase):
        self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)

    def get_format(self) -> ModelFormat:
        state_dict = self.checkpoint.get("state_dict") or self.checkpoint
        if "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict:
            return ModelFormat.BnbQuantizednf4b
        return ModelFormat("checkpoint")

    def get_variant_type(self) -> ModelVariantType:
        model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
        if model_type != ModelType.Main:
        base_type = self.get_base_type()
        if model_type != ModelType.Main or base_type == BaseModelType.Flux:
            return ModelVariantType.Normal
        state_dict = self.checkpoint.get("state_dict") or self.checkpoint
        in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
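The new `get_format()` relies on how bitsandbytes serializes NF4 layers: every quantized weight gains sidecar `quant_state` keys, so one well-known key identifies a pre-quantized FLUX checkpoint. A toy illustration (string results stand in for `ModelFormat`):

```python
NF4_MARKER = "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4"


def checkpoint_format(state_dict: dict) -> str:
    # Pre-quantized FLUX checkpoints carry quant_state sidecar keys next to each weight.
    return "bnb_quantized_nf4b" if NF4_MARKER in state_dict else "checkpoint"
```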
@ -440,6 +470,8 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        state_dict = self.checkpoint.get("state_dict") or checkpoint
        if "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
            return BaseModelType.Flux
        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
            return BaseModelType.StableDiffusion1
@ -482,6 +514,7 @@ class VaeCheckpointProbe(CheckpointProbeBase):
            (r"xl", BaseModelType.StableDiffusionXL),
            (r"sd2", BaseModelType.StableDiffusion2),
            (r"vae", BaseModelType.StableDiffusion1),
            (r"FLUX.1-schnell_ae", BaseModelType.Flux),
        ]:
            if re.search(regexp, self.model_path.name, re.IGNORECASE):
                return basetype
@ -713,6 +746,11 @@ class TextualInversionFolderProbe(FolderProbeBase):
        return TextualInversionCheckpointProbe(path).get_base_type()


class T5EncoderFolderProbe(FolderProbeBase):
    def get_format(self) -> ModelFormat:
        return ModelFormat.T5Encoder


class ONNXFolderProbe(PipelineFolderProbe):
    def get_base_type(self) -> BaseModelType:
        # Due to the way the installer is set up, the configuration file for safetensors
@ -805,6 +843,11 @@ class CLIPVisionFolderProbe(FolderProbeBase):
        return BaseModelType.Any


class CLIPEmbedFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        return BaseModelType.Any


class SpandrelImageToImageFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()
@ -835,8 +878,10 @@ ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPEmbed, CLIPEmbedFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
@ -2,7 +2,7 @@ from typing import Optional

from pydantic import BaseModel

from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType


class StarterModelWithoutDependencies(BaseModel):
@ -11,6 +11,7 @@ class StarterModelWithoutDependencies(BaseModel):
    name: str
    base: BaseModelType
    type: ModelType
    format: Optional[ModelFormat] = None
    is_installed: bool = False
@ -51,10 +52,76 @@ cyberrealistic_negative = StarterModel(
    type=ModelType.TextualInversion,
)

t5_base_encoder = StarterModel(
    name="t5_base_encoder",
    base=BaseModelType.Any,
    source="InvokeAI/t5-v1_1-xxl::bfloat16",
    description="T5-XXL text encoder (used in FLUX pipelines). ~8GB",
    type=ModelType.T5Encoder,
)

t5_8b_quantized_encoder = StarterModel(
    name="t5_bnb_int8_quantized_encoder",
    base=BaseModelType.Any,
    source="InvokeAI/t5-v1_1-xxl::bnb_llm_int8",
    description="T5-XXL text encoder with bitsandbytes LLM.int8() quantization (used in FLUX pipelines). ~5GB",
    type=ModelType.T5Encoder,
    format=ModelFormat.BnbQuantizedLlmInt8b,
)

clip_l_encoder = StarterModel(
    name="clip-vit-large-patch14",
    base=BaseModelType.Any,
    source="InvokeAI/clip-vit-large-patch14-text-encoder::bfloat16",
    description="CLIP-L text encoder (used in FLUX pipelines). ~250MB",
    type=ModelType.CLIPEmbed,
)

flux_vae = StarterModel(
    name="FLUX.1-schnell_ae",
    base=BaseModelType.Flux,
    source="black-forest-labs/FLUX.1-schnell::ae.safetensors",
    description="FLUX VAE compatible with both schnell and dev variants.",
    type=ModelType.VAE,
)


# List of starter models, displayed on the frontend.
# The order/sort of this list is not changed by the frontend - set it how you want it here.
STARTER_MODELS: list[StarterModel] = [
    # region: Main
    StarterModel(
        name="FLUX Schnell (Quantized)",
        base=BaseModelType.Flux,
        source="InvokeAI/flux_schnell::transformer/bnb_nf4/flux1-schnell-bnb_nf4.safetensors",
        description="FLUX schnell transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
        type=ModelType.Main,
        dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
    ),
    StarterModel(
        name="FLUX Dev (Quantized)",
        base=BaseModelType.Flux,
        source="InvokeAI/flux_dev::transformer/bnb_nf4/flux1-dev-bnb_nf4.safetensors",
        description="FLUX dev transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
        type=ModelType.Main,
        dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
    ),
    StarterModel(
        name="FLUX Schnell",
        base=BaseModelType.Flux,
        source="InvokeAI/flux_schnell::transformer/base/flux1-schnell.safetensors",
        description="FLUX schnell transformer in bfloat16. Total size with dependencies: ~33GB",
        type=ModelType.Main,
        dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
    ),
    StarterModel(
        name="FLUX Dev",
        base=BaseModelType.Flux,
        source="InvokeAI/flux_dev::transformer/base/flux1-dev.safetensors",
        description="FLUX dev transformer in bfloat16. Total size with dependencies: ~33GB",
        type=ModelType.Main,
        dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
    ),
    StarterModel(
        name="CyberRealistic v4.1",
        base=BaseModelType.StableDiffusion1,
@ -125,6 +192,7 @@ STARTER_MODELS: list[StarterModel] = [
    # endregion
    # region VAE
    sdxl_fp16_vae_fix,
    flux_vae,
    # endregion
    # region LoRA
    StarterModel(
@ -450,6 +518,11 @@ STARTER_MODELS: list[StarterModel] = [
        type=ModelType.SpandrelImageToImage,
    ),
    # endregion
    # region TextEncoders
    t5_base_encoder,
    t5_8b_quantized_encoder,
    clip_l_encoder,
    # endregion
]

assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"
@ -54,6 +54,7 @@ def filter_files(
                "lora_weights.safetensors",
                "weights.pb",
                "onnx_data",
                "spiece.model",  # Added for `black-forest-labs/FLUX.1-schnell`.
            )
        ):
            paths.append(file)
@ -62,13 +63,13 @@ def filter_files(
        # downloading random checkpoints that might also be in the repo. However there is no guarantee
        # that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
        # will adhere to this naming convention, so this is an area to be careful of.
        elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
        elif re.search(r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
            paths.append(file)

    # limit search to subfolder if requested
    if subfolder:
        subfolder = root / subfolder
        paths = [x for x in paths if x.parent == Path(subfolder)]
        paths = [x for x in paths if Path(subfolder) in x.parents]

    # _filter_by_variant uniquifies the paths and returns a set
    return sorted(_filter_by_variant(paths, variant))
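The regex relaxation matters for sharded checkpoints such as FLUX.1-schnell's T5 encoder, whose shards are named like `model-00001-of-00002.safetensors`: the old pattern demanded a dot immediately after `model`, so those files were skipped. A quick comparison (the filename is illustrative):

```python
import re

OLD = r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
NEW = r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"

name = "model-00001-of-00002.safetensors"  # sharded weight file
print(bool(re.search(OLD, name)))  # False: "-00001..." is neither the extension nor a ".variant"
print(bool(re.search(NEW, name)))  # True: anything may sit between "model" and the extension
```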
@ -97,7 +98,9 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
        if variant == ModelRepoVariant.Flax:
            result.add(path)

        elif path.suffix in [".json", ".txt"]:
        # Note: '.model' was added to support:
        # https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/768d12a373ed5cc9ef9a9dea7504dc09fcc14842/tokenizer_2/spiece.model
        elif path.suffix in [".json", ".txt", ".model"]:
            result.add(path)

        elif variant in [
@ -140,6 +143,23 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
            continue

    for candidate_list in subfolder_weights.values():
        # Check if at least one of the files has the explicit fp16 variant.
        at_least_one_fp16 = False
        for candidate in candidate_list:
            if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
                at_least_one_fp16 = True
                break

        if not at_least_one_fp16:
            # If none of the candidates in this candidate_list have the explicit fp16 variant label, then this
            # candidate_list probably doesn't adhere to the variant naming convention that we expected. In this case,
            # we'll simply keep all the candidates. An example of a model that hits this case is
            # `black-forest-labs/FLUX.1-schnell` (as of commit 012d2fd).
            for candidate in candidate_list:
                result.add(candidate.path)

        # The candidate_list seems to have the expected variant naming convention. We'll select the highest scoring
        # candidate.
        highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
        if highest_score_candidate:
            result.add(highest_score_candidate.path)
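The fp16 check above leans on `Path.suffixes`, which returns every dot-separated extension. A name carries the explicit variant label only when there are exactly two suffixes and the first is `.fp16`:

```python
from pathlib import Path

print(Path("diffusion_pytorch_model.fp16.safetensors").suffixes)  # ['.fp16', '.safetensors'] -> labeled
print(Path("diffusion_pytorch_model.safetensors").suffixes)       # ['.safetensors'] -> no variant label
```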
0
invokeai/backend/quantization/__init__.py
Normal file
135
invokeai/backend/quantization/bnb_llm_int8.py
Normal file
@ -0,0 +1,135 @@
import bitsandbytes as bnb
import torch

# This file contains utils for working with models that use bitsandbytes LLM.int8() quantization.
# The utils in this file are partially inspired by:
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py


# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
# cleaner by re-implementing bnb.nn.Linear8bitLt with proper use of buffers and less magic. But, for now, we try to
# stick close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.


class InvokeInt8Params(bnb.nn.Int8Params):
    """We override cuda() to avoid re-quantizing the weights in the following cases:
    - We loaded quantized weights from a state_dict on the cpu, and then moved the model to the gpu.
    - We are moving the model back-and-forth between the cpu and gpu.
    """

    def cuda(self, device):
        if self.has_fp16_weights:
            return super().cuda(device)
        elif self.CB is not None and self.SCB is not None:
            self.data = self.data.cuda()
            self.CB = self.data
            self.SCB = self.SCB.cuda()
        else:
            # we store the 8-bit row-major weight
            # we convert this weight to the turing/ampere weight during the first inference pass
            B = self.data.contiguous().half().cuda(device)
            CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
            del CBt
            del SCBt
            self.data = CB
            self.CB = CB
            self.SCB = SCB

        return self


class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
    def _load_from_state_dict(
        self,
        state_dict: dict[str, torch.Tensor],
        prefix: str,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        weight = state_dict.pop(prefix + "weight")
        bias = state_dict.pop(prefix + "bias", None)

        # See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
        scb = state_dict.pop(prefix + "SCB", None)

        # Currently, we only support weight_format=0.
        weight_format = state_dict.pop(prefix + "weight_format", None)
        assert weight_format == 0

        # TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
        # rather than raising an exception to correctly implement this API.
        assert len(state_dict) == 0

        if scb is not None:
            # We are loading a pre-quantized state dict.
            self.weight = InvokeInt8Params(
                data=weight,
                requires_grad=self.weight.requires_grad,
                has_fp16_weights=False,
                # Note: After quantization, CB is the same as weight.
                CB=weight,
                SCB=scb,
            )
            self.bias = bias if bias is None else torch.nn.Parameter(bias)
        else:
            # We are loading a non-quantized state dict.

            # We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
            # load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
            # device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
            # implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
            # `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
            self.weight = InvokeInt8Params(
                data=weight,
                requires_grad=self.weight.requires_grad,
                has_fp16_weights=False,
                CB=None,
                SCB=None,
            )
            self.bias = bias if bias is None else torch.nn.Parameter(bias)

        # Reset the state. The persisted fields are based on the initialization behaviour in
        # `bnb.nn.Linear8bitLt.__init__()`.
        new_state = bnb.MatmulLtState()
        new_state.threshold = self.state.threshold
        new_state.has_fp16_weights = False
        new_state.use_pool = self.state.use_pool
        self.state = new_state


def _convert_linear_layers_to_llm_8bit(
    module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
) -> None:
    """Convert all linear layers in the module to bnb.nn.Linear8bitLt layers."""
    for name, child in module.named_children():
        fullname = f"{prefix}.{name}" if prefix else name
        if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
            has_bias = child.bias is not None
            replacement = InvokeLinear8bitLt(
                child.in_features,
                child.out_features,
                bias=has_bias,
                has_fp16_weights=False,
                threshold=outlier_threshold,
            )
            replacement.weight.data = child.weight.data
            if has_bias:
                replacement.bias.data = child.bias.data
            replacement.requires_grad_(False)
            module.__setattr__(name, replacement)
        else:
            _convert_linear_layers_to_llm_8bit(
                child, ignore_modules, outlier_threshold=outlier_threshold, prefix=fullname
            )


def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], outlier_threshold: float = 6.0):
    """Apply bitsandbytes LLM.int8() quantization to the model."""
    _convert_linear_layers_to_llm_8bit(
        module=model, ignore_modules=modules_to_not_convert, outlier_threshold=outlier_threshold
    )

    return model
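A minimal usage sketch for `quantize_model_llm_int8()` on a toy module. This assumes a CUDA device and an installed bitsandbytes; the layer sizes are arbitrary.

```python
import torch

from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8

toy = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8))
toy = quantize_model_llm_int8(toy, modules_to_not_convert=set())

# The weights are still unquantized here; InvokeInt8Params.cuda() performs the actual
# int8 conversion on the first move to the GPU.
toy = toy.to("cuda")
out = toy(torch.randn(1, 64, device="cuda", dtype=torch.float16))
```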
156
invokeai/backend/quantization/bnb_nf4.py
Normal file
@ -0,0 +1,156 @@
import bitsandbytes as bnb
import torch

# This file contains utils for working with models that use bitsandbytes NF4 quantization.
# The utils in this file are partially inspired by:
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py

# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
# cleaner by re-implementing bnb.nn.LinearNF4 with proper use of buffers and less magic. But, for now, we try to stick
# close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.


class InvokeLinearNF4(bnb.nn.LinearNF4):
    """A class that extends `bnb.nn.LinearNF4` to add the following functionality:
    - Ability to load Linear NF4 layers from a pre-quantized state_dict.
    - Ability to load Linear NF4 layers from a state_dict when the model is on the "meta" device.
    """

    def _load_from_state_dict(
        self,
        state_dict: dict[str, torch.Tensor],
        prefix: str,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """This method is based on the logic in the bitsandbytes serialization unit tests for `Linear4bit`:
        https://github.com/bitsandbytes-foundation/bitsandbytes/blob/6d714a5cce3db5bd7f577bc447becc7a92d5ccc7/tests/test_linear4bit.py#L52-L71
        """
        weight = state_dict.pop(prefix + "weight")
        bias = state_dict.pop(prefix + "bias", None)
        # We expect the remaining keys to be quant_state keys.
        quant_state_sd = state_dict

        # During serialization, the quant_state is stored as subkeys of "weight." (See
        # `bnb.nn.LinearNF4._save_to_state_dict()`). We validate that they at least have the correct prefix.
        # TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
        # rather than raising an exception to correctly implement this API.
        assert all(k.startswith(prefix + "weight.") for k in quant_state_sd.keys())

        if len(quant_state_sd) > 0:
            # We are loading a pre-quantized state dict.
            self.weight = bnb.nn.Params4bit.from_prequantized(
                data=weight, quantized_stats=quant_state_sd, device=weight.device
            )
            self.bias = bias if bias is None else torch.nn.Parameter(bias, requires_grad=False)
        else:
            # We are loading a non-quantized state dict.

            # We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
            # load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
            # device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
            # implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
            # `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
            self.weight = bnb.nn.Params4bit(
                data=weight,
                requires_grad=self.weight.requires_grad,
                compress_statistics=self.weight.compress_statistics,
                quant_type=self.weight.quant_type,
                quant_storage=self.weight.quant_storage,
                module=self,
            )
            self.bias = bias if bias is None else torch.nn.Parameter(bias)


def _replace_param(
    param: torch.nn.Parameter | bnb.nn.Params4bit,
    data: torch.Tensor,
) -> torch.nn.Parameter:
    """A helper function to replace the data of a model parameter with new data in a way that allows replacing params on
    the "meta" device.

    Supports both `torch.nn.Parameter` and `bnb.nn.Params4bit` parameters.
    """
    if param.device.type == "meta":
        # Doing `param.data = data` raises a RuntimeError if param.data was on the "meta" device, so we need to
        # re-create the param instead of overwriting the data.
        if isinstance(param, bnb.nn.Params4bit):
            return bnb.nn.Params4bit(
                data,
                requires_grad=data.requires_grad,
                quant_state=param.quant_state,
                compress_statistics=param.compress_statistics,
                quant_type=param.quant_type,
            )
        return torch.nn.Parameter(data, requires_grad=data.requires_grad)

    param.data = data
    return param


def _convert_linear_layers_to_nf4(
    module: torch.nn.Module,
    ignore_modules: set[str],
    compute_dtype: torch.dtype,
    compress_statistics: bool = False,
    prefix: str = "",
) -> None:
    """Convert all linear layers in the model to NF4 quantized linear layers.

    Args:
        module: All linear layers in this module will be converted.
        ignore_modules: A set of module prefixes to ignore when converting linear layers.
        compute_dtype: The dtype to use for computation in the quantized linear layers.
        compress_statistics: Whether to enable nested quantization (aka double quantization) where the quantization
            constants from the first quantization are quantized again.
        prefix: The prefix of the current module in the model. Used to call this function recursively.
    """
    for name, child in module.named_children():
        fullname = f"{prefix}.{name}" if prefix else name
        if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
            has_bias = child.bias is not None
            replacement = InvokeLinearNF4(
                child.in_features,
                child.out_features,
                bias=has_bias,
                compute_dtype=compute_dtype,
                compress_statistics=compress_statistics,
            )
            if has_bias:
                replacement.bias = _replace_param(replacement.bias, child.bias.data)
            replacement.weight = _replace_param(replacement.weight, child.weight.data)
            replacement.requires_grad_(False)
            module.__setattr__(name, replacement)
        else:
            _convert_linear_layers_to_nf4(child, ignore_modules, compute_dtype=compute_dtype, prefix=fullname)


def quantize_model_nf4(model: torch.nn.Module, modules_to_not_convert: set[str], compute_dtype: torch.dtype):
    """Apply bitsandbytes nf4 quantization to the model.

    You likely want to call this function inside a `accelerate.init_empty_weights()` context.

    Example usage:
    ```
    # Initialize the model from a config on the meta device.
    with accelerate.init_empty_weights():
        model = ModelClass.from_config(...)

    # Add NF4 quantization linear layers to the model - still on the meta device.
    with accelerate.init_empty_weights():
        model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.float16)

    # Load a state_dict into the model. (Could be either a prequantized or non-quantized state_dict.)
    model.load_state_dict(state_dict, strict=True, assign=True)

    # Move the model to the "cuda" device. If the model was non-quantized, this is where the weight quantization takes
    # place.
    model.to("cuda")
    ```
    """
    _convert_linear_layers_to_nf4(module=model, ignore_modules=modules_to_not_convert, compute_dtype=compute_dtype)

    return model
@ -0,0 +1,79 @@
from pathlib import Path

import accelerate
from safetensors.torch import load_file, save_file

from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time


def main():
    """A script for quantizing a FLUX transformer model using the bitsandbytes LLM.int8() quantization method.

    This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
    etc.) are hardcoded and would need to be modified for other use cases.
    """
    # Load the FLUX transformer model onto the meta device.
    model_path = Path(
        "/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
    )

    with log_time("Initialize FLUX transformer on meta device"):
        # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
        p = params["flux-schnell"]

        # Initialize the model on the "meta" device.
        with accelerate.init_empty_weights():
            model = Flux(p)

    # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
    # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
    modules_to_not_convert: set[str] = set()

    model_int8_path = model_path.parent / "bnb_llm_int8.safetensors"
    if model_int8_path.exists():
        # The quantized model already exists, load it and return it.
        print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")

        # Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
        with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
            model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)

        with log_time("Load state dict into model"):
            sd = load_file(model_int8_path)
            model.load_state_dict(sd, strict=True, assign=True)

        with log_time("Move model to cuda"):
            model = model.to("cuda")

        print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")

    else:
        # The quantized model does not exist, quantize the model and save it.
        print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")

        with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
            model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)

        with log_time("Load state dict into model"):
            state_dict = load_file(model_path)
            # TODO(ryand): Cast the state_dict to the appropriate dtype?
            model.load_state_dict(state_dict, strict=True, assign=True)

        with log_time("Move model to cuda and quantize"):
            model = model.to("cuda")

        with log_time("Save quantized model"):
            model_int8_path.parent.mkdir(parents=True, exist_ok=True)
            save_file(model.state_dict(), model_int8_path)

        print(f"Successfully quantized and saved model to '{model_int8_path}'.")

    assert isinstance(model, Flux)
    return model


if __name__ == "__main__":
    main()
@ -0,0 +1,96 @@
import time
from contextlib import contextmanager
from pathlib import Path

import accelerate
import torch
from safetensors.torch import load_file, save_file

from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4


@contextmanager
def log_time(name: str):
    """Helper context manager to log the time taken by a block of code."""
    start = time.time()
    try:
        yield None
    finally:
        end = time.time()
        print(f"'{name}' took {end - start:.4f} secs")


def main():
    """A script for quantizing a FLUX transformer model using the bitsandbytes NF4 quantization method.

    This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
    etc.) are hardcoded and would need to be modified for other use cases.
    """
    model_path = Path(
        "/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
    )

    # inference_dtype = torch.bfloat16
    with log_time("Initialize FLUX transformer on meta device"):
        # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
        p = params["flux-schnell"]

        # Initialize the model on the "meta" device.
        with accelerate.init_empty_weights():
            model = Flux(p)

    # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
    # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
    modules_to_not_convert: set[str] = set()

    model_nf4_path = model_path.parent / "bnb_nf4.safetensors"
    if model_nf4_path.exists():
        # The quantized model already exists, load it and return it.
        print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...")

        # Replace the linear layers with NF4 quantized linear layers (still on the meta device).
        with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
            model = quantize_model_nf4(
                model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
            )

        with log_time("Load state dict into model"):
            state_dict = load_file(model_nf4_path)
            model.load_state_dict(state_dict, strict=True, assign=True)

        with log_time("Move model to cuda"):
            model = model.to("cuda")

        print(f"Successfully loaded pre-quantized model from '{model_nf4_path}'.")

    else:
        # The quantized model does not exist, quantize the model and save it.
        print(f"No pre-quantized model found at '{model_nf4_path}'. Quantizing the model...")

        with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
            model = quantize_model_nf4(
                model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
            )

        with log_time("Load state dict into model"):
            state_dict = load_file(model_path)
            # TODO(ryand): Cast the state_dict to the appropriate dtype?
            model.load_state_dict(state_dict, strict=True, assign=True)

        with log_time("Move model to cuda and quantize"):
            model = model.to("cuda")

        with log_time("Save quantized model"):
            model_nf4_path.parent.mkdir(parents=True, exist_ok=True)
            save_file(model.state_dict(), model_nf4_path)

        print(f"Successfully quantized and saved model to '{model_nf4_path}'.")

    assert isinstance(model, Flux)
    return model


if __name__ == "__main__":
    main()
@ -0,0 +1,92 @@
from pathlib import Path

import accelerate
from safetensors.torch import load_file, save_file
from transformers import AutoConfig, AutoModelForTextEncoding, T5EncoderModel

from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time


def load_state_dict_into_t5(model: T5EncoderModel, state_dict: dict):
    # There is a shared reference to a single weight tensor in the model.
    # Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
    # be present in the state_dict.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
    assert len(unexpected_keys) == 0
    assert set(missing_keys) == {"encoder.embed_tokens.weight"}
    # Assert that the layers we expect to be shared are actually shared.
    assert model.encoder.embed_tokens.weight is model.shared.weight


def main():
    """A script for quantizing a T5 text encoder model using the bitsandbytes LLM.int8() quantization method.

    This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
    etc.) are hardcoded and would need to be modified for other use cases.
    """
    model_path = Path("/data/misc/text_encoder_2")

    with log_time("Initialize T5 on meta device"):
        model_config = AutoConfig.from_pretrained(model_path)
        with accelerate.init_empty_weights():
            model = AutoModelForTextEncoding.from_config(model_config)

    # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
    # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
    modules_to_not_convert: set[str] = set()

    model_int8_path = model_path / "bnb_llm_int8.safetensors"
    if model_int8_path.exists():
        # The quantized model already exists, load it and return it.
        print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")

        # Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
        with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
            model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)

        with log_time("Load state dict into model"):
            sd = load_file(model_int8_path)
            load_state_dict_into_t5(model, sd)

        with log_time("Move model to cuda"):
            model = model.to("cuda")

        print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")

    else:
        # The quantized model does not exist, quantize the model and save it.
        print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")

        with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
            model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)

        with log_time("Load state dict into model"):
            # Load sharded state dict.
            files = list(model_path.glob("*.safetensors"))
            state_dict = {}
            for file in files:
                sd = load_file(file)
                state_dict.update(sd)
            load_state_dict_into_t5(model, state_dict)

        with log_time("Move model to cuda and quantize"):
            model = model.to("cuda")

        with log_time("Save quantized model"):
            model_int8_path.parent.mkdir(parents=True, exist_ok=True)
            state_dict = model.state_dict()
            state_dict.pop("encoder.embed_tokens.weight")
            save_file(state_dict, model_int8_path)
            # This handling of shared weights could also be achieved with save_model(...), but then we'd lose control
            # over which keys are kept. And, the corresponding load_model(...) function does not support assign=True.
            # save_model(model, model_int8_path)

        print(f"Successfully quantized and saved model to '{model_int8_path}'.")

    assert isinstance(model, T5EncoderModel)
    return model


if __name__ == "__main__":
    main()
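The tied-weight handling in `load_state_dict_into_t5()` can be reproduced with a toy module (not the real T5 classes): two attributes share one tensor, so the saved state dict should carry it under only one key, and loading with `strict=False` reports the other key as missing while the tie survives. Requires a torch version with `assign=True` support (>= 2.1).

```python
import torch


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = torch.nn.Embedding(10, 4)
        self.embed_tokens = self.shared  # tied reference, like T5's encoder.embed_tokens


toy = Toy()
sd = toy.state_dict()          # contains both "shared.weight" and "embed_tokens.weight"
sd.pop("embed_tokens.weight")  # keep only one copy, as the script above does before saving

missing, unexpected = toy.load_state_dict(sd, strict=False, assign=True)
assert missing == ["embed_tokens.weight"] and not unexpected
assert toy.embed_tokens.weight is toy.shared.weight  # the tie is intact
```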
@ -25,11 +25,6 @@ class BasicConditioningInfo:
        return self


@dataclass
class ConditioningFieldData:
    conditionings: List[BasicConditioningInfo]


@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
    """SDXL text conditioning information produced by Compel."""
@ -43,6 +38,22 @@ class SDXLConditioningInfo(BasicConditioningInfo):
        return super().to(device=device, dtype=dtype)


@dataclass
class FLUXConditioningInfo:
    clip_embeds: torch.Tensor
    t5_embeds: torch.Tensor

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
        self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
        return self


@dataclass
class ConditioningFieldData:
    conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]


@dataclass
class IPAdapterConditioningInfo:
    cond_image_prompt_embeds: torch.Tensor
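For reference, a short sketch of the new dataclass in use. The module path is assumed to be `invokeai.backend.stable_diffusion.diffusion.conditioning_data` (the file this hunk edits), and the embedding shapes are illustrative only.

```python
import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    ConditioningFieldData,
    FLUXConditioningInfo,
)

cond = FLUXConditioningInfo(
    clip_embeds=torch.zeros(1, 768),      # pooled CLIP-L embedding (illustrative shape)
    t5_embeds=torch.zeros(1, 512, 4096),  # T5-XXL sequence embedding (illustrative shape)
)
cond = cond.to(dtype=torch.bfloat16)  # the in-place .to() defined above

data = ConditioningFieldData(conditionings=[cond])  # now a legal element type
```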
@ -3,10 +3,9 @@ Initialization file for invokeai.backend.util
"""

from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.util import GIG, Chdir, directory_size
from invokeai.backend.util.util import Chdir, directory_size

__all__ = [
    "GIG",
    "directory_size",
    "Chdir",
    "InvokeAILogger",
@ -7,9 +7,6 @@ from pathlib import Path

from PIL import Image

# actual size of a gig
GIG = 1073741824


def slugify(value: str, allow_unicode: bool = False) -> str:
    """
@ -12,10 +12,6 @@ module.exports = {
    'i18next/no-literal-string': 'error',
    // https://eslint.org/docs/latest/rules/no-console
    'no-console': 'error',
    // https://eslint.org/docs/latest/rules/no-promise-executor-return
    'no-promise-executor-return': 'error',
    // https://eslint.org/docs/latest/rules/require-await
    'require-await': 'error',
  },
  overrides: [
    /**
@ -1,5 +1,5 @@
import { PropsWithChildren, memo, useEffect } from 'react';
import { modelChanged } from '../src/features/controlLayers/store/canvasV2Slice';
import { modelChanged } from '../src/features/parameters/store/generationSlice';
import { useAppDispatch } from '../src/app/store/storeHooks';
import { useGlobalModifiersInit } from '@invoke-ai/ui-library';
/**
@ -10,9 +10,7 @@ export const ReduxInit = memo((props: PropsWithChildren) => {
  const dispatch = useAppDispatch();
  useGlobalModifiersInit();
  useEffect(() => {
    dispatch(
      modelChanged({ model: { key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' } })
    );
    dispatch(modelChanged({ key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' }));
  }, []);

  return props.children;
@ -9,8 +9,6 @@ const config: KnipConfig = {
    'src/services/api/schema.ts',
    'src/features/nodes/types/v1/**',
    'src/features/nodes/types/v2/**',
    // TODO(psyche): maybe we can clean up these utils after canvas v2 release
    'src/features/controlLayers/konva/util.ts',
  ],
  ignoreBinaries: ['only-allow'],
  paths: {
@ -52,17 +52,17 @@
    }
  },
  "dependencies": {
    "@chakra-ui/react-use-size": "^2.1.0",
    "@dagrejs/dagre": "^1.1.3",
    "@dagrejs/graphlib": "^2.2.3",
    "@dnd-kit/core": "^6.1.0",
    "@dnd-kit/sortable": "^8.0.0",
    "@dnd-kit/utilities": "^3.2.2",
    "@fontsource-variable/inter": "^5.0.20",
    "@invoke-ai/ui-library": "^0.0.31",
    "@invoke-ai/ui-library": "^0.0.29",
    "@nanostores/react": "^0.7.3",
    "@reduxjs/toolkit": "2.2.3",
    "@roarr/browser-log-writer": "^1.3.0",
    "async-mutex": "^0.5.0",
    "chakra-react-select": "^4.9.1",
    "compare-versions": "^6.1.1",
    "dateformat": "^5.0.3",
@ -74,8 +74,6 @@
    "jsondiffpatch": "^0.6.0",
    "konva": "^9.3.14",
    "lodash-es": "^4.17.21",
    "lru-cache": "^11.0.0",
    "nanoid": "^5.0.7",
    "nanostores": "^0.11.2",
    "new-github-issue-url": "^1.0.0",
    "overlayscrollbars": "^2.10.0",
@ -90,6 +88,7 @@
    "react-hotkeys-hook": "4.5.0",
    "react-i18next": "^14.1.3",
    "react-icons": "^5.2.1",
    "react-konva": "^18.2.10",
    "react-redux": "9.1.2",
    "react-resizable-panels": "^2.0.23",
    "react-select": "5.8.0",
@ -103,9 +102,9 @@
    "roarr": "^7.21.1",
    "serialize-error": "^11.0.3",
    "socket.io-client": "^4.7.5",
    "stable-hash": "^0.0.4",
    "use-debounce": "^10.0.2",
    "use-device-pixel-ratio": "^1.1.2",
    "use-image": "^1.1.1",
    "uuid": "^10.0.0",
    "zod": "^3.23.8",
    "zod-validation-error": "^3.3.1"
102
invokeai/frontend/web/pnpm-lock.yaml
generated
@ -5,6 +5,9 @@ settings:
  excludeLinksFromLockfile: false

dependencies:
  '@chakra-ui/react-use-size':
    specifier: ^2.1.0
    version: 2.1.0(react@18.3.1)
  '@dagrejs/dagre':
    specifier: ^1.1.3
    version: 1.1.3
@ -24,8 +27,8 @@ dependencies:
    specifier: ^5.0.20
    version: 5.0.20
  '@invoke-ai/ui-library':
    specifier: ^0.0.31
    version: 0.0.31(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.20)(@types/react@18.3.3)(i18next@23.12.2)(react-dom@18.3.1)(react@18.3.1)
    specifier: ^0.0.29
    version: 0.0.29(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.20)(@types/react@18.3.3)(i18next@23.12.2)(react-dom@18.3.1)(react@18.3.1)
  '@nanostores/react':
    specifier: ^0.7.3
    version: 0.7.3(nanostores@0.11.2)(react@18.3.1)
@ -35,9 +38,6 @@ dependencies:
  '@roarr/browser-log-writer':
    specifier: ^1.3.0
    version: 1.3.0
  async-mutex:
    specifier: ^0.5.0
    version: 0.5.0
  chakra-react-select:
    specifier: ^4.9.1
    version: 4.9.1(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/layout@2.3.1)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@emotion/react@11.13.0)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1)
@ -71,12 +71,6 @@ dependencies:
  lodash-es:
    specifier: ^4.17.21
    version: 4.17.21
  lru-cache:
    specifier: ^11.0.0
    version: 11.0.0
  nanoid:
    specifier: ^5.0.7
    version: 5.0.7
  nanostores:
    specifier: ^0.11.2
    version: 0.11.2
@ -119,6 +113,9 @@ dependencies:
  react-icons:
    specifier: ^5.2.1
    version: 5.2.1(react@18.3.1)
  react-konva:
    specifier: ^18.2.10
    version: 18.2.10(konva@9.3.14)(react-dom@18.3.1)(react@18.3.1)
  react-redux:
    specifier: 9.1.2
    version: 9.1.2(@types/react@18.3.3)(react@18.3.1)(redux@5.0.1)
@ -158,15 +155,15 @@ dependencies:
  socket.io-client:
    specifier: ^4.7.5
    version: 4.7.5
  stable-hash:
    specifier: ^0.0.4
    version: 0.0.4
  use-debounce:
    specifier: ^10.0.2
    version: 10.0.2(react@18.3.1)
  use-device-pixel-ratio:
    specifier: ^1.1.2
    version: 1.1.2(react@18.3.1)
  use-image:
    specifier: ^1.1.1
    version: 1.1.1(react-dom@18.3.1)(react@18.3.1)
  uuid:
    specifier: ^10.0.0
    version: 10.0.0
@ -3564,8 +3561,8 @@ packages:
      prettier: 3.3.3
    dev: true

  /@invoke-ai/ui-library@0.0.31(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.20)(@types/react@18.3.3)(i18next@23.12.2)(react-dom@18.3.1)(react@18.3.1):
    resolution: {integrity: sha512-7LtOUN/bcGHc8jCRd2m22DvP2eeogqwM/shdXQpLH5RY2FzWJNXlWdVT4hIPGDu7znnk3xvXlZvo6tiGSjbnCQ==}
  /@invoke-ai/ui-library@0.0.29(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.20)(@types/react@18.3.3)(i18next@23.12.2)(react-dom@18.3.1)(react@18.3.1):
    resolution: {integrity: sha512-7SYOaiEEKk9iHk0hg2R2yVxiuV3I1x6bDEv0R3Y2tCH/Aq5XDG2tR+d7SQAPqf5+za3S+qfNFjbjl7GvEMwqmA==}
    peerDependencies:
      '@fontsource-variable/inter': ^5.0.16
      react: ^18.2.0
@ -5205,6 +5202,12 @@ packages:
      '@types/react': 18.3.3
    dev: true

  /@types/react-reconciler@0.28.8:
    resolution: {integrity: sha512-SN9c4kxXZonFhbX4hJrZy37yw9e7EIxcpHCxQv5JUS18wDE5ovkQKlqQEkufdJCCMfuI9BnjUJvhYeJ9x5Ra7g==}
    dependencies:
      '@types/react': 18.3.3
    dev: false

  /@types/react-transition-group@4.4.10:
    resolution: {integrity: sha512-hT/+s0VQs2ojCX823m60m5f0sL5idt9SO6Tj6Dg+rdphGPIeJbJ6CxvBYkgkGKrYeDjvIpKTR38UzmtHJOGW3Q==}
    dependencies:
@ -5900,12 +5903,6 @@ packages:
      tslib: 2.6.3
    dev: true

  /async-mutex@0.5.0:
    resolution: {integrity: sha512-1A94B18jkJ3DYq284ohPxoXbfTA5HsQ7/Mf4DEhcyLx3Bz27Rh59iScbB6EPiP+B+joue6YCxcMXSbFC1tZKwA==}
    dependencies:
      tslib: 2.6.3
    dev: false

  /attr-accept@2.2.2:
    resolution: {integrity: sha512-7prDjvt9HmqiZ0cl5CRjtS84sEyhsHP2coDkaZKRKVfCDo9s7iw7ChVmar78Gu9pC4SoR/28wFu/G5JJhTnqEg==}
    engines: {node: '>=4'}
@ -8387,6 +8384,15 @@ packages:
      set-function-name: 2.0.2
    dev: true

  /its-fine@1.2.5(react@18.3.1):
    resolution: {integrity: sha512-fXtDA0X0t0eBYAGLVM5YsgJGsJ5jEmqZEPrGbzdf5awjv0xE7nqv3TVnvtUF060Tkes15DbDAKW/I48vsb6SyA==}
    peerDependencies:
      react: '>=18.0'
    dependencies:
      '@types/react-reconciler': 0.28.8
      react: 18.3.1
    dev: false

  /jackspeak@2.3.6:
    resolution: {integrity: sha512-N3yCS/NegsOBokc8GAdM8UcmfsKiSS8cipheD/nivzr700H+nsMOxJjQnvwOcRYVuFkdH0wGUvW2WbXGmrZGbQ==}
    engines: {node: '>=14'}
@ -8704,11 +8710,6 @@ packages:
    engines: {node: 14 || >=16.14}
    dev: true

  /lru-cache@11.0.0:
    resolution: {integrity: sha512-Qv32eSV1RSCfhY3fpPE2GNZ8jgM9X7rdAfemLWqTUxwiyIC4jJ6Sy0fZ8H+oLWevO6i4/bizg7c8d8i6bxrzbA==}
    engines: {node: 20 || >=22}
    dev: false

  /lru-cache@5.1.1:
    resolution: {integrity: sha512-KpNARQA3Iwv+jTA0utUVVbrh+Jlrr1Fv0e56GGzAFOXN7dk/FviaDW8LHmK52DlcH4WP2n6gI8vN1aesBFgo9w==}
    dependencies:
@ -8996,12 +8997,6 @@ packages:
    hasBin: true
    dev: true

  /nanoid@5.0.7:
    resolution: {integrity: sha512-oLxFY2gd2IqnjcYyOXD8XGCftpGtZP2AbHbOkthDkvRywH5ayNtPVy9YlOPcHckXzbLTCHpkb7FB+yuxKV13pQ==}
    engines: {node: ^18 || >=20}
    hasBin: true
    dev: false

  /nanostores@0.11.2:
    resolution: {integrity: sha512-6bucNxMJA5rNV554WQl+MWGng0QVMzlRgpKTHHfIbVLrhQ+yRXBychV9ECGVuuUfCMQPjfIG9bj8oJFZ9hYP/Q==}
    engines: {node: ^18.0.0 || >=20.0.0}
@ -9806,6 +9801,33 @@ packages:
    resolution: {integrity: sha512-/LLMVyas0ljjAtoYiPqYiL8VWXzUUdThrmU5+n20DZv+a+ClRoevUzw5JxU+Ieh5/c87ytoTBV9G1FiKfNJdmg==}
    dev: true

  /react-konva@18.2.10(konva@9.3.14)(react-dom@18.3.1)(react@18.3.1):
    resolution: {integrity: sha512-ohcX1BJINL43m4ynjZ24MxFI1syjBdrXhqVxYVDw2rKgr3yuS0x/6m1Y2Z4sl4T/gKhfreBx8KHisd0XC6OT1g==}
    peerDependencies:
      konva: ^8.0.1 || ^7.2.5 || ^9.0.0
      react: '>=18.0.0'
      react-dom: '>=18.0.0'
    dependencies:
      '@types/react-reconciler': 0.28.8
      its-fine: 1.2.5(react@18.3.1)
      konva: 9.3.14
      react: 18.3.1
      react-dom: 18.3.1(react@18.3.1)
      react-reconciler: 0.29.2(react@18.3.1)
      scheduler: 0.23.2
    dev: false

  /react-reconciler@0.29.2(react@18.3.1):
    resolution: {integrity: sha512-zZQqIiYgDCTP/f1N/mAR10nJGrPD2ZR+jDSEsKWJHYC7Cm2wodlwbR3upZRdC3cjIjSlTLNVyO7Iu0Yy7t2AYg==}
    engines: {node: '>=0.10.0'}
    peerDependencies:
      react: ^18.3.1
    dependencies:
      loose-envify: 1.4.0
      react: 18.3.1
      scheduler: 0.23.2
    dev: false

  /react-redux@9.1.2(@types/react@18.3.3)(react@18.3.1)(redux@5.0.1):
    resolution: {integrity: sha512-0OA4dhM1W48l3uzmv6B7TXPCGmokUU4p1M44DGN2/D9a1FjVPukVjER1PcPX97jIg6aUeLq1XJo1IpfbgULn0w==}
    peerDependencies:
@ -10603,10 +10625,6 @@ packages:
    resolution: {integrity: sha512-D9cPgkvLlV3t3IzL0D0YLvGA9Ahk4PcvVwUbN0dSGr1aP0Nrt4AEnTUbuGvquEC0mA64Gqt1fzirlRs5ibXx8g==}
    dev: true

  /stable-hash@0.0.4:
    resolution: {integrity: sha512-LjdcbuBeLcdETCrPn9i8AYAZ1eCtu4ECAWtP7UleOiZ9LzVxRzzUZEoZ8zB24nhkQnDWyET0I+3sWokSDS3E7g==}
    dev: false

  /stack-generator@2.0.10:
    resolution: {integrity: sha512-mwnua/hkqM6pF4k8SnmZ2zfETsRUpWXREfA/goT8SLCV4iOFa4bzOX2nDipWAZFPTjLvQB82f5yaodMVhK0yJQ==}
    dependencies:
@ -11312,6 +11330,16 @@ packages:
      react: 18.3.1
    dev: false

  /use-image@1.1.1(react-dom@18.3.1)(react@18.3.1):
    resolution: {integrity: sha512-n4YO2k8AJG/BcDtxmBx8Aa+47kxY5m335dJiCQA5tTeVU4XdhrhqR6wT0WISRXwdMEOv5CSjqekDZkEMiiWaYQ==}
    peerDependencies:
      react: '>=16.8.0'
      react-dom: '>=16.8.0'
    dependencies:
      react: 18.3.1
      react-dom: 18.3.1(react@18.3.1)
    dev: false

  /use-isomorphic-layout-effect@1.1.2(@types/react@18.3.3)(react@18.3.1):
    resolution: {integrity: sha512-49L8yCO3iGT/ZF9QttjwLF/ZD9Iwto5LnH5LmEdk/6cFmXddqi2ulF0edxTwjj+7mqvpVVGQWvbXZdn32wRSHA==}
    peerDependencies:
|
||||
|
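The react-konva entry above is the hub of this dependency cluster: it pins konva for the canvas engine, react-reconciler for the custom renderer, and its-fine for bridging React context across that renderer. A rough usage sketch of how the packages fit together at runtime (Stage, Layer, and Rect are react-konva's documented exports; the sizes and colors here are arbitrary):

import React from 'react';
import { Stage, Layer, Rect } from 'react-konva';

// Renders a Konva scene graph through React. react-konva drives this via
// react-reconciler, which is why both appear together in the lockfile.
export const CanvasSketch = () => (
  <Stage width={512} height={512}>
    <Layer>
      <Rect x={16} y={16} width={128} height={128} fill="tomato" draggable />
    </Layer>
  </Stage>
);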
@@ -80,7 +80,6 @@
    "aboutDesc": "Using Invoke for work? Check out:",
    "aboutHeading": "Own Your Creative Power",
    "accept": "Accept",
    "apply": "Apply",
    "add": "Add",
    "advanced": "Advanced",
    "ai": "ai",
@@ -116,7 +115,6 @@
    "githubLabel": "Github",
    "goTo": "Go to",
    "hotkeysLabel": "Hotkeys",
    "loadingImage": "Loading Image",
    "imageFailedToLoad": "Unable to Load Image",
    "img2img": "Image To Image",
    "inpaint": "inpaint",
@@ -327,10 +325,6 @@
    "canceled": "Canceled",
    "completedIn": "Completed in",
    "batch": "Batch",
    "origin": "Origin",
    "originCanvas": "Canvas",
    "originWorkflows": "Workflows",
    "originOther": "Other",
    "batchFieldValues": "Batch Field Values",
    "item": "Item",
    "session": "Session",
@@ -702,6 +696,8 @@
    "availableModels": "Available Models",
    "baseModel": "Base Model",
    "cancel": "Cancel",
    "clipEmbed": "CLIP Embed",
    "clipVision": "CLIP Vision",
    "config": "Config",
    "convert": "Convert",
    "convertingModelBegin": "Converting Model. Please wait.",
@@ -789,13 +785,16 @@
    "settings": "Settings",
    "simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
    "source": "Source",
    "spandrelImageToImage": "Image to Image (Spandrel)",
    "starterModels": "Starter Models",
    "starterModelsInModelManager": "Starter Models can be found in Model Manager",
    "syncModels": "Sync Models",
    "textualInversions": "Textual Inversions",
    "triggerPhrases": "Trigger Phrases",
    "loraTriggerPhrases": "LoRA Trigger Phrases",
    "mainModelTriggerPhrases": "Main Model Trigger Phrases",
    "typePhraseHere": "Type phrase here",
    "t5Encoder": "T5 Encoder",
    "upcastAttention": "Upcast Attention",
    "uploadImage": "Upload Image",
    "urlOrLocalPath": "URL or Local Path",
@@ -1101,6 +1100,7 @@
    "confirmOnDelete": "Confirm On Delete",
    "developer": "Developer",
    "displayInProgress": "Display Progress Images",
    "enableImageDebugging": "Enable Image Debugging",
    "enableInformationalPopovers": "Enable Informational Popovers",
    "informationalPopoversDisabled": "Informational Popovers Disabled",
    "informationalPopoversDisabledDesc": "Informational popovers have been disabled. Enable them in Settings.",
@@ -1567,7 +1567,7 @@
    "copyToClipboard": "Copy to Clipboard",
    "cursorPosition": "Cursor Position",
    "darkenOutsideSelection": "Darken Outside Selection",
    "discardAll": "Discard All & Cancel Pending Generations",
    "discardAll": "Discard All",
    "discardCurrent": "Discard Current",
    "downloadAsImage": "Download As Image",
    "enableMask": "Enable Mask",
@@ -1645,31 +1645,16 @@
    "storeNotInitialized": "Store is not initialized"
  },
  "controlLayers": {
    "generateMode": "Generate",
    "generateModeDesc": "Create individual images. Generated images are added directly to the gallery.",
    "composeMode": "Compose",
    "composeModeDesc": "Compose your work iterative. Generated images are added back to the canvas.",
    "autoSave": "Auto-save to Gallery",
    "resetCanvas": "Reset Canvas",
    "resetAll": "Reset All",
    "deleteAll": "Delete All",
    "clearCaches": "Clear Caches",
    "recalculateRects": "Recalculate Rects",
    "clipToBbox": "Clip Strokes to Bbox",
    "addLayer": "Add Layer",
    "moveToFront": "Move to Front",
    "moveToBack": "Move to Back",
    "moveForward": "Move Forward",
    "moveBackward": "Move Backward",
    "brushSize": "Brush Size",
    "width": "Width",
    "zoom": "Zoom",
    "resetView": "Reset View",
    "controlLayers": "Control Layers",
    "globalMaskOpacity": "Global Mask Opacity",
    "autoNegative": "Auto Negative",
    "enableAutoNegative": "Enable Auto Negative",
    "disableAutoNegative": "Disable Auto Negative",
    "deletePrompt": "Delete Prompt",
    "resetRegion": "Reset Region",
    "debugLayers": "Debug Layers",
@@ -1678,83 +1663,21 @@
    "addPositivePrompt": "Add $t(common.positivePrompt)",
    "addNegativePrompt": "Add $t(common.negativePrompt)",
    "addIPAdapter": "Add $t(common.ipAdapter)",
    "regionalGuidance": "Regional Guidance",
    "regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)",
    "raster": "Raster",
    "rasterLayer_one": "Raster Layer",
    "controlLayer_one": "Control Layer",
    "inpaintMask_one": "Inpaint Mask",
    "regionalGuidance_one": "Regional Guidance",
    "ipAdapter_one": "IP Adapter",
    "rasterLayer_other": "Raster Layers",
    "controlLayer_other": "Control Layers",
    "inpaintMask_other": "Inpaint Masks",
    "regionalGuidance_other": "Regional Guidance",
    "ipAdapter_other": "IP Adapters",
    "opacity": "Opacity",
    "regionalGuidance_withCount_hidden": "Regional Guidance ({{count}} hidden)",
    "controlAdapters_withCount_hidden": "Control Adapters ({{count}} hidden)",
    "controlLayers_withCount_hidden": "Control Layers ({{count}} hidden)",
    "rasterLayers_withCount_hidden": "Raster Layers ({{count}} hidden)",
    "ipAdapters_withCount_hidden": "IP Adapters ({{count}} hidden)",
    "inpaintMasks_withCount_hidden": "Inpaint Masks ({{count}} hidden)",
    "regionalGuidance_withCount_visible": "Regional Guidance ({{count}})",
    "controlAdapters_withCount_visible": "Control Adapters ({{count}})",
    "controlLayers_withCount_visible": "Control Layers ({{count}})",
    "rasterLayers_withCount_visible": "Raster Layers ({{count}})",
    "ipAdapters_withCount_visible": "IP Adapters ({{count}})",
    "inpaintMasks_withCount_visible": "Inpaint Masks ({{count}})",
    "globalControlAdapter": "Global $t(controlnet.controlAdapter_one)",
    "globalControlAdapterLayer": "Global $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)",
    "globalIPAdapter": "Global $t(common.ipAdapter)",
    "globalIPAdapterLayer": "Global $t(common.ipAdapter) $t(unifiedCanvas.layer)",
    "globalInitialImage": "Global Initial Image",
    "globalInitialImageLayer": "$t(controlLayers.globalInitialImage) $t(unifiedCanvas.layer)",
    "layer": "Layer",
    "opacityFilter": "Opacity Filter",
    "clearProcessor": "Clear Processor",
    "resetProcessor": "Reset Processor to Defaults",
    "noLayersAdded": "No Layers Added",
    "layers_one": "Layer",
    "layers_other": "Layers",
    "objects_zero": "empty",
    "objects_one": "{{count}} object",
    "objects_other": "{{count}} objects",
    "convertToControlLayer": "Convert to Control Layer",
    "convertToRasterLayer": "Convert to Raster Layer",
    "transparency": "Transparency",
    "enableTransparencyEffect": "Enable Transparency Effect",
    "disableTransparencyEffect": "Disable Transparency Effect",
    "hidingType": "Hiding {{type}}",
    "showingType": "Showing {{type}}",
    "dynamicGrid": "Dynamic Grid",
    "logDebugInfo": "Log Debug Info",
    "fill": {
      "fillStyle": "Fill Style",
      "solid": "Solid",
      "grid": "Grid",
      "crosshatch": "Crosshatch",
      "vertical": "Vertical",
      "horizontal": "Horizontal",
      "diagonal": "Diagonal"
    },
    "tool": {
      "brush": "Brush",
      "eraser": "Eraser",
      "rectangle": "Rectangle",
      "bbox": "Bbox",
      "move": "Move",
      "view": "View",
      "transform": "Transform",
      "colorPicker": "Color Picker"
    },
    "filter": {
      "filter": "Filter",
      "filters": "Filters",
      "filterType": "Filter Type",
      "preview": "Preview",
      "apply": "Apply",
      "cancel": "Cancel"
    }
    "layers_other": "Layers"
  },
  "upscaling": {
    "upscale": "Upscale",
@@ -1842,30 +1765,5 @@
      "upscaling": "Upscaling",
      "upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
    }
  },
  "system": {
    "enableLogging": "Enable Logging",
    "logLevel": {
      "logLevel": "Log Level",
      "trace": "Trace",
      "debug": "Debug",
      "info": "Info",
      "warn": "Warn",
      "error": "Error",
      "fatal": "Fatal"
    },
    "logNamespaces": {
      "logNamespaces": "Log Namespaces",
      "gallery": "Gallery",
      "models": "Models",
      "config": "Config",
      "canvas": "Canvas",
      "generation": "Generation",
      "workflows": "Workflows",
      "system": "System",
      "events": "Events",
      "queue": "Queue",
      "metadata": "Metadata"
    }
  }
}
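Two i18next features carry most of the structure in the translation file above: $t(...) nesting, which interpolates one key into another at lookup time, and the _zero/_one/_other suffixes, which i18next selects between based on a count option. A small sketch of how the keys above resolve (the return values assume the English strings shown in the diff, and that unifiedCanvas.layer resolves to "Layer", which is not visible in this hunk):

import { t } from 'i18next';

t('controlLayers.rasterLayer', { count: 1 }); // "Raster Layer"  -- picks rasterLayer_one
t('controlLayers.rasterLayer', { count: 3 }); // "Raster Layers" -- picks rasterLayer_other
t('controlLayers.objects', { count: 0 });     // "empty"         -- picks objects_zero

// $t(...) nesting expands the referenced keys at lookup time:
t('controlLayers.regionalGuidanceLayer');     // "Regional Guidance Layer" (assuming unifiedCanvas.layer === "Layer")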
@@ -38,7 +38,7 @@ async function generateTypes(schema) {
  process.stdout.write(`\nOK!\r\n`);
}

function main() {
async function main() {
  const encoding = 'utf-8';

  if (process.stdin.isTTY) {
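Converting main to async lets the script await the schema arriving on stdin and the codegen step directly, instead of chaining promises. A rough sketch of the usual stdin-reading idiom such a script relies on (the helper name and body are illustrative, not taken from the repo):

// Illustrative helper: collect all of stdin as a string so an async main()
// can simply `await readStdin(encoding)` before calling generateTypes(schema).
async function readStdin(encoding) {
  process.stdin.setEncoding(encoding);
  let data = '';
  for await (const chunk of process.stdin) {
    data += chunk;
  }
  return data;
}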
@@ -6,7 +6,6 @@ import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/ap
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { PartialAppConfig } from 'app/types/invokeai';
import ImageUploadOverlay from 'common/components/ImageUploadOverlay';
import { useScopeFocusWatcher } from 'common/hooks/interactionScopes';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { useFullscreenDropzone } from 'common/hooks/useFullscreenDropzone';
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
@@ -14,13 +13,13 @@ import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardMo
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { configChanged } from 'features/system/store/configSlice';
import { languageSelector } from 'features/system/store/systemSelectors';
import { AppContent } from 'features/ui/components/AppContent';
import InvokeTabs from 'features/ui/components/InvokeTabs';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import { setActiveTab } from 'features/ui/store/uiSlice';
import type { TabName } from 'features/ui/store/uiTypes';
import { useGetAndLoadLibraryWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadLibraryWorkflow';
import { AnimatePresence } from 'framer-motion';
import i18n from 'i18n';
@@ -41,10 +40,17 @@ interface Props {
    action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
  };
  selectedWorkflowId?: string;
  destination?: TabName | undefined;
  selectedStylePresetId?: string;
  destination?: InvokeTabName | undefined;
}

const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, destination }: Props) => {
const App = ({
  config = DEFAULT_CONFIG,
  selectedImage,
  selectedWorkflowId,
  selectedStylePresetId,
  destination,
}: Props) => {
  const language = useAppSelector(languageSelector);
  const logger = useLogger('system');
  const dispatch = useAppDispatch();
@@ -83,6 +89,12 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
    }
  }, [selectedWorkflowId, getAndLoadWorkflow]);

  useEffect(() => {
    if (selectedStylePresetId) {
      dispatch(activeStylePresetIdChanged(selectedStylePresetId));
    }
  }, [dispatch, selectedStylePresetId]);

  useEffect(() => {
    if (destination) {
      dispatch(setActiveTab(destination));
@@ -95,7 +107,6 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti

  useStarterModelsToast();
  useSyncQueueStatus();
  useScopeFocusWatcher();

  return (
    <ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
@@ -108,7 +119,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
        {...dropzone.getRootProps()}
      >
        <input {...dropzone.getInputProps()} />
        <AppContent />
        <InvokeTabs />
        <AnimatePresence>
          {dropzone.isDragActive && isHandlingUpload && (
            <ImageUploadOverlay dropzone={dropzone} setIsHandlingUpload={setIsHandlingUpload} />
@@ -119,7 +130,6 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
      <ChangeBoardModal />
      <DynamicPromptsModal />
      <StylePresetModal />
      <ClearQueueConfirmationsAlertDialog />
      <PreselectedImage selectedImage={selectedImage} />
    </ErrorBoundary>
  );
@ -19,7 +19,7 @@ import type { PartialAppConfig } from 'app/types/invokeai';
|
||||
import Loading from 'common/components/Loading/Loading';
|
||||
import AppDndContext from 'features/dnd/components/AppDndContext';
|
||||
import type { WorkflowCategory } from 'features/nodes/types/workflow';
|
||||
import type { TabName } from 'features/ui/store/uiTypes';
|
||||
import type { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import type { PropsWithChildren, ReactNode } from 'react';
|
||||
import React, { lazy, memo, useEffect, useMemo } from 'react';
|
||||
import { Provider } from 'react-redux';
|
||||
@ -45,7 +45,8 @@ interface Props extends PropsWithChildren {
|
||||
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
|
||||
};
|
||||
selectedWorkflowId?: string;
|
||||
destination?: TabName;
|
||||
selectedStylePresetId?: string;
|
||||
destination?: InvokeTabName;
|
||||
customStarUi?: CustomStarUi;
|
||||
socketOptions?: Partial<ManagerOptions & SocketOptions>;
|
||||
isDebugging?: boolean;
|
||||
@ -66,6 +67,7 @@ const InvokeAIUI = ({
|
||||
queueId,
|
||||
selectedImage,
|
||||
selectedWorkflowId,
|
||||
selectedStylePresetId,
|
||||
destination,
|
||||
customStarUi,
|
||||
socketOptions,
|
||||
@ -227,6 +229,7 @@ const InvokeAIUI = ({
|
||||
config={config}
|
||||
selectedImage={selectedImage}
|
||||
selectedWorkflowId={selectedWorkflowId}
|
||||
selectedStylePresetId={selectedStylePresetId}
|
||||
destination={destination}
|
||||
/>
|
||||
</AppDndContext>
|
||||
|
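Both diffs above thread the new selectedStylePresetId prop from the public InvokeAIUI component down into App, where an effect pushes it into the store. A hedged sketch of how an embedding host would supply it (the concrete ID values and the "canvas" tab name are placeholders invented for illustration; the prop names are the ones the Props interfaces above declare):

// Hypothetical host-side usage of the public component:
<InvokeAIUI
  selectedStylePresetId="my-preset-id"   // placeholder ID
  selectedWorkflowId="my-workflow-id"    // placeholder ID
  destination="canvas"                   // assumed to be a valid InvokeTabName
/>;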
@@ -2,7 +2,7 @@ import { useStore } from '@nanostores/react';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $isDebugging } from 'app/store/nanostores/isDebugging';
import { useAppStore } from 'app/store/nanostores/store';
import { useAppDispatch } from 'app/store/storeHooks';
import type { MapStore } from 'nanostores';
import { atom, map } from 'nanostores';
import { useEffect, useMemo } from 'react';
@@ -18,19 +18,14 @@ declare global {
  }
}

export type AppSocket = Socket<ServerToClientEvents, ClientToServerEvents>;

export const $socket = atom<AppSocket | null>(null);
export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});

const $isSocketInitialized = atom<boolean>(false);
export const $isConnected = atom<boolean>(false);

/**
 * Initializes the socket.io connection and sets up event listeners.
 */
export const useSocketIO = () => {
  const { dispatch, getState } = useAppStore();
  const dispatch = useAppDispatch();
  const baseUrl = useStore($baseUrl);
  const authToken = useStore($authToken);
  const addlSocketOptions = useStore($socketOptions);
@@ -66,9 +61,8 @@ export const useSocketIO = () => {
      return;
    }

    const socket: AppSocket = io(socketUrl, socketOptions);
    $socket.set(socket);
    setEventListeners({ socket, dispatch, getState, setIsConnected: $isConnected.set });
    const socket: Socket<ServerToClientEvents, ClientToServerEvents> = io(socketUrl, socketOptions);
    setEventListeners({ dispatch, socket });
    socket.connect();

    if ($isDebugging.get() || import.meta.env.MODE === 'development') {
@@ -90,5 +84,5 @@ export const useSocketIO = () => {
      socket.disconnect();
      $isSocketInitialized.set(false);
    };
  }, [dispatch, getState, socketOptions, socketUrl]);
  }, [dispatch, socketOptions, socketUrl]);
};
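The key detail in this hook is the generic Socket<ServerToClientEvents, ClientToServerEvents>: socket.io-client uses the first event map to type socket.on and the second to type socket.emit. A minimal sketch of that mechanism with stand-in event maps (the app's real maps live elsewhere in its services/events types; the two-line versions and event names here are invented for illustration):

import { io, type Socket } from 'socket.io-client';

// Stand-in event maps; the app's real ones are much larger.
type ServerToClient = { invocation_started: (data: { id: string }) => void };
type ClientToServer = { subscribe_queue: (queueId: string) => void };

const socket: Socket<ServerToClient, ClientToServer> = io('https://example.com', { autoConnect: false });
socket.on('invocation_started', (data) => console.log(data.id)); // payload is typed
socket.emit('subscribe_queue', 'default');                       // args are typed
socket.connect(); // the hook likewise connects explicitly after wiring listeners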
@@ -15,21 +15,21 @@ export const BASE_CONTEXT = {};

export const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));

export const zLogNamespace = z.enum([
  'canvas',
  'config',
  'events',
  'gallery',
  'generation',
  'metadata',
  'models',
  'system',
  'queue',
  'workflows',
]);
export type LogNamespace = z.infer<typeof zLogNamespace>;
export type LoggerNamespace =
  | 'images'
  | 'models'
  | 'config'
  | 'canvas'
  | 'generation'
  | 'nodes'
  | 'system'
  | 'socketio'
  | 'session'
  | 'queue'
  | 'dnd'
  | 'controlLayers';

export const logger = (namespace: LogNamespace) => $logger.get().child({ namespace });
export const logger = (namespace: LoggerNamespace) => $logger.get().child({ namespace });

export const zLogLevel = z.enum(['trace', 'debug', 'info', 'warn', 'error', 'fatal']);
export type LogLevel = z.infer<typeof zLogLevel>;
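The switch from a hand-written union to z.enum keeps a single source of truth: the same definition yields both the compile-time type and a runtime validator, which matters when namespaces arrive from persisted settings rather than code. A minimal standalone sketch with an abridged list:

import { z } from 'zod';

const zLogNamespace = z.enum(['canvas', 'config', 'events']); // abridged list
type LogNamespace = z.infer<typeof zLogNamespace>; // 'canvas' | 'config' | 'events'

zLogNamespace.parse('canvas');   // returns 'canvas'
zLogNamespace.safeParse('nope'); // { success: false, error: ZodError }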
@@ -3,34 +3,27 @@ import { useAppSelector } from 'app/store/storeHooks';
import { useEffect, useMemo } from 'react';
import { ROARR, Roarr } from 'roarr';

import type { LogNamespace } from './logger';
import type { LoggerNamespace } from './logger';
import { $logger, BASE_CONTEXT, LOG_LEVEL_MAP, logger } from './logger';

export const useLogger = (namespace: LogNamespace) => {
  const logLevel = useAppSelector((s) => s.system.logLevel);
  const logNamespaces = useAppSelector((s) => s.system.logNamespaces);
  const logIsEnabled = useAppSelector((s) => s.system.logIsEnabled);
export const useLogger = (namespace: LoggerNamespace) => {
  const consoleLogLevel = useAppSelector((s) => s.system.consoleLogLevel);
  const shouldLogToConsole = useAppSelector((s) => s.system.shouldLogToConsole);

  // The provided Roarr browser log writer uses localStorage to config logging to console
  useEffect(() => {
    if (logIsEnabled) {
    if (shouldLogToConsole) {
      // Enable console log output
      localStorage.setItem('ROARR_LOG', 'true');

      // Use a filter to show only logs of the given level
      let filter = `context.logLevel:>=${LOG_LEVEL_MAP[logLevel]}`;
      if (logNamespaces.length > 0) {
        filter += ` AND (${logNamespaces.map((ns) => `context.namespace:${ns}`).join(' OR ')})`;
      } else {
        filter += ' AND context.namespace:undefined';
      }
      localStorage.setItem('ROARR_FILTER', filter);
      localStorage.setItem('ROARR_FILTER', `context.logLevel:>=${LOG_LEVEL_MAP[consoleLogLevel]}`);
    } else {
      // Disable console log output
      localStorage.setItem('ROARR_LOG', 'false');
    }
    ROARR.write = createLogWriter();
  }, [logLevel, logIsEnabled, logNamespaces]);
  }, [consoleLogLevel, shouldLogToConsole]);

  // Update the module-scoped logger context as needed
  useEffect(() => {
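The ROARR_FILTER value is a liqe-style query string, the filter syntax Roarr's browser log writer documents. For example, assuming the level "warn" maps to 40 (Roarr's conventional numeric level) and the namespaces ['canvas', 'events'] are selected, the filter construction above evaluates to:

// What the filter-building branch above produces for that state (assumption:
// LOG_LEVEL_MAP['warn'] === 40):
const filter = 'context.logLevel:>=40 AND (context.namespace:canvas OR context.namespace:events)';
localStorage.setItem('ROARR_FILTER', filter);

// With no namespaces selected, only un-namespaced logs pass:
// 'context.logLevel:>=40 AND context.namespace:undefined'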
@@ -1,7 +1,7 @@
import { createAction } from '@reduxjs/toolkit';
import type { TabName } from 'features/ui/store/uiTypes';
import type { InvokeTabName } from 'features/ui/store/tabMap';

export const enqueueRequested = createAction<{
  tabName: TabName;
  tabName: InvokeTabName;
  prepend: boolean;
}>('app/enqueueRequested');
@@ -1,6 +1,5 @@
import { createDraftSafeSelectorCreator, createSelectorCreator, lruMemoize } from '@reduxjs/toolkit';
import type { GetSelectorsOptions } from '@reduxjs/toolkit/dist/entities/state_selectors';
import type { RootState } from 'app/store/store';
import { isEqual } from 'lodash-es';

/**
@@ -20,5 +19,3 @@ export const getSelectorsOptions: GetSelectorsOptions = {
    argsMemoize: lruMemoize,
  }),
};

export const createMemoizedAppSelector = createMemoizedSelector.withTypes<RootState>();
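createMemoizedSelector.withTypes<RootState>() is Reselect's typed-factory pattern: it pre-binds the state type so call sites infer it instead of annotating it. A small sketch of what that buys at a call site (the slice and field names are placeholders, not taken from the store):

// With the .withTypes-bound creator, `state` is inferred as RootState, and the
// result is memoized with whatever options the creator above configured.
const selectGreeting = createMemoizedAppSelector(
  [(state) => state.system.language], // placeholder slice/field
  (language) => `Language: ${language}`
);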
@@ -1,4 +1,5 @@
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { PersistError, RehydrateError } from 'redux-remember';
import { serializeError } from 'serialize-error';

@@ -40,6 +41,6 @@ export const errorHandler = (err: PersistError | RehydrateError) => {
  } else if (err instanceof RehydrateError) {
    log.error({ error: serializeError(err) }, 'Problem rehydrating state');
  } else {
    log.error({ error: serializeError(err) }, 'Problem in persistence layer');
    log.error({ error: parseify(err) }, 'Problem in persistence layer');
  }
};
@@ -1,7 +1,9 @@
import type { UnknownAction } from '@reduxjs/toolkit';
import { deepClone } from 'common/util/deepClone';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { Graph } from 'services/api/types';
import { socketGeneratorProgress } from 'services/events/actions';

export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
  if (isAnyGraphBuilt(action)) {
@@ -22,5 +24,13 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
    };
  }

  if (socketGeneratorProgress.match(action)) {
    const sanitized = deepClone(action);
    if (sanitized.payload.data.progress_image) {
      sanitized.payload.data.progress_image.dataURL = '<Progress image omitted>';
    }
    return sanitized;
  }

  return action;
};
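An action sanitizer is purely cosmetic for the devtools: the real action flows through the store untouched, while the devtools serializer only sees the trimmed copy, so a multi-megabyte base64 dataURL never gets stringified on every progress event. A sketch of where a sanitizer like this plugs in (rootReducer is a placeholder for the app's actual combined reducer; actionSanitizer is a documented Redux DevTools option):

import { configureStore } from '@reduxjs/toolkit';

// devTools options are passed straight through to the Redux DevTools extension.
const store = configureStore({
  reducer: rootReducer, // placeholder
  devTools: { actionSanitizer },
});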
@@ -1,7 +1,7 @@
import type { TypedStartListening } from '@reduxjs/toolkit';
import { createListenerMiddleware } from '@reduxjs/toolkit';
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { addStagingListeners } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
import { addCommitStagingAreaImageListener } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
@@ -9,6 +9,17 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addCanvasCopiedToClipboardListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard';
import { addCanvasDownloadedAsImageListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage';
import { addCanvasImageToControlNetListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasImageToControlNet';
import { addCanvasMaskSavedToGalleryListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMaskSavedToGallery';
import { addCanvasMaskToControlNetListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet';
import { addCanvasMergedListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMerged';
import { addCanvasSavedToGalleryListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery';
import { addControlAdapterPreprocessor } from 'app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor';
import { addControlNetAutoProcessListener } from 'app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess';
import { addControlNetImageProcessedListener } from 'app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed';
import { addEnqueueRequestedCanvasListener } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas';
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
@@ -26,7 +37,16 @@ import { addModelSelectedListener } from 'app/store/middleware/listenerMiddlewar
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged';
import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved';
import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested';
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store';
@@ -63,6 +83,7 @@ addGalleryImageClickedListener(startAppListening);
addGalleryOffsetChangedListener(startAppListening);

// User Invoked
addEnqueueRequestedCanvasListener(startAppListening);
addEnqueueRequestedNodes(startAppListening);
addEnqueueRequestedLinear(startAppListening);
addEnqueueRequestedUpscale(startAppListening);
@@ -70,23 +91,32 @@ addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);

// Canvas actions
// addCanvasSavedToGalleryListener(startAppListening);
// addCanvasMaskSavedToGalleryListener(startAppListening);
// addCanvasImageToControlNetListener(startAppListening);
// addCanvasMaskToControlNetListener(startAppListening);
// addCanvasDownloadedAsImageListener(startAppListening);
// addCanvasCopiedToClipboardListener(startAppListening);
// addCanvasMergedListener(startAppListening);
// addStagingAreaImageSavedListener(startAppListening);
// addCommitStagingAreaImageListener(startAppListening);
addStagingListeners(startAppListening);
addCanvasSavedToGalleryListener(startAppListening);
addCanvasMaskSavedToGalleryListener(startAppListening);
addCanvasImageToControlNetListener(startAppListening);
addCanvasMaskToControlNetListener(startAppListening);
addCanvasDownloadedAsImageListener(startAppListening);
addCanvasCopiedToClipboardListener(startAppListening);
addCanvasMergedListener(startAppListening);
addStagingAreaImageSavedListener(startAppListening);
addCommitStagingAreaImageListener(startAppListening);

// Socket.IO
addGeneratorProgressEventListener(startAppListening);
addInvocationCompleteEventListener(startAppListening);
addInvocationErrorEventListener(startAppListening);
addInvocationStartedEventListener(startAppListening);
addSocketConnectedEventListener(startAppListening);

// Gallery bulk download
addSocketDisconnectedEventListener(startAppListening);
addModelLoadEventListener(startAppListening);
addModelInstallEventListener(startAppListening);
addSocketQueueItemStatusChangedEventListener(startAppListening);
addBulkDownloadListeners(startAppListening);

// ControlNet
addControlNetImageProcessedListener(startAppListening);
addControlNetAutoProcessListener(startAppListening);

// Boards
addImageAddedToBoardFulfilledListener(startAppListening);
addImageRemovedFromBoardFulfilledListener(startAppListening);
@@ -118,4 +148,4 @@ addAdHocPostProcessingRequestedListener(startAppListening);
addDynamicPromptsListener(startAppListening);

addSetDefaultSettingsListener(startAppListening);
// addControlAdapterPreprocessor(startAppListening);
addControlAdapterPreprocessor(startAppListening);
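Every add*Listener registered above has the same shape: it receives the typed startAppListening and attaches one or more matcher/effect pairs to RTK's listener middleware. A minimal sketch of that contract (the listener name, matcher, and effect body are illustrative; AppStartListening is the typed alias defined in this module):

export const addExampleListener = (startAppListening: AppStartListening) => {
  startAppListening({
    matcher: queueApi.endpoints.enqueueBatch.matchFulfilled, // any RTK matcher works here
    effect: async (action, { dispatch, getState }) => {
      // inspect state, dispatch follow-up actions, debounce, etc.
    },
  });
};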
@@ -1,21 +1,21 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import { parseify } from 'common/util/serialize';
import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO } from 'services/api/types';

const log = logger('queue');

export const adHocPostProcessingRequested = createAction<{ imageDTO: ImageDTO }>(`upscaling/postProcessingRequested`);

export const addAdHocPostProcessingRequestedListener = (startAppListening: AppStartListening) => {
  startAppListening({
    actionCreator: adHocPostProcessingRequested,
    effect: async (action, { dispatch, getState }) => {
      const log = logger('session');

      const { imageDTO } = action.payload;
      const state = getState();

@@ -39,9 +39,9 @@ export const addAdHocPostProcessingRequestedListener = (startAppListening: AppSt

        const enqueueResult = await req.unwrap();
        req.reset();
        log.debug({ enqueueResult } as SerializableObject, t('queue.graphQueued'));
        log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));
      } catch (error) {
        log.error({ enqueueBatchArg } as SerializableObject, t('queue.graphFailedToQueue'));
        log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));

        if (error instanceof Object && 'status' in error && error.status === 403) {
          return;
@@ -23,7 +23,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
   */
  startAppListening({
    matcher: matchAnyBoardDeleted,
    effect: (action, { dispatch, getState }) => {
    effect: async (action, { dispatch, getState }) => {
      const state = getState();
      const deletedBoardId = action.meta.arg.originalArgs;
      const { autoAddBoardId, selectedBoardId } = state.gallery;
@@ -44,7 +44,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
  // If we archived a board, it may end up hidden. If it's selected or the auto-add board, we should reset those.
  startAppListening({
    matcher: boardsApi.endpoints.updateBoard.matchFulfilled,
    effect: (action, { dispatch, getState }) => {
    effect: async (action, { dispatch, getState }) => {
      const state = getState();
      const { shouldShowArchivedBoards } = state.gallery;

@@ -61,7 +61,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
  // When we hide archived boards, if the selected or the auto-add board is archived, we should reset those.
  startAppListening({
    actionCreator: shouldShowArchivedBoardsChanged,
    effect: (action, { dispatch, getState }) => {
    effect: async (action, { dispatch, getState }) => {
      const shouldShowArchivedBoards = action.payload;

      // We only need to take action if we have just hidden archived boards.
@@ -100,7 +100,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
   */
  startAppListening({
    matcher: boardsApi.endpoints.listAllBoards.matchFulfilled,
    effect: (action, { dispatch, getState }) => {
    effect: async (action, { dispatch, getState }) => {
      const boards = action.payload;
      const state = getState();
      const { selectedBoardId, autoAddBoardId } = state.gallery;
@@ -1,36 +1,33 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import {
  rasterLayerAdded,
  sessionStagingAreaImageAccepted,
  sessionStagingAreaReset,
} from 'features/controlLayers/store/canvasV2Slice';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
  canvasBatchIdsReset,
  commitStagingAreaImage,
  discardStagedImages,
  resetCanvas,
  setInitialCanvasImage,
} from 'features/canvas/store/canvasSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
import { assert } from 'tsafe';

const log = logger('canvas');
const matcher = isAnyOf(commitStagingAreaImage, discardStagedImages, resetCanvas, setInitialCanvasImage);

export const addStagingListeners = (startAppListening: AppStartListening) => {
export const addCommitStagingAreaImageListener = (startAppListening: AppStartListening) => {
  startAppListening({
    actionCreator: sessionStagingAreaReset,
    effect: async (_, { dispatch }) => {
    matcher,
    effect: async (_, { dispatch, getState }) => {
      const log = logger('canvas');
      const state = getState();
      const { batchIds } = state.canvas;

      try {
        const req = dispatch(
          queueApi.endpoints.cancelByBatchOrigin.initiate(
            { origin: 'canvas' },
            { fixedCacheKey: 'cancelByBatchOrigin' }
          )
          queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: batchIds }, { fixedCacheKey: 'cancelByBatchIds' })
        );
        const { canceled } = await req.unwrap();
        req.reset();

        $lastCanvasProgressEvent.set(null);

        if (canceled > 0) {
          log.debug(`Canceled ${canceled} canvas batches`);
          toast({
@@ -39,6 +36,7 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
            status: 'success',
          });
        }
        dispatch(canvasBatchIdsReset());
      } catch {
        log.error('Failed to cancel canvas batches');
        toast({
@@ -49,26 +47,4 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
      }
    },
  });

  startAppListening({
    actionCreator: sessionStagingAreaImageAccepted,
    effect: (action, api) => {
      const { index } = action.payload;
      const state = api.getState();
      const stagingAreaImage = state.canvasV2.session.stagedImages[index];

      assert(stagingAreaImage, 'No staged image found to accept');
      const { x, y } = state.canvasV2.bbox.rect;

      const { imageDTO, offsetX, offsetY } = stagingAreaImage;
      const imageObject = imageDTOToImageObject(imageDTO);
      const overrides: Partial<CanvasRasterLayerState> = {
        position: { x: x + offsetX, y: y + offsetY },
        objects: [imageObject],
      };

      api.dispatch(rasterLayerAdded({ overrides, isSelected: true }));
      api.dispatch(sessionStagingAreaReset());
    },
  });
};
@@ -4,7 +4,7 @@ import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
export const addAnyEnqueuedListener = (startAppListening: AppStartListening) => {
  startAppListening({
    matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
    effect: (_, { dispatch, getState }) => {
    effect: async (_, { dispatch, getState }) => {
      const { data } = selectQueueStatus(getState());

      if (!data || data.processor.is_started) {
@@ -1,14 +1,14 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setInfillMethod } from 'features/controlLayers/store/canvasV2Slice';
import { setInfillMethod } from 'features/parameters/store/generationSlice';
import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged } from 'features/system/store/systemSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo';

export const addAppConfigReceivedListener = (startAppListening: AppStartListening) => {
  startAppListening({
    matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled,
    effect: (action, { getState, dispatch }) => {
    effect: async (action, { getState, dispatch }) => {
      const { infill_methods = [], nsfw_methods = [], watermarking_methods = [] } = action.payload;
      const infillMethod = getState().canvasV2.compositing.infillMethod;
      const infillMethod = getState().generation.infillMethod;

      if (!infill_methods.includes(infillMethod)) {
        // if there is no infill method, set it to the first one
@@ -6,7 +6,7 @@ export const appStarted = createAction('app/appStarted');
export const addAppStartedListener = (startAppListening: AppStartListening) => {
  startAppListening({
    actionCreator: appStarted,
    effect: (action, { unsubscribe, cancelActiveListeners }) => {
    effect: async (action, { unsubscribe, cancelActiveListeners }) => {
      // this should only run once
      cancelActiveListeners();
      unsubscribe();
@@ -1,30 +1,27 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import { parseify } from 'common/util/serialize';
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { truncate, upperFirst } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { queueApi } from 'services/api/endpoints/queue';

const log = logger('queue');

export const addBatchEnqueuedListener = (startAppListening: AppStartListening) => {
  // success
  startAppListening({
    matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
    effect: (action) => {
      const enqueueResult = action.payload;
    effect: async (action) => {
      const response = action.payload;
      const arg = action.meta.arg.originalArgs;
      log.debug({ enqueueResult } as SerializableObject, 'Batch enqueued');
      logger('queue').debug({ enqueueResult: parseify(response) }, 'Batch enqueued');

      toast({
        id: 'QUEUE_BATCH_SUCCEEDED',
        title: t('queue.batchQueued'),
        status: 'success',
        description: t('queue.batchQueuedDesc', {
          count: enqueueResult.enqueued,
          count: response.enqueued,
          direction: arg.prepend ? t('queue.front') : t('queue.back'),
        }),
      });
@@ -34,9 +31,9 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
  // error
  startAppListening({
    matcher: queueApi.endpoints.enqueueBatch.matchRejected,
    effect: (action) => {
    effect: async (action) => {
      const response = action.payload;
      const batchConfig = action.meta.arg.originalArgs;
      const arg = action.meta.arg.originalArgs;

      if (!response) {
        toast({
@@ -45,7 +42,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
          status: 'error',
          description: t('common.unknownError'),
        });
        log.error({ batchConfig } as SerializableObject, t('queue.batchFailedToQueue'));
        logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue'));
        return;
      }

@@ -71,7 +68,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
          description: t('common.unknownError'),
        });
      }
      log.error({ batchConfig, error: serializeError(response) } as SerializableObject, t('queue.batchFailedToQueue'));
      logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue'));
    },
  });
};
@@ -1,4 +1,7 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice';
import { allLayersDeleted } from 'features/controlLayers/store/controlLayersSlice';
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { imagesApi } from 'services/api/endpoints/images';
@@ -6,22 +9,39 @@ import { imagesApi } from 'services/api/endpoints/images';
export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppStartListening) => {
  startAppListening({
    matcher: imagesApi.endpoints.deleteBoardAndImages.matchFulfilled,
    effect: (action, { dispatch, getState }) => {
    effect: async (action, { dispatch, getState }) => {
      const { deleted_images } = action.payload;

      // Remove all deleted images from the UI

      let wasCanvasReset = false;
      let wasNodeEditorReset = false;
      let wereControlAdaptersReset = false;
      let wereControlLayersReset = false;

      const { nodes, canvasV2 } = getState();

      const { canvas, nodes, controlAdapters, controlLayers } = getState();
      deleted_images.forEach((image_name) => {
        const imageUsage = getImageUsage(nodes.present, canvasV2, image_name);
        const imageUsage = getImageUsage(canvas, nodes.present, controlAdapters, controlLayers.present, image_name);

        if (imageUsage.isCanvasImage && !wasCanvasReset) {
          dispatch(resetCanvas());
          wasCanvasReset = true;
        }

        if (imageUsage.isNodesImage && !wasNodeEditorReset) {
          dispatch(nodeEditorReset());
          wasNodeEditorReset = true;
        }

        if (imageUsage.isControlImage && !wereControlAdaptersReset) {
          dispatch(controlAdaptersReset());
          wereControlAdaptersReset = true;
        }

        if (imageUsage.isControlLayerImage && !wereControlLayersReset) {
          dispatch(allLayersDeleted());
          wereControlLayersReset = true;
        }
      });
    },
  });
@@ -1,15 +1,21 @@
import { ExternalLink } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import {
  socketBulkDownloadComplete,
  socketBulkDownloadError,
  socketBulkDownloadStarted,
} from 'services/events/actions';

const log = logger('gallery');
const log = logger('images');

export const addBulkDownloadListeners = (startAppListening: AppStartListening) => {
  startAppListening({
    matcher: imagesApi.endpoints.bulkDownloadImages.matchFulfilled,
    effect: (action) => {
    effect: async (action) => {
      log.debug(action.payload, 'Bulk download requested');

      // If we have an item name, we are processing the bulk download locally and should use it as the toast id to
@@ -27,7 +33,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =

  startAppListening({
    matcher: imagesApi.endpoints.bulkDownloadImages.matchRejected,
    effect: () => {
    effect: async () => {
      log.debug('Bulk download request failed');

      // There isn't any toast to update if we get this event.
@@ -38,4 +44,55 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
      });
    },
  });

  startAppListening({
    actionCreator: socketBulkDownloadStarted,
    effect: async (action) => {
      // This should always happen immediately after the bulk download request, so we don't need to show a toast here.
      log.debug(action.payload.data, 'Bulk download preparation started');
    },
  });

  startAppListening({
    actionCreator: socketBulkDownloadComplete,
    effect: async (action) => {
      log.debug(action.payload.data, 'Bulk download preparation completed');

      const { bulk_download_item_name } = action.payload.data;

      // TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first
      const url = `/api/v1/images/download/${bulk_download_item_name}`;

      toast({
        id: bulk_download_item_name,
        title: t('gallery.bulkDownloadReady', 'Download ready'),
        status: 'success',
        description: (
          <ExternalLink
            label={t('gallery.clickToDownload', 'Click here to download')}
            href={url}
            download={bulk_download_item_name}
          />
        ),
        duration: null,
      });
    },
  });

  startAppListening({
    actionCreator: socketBulkDownloadError,
    effect: async (action) => {
      log.debug(action.payload.data, 'Bulk download preparation failed');

      const { bulk_download_item_name } = action.payload.data;

      toast({
        id: bulk_download_item_name,
        title: t('gallery.bulkDownloadFailed'),
        status: 'error',
        description: action.payload.data.error,
        duration: null,
      });
    },
  });
};
@ -0,0 +1,38 @@
|
||||
import { $logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { canvasCopiedToClipboard } from 'features/canvas/store/actions';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
|
||||
export const addCanvasCopiedToClipboardListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: canvasCopiedToClipboard,
|
||||
effect: async (action, { getState }) => {
|
||||
const moduleLog = $logger.get().child({ namespace: 'canvasCopiedToClipboardListener' });
|
||||
const state = getState();
|
||||
|
||||
try {
|
||||
const blob = getBaseLayerBlob(state);
|
||||
|
||||
copyBlobToClipboard(blob);
|
||||
} catch (err) {
|
||||
moduleLog.error(String(err));
|
||||
toast({
|
||||
id: 'CANVAS_COPY_FAILED',
|
||||
title: t('toast.problemCopyingCanvas'),
|
||||
description: t('toast.problemCopyingCanvasDesc'),
|
||||
status: 'error',
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
toast({
|
||||
id: 'CANVAS_COPY_SUCCEEDED',
|
||||
title: t('toast.canvasCopiedClipboard'),
|
||||
status: 'success',
|
||||
});
|
||||
},
|
||||
});
|
||||
};
|
@ -0,0 +1,34 @@
|
||||
import { $logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { canvasDownloadedAsImage } from 'features/canvas/store/actions';
|
||||
import { downloadBlob } from 'features/canvas/util/downloadBlob';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
|
||||
export const addCanvasDownloadedAsImageListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: canvasDownloadedAsImage,
|
||||
effect: async (action, { getState }) => {
|
||||
const moduleLog = $logger.get().child({ namespace: 'canvasSavedToGalleryListener' });
|
||||
const state = getState();
|
||||
|
||||
let blob;
|
||||
try {
|
||||
blob = await getBaseLayerBlob(state);
|
||||
} catch (err) {
|
||||
moduleLog.error(String(err));
|
||||
toast({
|
||||
id: 'CANVAS_DOWNLOAD_FAILED',
|
||||
title: t('toast.problemDownloadingCanvas'),
|
||||
description: t('toast.problemDownloadingCanvasDesc'),
|
||||
status: 'error',
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
downloadBlob(blob, 'canvas.png');
|
||||
toast({ id: 'CANVAS_DOWNLOAD_SUCCEEDED', title: t('toast.canvasDownloaded'), status: 'success' });
|
||||
},
|
||||
});
|
||||
};
|
@ -0,0 +1,60 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { canvasImageToControlAdapter } from 'features/canvas/store/actions';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
||||
export const addCanvasImageToControlNetListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: canvasImageToControlAdapter,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const log = logger('canvas');
|
||||
const state = getState();
|
||||
const { id } = action.payload;
|
||||
|
||||
let blob: Blob;
|
||||
try {
|
||||
blob = await getBaseLayerBlob(state, true);
|
||||
} catch (err) {
|
||||
log.error(String(err));
|
||||
toast({
|
||||
id: 'PROBLEM_SAVING_CANVAS',
|
||||
title: t('toast.problemSavingCanvas'),
|
||||
description: t('toast.problemSavingCanvasDesc'),
|
||||
status: 'error',
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const { autoAddBoardId } = state.gallery;
|
||||
|
||||
const imageDTO = await dispatch(
|
||||
imagesApi.endpoints.uploadImage.initiate({
|
||||
file: new File([blob], 'savedCanvas.png', {
|
||||
type: 'image/png',
|
||||
}),
|
||||
image_category: 'control',
|
||||
is_intermediate: true,
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
crop_visible: false,
|
||||
postUploadAction: {
|
||||
type: 'TOAST',
|
||||
title: t('toast.canvasSentControlnetAssets'),
|
||||
},
|
||||
})
|
||||
).unwrap();
|
||||
|
||||
const { image_name } = imageDTO;
|
||||
|
||||
dispatch(
|
||||
controlAdapterImageChanged({
|
||||
id,
|
||||
controlImage: image_name,
|
||||
})
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
@ -0,0 +1,60 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasMaskSavedToGallery } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';

export const addCanvasMaskSavedToGalleryListener = (startAppListening: AppStartListening) => {
  startAppListening({
    actionCreator: canvasMaskSavedToGallery,
    effect: async (action, { dispatch, getState }) => {
      const log = logger('canvas');
      const state = getState();

      const canvasBlobsAndImageData = await getCanvasData(
        state.canvas.layerState,
        state.canvas.boundingBoxCoordinates,
        state.canvas.boundingBoxDimensions,
        state.canvas.isMaskEnabled,
        state.canvas.shouldPreserveMaskedArea
      );

      if (!canvasBlobsAndImageData) {
        return;
      }

      const { maskBlob } = canvasBlobsAndImageData;

      if (!maskBlob) {
        log.error('Problem getting mask layer blob');
        toast({
          id: 'PROBLEM_SAVING_MASK',
          title: t('toast.problemSavingMask'),
          description: t('toast.problemSavingMaskDesc'),
          status: 'error',
        });
        return;
      }

      const { autoAddBoardId } = state.gallery;

      dispatch(
        imagesApi.endpoints.uploadImage.initiate({
          file: new File([maskBlob], 'canvasMaskImage.png', {
            type: 'image/png',
          }),
          image_category: 'mask',
          is_intermediate: false,
          board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
          crop_visible: true,
          postUploadAction: {
            type: 'TOAST',
            title: t('toast.maskSavedAssets'),
          },
        })
      );
    },
  });
};
@ -0,0 +1,70 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasMaskToControlAdapter } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';

export const addCanvasMaskToControlNetListener = (startAppListening: AppStartListening) => {
  startAppListening({
    actionCreator: canvasMaskToControlAdapter,
    effect: async (action, { dispatch, getState }) => {
      const log = logger('canvas');
      const state = getState();
      const { id } = action.payload;
      const canvasBlobsAndImageData = await getCanvasData(
        state.canvas.layerState,
        state.canvas.boundingBoxCoordinates,
        state.canvas.boundingBoxDimensions,
        state.canvas.isMaskEnabled,
        state.canvas.shouldPreserveMaskedArea
      );

      if (!canvasBlobsAndImageData) {
        return;
      }

      const { maskBlob } = canvasBlobsAndImageData;

      if (!maskBlob) {
        log.error('Problem getting mask layer blob');
        toast({
          id: 'PROBLEM_IMPORTING_MASK',
          title: t('toast.problemImportingMask'),
          description: t('toast.problemImportingMaskDesc'),
          status: 'error',
        });
        return;
      }

      const { autoAddBoardId } = state.gallery;

      const imageDTO = await dispatch(
        imagesApi.endpoints.uploadImage.initiate({
          file: new File([maskBlob], 'canvasMaskImage.png', {
            type: 'image/png',
          }),
          image_category: 'mask',
          is_intermediate: true,
          board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
          crop_visible: false,
          postUploadAction: {
            type: 'TOAST',
            title: t('toast.maskSentControlnetAssets'),
          },
        })
      ).unwrap();

      const { image_name } = imageDTO;

      dispatch(
        controlAdapterImageChanged({
          id,
          controlImage: image_name,
        })
      );
    },
  });
};
@ -0,0 +1,73 @@
import { $logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasMerged } from 'features/canvas/store/actions';
import { $canvasBaseLayer } from 'features/canvas/store/canvasNanostore';
import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
import { getFullBaseLayerBlob } from 'features/canvas/util/getFullBaseLayerBlob';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';

export const addCanvasMergedListener = (startAppListening: AppStartListening) => {
  startAppListening({
    actionCreator: canvasMerged,
    effect: async (action, { dispatch }) => {
      const moduleLog = $logger.get().child({ namespace: 'canvasCopiedToClipboardListener' });
      const blob = await getFullBaseLayerBlob();

      if (!blob) {
        moduleLog.error('Problem getting base layer blob');
        toast({
          id: 'PROBLEM_MERGING_CANVAS',
          title: t('toast.problemMergingCanvas'),
          description: t('toast.problemMergingCanvasDesc'),
          status: 'error',
        });
        return;
      }

      const canvasBaseLayer = $canvasBaseLayer.get();

      if (!canvasBaseLayer) {
        moduleLog.error('Problem getting canvas base layer');
        toast({
          id: 'PROBLEM_MERGING_CANVAS',
          title: t('toast.problemMergingCanvas'),
          description: t('toast.problemMergingCanvasDesc'),
          status: 'error',
        });
        return;
      }

      const baseLayerRect = canvasBaseLayer.getClientRect({
        relativeTo: canvasBaseLayer.getParent() ?? undefined,
      });

      const imageDTO = await dispatch(
        imagesApi.endpoints.uploadImage.initiate({
          file: new File([blob], 'mergedCanvas.png', {
            type: 'image/png',
          }),
          image_category: 'general',
          is_intermediate: true,
          postUploadAction: {
            type: 'TOAST',
            title: t('toast.canvasMerged'),
          },
        })
      ).unwrap();

      // TODO: I can't figure out how to do the type narrowing in the `take()` so just brute forcing it here
      const { image_name } = imageDTO;

      dispatch(
        setMergedCanvas({
          kind: 'image',
          layer: 'base',
          imageName: image_name,
          ...baseLayerRect,
        })
      );
    },
  });
};
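The TODO above about type narrowing in take() can be addressed with a type-guard predicate, which is what the control adapter preprocessor later in this diff does. A minimal sketch of the idea, using a hypothetical resultReceived action:

import { createAction, type Action } from '@reduxjs/toolkit';

// Hypothetical action, for illustration only.
const resultReceived = createAction<{ value: number }>('example/resultReceived');

// The `action is ...` return type is what lets the listener middleware
// narrow the tuple that `take` resolves with, e.g.
//   const [matched] = await take(isResultReceived);
//   matched.payload.value; // typed as number
const isResultReceived = (action: Action): action is ReturnType<typeof resultReceived> =>
  resultReceived.match(action);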
@ -0,0 +1,53 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { canvasSavedToGallery } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';

export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListening) => {
  startAppListening({
    actionCreator: canvasSavedToGallery,
    effect: async (action, { dispatch, getState }) => {
      const log = logger('canvas');
      const state = getState();

      let blob;
      try {
        blob = await getBaseLayerBlob(state);
      } catch (err) {
        log.error(String(err));
        toast({
          id: 'CANVAS_SAVE_FAILED',
          title: t('toast.problemSavingCanvas'),
          description: t('toast.problemSavingCanvasDesc'),
          status: 'error',
        });
        return;
      }

      const { autoAddBoardId } = state.gallery;

      dispatch(
        imagesApi.endpoints.uploadImage.initiate({
          file: new File([blob], 'savedCanvas.png', {
            type: 'image/png',
          }),
          image_category: 'general',
          is_intermediate: false,
          board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
          crop_visible: true,
          postUploadAction: {
            type: 'TOAST',
            title: t('toast.canvasSavedGallery'),
          },
          metadata: {
            _canvas_objects: parseify(state.canvas.layerState.objects),
          },
        })
      );
    },
  });
};
@ -0,0 +1,194 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch } from 'app/store/store';
import { parseify } from 'common/util/serialize';
import {
  caLayerImageChanged,
  caLayerModelChanged,
  caLayerProcessedImageChanged,
  caLayerProcessorConfigChanged,
  caLayerProcessorPendingBatchIdChanged,
  caLayerRecalled,
  isControlAdapterLayer,
} from 'features/controlLayers/store/controlLayersSlice';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { isEqual } from 'lodash-es';
import { getImageDTO } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions';
import { assert } from 'tsafe';

const matcher = isAnyOf(
  caLayerImageChanged,
  caLayerProcessedImageChanged,
  caLayerProcessorConfigChanged,
  caLayerModelChanged,
  caLayerRecalled
);

const DEBOUNCE_MS = 300;
const log = logger('session');

/**
 * Simple helper to cancel a batch and reset the pending batch ID
 */
const cancelProcessorBatch = async (dispatch: AppDispatch, layerId: string, batchId: string) => {
  const req = dispatch(queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: [batchId] }));
  log.trace({ batchId }, 'Cancelling existing preprocessor batch');
  try {
    await req.unwrap();
  } catch {
    // no-op
  } finally {
    req.reset();
    // Always reset the pending batch ID - the cancel req could fail if the batch doesn't exist
    dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: null }));
  }
};

export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => {
  startAppListening({
    matcher,
    effect: async (action, { dispatch, getState, getOriginalState, cancelActiveListeners, delay, take, signal }) => {
      const layerId = caLayerRecalled.match(action) ? action.payload.id : action.payload.layerId;
      const state = getState();
      const originalState = getOriginalState();

      // Cancel any in-progress instances of this listener
      cancelActiveListeners();
      log.trace('Control Layer CA auto-process triggered');

      // Delay before starting actual work
      await delay(DEBOUNCE_MS);

      const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId);

      if (!layer) {
        return;
      }

      // We should only process if the processor settings or image have changed
      const originalLayer = originalState.controlLayers.present.layers
        .filter(isControlAdapterLayer)
        .find((l) => l.id === layerId);
      const originalImage = originalLayer?.controlAdapter.image;
      const originalConfig = originalLayer?.controlAdapter.processorConfig;

      const image = layer.controlAdapter.image;
      const processedImage = layer.controlAdapter.processedImage;
      const config = layer.controlAdapter.processorConfig;

      if (isEqual(config, originalConfig) && isEqual(image, originalImage) && processedImage) {
        // Neither config nor image have changed, we can bail
        return;
      }

      if (!image || !config) {
        // - If we have no image, we have nothing to process
        // - If we have no processor config, we have nothing to process
        // Clear the processed image and bail
        dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null }));
        return;
      }

      // At this point, the user has stopped fiddling with the processor settings and there is a processor selected.

      // If there is a pending processor batch, cancel it.
      if (layer.controlAdapter.processorPendingBatchId) {
        cancelProcessorBatch(dispatch, layerId, layer.controlAdapter.processorPendingBatchId);
      }

      // TODO(psyche): I can't get TS to be happy, it thinks `config` is `never` but it should be inferred from the generic... I'll just cast it for now
      const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config as never);
      const enqueueBatchArg: BatchConfig = {
        prepend: true,
        batch: {
          graph: {
            nodes: {
              [processorNode.id]: {
                ...processorNode,
                // Control images are always intermediate - do not save to gallery
                is_intermediate: true,
              },
            },
            edges: [],
          },
          runs: 1,
        },
      };

      // Kick off the processor batch
      const req = dispatch(
        queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
          fixedCacheKey: 'enqueueBatch',
        })
      );

      try {
        const enqueueResult = await req.unwrap();
        // TODO(psyche): Update the pydantic models, pretty sure we will _always_ have a batch_id here, but the model says it's optional
        assert(enqueueResult.batch.batch_id, 'Batch ID not returned from queue');
        dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: enqueueResult.batch.batch_id }));
        log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));

        // Wait for the processor node to complete
        const [invocationCompleteAction] = await take(
          (action): action is ReturnType<typeof socketInvocationComplete> =>
            socketInvocationComplete.match(action) &&
            action.payload.data.batch_id === enqueueResult.batch.batch_id &&
            action.payload.data.invocation_source_id === processorNode.id
        );

        // We still have to check the output type
        assert(
          invocationCompleteAction.payload.data.result.type === 'image_output',
          `Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}`
        );
        const { image_name } = invocationCompleteAction.payload.data.result.image;

        const imageDTO = await getImageDTO(image_name);
        assert(imageDTO, "Failed to fetch processor output's image DTO");

        // Whew! We made it. Update the layer with the processed image
        log.debug({ layerId, imageDTO }, 'ControlNet image processed');
        dispatch(caLayerProcessedImageChanged({ layerId, imageDTO }));
        dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: null }));
      } catch (error) {
        if (signal.aborted) {
          // The listener was canceled - we need to cancel the pending processor batch, if there is one (could have changed by now).
          const pendingBatchId = getState()
            .controlLayers.present.layers.filter(isControlAdapterLayer)
            .find((l) => l.id === layerId)?.controlAdapter.processorPendingBatchId;
          if (pendingBatchId) {
            cancelProcessorBatch(dispatch, layerId, pendingBatchId);
          }
          log.trace('Control Adapter preprocessor cancelled');
        } else {
          // Some other error condition...
          log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));

          if (error instanceof Object) {
            if ('data' in error && 'status' in error) {
              if (error.status === 403) {
                dispatch(caLayerImageChanged({ layerId, imageDTO: null }));
                return;
              }
            }
          }

          toast({
            id: 'GRAPH_QUEUE_FAILED',
            title: t('queue.graphFailedToQueue'),
            status: 'error',
          });
        }
      } finally {
        req.reset();
      }
    },
  });
};
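The cancelActiveListeners()-then-delay(DEBOUNCE_MS) sequence above is the listener middleware's debounce idiom: each new matching action kills the previous effect, and only the run that survives the delay does real work. A minimal sketch of the pattern in isolation, with a hypothetical valueChanged action:

import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';

const valueChanged = createAction<string>('example/valueChanged');

export const addDebouncedExampleListener = (startAppListening: AppStartListening) => {
  startAppListening({
    actionCreator: valueChanged,
    effect: async (action, { cancelActiveListeners, delay }) => {
      // Kill any effect already in flight for this listener...
      cancelActiveListeners();
      // ...then wait; if another action arrives, this run is cancelled here.
      await delay(300);
      // Only the most recent action's effect reaches this point.
      console.log('debounced value:', action.payload);
    },
  });
};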
@ -0,0 +1,85 @@
import type { AnyListenerPredicate } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { RootState } from 'app/store/store';
import { controlAdapterImageProcessed } from 'features/controlAdapters/store/actions';
import {
  controlAdapterAutoConfigToggled,
  controlAdapterImageChanged,
  controlAdapterModelChanged,
  controlAdapterProcessorParamsChanged,
  controlAdapterProcessortTypeChanged,
  selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';

type AnyControlAdapterParamChangeAction =
  | ReturnType<typeof controlAdapterProcessorParamsChanged>
  | ReturnType<typeof controlAdapterModelChanged>
  | ReturnType<typeof controlAdapterImageChanged>
  | ReturnType<typeof controlAdapterProcessortTypeChanged>
  | ReturnType<typeof controlAdapterAutoConfigToggled>;

const predicate: AnyListenerPredicate<RootState> = (action, state, prevState) => {
  const isActionMatched =
    controlAdapterProcessorParamsChanged.match(action) ||
    controlAdapterModelChanged.match(action) ||
    controlAdapterImageChanged.match(action) ||
    controlAdapterProcessortTypeChanged.match(action) ||
    controlAdapterAutoConfigToggled.match(action);

  if (!isActionMatched) {
    return false;
  }

  const { id } = action.payload;
  const prevCA = selectControlAdapterById(prevState.controlAdapters, id);
  const ca = selectControlAdapterById(state.controlAdapters, id);
  if (!prevCA || !isControlNetOrT2IAdapter(prevCA) || !ca || !isControlNetOrT2IAdapter(ca)) {
    return false;
  }

  if (controlAdapterAutoConfigToggled.match(action)) {
    // do not process if the user just disabled auto-config
    if (prevCA.shouldAutoConfig === true) {
      return false;
    }
  }

  const { controlImage, processorType, shouldAutoConfig } = ca;
  if (controlAdapterModelChanged.match(action) && !shouldAutoConfig) {
    // do not process if the action is a model change but the processor settings are dirty
    return false;
  }

  const isProcessorSelected = processorType !== 'none';

  const hasControlImage = Boolean(controlImage);

  return isProcessorSelected && hasControlImage;
};

const DEBOUNCE_MS = 300;

/**
 * Listener that automatically processes a ControlNet image when its processor parameters are changed.
 *
 * The network request is debounced.
 */
export const addControlNetAutoProcessListener = (startAppListening: AppStartListening) => {
  startAppListening({
    predicate,
    effect: async (action, { dispatch, cancelActiveListeners, delay }) => {
      const log = logger('session');
      const { id } = (action as AnyControlAdapterParamChangeAction).payload;

      // Cancel any in-progress instances of this listener
      cancelActiveListeners();
      log.trace('ControlNet auto-process triggered');
      // Delay before starting actual work
      await delay(DEBOUNCE_MS);

      dispatch(controlAdapterImageProcessed({ id }));
    },
  });
};
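Note that the gating above happens in the predicate rather than the effect: AnyListenerPredicate receives both the current and previous state, so the listener fires only on transitions that matter. A minimal sketch of that shape, with a hypothetical counter slice:

import type { AnyListenerPredicate } from '@reduxjs/toolkit';

type ExampleState = { counter: { value: number } };

// Fire only when the counter crosses the threshold, not on every action.
const crossedTen: AnyListenerPredicate<ExampleState> = (_action, state, prevState) =>
  prevState.counter.value < 10 && state.counter.value >= 10;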
@ -0,0 +1,118 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { controlAdapterImageProcessed } from 'features/controlAdapters/store/actions';
import {
  controlAdapterImageChanged,
  controlAdapterProcessedImageChanged,
  pendingControlImagesCleared,
  selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions';

export const addControlNetImageProcessedListener = (startAppListening: AppStartListening) => {
  startAppListening({
    actionCreator: controlAdapterImageProcessed,
    effect: async (action, { dispatch, getState, take }) => {
      const log = logger('session');
      const { id } = action.payload;
      const ca = selectControlAdapterById(getState().controlAdapters, id);

      if (!ca?.controlImage || !isControlNetOrT2IAdapter(ca)) {
        log.error('Unable to process ControlNet image');
        return;
      }

      if (ca.processorType === 'none' || ca.processorNode.type === 'none') {
        return;
      }

      // ControlNet one-off processing graph is just the processor node, no edges.
      // Also we need to grab the image.

      const nodeId = ca.processorNode.id;
      const enqueueBatchArg: BatchConfig = {
        prepend: true,
        batch: {
          graph: {
            nodes: {
              [ca.processorNode.id]: {
                ...ca.processorNode,
                is_intermediate: true,
                use_cache: false,
                image: { image_name: ca.controlImage },
              },
            },
            edges: [],
          },
          runs: 1,
        },
      };

      try {
        const req = dispatch(
          queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
            fixedCacheKey: 'enqueueBatch',
          })
        );
        const enqueueResult = await req.unwrap();
        req.reset();
        log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));

        const [invocationCompleteAction] = await take(
          (action): action is ReturnType<typeof socketInvocationComplete> =>
            socketInvocationComplete.match(action) &&
            action.payload.data.batch_id === enqueueResult.batch.batch_id &&
            action.payload.data.invocation_source_id === nodeId
        );

        // We still have to check the output type
        if (invocationCompleteAction.payload.data.result.type === 'image_output') {
          const { image_name } = invocationCompleteAction.payload.data.result.image;

          // Wait for the ImageDTO to be received
          const [{ payload }] = await take(
            (action) =>
              imagesApi.endpoints.getImageDTO.matchFulfilled(action) && action.payload.image_name === image_name
          );

          const processedControlImage = payload as ImageDTO;

          log.debug({ controlNetId: action.payload, processedControlImage }, 'ControlNet image processed');

          // Update the processed image in the store
          dispatch(
            controlAdapterProcessedImageChanged({
              id,
              processedControlImage: processedControlImage.image_name,
            })
          );
        }
      } catch (error) {
        log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));

        if (error instanceof Object) {
          if ('data' in error && 'status' in error) {
            if (error.status === 403) {
              dispatch(pendingControlImagesCleared());
              dispatch(controlAdapterImageChanged({ id, controlImage: null }));
              return;
            }
          }
        }

        toast({
          id: 'GRAPH_QUEUE_FAILED',
          title: t('queue.graphFailedToQueue'),
          status: 'error',
        });
      }
    },
  });
};
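A recurring shape in these listeners is the imperative use of an RTK Query endpoint: dispatching initiate() returns a request handle whose unwrap() resolves with the payload or throws, and whose reset() releases the cache subscription. A minimal sketch of that pairing, reusing the uploadImage endpoint and argument fields seen in the hunks above:

import type { AppDispatch } from 'app/store/store';
import { imagesApi } from 'services/api/endpoints/images';

// Sketch only: fire the mutation, await its result, and always drop the
// subscription, mirroring the req.unwrap()/req.reset() pairs above.
const uploadPng = async (dispatch: AppDispatch, file: File) => {
  const req = dispatch(
    imagesApi.endpoints.uploadImage.initiate({
      file,
      image_category: 'general',
      is_intermediate: true,
    })
  );
  try {
    return await req.unwrap(); // resolves with the ImageDTO or throws
  } finally {
    req.reset(); // release the cache subscription either way
  }
};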
@ -0,0 +1,144 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { parseify } from 'common/util/serialize';
import { canvasBatchIdAdded, stagingAreaInitialized } from 'features/canvas/store/canvasSlice';
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
import { canvasGraphBuilt } from 'features/nodes/store/actions';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildCanvasGraph } from 'features/nodes/util/graph/canvas/buildCanvasGraph';
import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { ImageDTO } from 'services/api/types';

/**
 * This listener is responsible for invoking the canvas. This involves a number of steps:
 *
 * 1. Generate image blobs from the canvas layers
 * 2. Determine the generation mode from the layers (txt2img, img2img, inpaint)
 * 3. Build the canvas graph
 * 4. Create the session with the graph
 * 5. Upload the init image if necessary
 * 6. Upload the mask image if necessary
 * 7. Update the init and mask images with the session ID
 * 8. Initialize the staging area if not yet initialized
 * 9. Dispatch the sessionReadyToInvoke action to invoke the session
 */
export const addEnqueueRequestedCanvasListener = (startAppListening: AppStartListening) => {
  startAppListening({
    predicate: (action): action is ReturnType<typeof enqueueRequested> =>
      enqueueRequested.match(action) && action.payload.tabName === 'canvas',
    effect: async (action, { getState, dispatch }) => {
      const log = logger('queue');
      const { prepend } = action.payload;
      const state = getState();

      const { layerState, boundingBoxCoordinates, boundingBoxDimensions, isMaskEnabled, shouldPreserveMaskedArea } =
        state.canvas;

      // Build canvas blobs
      const canvasBlobsAndImageData = await getCanvasData(
        layerState,
        boundingBoxCoordinates,
        boundingBoxDimensions,
        isMaskEnabled,
        shouldPreserveMaskedArea
      );

      if (!canvasBlobsAndImageData) {
        log.error('Unable to create canvas data');
        return;
      }

      const { baseBlob, baseImageData, maskBlob, maskImageData } = canvasBlobsAndImageData;

      // Determine the generation mode
      const generationMode = getCanvasGenerationMode(baseImageData, maskImageData);

      if (state.system.enableImageDebugging) {
        const baseDataURL = await blobToDataURL(baseBlob);
        const maskDataURL = await blobToDataURL(maskBlob);
        openBase64ImageInTab([
          { base64: maskDataURL, caption: 'mask b64' },
          { base64: baseDataURL, caption: 'image b64' },
        ]);
      }

      log.debug(`Generation mode: ${generationMode}`);

      // Temp placeholders for the init and mask images
      let canvasInitImage: ImageDTO | undefined;
      let canvasMaskImage: ImageDTO | undefined;

      // For img2img and inpaint/outpaint, we need to upload the init images
      if (['img2img', 'inpaint', 'outpaint'].includes(generationMode)) {
        // upload the image, saving the request id
        canvasInitImage = await dispatch(
          imagesApi.endpoints.uploadImage.initiate({
            file: new File([baseBlob], 'canvasInitImage.png', {
              type: 'image/png',
            }),
            image_category: 'general',
            is_intermediate: true,
          })
        ).unwrap();
      }

      // For inpaint/outpaint, we also need to upload the mask layer
      if (['inpaint', 'outpaint'].includes(generationMode)) {
        // upload the image, saving the request id
        canvasMaskImage = await dispatch(
          imagesApi.endpoints.uploadImage.initiate({
            file: new File([maskBlob], 'canvasMaskImage.png', {
              type: 'image/png',
            }),
            image_category: 'mask',
            is_intermediate: true,
          })
        ).unwrap();
      }

      const graph = await buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);

      log.debug({ graph: parseify(graph) }, `Canvas graph built`);

      // currently this action is just listened to for logging
      dispatch(canvasGraphBuilt(graph));

      const batchConfig = prepareLinearUIBatch(state, graph, prepend);

      try {
        const req = dispatch(
          queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
            fixedCacheKey: 'enqueueBatch',
          })
        );

        const enqueueResult = await req.unwrap();
        req.reset();

        const batchId = enqueueResult.batch.batch_id as string; // we know this is a string, backend provides it

        // Prep the canvas staging area if it is not yet initialized
        if (!state.canvas.layerState.stagingArea.boundingBox) {
          dispatch(
            stagingAreaInitialized({
              boundingBox: {
                ...state.canvas.boundingBoxCoordinates,
                ...state.canvas.boundingBoxDimensions,
              },
            })
          );
        }

        // Associate the session with the canvas session ID
        dispatch(canvasBatchIdAdded(batchId));
      } catch {
        // no-op
      }
    },
  });
};
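The two upload branches above are driven by the generation mode: img2img and inpaint/outpaint need the init image uploaded first, and inpaint/outpaint additionally need the mask. A minimal sketch of that decision, with the mode union taken from the modes named in the comment and code above:

type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint';

// Which canvas blobs must be uploaded before the graph can be enqueued.
const requiredUploads = (mode: GenerationMode) => ({
  init: ['img2img', 'inpaint', 'outpaint'].includes(mode),
  mask: ['inpaint', 'outpaint'].includes(mode),
});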
@ -1,18 +1,10 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import { sessionStagingAreaReset, sessionStartedStaging } from 'features/controlLayers/store/canvasV2Slice';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { serializeError } from 'serialize-error';
import { buildGenerationTabGraph } from 'features/nodes/util/graph/generation/buildGenerationTabGraph';
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/generation/buildGenerationTabSDXLGraph';
import { queueApi } from 'services/api/endpoints/queue';
import type { Invocation } from 'services/api/types';
import { assert } from 'tsafe';

const log = logger('generation');

export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
  startAppListening({
@ -20,54 +12,32 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
      enqueueRequested.match(action) && action.payload.tabName === 'generation',
    effect: async (action, { getState, dispatch }) => {
      const state = getState();
      const model = state.canvasV2.params.model;
      const { shouldShowProgressInViewer } = state.ui;
      const model = state.generation.model;
      const { prepend } = action.payload;

      const manager = $canvasManager.get();
      assert(manager, 'No model found in state');
      let graph;

      let didStartStaging = false;
      if (!state.canvasV2.session.isStaging && state.canvasV2.session.mode === 'compose') {
        dispatch(sessionStartedStaging());
        didStartStaging = true;
      if (model?.base === 'sdxl') {
        graph = await buildGenerationTabSDXLGraph(state);
      } else {
        graph = await buildGenerationTabGraph(state);
      }

      const batchConfig = prepareLinearUIBatch(state, graph, prepend);

      const req = dispatch(
        queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
          fixedCacheKey: 'enqueueBatch',
        })
      );
      try {
        let g: Graph;
        let noise: Invocation<'noise'>;
        let posCond: Invocation<'compel' | 'sdxl_compel_prompt'>;

        assert(model, 'No model found in state');
        const base = model.base;

        if (base === 'sdxl') {
          const result = await buildSDXLGraph(state, manager);
          g = result.g;
          noise = result.noise;
          posCond = result.posCond;
        } else if (base === 'sd-1' || base === 'sd-2') {
          const result = await buildSD1Graph(state, manager);
          g = result.g;
          noise = result.noise;
          posCond = result.posCond;
        } else {
          assert(false, `No graph builders for base ${base}`);
        }

        const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond);

        const req = dispatch(
          queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
            fixedCacheKey: 'enqueueBatch',
          })
        );
        req.reset();
        await req.unwrap();
      } catch (error) {
        log.error({ error: serializeError(error) }, 'Failed to enqueue batch');
        if (didStartStaging && getState().canvasV2.session.isStaging) {
          dispatch(sessionStagingAreaReset());
        if (shouldShowProgressInViewer) {
          dispatch(isImageViewerOpenChanged(true));
        }
      } finally {
        req.reset();
      }
    },
  });
@ -29,8 +29,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
      batch: {
        graph,
        workflow: builtWorkflow,
        runs: state.canvasV2.params.iterations,
        origin: 'workflows',
        runs: state.generation.iterations,
      },
      prepend: action.payload.prepend,
    };
@ -14,9 +14,9 @@ export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening)
      const { shouldShowProgressInViewer } = state.ui;
      const { prepend } = action.payload;

      const { g, noise, posCond } = await buildMultidiffusionUpscaleGraph(state);
      const graph = await buildMultidiffusionUpscaleGraph(state);

      const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond);
      const batchConfig = prepareLinearUIBatch(state, graph, prepend);

      const req = dispatch(
        queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
@ -27,7 +27,7 @@ export const galleryImageClicked = createAction<{
export const addGalleryImageClickedListener = (startAppListening: AppStartListening) => {
  startAppListening({
    actionCreator: galleryImageClicked,
    effect: (action, { dispatch, getState }) => {
    effect: async (action, { dispatch, getState }) => {
      const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
      const state = getState();
      const queryArgs = selectListImagesQueryArgs(state);
@ -1,27 +1,24 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import { parseify } from 'common/util/serialize';
import { $templates } from 'features/nodes/store/nodesSlice';
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
import { size } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { appInfoApi } from 'services/api/endpoints/appInfo';

const log = logger('system');

export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening) => {
  startAppListening({
    matcher: appInfoApi.endpoints.getOpenAPISchema.matchFulfilled,
    effect: (action, { getState }) => {
      const log = logger('system');
      const schemaJSON = action.payload;

      log.debug({ schemaJSON: parseify(schemaJSON) } as SerializableObject, 'Received OpenAPI schema');
      log.debug({ schemaJSON: parseify(schemaJSON) }, 'Received OpenAPI schema');
      const { nodesAllowlist, nodesDenylist } = getState().config;

      const nodeTemplates = parseSchema(schemaJSON, nodesAllowlist, nodesDenylist);

      log.debug({ nodeTemplates } as SerializableObject, `Built ${size(nodeTemplates)} node templates`);
      log.debug({ nodeTemplates: parseify(nodeTemplates) }, `Built ${size(nodeTemplates)} node templates`);

      $templates.set(nodeTemplates);
    },
@ -33,7 +30,8 @@ export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening
      // If action.meta.condition === true, the request was canceled/skipped because another request was in flight or
      // the value was already in the cache. We don't want to log these errors.
      if (!action.meta.condition) {
        log.error({ error: serializeError(action.error) }, 'Problem retrieving OpenAPI Schema');
        const log = logger('system');
        log.error({ error: parseify(action.error) }, 'Problem retrieving OpenAPI Schema');
      }
    },
  });
@ -2,13 +2,15 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { imagesApi } from 'services/api/endpoints/images';

const log = logger('gallery');

export const addImageAddedToBoardFulfilledListener = (startAppListening: AppStartListening) => {
  startAppListening({
    matcher: imagesApi.endpoints.addImageToBoard.matchFulfilled,
    effect: (action) => {
      const log = logger('images');
      const { board_id, imageDTO } = action.meta.arg.originalArgs;

      // TODO: update listImages cache for this board

      log.debug({ board_id, imageDTO }, 'Image added to board');
    },
  });
@ -16,7 +18,9 @@ export const addImageAddedToBoardFulfilledListener = (startAppListening: AppStar
  startAppListening({
    matcher: imagesApi.endpoints.addImageToBoard.matchRejected,
    effect: (action) => {
      const log = logger('images');
      const { board_id, imageDTO } = action.meta.arg.originalArgs;

      log.debug({ board_id, imageDTO }, 'Problem adding image to board');
    },
  });
@ -1,7 +1,20 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import { entityDeleted, ipaImageChanged } from 'features/controlLayers/store/canvasV2Slice';
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import {
  controlAdapterImageChanged,
  controlAdapterProcessedImageChanged,
  selectControlAdapterAll,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import {
  isControlAdapterLayer,
  isInitialImageLayer,
  isIPAdapterLayer,
  isRegionalGuidanceLayer,
  layerDeleted,
} from 'features/controlLayers/store/controlLayersSlice';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
@ -13,10 +26,6 @@ import { forEach, intersectionBy } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';

const log = logger('gallery');

//TODO(psyche): handle image deletion (canvas sessions?)

// Some utils to delete images from different parts of the app
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
  state.nodes.present.nodes.forEach((node) => {
@ -38,37 +47,52 @@ const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: Im
  });
};

// const deleteControlAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
//   state.canvasV2.controlAdapters.entities.forEach(({ id, imageObject, processedImageObject }) => {
//     if (
//       imageObject?.image.image_name === imageDTO.image_name ||
//       processedImageObject?.image.image_name === imageDTO.image_name
//     ) {
//       dispatch(caImageChanged({ id, imageDTO: null }));
//       dispatch(caProcessedImageChanged({ id, imageDTO: null }));
//     }
//   });
// };

const deleteIPAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
  state.canvasV2.ipAdapters.entities.forEach(({ id, ipAdapter }) => {
    if (ipAdapter.image?.image_name === imageDTO.image_name) {
      dispatch(ipaImageChanged({ id, imageDTO: null }));
const deleteControlAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
  forEach(selectControlAdapterAll(state.controlAdapters), (ca) => {
    if (
      ca.controlImage === imageDTO.image_name ||
      (isControlNetOrT2IAdapter(ca) && ca.processedControlImage === imageDTO.image_name)
    ) {
      dispatch(
        controlAdapterImageChanged({
          id: ca.id,
          controlImage: null,
        })
      );
      dispatch(
        controlAdapterProcessedImageChanged({
          id: ca.id,
          processedControlImage: null,
        })
      );
    }
  });
};

const deleteLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
  state.canvasV2.rasterLayers.entities.forEach(({ id, objects }) => {
    let shouldDelete = false;
    for (const obj of objects) {
      if (obj.type === 'image' && obj.image.image_name === imageDTO.image_name) {
        shouldDelete = true;
        break;
const deleteControlLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
  state.controlLayers.present.layers.forEach((l) => {
    if (isRegionalGuidanceLayer(l)) {
      if (l.ipAdapters.some((ipa) => ipa.image?.name === imageDTO.image_name)) {
        dispatch(layerDeleted(l.id));
      }
    }
    if (shouldDelete) {
      dispatch(entityDeleted({ entityIdentifier: { id, type: 'raster_layer' } }));
    if (isControlAdapterLayer(l)) {
      if (
        l.controlAdapter.image?.name === imageDTO.image_name ||
        l.controlAdapter.processedImage?.name === imageDTO.image_name
      ) {
        dispatch(layerDeleted(l.id));
      }
    }
    if (isIPAdapterLayer(l)) {
      if (l.ipAdapter.image?.name === imageDTO.image_name) {
        dispatch(layerDeleted(l.id));
      }
    }
    if (isInitialImageLayer(l)) {
      if (l.image?.name === imageDTO.image_name) {
        dispatch(layerDeleted(l.id));
      }
    }
  });
};
@ -121,10 +145,14 @@ export const addImageDeletionListeners = (startAppListening: AppStartListening)
        }
      }

      // We need to reset the features where the image is in use - none of these work if their image(s) don't exist
      if (imageUsage.isCanvasImage) {
        dispatch(resetCanvas());
      }

      deleteControlAdapterImages(state, dispatch, imageDTO);
      deleteNodesImages(state, dispatch, imageDTO);
      // deleteControlAdapterImages(state, dispatch, imageDTO);
      deleteIPAdapterImages(state, dispatch, imageDTO);
      deleteLayerImages(state, dispatch, imageDTO);
      deleteControlLayerImages(state, dispatch, imageDTO);
    } catch {
      // no-op
    } finally {
@ -161,11 +189,14 @@ export const addImageDeletionListeners = (startAppListening: AppStartListening)

      // We need to reset the features where the image is in use - none of these work if their image(s) don't exist

      if (imagesUsage.some((i) => i.isCanvasImage)) {
        dispatch(resetCanvas());
      }

      imageDTOs.forEach((imageDTO) => {
        deleteControlAdapterImages(state, dispatch, imageDTO);
        deleteNodesImages(state, dispatch, imageDTO);
        // deleteControlAdapterImages(state, dispatch, imageDTO);
        deleteIPAdapterImages(state, dispatch, imageDTO);
        deleteLayerImages(state, dispatch, imageDTO);
        deleteControlLayerImages(state, dispatch, imageDTO);
      });
    } catch {
      // no-op
@ -189,6 +220,7 @@ export const addImageDeletionListeners = (startAppListening: AppStartListening)
  startAppListening({
    matcher: imagesApi.endpoints.deleteImage.matchFulfilled,
    effect: (action) => {
      const log = logger('images');
      log.debug({ imageDTO: action.meta.arg.originalArgs }, 'Image deleted');
    },
  });
@ -196,6 +228,7 @@ export const addImageDeletionListeners = (startAppListening: AppStartListening)
  startAppListening({
    matcher: imagesApi.endpoints.deleteImage.matchRejected,
    effect: (action) => {
      const log = logger('images');
      log.debug({ imageDTO: action.meta.arg.originalArgs }, 'Unable to delete image');
    },
  });
@ -1,18 +1,28 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
  controlLayerAdded,
  ipaImageChanged,
  rasterLayerAdded,
  rgIPAdapterImageChanged,
} from 'features/controlLayers/store/canvasV2Slice';
import type { CanvasControlLayerState, CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
  controlAdapterImageChanged,
  controlAdapterIsEnabledChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import {
  caLayerImageChanged,
  iiLayerImageChanged,
  ipaLayerImageChanged,
  rgLayerIPAdapterImageChanged,
} from 'features/controlLayers/store/controlLayersSlice';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { imageToCompareChanged, isImageViewerOpenChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import {
  imageSelected,
  imageToCompareChanged,
  isImageViewerOpenChanged,
  selectionChanged,
} from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { imagesApi } from 'services/api/endpoints/images';
@ -21,12 +31,11 @@ export const dndDropped = createAction<{
  activeData: TypesafeDraggableData;
}>('dnd/dndDropped');

const log = logger('system');

export const addImageDroppedListener = (startAppListening: AppStartListening) => {
  startAppListening({
    actionCreator: dndDropped,
    effect: (action, { dispatch, getState }) => {
    effect: async (action, { dispatch, getState }) => {
      const log = logger('dnd');
      const { activeData, overData } = action.payload;
      if (!isValidDrop(overData, activeData)) {
        return;
@ -37,21 +46,81 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
      } else if (activeData.payloadType === 'GALLERY_SELECTION') {
        log.debug({ activeData, overData }, `Images (${getState().gallery.selection.length}) dropped`);
      } else if (activeData.payloadType === 'NODE_FIELD') {
        log.debug({ activeData, overData }, 'Node field dropped');
        log.debug({ activeData: parseify(activeData), overData: parseify(overData) }, 'Node field dropped');
      } else {
        log.debug({ activeData, overData }, `Unknown payload dropped`);
      }

      /**
       * Image dropped on current image
       */
      if (
        overData.actionType === 'SET_CURRENT_IMAGE' &&
        activeData.payloadType === 'IMAGE_DTO' &&
        activeData.payload.imageDTO
      ) {
        dispatch(imageSelected(activeData.payload.imageDTO));
        dispatch(isImageViewerOpenChanged(true));
        return;
      }

      /**
       * Image dropped on ControlNet
       */
      if (
        overData.actionType === 'SET_CONTROL_ADAPTER_IMAGE' &&
        activeData.payloadType === 'IMAGE_DTO' &&
        activeData.payload.imageDTO
      ) {
        const { id } = overData.context;
        dispatch(
          controlAdapterImageChanged({
            id,
            controlImage: activeData.payload.imageDTO.image_name,
          })
        );
        dispatch(
          controlAdapterIsEnabledChanged({
            id,
            isEnabled: true,
          })
        );
        return;
      }

      /**
       * Image dropped on Control Adapter Layer
       */
      if (
        overData.actionType === 'SET_CA_LAYER_IMAGE' &&
        activeData.payloadType === 'IMAGE_DTO' &&
        activeData.payload.imageDTO
      ) {
        const { layerId } = overData.context;
        dispatch(
          caLayerImageChanged({
            layerId,
            imageDTO: activeData.payload.imageDTO,
          })
        );
        return;
      }

      /**
       * Image dropped on IP Adapter Layer
       */
      if (
        overData.actionType === 'SET_IPA_IMAGE' &&
        overData.actionType === 'SET_IPA_LAYER_IMAGE' &&
        activeData.payloadType === 'IMAGE_DTO' &&
        activeData.payload.imageDTO
      ) {
        const { id } = overData.context;
        dispatch(ipaImageChanged({ id, imageDTO: activeData.payload.imageDTO }));
        const { layerId } = overData.context;
        dispatch(
          ipaLayerImageChanged({
            layerId,
            imageDTO: activeData.payload.imageDTO,
          })
        );
        return;
      }
@ -59,48 +128,48 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
       * Image dropped on RG Layer IP Adapter
       */
      if (
        overData.actionType === 'SET_RG_IP_ADAPTER_IMAGE' &&
        overData.actionType === 'SET_RG_LAYER_IP_ADAPTER_IMAGE' &&
        activeData.payloadType === 'IMAGE_DTO' &&
        activeData.payload.imageDTO
      ) {
        const { id, ipAdapterId } = overData.context;
        dispatch(rgIPAdapterImageChanged({ id, ipAdapterId, imageDTO: activeData.payload.imageDTO }));
        const { layerId, ipAdapterId } = overData.context;
        dispatch(
          rgLayerIPAdapterImageChanged({
            layerId,
            ipAdapterId,
            imageDTO: activeData.payload.imageDTO,
          })
        );
        return;
      }

      /**
       * Image dropped on Raster layer
       * Image dropped on II Layer Image
       */
      if (
        overData.actionType === 'ADD_RASTER_LAYER_FROM_IMAGE' &&
        overData.actionType === 'SET_II_LAYER_IMAGE' &&
        activeData.payloadType === 'IMAGE_DTO' &&
        activeData.payload.imageDTO
      ) {
        const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
        const { x, y } = getState().canvasV2.bbox.rect;
        const overrides: Partial<CanvasRasterLayerState> = {
          objects: [imageObject],
          position: { x, y },
        };
        dispatch(rasterLayerAdded({ overrides, isSelected: true }));
        const { layerId } = overData.context;
        dispatch(
          iiLayerImageChanged({
            layerId,
            imageDTO: activeData.payload.imageDTO,
          })
        );
        return;
      }

      /**
       * Image dropped on Raster layer
       * Image dropped on Canvas
       */
      if (
        overData.actionType === 'ADD_CONTROL_LAYER_FROM_IMAGE' &&
        overData.actionType === 'SET_CANVAS_INITIAL_IMAGE' &&
        activeData.payloadType === 'IMAGE_DTO' &&
        activeData.payload.imageDTO
      ) {
        const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
        const { x, y } = getState().canvasV2.bbox.rect;
        const overrides: Partial<CanvasControlLayerState> = {
          objects: [imageObject],
          position: { x, y },
        };
        dispatch(controlLayerAdded({ overrides, isSelected: true }));
        dispatch(setInitialCanvasImage(activeData.payload.imageDTO, selectOptimalDimension(getState())));
        return;
      }
@ -2,13 +2,13 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { imagesApi } from 'services/api/endpoints/images';

const log = logger('gallery');

export const addImageRemovedFromBoardFulfilledListener = (startAppListening: AppStartListening) => {
  startAppListening({
    matcher: imagesApi.endpoints.removeImageFromBoard.matchFulfilled,
    effect: (action) => {
      const log = logger('images');
      const imageDTO = action.meta.arg.originalArgs;

      log.debug({ imageDTO }, 'Image removed from board');
    },
  });
@ -16,7 +16,9 @@ export const addImageRemovedFromBoardFulfilledListener = (startAppListening: App
  startAppListening({
    matcher: imagesApi.endpoints.removeImageFromBoard.matchRejected,
    effect: (action) => {
      const log = logger('images');
      const imageDTO = action.meta.arg.originalArgs;

      log.debug({ imageDTO }, 'Problem removing image from board');
    },
  });
@ -6,17 +6,16 @@ import { imagesToDeleteSelected, isModalOpenChanged } from 'features/deleteImage
export const addImageToDeleteSelectedListener = (startAppListening: AppStartListening) => {
  startAppListening({
    actionCreator: imagesToDeleteSelected,
    effect: (action, { dispatch, getState }) => {
    effect: async (action, { dispatch, getState }) => {
      const imageDTOs = action.payload;
      const state = getState();
      const { shouldConfirmOnDelete } = state.system;
      const imagesUsage = selectImageUsage(getState());

      const isImageInUse =
        imagesUsage.some((i) => i.isLayerImage) ||
        imagesUsage.some((i) => i.isControlAdapterImage) ||
        imagesUsage.some((i) => i.isIPAdapterImage) ||
        imagesUsage.some((i) => i.isLayerImage);
        imagesUsage.some((i) => i.isCanvasImage) ||
        imagesUsage.some((i) => i.isControlImage) ||
        imagesUsage.some((i) => i.isNodesImage);

      if (shouldConfirmOnDelete || isImageInUse) {
        dispatch(isModalOpenChanged(true));
@ -1,8 +1,19 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { ipaImageChanged, rgIPAdapterImageChanged } from 'features/controlLayers/store/canvasV2Slice';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
  controlAdapterImageChanged,
  controlAdapterIsEnabledChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import {
  caLayerImageChanged,
  iiLayerImageChanged,
  ipaLayerImageChanged,
  rgLayerIPAdapterImageChanged,
} from 'features/controlLayers/store/controlLayersSlice';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
@ -10,12 +21,11 @@ import { omit } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';

const log = logger('gallery');

export const addImageUploadedFulfilledListener = (startAppListening: AppStartListening) => {
  startAppListening({
    matcher: imagesApi.endpoints.uploadImage.matchFulfilled,
    effect: (action, { dispatch, getState }) => {
      const log = logger('images');
      const imageDTO = action.payload;
      const state = getState();
      const { autoAddBoardId } = state.gallery;
@ -71,6 +81,15 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
|
||||
return;
|
||||
}
|
||||
|
||||
if (postUploadAction?.type === 'SET_CANVAS_INITIAL_IMAGE') {
|
||||
dispatch(setInitialCanvasImage(imageDTO, selectOptimalDimension(state)));
|
||||
toast({
|
||||
...DEFAULT_UPLOADED_TOAST,
|
||||
description: t('toast.setAsCanvasInitialImage'),
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (postUploadAction?.type === 'SET_UPSCALE_INITIAL_IMAGE') {
|
||||
dispatch(upscaleInitialImageChanged(imageDTO));
|
||||
toast({
|
||||
@ -80,31 +99,70 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
|
||||
return;
|
||||
}
|
||||
|
||||
// if (postUploadAction?.type === 'SET_CA_IMAGE') {
|
||||
// const { id } = postUploadAction;
|
||||
// dispatch(caImageChanged({ id, imageDTO }));
|
||||
// toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
|
||||
// return;
|
||||
// }
|
||||
|
||||
if (postUploadAction?.type === 'SET_IPA_IMAGE') {
|
||||
if (postUploadAction?.type === 'SET_CONTROL_ADAPTER_IMAGE') {
|
||||
const { id } = postUploadAction;
|
||||
dispatch(ipaImageChanged({ id, imageDTO }));
|
||||
toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
|
||||
dispatch(
|
||||
controlAdapterIsEnabledChanged({
|
||||
id,
|
||||
isEnabled: true,
|
||||
})
|
||||
);
|
||||
dispatch(
|
||||
controlAdapterImageChanged({
|
||||
id,
|
||||
controlImage: imageDTO.image_name,
|
||||
})
|
||||
);
|
||||
toast({
|
||||
...DEFAULT_UPLOADED_TOAST,
|
||||
description: t('toast.setControlImage'),
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (postUploadAction?.type === 'SET_RG_IP_ADAPTER_IMAGE') {
|
||||
const { id, ipAdapterId } = postUploadAction;
|
||||
dispatch(rgIPAdapterImageChanged({ id, ipAdapterId, imageDTO }));
|
||||
toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
|
||||
return;
|
||||
if (postUploadAction?.type === 'SET_CA_LAYER_IMAGE') {
|
||||
const { layerId } = postUploadAction;
|
||||
dispatch(caLayerImageChanged({ layerId, imageDTO }));
|
||||
toast({
|
||||
...DEFAULT_UPLOADED_TOAST,
|
||||
description: t('toast.setControlImage'),
|
||||
});
|
||||
}
|
||||
|
||||
if (postUploadAction?.type === 'SET_IPA_LAYER_IMAGE') {
|
||||
const { layerId } = postUploadAction;
|
||||
dispatch(ipaLayerImageChanged({ layerId, imageDTO }));
|
||||
toast({
|
||||
...DEFAULT_UPLOADED_TOAST,
|
||||
description: t('toast.setControlImage'),
|
||||
});
|
||||
}
|
||||
|
||||
if (postUploadAction?.type === 'SET_RG_LAYER_IP_ADAPTER_IMAGE') {
|
||||
const { layerId, ipAdapterId } = postUploadAction;
|
||||
dispatch(rgLayerIPAdapterImageChanged({ layerId, ipAdapterId, imageDTO }));
|
||||
toast({
|
||||
...DEFAULT_UPLOADED_TOAST,
|
||||
description: t('toast.setControlImage'),
|
||||
});
|
||||
}
|
||||
|
||||
if (postUploadAction?.type === 'SET_II_LAYER_IMAGE') {
|
||||
const { layerId } = postUploadAction;
|
||||
dispatch(iiLayerImageChanged({ layerId, imageDTO }));
|
||||
toast({
|
||||
...DEFAULT_UPLOADED_TOAST,
|
||||
description: t('toast.setControlImage'),
|
||||
});
|
||||
}
|
||||
|
||||
if (postUploadAction?.type === 'SET_NODES_IMAGE') {
|
||||
const { nodeId, fieldName } = postUploadAction;
|
||||
dispatch(fieldImageValueChanged({ nodeId, fieldName, value: imageDTO }));
|
||||
toast({ ...DEFAULT_UPLOADED_TOAST, description: `${t('toast.setNodeField')} ${fieldName}` });
|
||||
toast({
|
||||
...DEFAULT_UPLOADED_TOAST,
|
||||
description: `${t('toast.setNodeField')} ${fieldName}`,
|
||||
});
|
||||
return;
|
||||
}
|
||||
},
|
||||
@ -113,6 +171,7 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
|
||||
startAppListening({
|
||||
matcher: imagesApi.endpoints.uploadImage.matchRejected,
|
||||
effect: (action) => {
|
||||
const log = logger('images');
|
||||
const sanitizedData = {
|
||||
arg: {
|
||||
...omit(action.meta.arg.originalArgs, ['file', 'postUploadAction']),
|
||||
|
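The upload listener fans out on `postUploadAction?.type`, which works because post-upload actions form a discriminated union on `type`: each branch can destructure only the fields its variant carries. A trimmed sketch with a hypothetical three-member union modeled on the branch names above:

// Hypothetical three-member union; the app's PostUploadAction has more
// variants (canvas, upscale, IP adapter, layers, ...).
type PostUploadAction =
  | { type: 'SET_UPSCALE_INITIAL_IMAGE' }
  | { type: 'SET_CONTROL_ADAPTER_IMAGE'; id: string }
  | { type: 'SET_NODES_IMAGE'; nodeId: string; fieldName: string };

const describe = (postUploadAction?: PostUploadAction): string => {
  if (postUploadAction?.type === 'SET_CONTROL_ADAPTER_IMAGE') {
    // TypeScript narrows the union here, so `id` is known to exist.
    return `set control adapter image on ${postUploadAction.id}`;
  }
  if (postUploadAction?.type === 'SET_NODES_IMAGE') {
    const { nodeId, fieldName } = postUploadAction;
    return `set ${fieldName} on node ${nodeId}`;
  }
  return 'no post-upload action';
};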
@ -6,7 +6,7 @@ import type { ImageDTO } from 'services/api/types';
export const addImagesStarredListener = (startAppListening: AppStartListening) => {
  startAppListening({
    matcher: imagesApi.endpoints.starImages.matchFulfilled,
    effect: (action, { dispatch, getState }) => {
    effect: async (action, { dispatch, getState }) => {
      const { updated_image_names: starredImages } = action.payload;

      const state = getState();
@ -6,7 +6,7 @@ import type { ImageDTO } from 'services/api/types';
export const addImagesUnstarredListener = (startAppListening: AppStartListening) => {
  startAppListening({
    matcher: imagesApi.endpoints.unstarImages.matchFulfilled,
    effect: (action, { dispatch, getState }) => {
    effect: async (action, { dispatch, getState }) => {
      const { updated_image_names: unstarredImages } = action.payload;

      const state = getState();
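Both star/unstar effects become `async`, which the listener middleware supports directly; an effect may await work and use the listener API's concurrency helpers. A small sketch (the predicate here is illustrative, not the app's matcher):

import { createListenerMiddleware } from '@reduxjs/toolkit';

const listenerMiddleware = createListenerMiddleware();

listenerMiddleware.startListening({
  // Illustrative predicate; the real listeners match specific endpoints.
  predicate: (action) => action.type.endsWith('/fulfilled'),
  effect: async (action, listenerApi) => {
    // Cancel any in-flight run of this listener before starting a new one.
    listenerApi.cancelActiveListeners();
    // Awaiting inside an effect is fine, e.g. a debounce-style delay.
    await listenerApi.delay(250);
    console.debug('handled', action.type);
  },
});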
@ -1,17 +1,23 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { loraDeleted, modelChanged, vaeSelected } from 'features/controlLayers/store/canvasV2Slice';
import {
  controlAdapterIsEnabledChanged,
  selectControlAdapterAll,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { modelSelected } from 'features/parameters/store/actions';
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';

const log = logger('models');
import { forEach } from 'lodash-es';

export const addModelSelectedListener = (startAppListening: AppStartListening) => {
  startAppListening({
    actionCreator: modelSelected,
    effect: (action, { getState, dispatch }) => {
      const log = logger('models');

      const state = getState();
      const result = zParameterModel.safeParse(action.payload);

@ -23,36 +29,34 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
      const newModel = result.data;

      const newBaseModel = newModel.base;
      const didBaseModelChange = state.canvasV2.params.model?.base !== newBaseModel;
      const didBaseModelChange = state.generation.model?.base !== newBaseModel;

      if (didBaseModelChange) {
        // we may need to reset some incompatible submodels
        let modelsCleared = 0;

        // handle incompatible loras
        state.canvasV2.loras.forEach((lora) => {
        forEach(state.lora.loras, (lora, id) => {
          if (lora.model.base !== newBaseModel) {
            dispatch(loraDeleted({ id: lora.id }));
            dispatch(loraRemoved(id));
            modelsCleared += 1;
          }
        });

        // handle incompatible vae
        const { vae } = state.canvasV2.params;
        const { vae } = state.generation;
        if (vae && vae.base !== newBaseModel) {
          dispatch(vaeSelected(null));
          modelsCleared += 1;
        }

        // handle incompatible controlnets
        // state.canvasV2.controlAdapters.entities.forEach((ca) => {
        //   if (ca.model?.base !== newBaseModel) {
        //     modelsCleared += 1;
        //     if (ca.isEnabled) {
        //       dispatch(entityIsEnabledToggled({ entityIdentifier: { id: ca.id, type: 'control_adapter' } }));
        //     }
        //   }
        // });
        selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
          if (ca.model?.base !== newBaseModel) {
            dispatch(controlAdapterIsEnabledChanged({ id: ca.id, isEnabled: false }));
            modelsCleared += 1;
          }
        });

        if (modelsCleared > 0) {
          toast({
@ -66,7 +70,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
        }
      }

      dispatch(modelChanged({ model: newModel, previousModel: state.canvasV2.params.model }));
      dispatch(modelChanged(newModel, state.generation.model));
    },
  });
};
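The listener above is invariant maintenance: when the selected base architecture changes, submodels (LoRAs, the VAE) built for the old base are cleared and counted for the toast. A standalone sketch of that compatibility sweep; the `BaseModel` values and `SubModel` shape are simplified assumptions:

type BaseModel = 'sd-1' | 'sd-2' | 'sdxl'; // example values, not exhaustive
type SubModel = { id: string; base: BaseModel };

// Returns the ids of submodels trained against a different base than the
// newly selected model; callers dispatch a removal per id and count them.
const incompatibleSubModels = (subModels: SubModel[], newBase: BaseModel): string[] =>
  subModels.filter((m) => m.base !== newBase).map((m) => m.id);

// Example: switching to SDXL flags the SD-1 LoRA for removal.
const cleared = incompatibleSubModels(
  [
    { id: 'lora-a', base: 'sd-1' },
    { id: 'lora-b', base: 'sdxl' },
  ],
  'sdxl'
);
console.assert(cleared.length === 1 && cleared[0] === 'lora-a');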
@ -1,42 +1,36 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import type { SerializableObject } from 'common/types';
import type { JSONObject } from 'common/types';
import {
  bboxHeightChanged,
  bboxWidthChanged,
  controlLayerModelChanged,
  ipaModelChanged,
  loraDeleted,
  modelChanged,
  refinerModelChanged,
  rgIPAdapterModelChanged,
  vaeSelected,
} from 'features/controlLayers/store/canvasV2Slice';
import { calculateNewSize } from 'features/parameters/components/DocumentSize/calculateNewSize';
  controlAdapterModelCleared,
  selectControlAdapterAll,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { heightChanged, widthChanged } from 'features/controlLayers/store/controlLayersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
import { postProcessingModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { forEach } from 'lodash-es';
import type { Logger } from 'roarr';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import {
  isControlNetOrT2IAdapterModelConfig,
  isIPAdapterModelConfig,
  isLoRAModelConfig,
  isNonRefinerMainModelConfig,
  isRefinerMainModelModelConfig,
  isSpandrelImageToImageModelConfig,
  isVAEModelConfig,
} from 'services/api/types';

const log = logger('models');

export const addModelsLoadedListener = (startAppListening: AppStartListening) => {
  startAppListening({
    predicate: modelsApi.endpoints.getModelConfigs.matchFulfilled,
    effect: (action, { getState, dispatch }) => {
    effect: async (action, { getState, dispatch }) => {
      // models loaded, we need to ensure the selected model is available and if not, select the first one
      const log = logger('models');
      log.info({ models: action.payload.entities }, `Models loaded (${action.payload.ids.length})`);

      const state = getState();
@ -49,7 +43,6 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
      handleLoRAModels(models, state, dispatch, log);
      handleControlAdapterModels(models, state, dispatch, log);
      handleSpandrelImageToImageModels(models, state, dispatch, log);
      handleIPAdapterModels(models, state, dispatch, log);
    },
  });
};
@ -58,15 +51,15 @@ type ModelHandler = (
  models: AnyModelConfig[],
  state: RootState,
  dispatch: AppDispatch,
  log: Logger<SerializableObject>
  log: Logger<JSONObject>
) => undefined;

const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
  const currentModel = state.canvasV2.params.model;
  const currentModel = state.generation.model;
  const mainModels = models.filter(isNonRefinerMainModelConfig);
  if (mainModels.length === 0) {
    // No models loaded at all
    dispatch(modelChanged({ model: null }));
    dispatch(modelChanged(null));
    return;
  }

@ -81,19 +74,25 @@ const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
  if (defaultModelInList) {
    const result = zParameterModel.safeParse(defaultModelInList);
    if (result.success) {
      dispatch(modelChanged({ model: defaultModelInList, previousModel: currentModel }));
      dispatch(modelChanged(defaultModelInList, currentModel));

      const optimalDimension = getOptimalDimension(defaultModelInList);
      if (getIsSizeOptimal(state.canvasV2.bbox.rect.width, state.canvasV2.bbox.rect.height, optimalDimension)) {
      if (
        getIsSizeOptimal(
          state.controlLayers.present.size.width,
          state.controlLayers.present.size.height,
          optimalDimension
        )
      ) {
        return;
      }
      const { width, height } = calculateNewSize(
        state.canvasV2.bbox.aspectRatio.value,
        state.controlLayers.present.size.aspectRatio.value,
        optimalDimension * optimalDimension
      );

      dispatch(bboxWidthChanged({ width }));
      dispatch(bboxHeightChanged({ height }));
      dispatch(widthChanged({ width }));
      dispatch(heightChanged({ height }));
      return;
    }
  }
@ -105,11 +104,11 @@ const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
    return;
  }

  dispatch(modelChanged({ model: result.data, previousModel: currentModel }));
  dispatch(modelChanged(result.data, currentModel));
};

const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => {
  const currentRefinerModel = state.canvasV2.params.refinerModel;
  const currentRefinerModel = state.sdxl.refinerModel;
  const refinerModels = models.filter(isRefinerMainModelModelConfig);
  if (models.length === 0) {
    // No models loaded at all
@ -128,7 +127,7 @@ const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => {
};

const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
  const currentVae = state.canvasV2.params.vae;
  const currentVae = state.generation.vae;

  if (currentVae === null) {
    // null is a valid VAE! it means "use the default with the main model"
@ -161,45 +160,28 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
};

const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
  const loraModels = models.filter(isLoRAModelConfig);
  state.canvasV2.loras.forEach((lora) => {
    const isLoRAAvailable = loraModels.some((m) => m.key === lora.model.key);
  const loras = state.lora.loras;

  forEach(loras, (lora, id) => {
    const isLoRAAvailable = models.some((m) => m.key === lora.model.key);

    if (isLoRAAvailable) {
      return;
    }
    dispatch(loraDeleted({ id: lora.id }));

    dispatch(loraRemoved(id));
  });
};

const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
  const caModels = models.filter(isControlNetOrT2IAdapterModelConfig);
  state.canvasV2.controlLayers.entities.forEach((entity) => {
    const isModelAvailable = caModels.some((m) => m.key === entity.controlAdapter.model?.key);
  selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
    const isModelAvailable = models.some((m) => m.key === ca.model?.key);

    if (isModelAvailable) {
      return;
    }
    dispatch(controlLayerModelChanged({ id: entity.id, modelConfig: null }));
  });
};

const handleIPAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
  const ipaModels = models.filter(isIPAdapterModelConfig);
  state.canvasV2.ipAdapters.entities.forEach((entity) => {
    const isModelAvailable = ipaModels.some((m) => m.key === entity.ipAdapter.model?.key);
    if (isModelAvailable) {
      return;
    }
    dispatch(ipaModelChanged({ id: entity.id, modelConfig: null }));
  });

  state.canvasV2.regions.entities.forEach(({ id, ipAdapters }) => {
    ipAdapters.forEach(({ id: ipAdapterId, model }) => {
      const isModelAvailable = ipaModels.some((m) => m.key === model?.key);
      if (isModelAvailable) {
        return;
      }
      dispatch(rgIPAdapterModelChanged({ id, ipAdapterId, modelConfig: null }));
    });
    dispatch(controlAdapterModelCleared({ id: ca.id }));
  });
};
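This file splits reconciliation into per-category handlers with a shared `ModelHandler` signature, each checking that whatever is selected in state still exists in the freshly loaded configs. A reduced sketch of one such handler; the types and fallback policy here are simplified assumptions, not the repository's:

type ModelConfig = { key: string; name: string };

// Simplified analog of the file's ModelHandler shape: inspect the loaded
// configs and the current selection, and return the key that should be
// selected afterwards (null meaning "nothing selected").
type Handler = (models: ModelConfig[], selectedKey: string | null) => string | null;

const reconcileSelection: Handler = (models, selectedKey) => {
  if (selectedKey !== null && models.some((m) => m.key === selectedKey)) {
    return selectedKey; // still present on the server, keep it
  }
  // Selection vanished (or never existed): fall back to the first model.
  return models[0]?.key ?? null;
};

console.assert(reconcileSelection([{ key: 'a', name: 'A' }], 'gone') === 'a');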
@ -1,6 +1,6 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { positivePromptChanged } from 'features/controlLayers/store/canvasV2Slice';
import { positivePromptChanged } from 'features/controlLayers/store/controlLayersSlice';
import {
  combinatorialToggled,
  isErrorChanged,
@ -15,7 +15,7 @@ import { getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilder
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
import { utilitiesApi } from 'services/api/endpoints/utilities';
import { socketConnected } from 'services/events/setEventListeners';
import { socketConnected } from 'services/events/actions';

const matcher = isAnyOf(
  positivePromptChanged,
@ -24,6 +24,8 @@ const matcher = isAnyOf(
  maxPromptsReset,
  socketConnected,
  activeStylePresetIdChanged,
  stylePresetsApi.endpoints.deleteStylePreset.matchFulfilled,
  stylePresetsApi.endpoints.updateStylePreset.matchFulfilled,
  stylePresetsApi.endpoints.listStylePresets.matchFulfilled
);
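`isAnyOf` builds a single matcher from a heterogeneous list: plain action creators (like `positivePromptChanged`) and RTK Query endpoint matchers (like `matchFulfilled`) both satisfy the matcher contract, so one listener can react to all of them. A minimal sketch:

import { createAction, isAnyOf } from '@reduxjs/toolkit';

const promptChanged = createAction<string>('prompts/promptChanged');
const presetApplied = createAction<string>('presets/presetApplied');

// One combined matcher; isAnyOf also narrows the matched action's type to
// the union of its arguments' action types.
const shouldRecount = isAnyOf(promptChanged, presetApplied);

console.assert(shouldRecount(promptChanged('a cat')));
console.assert(!shouldRecount({ type: 'unrelated' }));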
@ -1,15 +1,14 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { heightChanged, widthChanged } from 'features/controlLayers/store/controlLayersSlice';
import { setDefaultSettings } from 'features/parameters/store/actions';
import {
  bboxHeightChanged,
  bboxWidthChanged,
  setCfgRescaleMultiplier,
  setCfgScale,
  setScheduler,
  setSteps,
  vaePrecisionChanged,
  vaeSelected,
} from 'features/controlLayers/store/canvasV2Slice';
import { setDefaultSettings } from 'features/parameters/store/actions';
} from 'features/parameters/store/generationSlice';
import {
  isParameterCFGRescaleMultiplier,
  isParameterCFGScale,
@ -31,7 +30,7 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
    effect: async (action, { dispatch, getState }) => {
      const state = getState();

      const currentModel = state.canvasV2.params.model;
      const currentModel = state.generation.model;

      if (!currentModel) {
        return;
@ -99,13 +98,13 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
      const setSizeOptions = { updateAspectRatio: true, clamp: true };
      if (width) {
      if (isParameterWidth(width)) {
        dispatch(bboxWidthChanged({ width, ...setSizeOptions }));
        dispatch(widthChanged({ width, ...setSizeOptions }));
      }
      }

      if (height) {
      if (isParameterHeight(height)) {
        dispatch(bboxHeightChanged({ height, ...setSizeOptions }));
        dispatch(heightChanged({ height, ...setSizeOptions }));
      }
      }
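Note the guards tighten from truthiness (`if (width)`) to parameter type guards (`isParameterWidth(width)`), so malformed default settings never reach the store. A sketch of the pattern with a hypothetical width rule; the real predicate lives in the parameter schemas and may differ:

// Hypothetical rule: widths must be positive multiples of 8, as many
// latent-diffusion pipelines require; the repository's real predicate
// lives in features/parameters/types/parameterSchemas and may differ.
const isParameterWidth = (v: unknown): v is number =>
  typeof v === 'number' && Number.isInteger(v) && v > 0 && v % 8 === 0;

const maybeDispatchWidth = (width: unknown, dispatch: (w: number) => void): void => {
  if (isParameterWidth(width)) {
    dispatch(width); // narrowed to number; 0, NaN, and '512' are all rejected
  }
};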
@ -6,9 +6,9 @@ import { atom } from 'nanostores';
import { api } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
import { socketConnected } from 'services/events/setEventListeners';
import { socketConnected } from 'services/events/actions';

const log = logger('events');
const log = logger('socketio');

const $isFirstConnection = atom(true);
@ -0,0 +1,14 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketDisconnected } from 'services/events/actions';

const log = logger('socketio');

export const addSocketDisconnectedEventListener = (startAppListening: AppStartListening) => {
  startAppListening({
    actionCreator: socketDisconnected,
    effect: () => {
      log.debug('Disconnected');
    },
  });
};
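The new file follows the registration convention used by every listener in this diff: export one `add...Listener(startAppListening)` function and invoke it where the middleware is assembled. A self-contained sketch of that wiring, with stand-in state and action types (the app's real `RootState` and `socketDisconnected` live elsewhere):

import { createAction, createListenerMiddleware } from '@reduxjs/toolkit';
import type { TypedStartListening } from '@reduxjs/toolkit';

// Stand-in for the app's real root state type.
type RootState = Record<string, unknown>;
type AppStartListening = TypedStartListening<RootState>;

// Stand-in for the socket event action defined in services/events/actions.
const socketDisconnected = createAction('socket/disconnected');

export const listenerMiddleware = createListenerMiddleware();
const startAppListening = listenerMiddleware.startListening as AppStartListening;

// One registration function per concern, invoked once at assembly time.
const addSocketDisconnectedEventListener = (startListening: AppStartListening) => {
  startListening({
    actionCreator: socketDisconnected,
    effect: () => {
      console.debug('Disconnected');
    },
  });
};

addSocketDisconnectedEventListener(startAppListening);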
Some files were not shown because too many files have changed in this diff.