Compare commits

..

29 Commits

Author SHA1 Message Date
Ryan Dick
87261bdbc9
FLUX memory management improvements (#6791)
## Summary

This PR contains several improvements to memory management for FLUX
workflows.

It is now possible to achieve better FLUX model caching performance, but
this still requires users to manually configure their `ram`/`vram`
settings. E.g. a `vram` setting of 16.0 should allow for all quantized
FLUX models to be kept in memory on the GPU.

Changes:
- Check the size of a model on disk and free the requisite space in the
model cache before loading it. (This behaviour existed previously, but
was removed in https://github.com/invoke-ai/InvokeAI/pull/6072/files.
The removal did not seem to be intentional).
- Removed the hack to free 24GB of space in the cache before loading the
FLUX model.
- Split the T5 embedding and CLIP embedding steps into separate
functions so that the two models don't both have to be held in RAM at
the same time.
- Fix a bug in `InvokeLinear8bitLt` that was causing some tensors to be
left on the GPU when the model was offloaded to the CPU. (This class is
getting very messy due to the non-standard state_dict handling in
`bnb.nn.Linear8bitLt`. )
- Tidy up some dtype handling in FluxTextToImageInvocation to avoid
situations where we hold references to two copies of the same tensor
unnecessarily.
- (minor) Misc cleanup of ModelCache: improve docs and remove unused
vars.

Future:
We should revisit our default ram/vram configs. The current defaults are
very conservative, and users could see major performance improvements
from tuning these values.

## QA Instructions

I tested the FLUX workflow with the following configurations and
verified that the cache hit rates and memory usage matched the expected
behaviour:
- `ram = 16` and `vram = 16`
- `ram = 16` and `vram = 1`
- `ram = 1` and `vram = 1`

Note that the changes in this PR are not isolated to FLUX. Since we now
check the size of models on disk, we may see slight changes in model
cache offload patterns for other models as well.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2024-08-29 15:17:45 -04:00
Ryan Dick
4e4b6c6dbc Tidy variable management and dtype handling in FluxTextToImageInvocation. 2024-08-29 19:08:18 +00:00
Ryan Dick
5e8cf9fb6a Remove hack to clear cache from the FluxTextToImageInvocation. We now clear the cache based on the on-disk model size. 2024-08-29 19:08:18 +00:00
Ryan Dick
c738fe051f Split T5 encoding and CLIP encoding into separate functions to ensure that all model references are locally-scoped so that the two models don't have to be help in memory at the same time. 2024-08-29 19:08:18 +00:00
Ryan Dick
29fe1533f2 Fix bug in InvokeLinear8bitLt that was causing old state information to persist after loading from a state dict. This manifested as state tensors being left on the GPU even when a model had been offloaded to the CPU cache. 2024-08-29 19:08:18 +00:00
Ryan Dick
77090070bd Check the size of a model on disk and make room for it in the cache before loading it. 2024-08-29 19:08:18 +00:00
Ryan Dick
6ba9b1b6b0 Tidy up GIG -> GB and remove unused GIG constant. 2024-08-29 19:08:18 +00:00
Ryan Dick
c578b8df1e Improve ModelCache docs. 2024-08-29 19:08:18 +00:00
Ryan Dick
cad9a41433 Remove unused MOdelCache.exists(...) function. 2024-08-29 19:08:18 +00:00
Ryan Dick
5fefb3b0f4 Remove unused param from ModelCache. 2024-08-29 19:08:18 +00:00
Ryan Dick
5284a870b0 Remove unused constructor params from ModelCache. 2024-08-29 19:08:18 +00:00
Ryan Dick
e064377c05 Remove default model cache sizes from model_cache_default.py. These defaults were misleading, because the config defaults take precedence over them. 2024-08-29 19:08:18 +00:00
Mary Hipp
3e569c8312 feat(ui): add fields for CLIP embed models and Flux VAE models in workflows 2024-08-29 11:52:51 -04:00
maryhipp
16825ee6e9 feat(nodes): bump version of flux model node, update default workflow 2024-08-29 11:52:51 -04:00
Mary Hipp
3f5340fa53 feat(nodes): add submodels as inputs to FLUX main model node instead of hardcoded names 2024-08-29 11:52:51 -04:00
chainchompa
f2a1a39b33
Add selectedStylePreset to app parameters (#6787)
## Summary
- Add selectedStylePreset to app parameters
<!--A description of the changes in this PR. Include the kind of change
(fix, feature, docs, etc), the "why" and the "how". Screenshots or
videos are useful for frontend changes.-->

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-08-28 10:53:07 -04:00
chainchompa
326de55d3e remove api changes and only preselect style preset 2024-08-28 09:53:29 -04:00
chainchompa
b2df909570 added selectedStylePreset to preload presets when app loads 2024-08-28 09:50:44 -04:00
chainchompa
026ac36b06 Revert "added selectedStylePreset to preload presets when app loads"
This reverts commit e97fd85904e32ebe0a35cc66e75f010970e7dc63.
2024-08-28 09:44:08 -04:00
chainchompa
92125e5fd2 bug fixes 2024-08-27 16:13:38 -04:00
chainchompa
c0c139da88 formatting ruff 2024-08-27 15:46:51 -04:00
chainchompa
404ad6a7fd cleanup 2024-08-27 15:42:42 -04:00
chainchompa
fc39086fb4 call stylePresetSelected 2024-08-27 15:34:31 -04:00
chainchompa
cd215700fe added route for selecting style preset 2024-08-27 15:34:07 -04:00
chainchompa
e97fd85904 added selectedStylePreset to preload presets when app loads 2024-08-27 15:33:24 -04:00
Brandon Rising
0a263fa5b1 chore: bump version to v4.2.9rc1 2024-08-27 12:09:27 -04:00
Mary Hipp
fae3836a8d fix CLIP 2024-08-27 10:29:10 -04:00
Mary Hipp
b3d2eb4178 add translations for new model types in MM, remove clip vision from filter since its not displayed in list 2024-08-27 10:29:10 -04:00
psychedelicious
576f1cbb75 build: remove broken scripts
These two scripts are broken and can cause data loss. Remove them.

They are not in the launcher script, but _are_ available to users in the terminal/file browser.

Hopefully, when we removing them here, `pip` will delete them on next installation of the package...
2024-08-27 22:01:45 +10:00
857 changed files with 30397 additions and 23288 deletions

View File

@ -11,7 +11,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByOriginResult,
ClearResult,
EnqueueBatchResult,
PruneResult,
@ -106,19 +105,6 @@ async def cancel_by_batch_ids(
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
@session_queue_router.put(
"/{queue_id}/cancel_by_origin",
operation_id="cancel_by_origin",
responses={200: {"model": CancelByBatchIDsResult}},
)
async def cancel_by_origin(
queue_id: str = Path(description="The queue id to perform this operation on"),
origin: str = Query(description="The origin to cancel all queue items for"),
) -> CancelByOriginResult:
"""Immediately cancels all queue items with the given origin"""
return ApiDependencies.invoker.services.session_queue.cancel_by_origin(queue_id=queue_id, origin=origin)
@session_queue_router.put(
"/{queue_id}/clear",
operation_id="clear",

View File

@ -20,6 +20,7 @@ from typing import (
Type,
TypeVar,
Union,
cast,
)
import semver
@ -79,7 +80,7 @@ class UIConfigBase(BaseModel):
version: str = Field(
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
)
node_pack: str = Field(description="The node pack that this node belongs to, will be 'invokeai' for built-in nodes")
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
classification: Classification = Field(default=Classification.Stable, description="The node's classification")
model_config = ConfigDict(
@ -229,16 +230,18 @@ class BaseInvocation(ABC, BaseModel):
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
if title := model_class.UIConfig.title:
schema["title"] = title
if tags := model_class.UIConfig.tags:
schema["tags"] = tags
if category := model_class.UIConfig.category:
schema["category"] = category
if node_pack := model_class.UIConfig.node_pack:
schema["node_pack"] = node_pack
schema["classification"] = model_class.UIConfig.classification
schema["version"] = model_class.UIConfig.version
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
if uiconfig is not None:
if uiconfig.title is not None:
schema["title"] = uiconfig.title
if uiconfig.tags is not None:
schema["tags"] = uiconfig.tags
if uiconfig.category is not None:
schema["category"] = uiconfig.category
if uiconfig.node_pack is not None:
schema["node_pack"] = uiconfig.node_pack
schema["classification"] = uiconfig.classification
schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
schema["class"] = "invocation"
@ -309,7 +312,7 @@ class BaseInvocation(ABC, BaseModel):
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
)
UIConfig: ClassVar[UIConfigBase]
UIConfig: ClassVar[Type[UIConfigBase]]
model_config = ConfigDict(
protected_namespaces=(),
@ -438,25 +441,30 @@ def invocation(
validate_fields(cls.model_fields, invocation_type)
# Add OpenAPI schema extras
uiconfig: dict[str, Any] = {}
uiconfig["title"] = title
uiconfig["tags"] = tags
uiconfig["category"] = category
uiconfig["classification"] = classification
# The node pack is the module name - will be "invokeai" for built-in nodes
uiconfig["node_pack"] = cls.__module__.split(".")[0]
uiconfig_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconfig_name:
cls.UIConfig = type(uiconfig_name, (UIConfigBase,), {})
cls.UIConfig.title = title
cls.UIConfig.tags = tags
cls.UIConfig.category = category
cls.UIConfig.classification = classification
# Grab the node pack's name from the module name, if it's a custom node
is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"
if is_custom_node:
cls.UIConfig.node_pack = cls.__module__.split(".")[0]
else:
cls.UIConfig.node_pack = None
if version is not None:
try:
semver.Version.parse(version)
except ValueError as e:
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
uiconfig["version"] = version
cls.UIConfig.version = version
else:
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
uiconfig["version"] = "1.0.0"
cls.UIConfig = UIConfigBase(**uiconfig)
cls.UIConfig.version = "1.0.0"
if use_cache is not None:
cls.model_fields["use_cache"].default = use_cache

View File

@ -45,11 +45,13 @@ class UIType(str, Enum, metaclass=MetaEnum):
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
VAEModel = "VAEModelField"
FluxVAEModel = "FluxVAEModelField"
LoRAModel = "LoRAModelField"
ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
T2IAdapterModel = "T2IAdapterModelField"
T5EncoderModel = "T5EncoderModelField"
CLIPEmbedModel = "CLIPEmbedModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
# endregion
@ -128,6 +130,7 @@ class FieldDescriptions:
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
t5_encoder = "T5 tokenizer and text encoder"
clip_embed_model = "CLIP Embed loader"
unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
vae = "VAE"

View File

@ -40,7 +40,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
t5_embeddings, clip_embeddings = self._encode_prompt(context)
# Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
# scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
t5_embeddings = self._t5_encode(context)
clip_embeddings = self._clip_encode(context)
conditioning_data = ConditioningFieldData(
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
)
@ -48,12 +51,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
conditioning_name = context.conditioning.save(conditioning_data)
return FluxConditioningOutput.build(conditioning_name)
def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
# Load CLIP.
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
# Load T5.
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
@ -70,6 +68,15 @@ class FluxTextEncoderInvocation(BaseInvocation):
prompt_embeds = t5_encoder(prompt)
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
prompt = [self.prompt]
with (
clip_text_encoder_info as clip_text_encoder,
clip_tokenizer_info as clip_tokenizer,
@ -81,6 +88,5 @@ class FluxTextEncoderInvocation(BaseInvocation):
pooled_prompt_embeds = clip_encoder(prompt)
assert isinstance(prompt_embeds, torch.Tensor)
assert isinstance(pooled_prompt_embeds, torch.Tensor)
return prompt_embeds, pooled_prompt_embeds
return pooled_prompt_embeds

View File

@ -58,13 +58,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Load the conditioning data.
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
latents = self._run_diffusion(context)
image = self._run_vae_decoding(context, latents)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)
@ -72,12 +66,20 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def _run_diffusion(
self,
context: InvocationContext,
clip_embeddings: torch.Tensor,
t5_embeddings: torch.Tensor,
):
transformer_info = context.models.load(self.transformer.transformer)
inference_dtype = torch.bfloat16
# Load the conditioning data.
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
t5_embeddings = flux_conditioning.t5_embeds
clip_embeddings = flux_conditioning.clip_embeds
transformer_info = context.models.load(self.transformer.transformer)
# Prepare input noise.
x = get_noise(
num_samples=1,
@ -88,24 +90,19 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
seed=self.seed,
)
img, img_ids = prepare_latent_img_patches(x)
x, img_ids = prepare_latent_img_patches(x)
is_schnell = "schnell" in transformer_info.config.config_path
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=img.shape[1],
image_seq_len=x.shape[1],
shift=not is_schnell,
)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
# if the cache is not empty.
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
with transformer_info as transformer:
assert isinstance(transformer, Flux)
@ -140,7 +137,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
x = denoise(
model=transformer,
img=img,
img=x,
img_ids=img_ids,
txt=t5_embeddings,
txt_ids=txt_ids,

View File

@ -6,19 +6,13 @@ import cv2
import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import (
ColorField,
FieldDescriptions,
ImageField,
InputField,
OutputField,
WithBoard,
WithMetadata,
)
@ -1013,62 +1007,3 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)
return ImageOutput.build(image_dto)
@invocation_output("canvas_v2_mask_and_crop_output")
class CanvasV2MaskAndCropOutput(ImageOutput):
offset_x: int = OutputField(description="The x offset of the image, after cropping")
offset_y: int = OutputField(description="The y offset of the image, after cropping")
@invocation(
"canvas_v2_mask_and_crop",
title="Canvas V2 Mask and Crop",
tags=["image", "mask", "id"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Handles Canvas V2 image output masking and cropping"""
source_image: ImageField | None = InputField(
default=None,
description="The source image onto which the masked generated image is pasted. If omitted, the masked generated image is returned with transparency.",
)
generated_image: ImageField = InputField(description="The image to apply the mask to")
mask: ImageField = InputField(description="The mask to apply")
mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")
def _prepare_mask(self, mask: Image.Image) -> Image.Image:
mask_array = numpy.array(mask)
kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
dilated_mask_array = cv2.erode(mask_array, kernel, iterations=3)
dilated_mask = Image.fromarray(dilated_mask_array)
if self.mask_blur > 0:
mask = dilated_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
return ImageOps.invert(mask.convert("L"))
def invoke(self, context: InvocationContext) -> CanvasV2MaskAndCropOutput:
mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
if self.source_image:
generated_image = context.images.get_pil(self.generated_image.image_name)
source_image = context.images.get_pil(self.source_image.image_name)
source_image.paste(generated_image, (0, 0), mask)
image_dto = context.images.save(image=source_image)
else:
generated_image = context.images.get_pil(self.generated_image.image_name)
generated_image.putalpha(mask)
image_dto = context.images.save(image=generated_image)
# bbox = image.getbbox()
# image = image.crop(bbox)
return CanvasV2MaskAndCropOutput(
image=ImageField(image_name=image_dto.image_name),
offset_x=0,
offset_y=0,
width=image_dto.width,
height=image_dto.height,
)

View File

@ -157,7 +157,7 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.3",
version="1.0.4",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
@ -169,23 +169,35 @@ class FluxModelLoaderInvocation(BaseInvocation):
input=Input.Direct,
)
t5_encoder: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder,
ui_type=UIType.T5EncoderModel,
t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
)
clip_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
input=Input.Direct,
title="CLIP Embed",
)
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
model_key = self.model.key
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
if not context.models.exists(model_key):
raise ValueError(f"Unknown model: {model_key}")
transformer = self._get_model(context, SubModelType.Transformer)
tokenizer = self._get_model(context, SubModelType.Tokenizer)
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
clip_encoder = self._get_model(context, SubModelType.TextEncoder)
t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
vae = self._get_model(context, SubModelType.VAE)
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
@ -197,52 +209,6 @@ class FluxModelLoaderInvocation(BaseInvocation):
max_seq_len=max_seq_lengths[transformer_config.config_path],
)
def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
match submodel:
case SubModelType.Transformer:
return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
case SubModelType.VAE:
return self._pull_model_from_mm(
context,
SubModelType.VAE,
"FLUX.1-schnell_ae",
ModelType.VAE,
BaseModelType.Flux,
)
case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
return self._pull_model_from_mm(
context,
submodel,
"clip-vit-large-patch14",
ModelType.CLIPEmbed,
BaseModelType.Any,
)
case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
return self._pull_model_from_mm(
context,
submodel,
self.t5_encoder.name,
ModelType.T5Encoder,
BaseModelType.Any,
)
case _:
raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
def _pull_model_from_mm(
self,
context: InvocationContext,
submodel: SubModelType,
name: str,
type: ModelType,
base: BaseModelType,
):
if models := context.models.search_by_attrs(name=name, base=base, type=type):
if len(models) != 1:
raise Exception(f"Multiple models detected for selected model with name {name}")
return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
else:
raise ValueError(f"Please install the {base}:{type} model named {name} via starter models")
@invocation(
"main_model_loader",

View File

@ -88,7 +88,6 @@ class QueueItemEventBase(QueueEventBase):
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
origin: str | None = Field(default=None, description="The origin of the batch")
class InvocationEventBase(QueueItemEventBase):
@ -96,6 +95,8 @@ class InvocationEventBase(QueueItemEventBase):
session_id: str = Field(description="The ID of the session (aka graph execution state)")
queue_id: str = Field(description="The ID of the queue")
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
invocation: AnyInvocation = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
@ -113,7 +114,6 @@ class InvocationStartedEvent(InvocationEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@ -147,7 +147,6 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@ -185,7 +184,6 @@ class InvocationCompleteEvent(InvocationEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@ -218,7 +216,6 @@ class InvocationErrorEvent(InvocationEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@ -256,7 +253,6 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
status=queue_item.status,
error_type=queue_item.error_type,
@ -283,14 +279,12 @@ class BatchEnqueuedEvent(QueueEventBase):
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
)
priority: int = Field(description="The priority of the batch")
origin: str | None = Field(default=None, description="The origin of the batch")
@classmethod
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
return cls(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
origin=enqueue_result.batch.origin,
enqueued=enqueue_result.enqueued,
requested=enqueue_result.requested,
priority=enqueue_result.priority,

View File

@ -6,7 +6,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByOriginResult,
CancelByQueueIDResult,
ClearResult,
EnqueueBatchResult,
@ -96,11 +95,6 @@ class SessionQueueBase(ABC):
"""Cancels all queue items with matching batch IDs"""
pass
@abstractmethod
def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
"""Cancels all queue items with the given batch origin"""
pass
@abstractmethod
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
"""Cancels all queue items with matching queue ID"""

View File

@ -77,7 +77,6 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]]
class Batch(BaseModel):
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
origin: str | None = Field(default=None, description="The origin of this batch.")
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
graph: Graph = Field(description="The graph to initialize the session with")
workflow: Optional[WorkflowWithoutID] = Field(
@ -196,7 +195,6 @@ class SessionQueueItemWithoutGraph(BaseModel):
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
priority: int = Field(default=0, description="The priority of this queue item")
batch_id: str = Field(description="The ID of the batch associated with this queue item")
origin: str | None = Field(default=None, description="The origin of this queue item. ")
session_id: str = Field(
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
)
@ -296,7 +294,6 @@ class SessionQueueStatus(BaseModel):
class BatchStatus(BaseModel):
queue_id: str = Field(..., description="The ID of the queue")
batch_id: str = Field(..., description="The ID of the batch")
origin: str | None = Field(..., description="The origin of the batch")
pending: int = Field(..., description="Number of queue items with status 'pending'")
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
completed: int = Field(..., description="Number of queue items with status 'complete'")
@ -331,12 +328,6 @@ class CancelByBatchIDsResult(BaseModel):
canceled: int = Field(..., description="Number of queue items canceled")
class CancelByOriginResult(BaseModel):
"""Result of canceling by list of batch ids"""
canceled: int = Field(..., description="Number of queue items canceled")
class CancelByQueueIDResult(CancelByBatchIDsResult):
"""Result of canceling by queue id"""
@ -442,7 +433,6 @@ class SessionQueueValueToInsert(NamedTuple):
field_values: Optional[str] # field_values json
priority: int # priority
workflow: Optional[str] # workflow json
origin: str | None
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
@ -463,7 +453,6 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json)
priority, # priority
json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json)
batch.origin, # origin
)
)
return values_to_insert

View File

@ -10,7 +10,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByOriginResult,
CancelByQueueIDResult,
ClearResult,
EnqueueBatchResult,
@ -128,8 +127,8 @@ class SqliteSessionQueue(SessionQueueBase):
self.__cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
@ -418,7 +417,11 @@ class SqliteSessionQueue(SessionQueueBase):
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
current_queue_item, batch_status, queue_status
)
except Exception:
self.__conn.rollback()
raise
@ -426,46 +429,6 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.release()
return CancelByBatchIDsResult(canceled=count)
def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
try:
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
where = """--sql
WHERE
queue_id == ?
AND origin == ?
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
"""
params = (queue_id, origin)
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
params,
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
{where};
""",
params,
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.origin == origin:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CancelByOriginResult(canceled=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
try:
current_queue_item = self.get_current(queue_id)
@ -578,8 +541,7 @@ class SqliteSessionQueue(SessionQueueBase):
started_at,
session_id,
batch_id,
queue_id,
origin
queue_id
FROM session_queue
WHERE queue_id = ?
"""
@ -659,7 +621,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT status, count(*), origin
SELECT status, count(*)
FROM session_queue
WHERE
queue_id = ?
@ -671,7 +633,6 @@ class SqliteSessionQueue(SessionQueueBase):
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
except Exception:
self.__conn.rollback()
raise
@ -680,7 +641,6 @@ class SqliteSessionQueue(SessionQueueBase):
return BatchStatus(
batch_id=batch_id,
origin=origin,
queue_id=queue_id,
pending=counts.get("pending", 0),
in_progress=counts.get("in_progress", 0),

View File

@ -17,7 +17,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_15 import build_migration_15
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -52,7 +51,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_12(app_config=config))
migrator.register_migration(build_migration_13())
migrator.register_migration(build_migration_14())
migrator.register_migration(build_migration_15())
migrator.run_migrations()
return db

View File

@ -1,31 +0,0 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration15Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._add_origin_col(cursor)
def _add_origin_col(self, cursor: sqlite3.Cursor) -> None:
"""
- Adds `origin` column to the session queue table.
"""
cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;")
def build_migration_15() -> Migration:
"""
Build the migration from database version 14 to 15.
This migration does the following:
- Adds `origin` column to the session queue table.
"""
migration_15 = Migration(
from_version=14,
to_version=15,
callback=Migration15Callback(),
)
return migration_15

View File

@ -2,13 +2,13 @@
"name": "FLUX Text to Image",
"author": "InvokeAI",
"description": "A simple text-to-image workflow using FLUX dev or schnell models. Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
"version": "1.0.0",
"version": "1.0.4",
"contact": "",
"tags": "text2image, flux",
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
"exposedFields": [
{
"nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "model"
},
{
@ -20,8 +20,8 @@
"fieldName": "num_steps"
},
{
"nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"fieldName": "t5_encoder"
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "t5_encoder_model"
}
],
"meta": {
@ -30,12 +30,12 @@
},
"nodes": [
{
"id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"type": "invocation",
"data": {
"id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"type": "flux_model_loader",
"version": "1.0.3",
"version": "1.0.4",
"label": "",
"notes": "",
"isOpen": true,
@ -44,31 +44,25 @@
"inputs": {
"model": {
"name": "model",
"label": "Model (Starter Models can be found in Model Manager)",
"value": {
"key": "f04a7a2f-c74d-4538-8d5e-879a53501662",
"hash": "random:4875da7a9508444ffa706f61961c260d0c6729f6181a86b31fad06df1277b850",
"name": "FLUX Dev (Quantized)",
"base": "flux",
"type": "main"
}
"label": ""
},
"t5_encoder": {
"name": "t5_encoder",
"label": "T 5 Encoder (Starter Models can be found in Model Manager)",
"value": {
"key": "20dcd9ec-5fbb-4012-8401-049e707da5e5",
"hash": "random:f986be43ff3502169e4adbdcee158afb0e0a65a1edc4cab16ae59963630cfd8f",
"name": "t5_bnb_int8_quantized_encoder",
"base": "any",
"type": "t5_encoder"
}
"t5_encoder_model": {
"name": "t5_encoder_model",
"label": ""
},
"clip_embed_model": {
"name": "clip_embed_model",
"label": ""
},
"vae_model": {
"name": "vae_model",
"label": ""
}
}
},
"position": {
"x": 337.09365228062825,
"y": 40.63469521079861
"x": 381.1882713063478,
"y": -95.89663532854017
}
},
{
@ -207,45 +201,45 @@
],
"edges": [
{
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33amax_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
"type": "default",
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "max_seq_len",
"targetHandle": "t5_max_seq_len"
},
{
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33avae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
"type": "default",
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33atransformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
"type": "default",
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "transformer",
"targetHandle": "transformer"
},
{
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33at5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
"type": "default",
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "t5_encoder",
"targetHandle": "t5_encoder"
},
{
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33aclip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
"type": "default",
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "transformer",
"targetHandle": "transformer"
},
{
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-159bdf1b-79e7-4174-b86e-d40e646964c8positive_text_conditioning",
"type": "default",

View File

@ -111,16 +111,7 @@ def denoise(
step_callback: Callable[[], None],
guidance: float = 4.0,
):
dtype = model.txt_in.bias.dtype
# TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
img = img.to(dtype=dtype)
img_ids = img_ids.to(dtype=dtype)
txt = txt.to(dtype=dtype)
txt_ids = txt_ids.to(dtype=dtype)
vec = vec.to(dtype=dtype)
# this is ignored for schnell
# guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
@ -168,9 +159,9 @@ def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor,
img = repeat(img, "1 ... -> bs ...", bs=bs)
# Generate patch position ids.
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
return img, img_ids

View File

@ -72,6 +72,7 @@ class ModelLoader(ModelLoaderBase):
pass
config.path = str(self._get_model_path(config))
self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
loaded_model = self._load_model(config, submodel_type)
self._ram_cache.put(

View File

@ -193,15 +193,6 @@ class ModelCacheBase(ABC, Generic[T]):
"""
pass
@abstractmethod
def exists(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
) -> bool:
"""Return true if the model identified by key and submodel_type is in the cache."""
pass
@abstractmethod
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""

View File

@ -1,22 +1,6 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
# TODO: Add Stalker's proper name to copyright
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
The cache returns context manager generators designed to load the
model into the GPU within the context, and unload outside the
context. Use like this:
cache = ModelCache(max_cache_size=7.5)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
do_something_in_GPU(SD1,SD2)
"""
""" """
import gc
import math
@ -40,45 +24,64 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_da
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
# amount of GPU memory to hold in reserve for use by generations (GB)
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
# actual size of a gig
GIG = 1073741824
# Size of a GB in bytes.
GB = 2**30
# Size of a MB in bytes.
MB = 2**20
class ModelCache(ModelCacheBase[AnyModel]):
"""Implementation of ModelCacheBase."""
"""A cache for managing models in memory.
The cache is based on two levels of model storage:
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
The model cache is based on the following assumptions:
- storage_device_mem_size > execution_device_mem_size
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
the execution_device.
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
configuration.
The cache returns context manager generators designed to load the model into the execution device (often GPU) within
the context, and unload outside the context.
Example usage:
```
cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
do_something_on_gpu(SD1)
```
"""
def __init__(
self,
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
max_cache_size: float,
max_vram_cache_size: float,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16,
sequential_offload: bool = False,
lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
):
"""
Initialize the model RAM cache.
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
:param max_cache_size: Maximum size of the storage_device cache in GBs.
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded.
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
@ -86,7 +89,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""
# allow lazy offloading only when vram cache enabled
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self._precision: torch.dtype = precision
self._max_cache_size: float = max_cache_size
self._max_vram_cache_size: float = max_vram_cache_size
self._execution_device: torch.device = execution_device
@ -145,15 +147,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
total += cache_record.size
return total
def exists(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
) -> bool:
"""Return true if the model identified by key and submodel_type is in the cache."""
key = self._make_cache_key(key, submodel_type)
return key in self._cached_models
def put(
self,
key: str,
@ -203,7 +196,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
# more stats
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GIG)
self.stats.cache_size = int(self._max_cache_size * GB)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
@ -231,10 +224,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
return model_key
def offload_unlocked_models(self, size_required: int) -> None:
"""Move any unused models from VRAM."""
reserved = self._max_vram_cache_size * GIG
"""Offload models from the execution_device to make room for size_required.
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
"""
reserved = self._max_vram_cache_size * GB
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
@ -245,7 +241,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
)
TorchDevice.empty_cache()
@ -303,7 +299,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self.logger.debug(
f"Moved model '{cache_entry.key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
@ -326,14 +322,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
f"Moving model '{cache_entry.key}' from {source_device} to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GIG):.3f} GB.\n"
f" {(cache_entry.size/GB):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
def print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
ram = "%4.2fG" % (self.cache_size() / GIG)
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
ram = "%4.2fG" % (self.cache_size() / GB)
in_ram_models = 0
in_vram_models = 0
@ -353,17 +349,20 @@ class ModelCache(ModelCacheBase[AnyModel]):
)
def make_room(self, size: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size."""
# calculate how much memory this model will require
# multiplier = 2 if self.precision==torch.float32 else 1
"""Make enough room in the cache to accommodate a new model of indicated size.
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
external references to the model, there's nothing that the cache can do about it, and those models will not be
garbage-collected.
"""
bytes_needed = size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
maximum_size = self.max_cache_size * GB # stored in GB, convert to bytes
current_size = self.cache_size()
if current_size + bytes_needed > maximum_size:
self.logger.debug(
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GIG):.2f} GB"
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GB):.2f} GB"
)
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
@ -380,7 +379,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
if not cache_entry.locked:
self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
)
current_size -= cache_entry.size
models_cleared += 1

View File

@ -54,8 +54,10 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
scb = state_dict.pop(prefix + "SCB", None)
# weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
_weight_format = state_dict.pop(prefix + "weight_format", None)
# Currently, we only support weight_format=0.
weight_format = state_dict.pop(prefix + "weight_format", None)
assert weight_format == 0
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
# rather than raising an exception to correctly implement this API.
@ -89,6 +91,14 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
)
self.bias = bias if bias is None else torch.nn.Parameter(bias)
# Reset the state. The persisted fields are based on the initialization behaviour in
# `bnb.nn.Linear8bitLt.__init__()`.
new_state = bnb.MatmulLtState()
new_state.threshold = self.state.threshold
new_state.has_fp16_weights = False
new_state.use_pool = self.state.use_pool
self.state = new_state
def _convert_linear_layers_to_llm_8bit(
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""

View File

@ -43,6 +43,11 @@ class FLUXConditioningInfo:
clip_embeds: torch.Tensor
t5_embeds: torch.Tensor
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
return self
@dataclass
class ConditioningFieldData:

View File

@ -3,10 +3,9 @@ Initialization file for invokeai.backend.util
"""
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.util import GIG, Chdir, directory_size
from invokeai.backend.util.util import Chdir, directory_size
__all__ = [
"GIG",
"directory_size",
"Chdir",
"InvokeAILogger",

View File

@ -7,9 +7,6 @@ from pathlib import Path
from PIL import Image
# actual size of a gig
GIG = 1073741824
def slugify(value: str, allow_unicode: bool = False) -> str:
"""

View File

@ -12,10 +12,6 @@ module.exports = {
'i18next/no-literal-string': 'error',
// https://eslint.org/docs/latest/rules/no-console
'no-console': 'error',
// https://eslint.org/docs/latest/rules/no-promise-executor-return
'no-promise-executor-return': 'error',
// https://eslint.org/docs/latest/rules/require-await
'require-await': 'error',
},
overrides: [
/**

View File

@ -1,5 +1,5 @@
import { PropsWithChildren, memo, useEffect } from 'react';
import { modelChanged } from '../src/features/controlLayers/store/paramsSlice';
import { modelChanged } from '../src/features/parameters/store/generationSlice';
import { useAppDispatch } from '../src/app/store/storeHooks';
import { useGlobalModifiersInit } from '@invoke-ai/ui-library';
/**
@ -10,9 +10,7 @@ export const ReduxInit = memo((props: PropsWithChildren) => {
const dispatch = useAppDispatch();
useGlobalModifiersInit();
useEffect(() => {
dispatch(
modelChanged({ model: { key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' } })
);
dispatch(modelChanged({ key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' }));
}, []);
return props.children;

View File

@ -9,8 +9,6 @@ const config: KnipConfig = {
'src/services/api/schema.ts',
'src/features/nodes/types/v1/**',
'src/features/nodes/types/v2/**',
// TODO(psyche): maybe we can clean up these utils after canvas v2 release
'src/features/controlLayers/konva/util.ts',
],
ignoreBinaries: ['only-allow'],
paths: {

View File

@ -24,7 +24,7 @@
"build": "pnpm run lint && vite build",
"typegen": "node scripts/typegen.js",
"preview": "vite preview",
"lint:knip": "knip --tags=-knipignore",
"lint:knip": "knip",
"lint:dpdm": "dpdm --no-warning --no-tree --transform --exit-code circular:1 src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .",
"lint:prettier": "prettier --check .",
@ -52,19 +52,18 @@
}
},
"dependencies": {
"@chakra-ui/react-use-size": "^2.1.0",
"@dagrejs/dagre": "^1.1.3",
"@dagrejs/graphlib": "^2.2.3",
"@dnd-kit/core": "^6.1.0",
"@dnd-kit/sortable": "^8.0.0",
"@dnd-kit/utilities": "^3.2.2",
"@fontsource-variable/inter": "^5.0.20",
"@invoke-ai/ui-library": "^0.0.32",
"@invoke-ai/ui-library": "^0.0.29",
"@nanostores/react": "^0.7.3",
"@reduxjs/toolkit": "2.2.3",
"@roarr/browser-log-writer": "^1.3.0",
"async-mutex": "^0.5.0",
"chakra-react-select": "^4.9.1",
"cmdk": "^1.0.0",
"compare-versions": "^6.1.1",
"dateformat": "^5.0.3",
"fracturedjsonjs": "^4.0.2",
@ -75,8 +74,6 @@
"jsondiffpatch": "^0.6.0",
"konva": "^9.3.14",
"lodash-es": "^4.17.21",
"lru-cache": "^11.0.0",
"nanoid": "^5.0.7",
"nanostores": "^0.11.2",
"new-github-issue-url": "^1.0.0",
"overlayscrollbars": "^2.10.0",
@ -91,8 +88,10 @@
"react-hotkeys-hook": "4.5.0",
"react-i18next": "^14.1.3",
"react-icons": "^5.2.1",
"react-konva": "^18.2.10",
"react-redux": "9.1.2",
"react-resizable-panels": "^2.0.23",
"react-select": "5.8.0",
"react-use": "^17.5.1",
"react-virtuoso": "^4.9.0",
"reactflow": "^11.11.4",
@ -103,9 +102,9 @@
"roarr": "^7.21.1",
"serialize-error": "^11.0.3",
"socket.io-client": "^4.7.5",
"stable-hash": "^0.0.4",
"use-debounce": "^10.0.2",
"use-device-pixel-ratio": "^1.1.2",
"use-image": "^1.1.1",
"uuid": "^10.0.0",
"zod": "^3.23.8",
"zod-validation-error": "^3.3.1"

File diff suppressed because it is too large Load Diff

View File

@ -80,7 +80,6 @@
"aboutDesc": "Using Invoke for work? Check out:",
"aboutHeading": "Own Your Creative Power",
"accept": "Accept",
"apply": "Apply",
"add": "Add",
"advanced": "Advanced",
"ai": "ai",
@ -116,7 +115,6 @@
"githubLabel": "Github",
"goTo": "Go to",
"hotkeysLabel": "Hotkeys",
"loadingImage": "Loading Image",
"imageFailedToLoad": "Unable to Load Image",
"img2img": "Image To Image",
"inpaint": "inpaint",
@ -327,10 +325,6 @@
"canceled": "Canceled",
"completedIn": "Completed in",
"batch": "Batch",
"origin": "Origin",
"originCanvas": "Canvas",
"originWorkflows": "Workflows",
"originOther": "Other",
"batchFieldValues": "Batch Field Values",
"item": "Item",
"session": "Session",
@ -702,6 +696,8 @@
"availableModels": "Available Models",
"baseModel": "Base Model",
"cancel": "Cancel",
"clipEmbed": "CLIP Embed",
"clipVision": "CLIP Vision",
"config": "Config",
"convert": "Convert",
"convertingModelBegin": "Converting Model. Please wait.",
@ -789,6 +785,7 @@
"settings": "Settings",
"simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
"source": "Source",
"spandrelImageToImage": "Image to Image (Spandrel)",
"starterModels": "Starter Models",
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
"syncModels": "Sync Models",
@ -797,6 +794,7 @@
"loraTriggerPhrases": "LoRA Trigger Phrases",
"mainModelTriggerPhrases": "Main Model Trigger Phrases",
"typePhraseHere": "Type phrase here",
"t5Encoder": "T5 Encoder",
"upcastAttention": "Upcast Attention",
"uploadImage": "Upload Image",
"urlOrLocalPath": "URL or Local Path",
@ -1102,6 +1100,7 @@
"confirmOnDelete": "Confirm On Delete",
"developer": "Developer",
"displayInProgress": "Display Progress Images",
"enableImageDebugging": "Enable Image Debugging",
"enableInformationalPopovers": "Enable Informational Popovers",
"informationalPopoversDisabled": "Informational Popovers Disabled",
"informationalPopoversDisabledDesc": "Informational popovers have been disabled. Enable them in Settings.",
@ -1568,7 +1567,7 @@
"copyToClipboard": "Copy to Clipboard",
"cursorPosition": "Cursor Position",
"darkenOutsideSelection": "Darken Outside Selection",
"discardAll": "Discard All & Cancel Pending Generations",
"discardAll": "Discard All",
"discardCurrent": "Discard Current",
"downloadAsImage": "Download As Image",
"enableMask": "Enable Mask",
@ -1646,123 +1645,39 @@
"storeNotInitialized": "Store is not initialized"
},
"controlLayers": {
"clearHistory": "Clear History",
"generateMode": "Generate",
"generateModeDesc": "Create individual images. Generated images are added directly to the gallery.",
"composeMode": "Compose",
"composeModeDesc": "Compose your work iterative. Generated images are added back to the canvas.",
"autoSave": "Auto-save to Gallery",
"resetCanvas": "Reset Canvas",
"resetAll": "Reset All",
"clearCaches": "Clear Caches",
"recalculateRects": "Recalculate Rects",
"clipToBbox": "Clip Strokes to Bbox",
"deleteAll": "Delete All",
"addLayer": "Add Layer",
"duplicate": "Duplicate",
"moveToFront": "Move to Front",
"moveToBack": "Move to Back",
"moveForward": "Move Forward",
"moveBackward": "Move Backward",
"brushSize": "Brush Size",
"width": "Width",
"zoom": "Zoom",
"resetView": "Reset View",
"controlLayers": "Control Layers",
"globalMaskOpacity": "Global Mask Opacity",
"autoNegative": "Auto Negative",
"enableAutoNegative": "Enable Auto Negative",
"disableAutoNegative": "Disable Auto Negative",
"deletePrompt": "Delete Prompt",
"resetRegion": "Reset Region",
"debugLayers": "Debug Layers",
"rectangle": "Rectangle",
"maskFill": "Mask Fill",
"maskPreviewColor": "Mask Preview Color",
"addPositivePrompt": "Add $t(common.positivePrompt)",
"addNegativePrompt": "Add $t(common.negativePrompt)",
"addIPAdapter": "Add $t(common.ipAdapter)",
"regionalGuidance": "Regional Guidance",
"regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)",
"raster": "Raster",
"rasterLayer_one": "Raster Layer",
"controlLayer_one": "Control Layer",
"inpaintMask_one": "Inpaint Mask",
"regionalGuidance_one": "Regional Guidance",
"ipAdapter_one": "IP Adapter",
"rasterLayer_other": "Raster Layers",
"controlLayer_other": "Control Layers",
"inpaintMask_other": "Inpaint Masks",
"regionalGuidance_other": "Regional Guidance",
"ipAdapter_other": "IP Adapters",
"opacity": "Opacity",
"regionalGuidance_withCount_hidden": "Regional Guidance ({{count}} hidden)",
"controlAdapters_withCount_hidden": "Control Adapters ({{count}} hidden)",
"controlLayers_withCount_hidden": "Control Layers ({{count}} hidden)",
"rasterLayers_withCount_hidden": "Raster Layers ({{count}} hidden)",
"ipAdapters_withCount_hidden": "IP Adapters ({{count}} hidden)",
"inpaintMasks_withCount_hidden": "Inpaint Masks ({{count}} hidden)",
"regionalGuidance_withCount_visible": "Regional Guidance ({{count}})",
"controlAdapters_withCount_visible": "Control Adapters ({{count}})",
"controlLayers_withCount_visible": "Control Layers ({{count}})",
"rasterLayers_withCount_visible": "Raster Layers ({{count}})",
"ipAdapters_withCount_visible": "IP Adapters ({{count}})",
"inpaintMasks_withCount_visible": "Inpaint Masks ({{count}})",
"globalControlAdapter": "Global $t(controlnet.controlAdapter_one)",
"globalControlAdapterLayer": "Global $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)",
"globalIPAdapter": "Global $t(common.ipAdapter)",
"globalIPAdapterLayer": "Global $t(common.ipAdapter) $t(unifiedCanvas.layer)",
"globalInitialImage": "Global Initial Image",
"globalInitialImageLayer": "$t(controlLayers.globalInitialImage) $t(unifiedCanvas.layer)",
"layer": "Layer",
"opacityFilter": "Opacity Filter",
"clearProcessor": "Clear Processor",
"resetProcessor": "Reset Processor to Defaults",
"noLayersAdded": "No Layers Added",
"layers_one": "Layer",
"layers_other": "Layers",
"objects_zero": "empty",
"objects_one": "{{count}} object",
"objects_other": "{{count}} objects",
"convertToControlLayer": "Convert to Control Layer",
"convertToRasterLayer": "Convert to Raster Layer",
"transparency": "Transparency",
"enableTransparencyEffect": "Enable Transparency Effect",
"disableTransparencyEffect": "Disable Transparency Effect",
"hidingType": "Hiding {{type}}",
"showingType": "Showing {{type}}",
"dynamicGrid": "Dynamic Grid",
"logDebugInfo": "Log Debug Info",
"locked": "Locked",
"unlocked": "Unlocked",
"deleteSelected": "Delete Selected",
"deleteAll": "Delete All",
"flipHorizontal": "Flip Horizontal",
"flipVertical": "Flip Vertical",
"fill": {
"fillStyle": "Fill Style",
"solid": "Solid",
"grid": "Grid",
"crosshatch": "Crosshatch",
"vertical": "Vertical",
"horizontal": "Horizontal",
"diagonal": "Diagonal"
},
"tool": {
"brush": "Brush",
"eraser": "Eraser",
"rectangle": "Rectangle",
"bbox": "Bbox",
"move": "Move",
"view": "View",
"transform": "Transform",
"colorPicker": "Color Picker"
},
"filter": {
"filter": "Filter",
"filters": "Filters",
"filterType": "Filter Type",
"preview": "Preview",
"apply": "Apply",
"cancel": "Cancel"
}
"layers_other": "Layers"
},
"upscaling": {
"upscale": "Upscale",
@ -1850,30 +1765,5 @@
"upscaling": "Upscaling",
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
}
},
"system": {
"enableLogging": "Enable Logging",
"logLevel": {
"logLevel": "Log Level",
"trace": "Trace",
"debug": "Debug",
"info": "Info",
"warn": "Warn",
"error": "Error",
"fatal": "Fatal"
},
"logNamespaces": {
"logNamespaces": "Log Namespaces",
"gallery": "Gallery",
"models": "Models",
"config": "Config",
"canvas": "Canvas",
"generation": "Generation",
"workflows": "Workflows",
"system": "System",
"events": "Events",
"queue": "Queue",
"metadata": "Metadata"
}
}
}

View File

@ -38,7 +38,7 @@ async function generateTypes(schema) {
process.stdout.write(`\nOK!\r\n`);
}
function main() {
async function main() {
const encoding = 'utf-8';
if (process.stdin.isTTY) {

View File

@ -6,7 +6,6 @@ import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/ap
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { PartialAppConfig } from 'app/types/invokeai';
import ImageUploadOverlay from 'common/components/ImageUploadOverlay';
import { useScopeFocusWatcher } from 'common/hooks/interactionScopes';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { useFullscreenDropzone } from 'common/hooks/useFullscreenDropzone';
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
@ -14,15 +13,13 @@ import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardMo
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
import RefreshAfterResetModal from 'features/system/components/SettingsModal/RefreshAfterResetModal';
import SettingsModal from 'features/system/components/SettingsModal/SettingsModal';
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { configChanged } from 'features/system/store/configSlice';
import { selectLanguage } from 'features/system/store/systemSelectors';
import { AppContent } from 'features/ui/components/AppContent';
import { languageSelector } from 'features/system/store/systemSelectors';
import InvokeTabs from 'features/ui/components/InvokeTabs';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import { setActiveTab } from 'features/ui/store/uiSlice';
import type { TabName } from 'features/ui/store/uiTypes';
import { useGetAndLoadLibraryWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadLibraryWorkflow';
import { AnimatePresence } from 'framer-motion';
import i18n from 'i18n';
@ -43,11 +40,18 @@ interface Props {
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
selectedWorkflowId?: string;
destination?: TabName | undefined;
selectedStylePresetId?: string;
destination?: InvokeTabName | undefined;
}
const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, destination }: Props) => {
const language = useAppSelector(selectLanguage);
const App = ({
config = DEFAULT_CONFIG,
selectedImage,
selectedWorkflowId,
selectedStylePresetId,
destination,
}: Props) => {
const language = useAppSelector(languageSelector);
const logger = useLogger('system');
const dispatch = useAppDispatch();
const clearStorage = useClearStorage();
@ -85,6 +89,12 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
}
}, [selectedWorkflowId, getAndLoadWorkflow]);
useEffect(() => {
if (selectedStylePresetId) {
dispatch(activeStylePresetIdChanged(selectedStylePresetId));
}
}, [dispatch, selectedStylePresetId]);
useEffect(() => {
if (destination) {
dispatch(setActiveTab(destination));
@ -97,7 +107,6 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
useStarterModelsToast();
useSyncQueueStatus();
useScopeFocusWatcher();
return (
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
@ -110,7 +119,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
{...dropzone.getRootProps()}
>
<input {...dropzone.getInputProps()} />
<AppContent />
<InvokeTabs />
<AnimatePresence>
{dropzone.isDragActive && isHandlingUpload && (
<ImageUploadOverlay dropzone={dropzone} setIsHandlingUpload={setIsHandlingUpload} />
@ -121,10 +130,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
<ChangeBoardModal />
<DynamicPromptsModal />
<StylePresetModal />
<ClearQueueConfirmationsAlertDialog />
<PreselectedImage selectedImage={selectedImage} />
<SettingsModal />
<RefreshAfterResetModal />
</ErrorBoundary>
);
};

View File

@ -1,7 +1,5 @@
import { Button, Flex, Heading, Image, Link, Text } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import newGithubIssueUrl from 'new-github-issue-url';
import InvokeLogoYellow from 'public/assets/images/invoke-symbol-ylw-lrg.svg';
@ -15,11 +13,9 @@ type Props = {
resetErrorBoundary: () => void;
};
const selectIsLocal = createSelector(selectConfigSlice, (config) => config.isLocal);
const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => {
const { t } = useTranslation();
const isLocal = useAppSelector(selectIsLocal);
const isLocal = useAppSelector((s) => s.config.isLocal);
const handleCopy = useCallback(() => {
const text = JSON.stringify(serializeError(error), null, 2);

View File

@ -19,7 +19,7 @@ import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import AppDndContext from 'features/dnd/components/AppDndContext';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import type { TabName } from 'features/ui/store/uiTypes';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import type { PropsWithChildren, ReactNode } from 'react';
import React, { lazy, memo, useEffect, useMemo } from 'react';
import { Provider } from 'react-redux';
@ -45,7 +45,8 @@ interface Props extends PropsWithChildren {
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
selectedWorkflowId?: string;
destination?: TabName;
selectedStylePresetId?: string;
destination?: InvokeTabName;
customStarUi?: CustomStarUi;
socketOptions?: Partial<ManagerOptions & SocketOptions>;
isDebugging?: boolean;
@ -66,6 +67,7 @@ const InvokeAIUI = ({
queueId,
selectedImage,
selectedWorkflowId,
selectedStylePresetId,
destination,
customStarUi,
socketOptions,
@ -227,6 +229,7 @@ const InvokeAIUI = ({
config={config}
selectedImage={selectedImage}
selectedWorkflowId={selectedWorkflowId}
selectedStylePresetId={selectedStylePresetId}
destination={destination}
/>
</AppDndContext>

View File

@ -2,7 +2,7 @@ import { useStore } from '@nanostores/react';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $isDebugging } from 'app/store/nanostores/isDebugging';
import { useAppStore } from 'app/store/nanostores/store';
import { useAppDispatch } from 'app/store/storeHooks';
import type { MapStore } from 'nanostores';
import { atom, map } from 'nanostores';
import { useEffect, useMemo } from 'react';
@ -18,19 +18,14 @@ declare global {
}
}
export type AppSocket = Socket<ServerToClientEvents, ClientToServerEvents>;
export const $socket = atom<AppSocket | null>(null);
export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});
const $isSocketInitialized = atom<boolean>(false);
export const $isConnected = atom<boolean>(false);
/**
* Initializes the socket.io connection and sets up event listeners.
*/
export const useSocketIO = () => {
const { dispatch, getState } = useAppStore();
const dispatch = useAppDispatch();
const baseUrl = useStore($baseUrl);
const authToken = useStore($authToken);
const addlSocketOptions = useStore($socketOptions);
@ -66,9 +61,8 @@ export const useSocketIO = () => {
return;
}
const socket: AppSocket = io(socketUrl, socketOptions);
$socket.set(socket);
setEventListeners({ socket, dispatch, getState, setIsConnected: $isConnected.set });
const socket: Socket<ServerToClientEvents, ClientToServerEvents> = io(socketUrl, socketOptions);
setEventListeners({ dispatch, socket });
socket.connect();
if ($isDebugging.get() || import.meta.env.MODE === 'development') {
@ -90,5 +84,5 @@ export const useSocketIO = () => {
socket.disconnect();
$isSocketInitialized.set(false);
};
}, [dispatch, getState, socketOptions, socketUrl]);
}, [dispatch, socketOptions, socketUrl]);
};

View File

@ -15,21 +15,21 @@ export const BASE_CONTEXT = {};
export const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));
export const zLogNamespace = z.enum([
'canvas',
'config',
'events',
'gallery',
'generation',
'metadata',
'models',
'system',
'queue',
'workflows',
]);
export type LogNamespace = z.infer<typeof zLogNamespace>;
export type LoggerNamespace =
| 'images'
| 'models'
| 'config'
| 'canvas'
| 'generation'
| 'nodes'
| 'system'
| 'socketio'
| 'session'
| 'queue'
| 'dnd'
| 'controlLayers';
export const logger = (namespace: LogNamespace) => $logger.get().child({ namespace });
export const logger = (namespace: LoggerNamespace) => $logger.get().child({ namespace });
export const zLogLevel = z.enum(['trace', 'debug', 'info', 'warn', 'error', 'fatal']);
export type LogLevel = z.infer<typeof zLogLevel>;

View File

@ -1,41 +1,29 @@
import { createLogWriter } from '@roarr/browser-log-writer';
import { useAppSelector } from 'app/store/storeHooks';
import {
selectSystemLogIsEnabled,
selectSystemLogLevel,
selectSystemLogNamespaces,
} from 'features/system/store/systemSlice';
import { useEffect, useMemo } from 'react';
import { ROARR, Roarr } from 'roarr';
import type { LogNamespace } from './logger';
import type { LoggerNamespace } from './logger';
import { $logger, BASE_CONTEXT, LOG_LEVEL_MAP, logger } from './logger';
export const useLogger = (namespace: LogNamespace) => {
const logLevel = useAppSelector(selectSystemLogLevel);
const logNamespaces = useAppSelector(selectSystemLogNamespaces);
const logIsEnabled = useAppSelector(selectSystemLogIsEnabled);
export const useLogger = (namespace: LoggerNamespace) => {
const consoleLogLevel = useAppSelector((s) => s.system.consoleLogLevel);
const shouldLogToConsole = useAppSelector((s) => s.system.shouldLogToConsole);
// The provided Roarr browser log writer uses localStorage to config logging to console
useEffect(() => {
if (logIsEnabled) {
if (shouldLogToConsole) {
// Enable console log output
localStorage.setItem('ROARR_LOG', 'true');
// Use a filter to show only logs of the given level
let filter = `context.logLevel:>=${LOG_LEVEL_MAP[logLevel]}`;
if (logNamespaces.length > 0) {
filter += ` AND (${logNamespaces.map((ns) => `context.namespace:${ns}`).join(' OR ')})`;
} else {
filter += ' AND context.namespace:undefined';
}
localStorage.setItem('ROARR_FILTER', filter);
localStorage.setItem('ROARR_FILTER', `context.logLevel:>=${LOG_LEVEL_MAP[consoleLogLevel]}`);
} else {
// Disable console log output
localStorage.setItem('ROARR_LOG', 'false');
}
ROARR.write = createLogWriter();
}, [logLevel, logIsEnabled, logNamespaces]);
}, [consoleLogLevel, shouldLogToConsole]);
// Update the module-scoped logger context as needed
useEffect(() => {

View File

@ -1,7 +1,7 @@
import { createAction } from '@reduxjs/toolkit';
import type { TabName } from 'features/ui/store/uiTypes';
import type { InvokeTabName } from 'features/ui/store/tabMap';
export const enqueueRequested = createAction<{
tabName: TabName;
tabName: InvokeTabName;
prepend: boolean;
}>('app/enqueueRequested');

View File

@ -1,3 +1,2 @@
export const STORAGE_PREFIX = '@@invokeai-';
export const EMPTY_ARRAY = [];
export const EMPTY_OBJECT = {};

View File

@ -1,6 +1,5 @@
import { createDraftSafeSelectorCreator, createSelectorCreator, lruMemoize } from '@reduxjs/toolkit';
import type { GetSelectorsOptions } from '@reduxjs/toolkit/dist/entities/state_selectors';
import type { RootState } from 'app/store/store';
import { isEqual } from 'lodash-es';
/**
@ -20,5 +19,3 @@ export const getSelectorsOptions: GetSelectorsOptions = {
argsMemoize: lruMemoize,
}),
};
export const createMemoizedAppSelector = createMemoizedSelector.withTypes<RootState>();

View File

@ -1,4 +1,5 @@
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { PersistError, RehydrateError } from 'redux-remember';
import { serializeError } from 'serialize-error';
@ -40,6 +41,6 @@ export const errorHandler = (err: PersistError | RehydrateError) => {
} else if (err instanceof RehydrateError) {
log.error({ error: serializeError(err) }, 'Problem rehydrating state');
} else {
log.error({ error: serializeError(err) }, 'Problem in persistence layer');
log.error({ error: parseify(err) }, 'Problem in persistence layer');
}
};

View File

@ -1,7 +1,9 @@
import type { UnknownAction } from '@reduxjs/toolkit';
import { deepClone } from 'common/util/deepClone';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { Graph } from 'services/api/types';
import { socketGeneratorProgress } from 'services/events/actions';
export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
if (isAnyGraphBuilt(action)) {
@ -22,5 +24,13 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
};
}
if (socketGeneratorProgress.match(action)) {
const sanitized = deepClone(action);
if (sanitized.payload.data.progress_image) {
sanitized.payload.data.progress_image.dataURL = '<Progress image omitted>';
}
return sanitized;
}
return action;
};

View File

@ -1,7 +1,7 @@
import type { TypedStartListening } from '@reduxjs/toolkit';
import { createListenerMiddleware } from '@reduxjs/toolkit';
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { addStagingListeners } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
import { addCommitStagingAreaImageListener } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
@ -9,6 +9,17 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addCanvasCopiedToClipboardListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard';
import { addCanvasDownloadedAsImageListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage';
import { addCanvasImageToControlNetListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasImageToControlNet';
import { addCanvasMaskSavedToGalleryListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMaskSavedToGallery';
import { addCanvasMaskToControlNetListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet';
import { addCanvasMergedListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMerged';
import { addCanvasSavedToGalleryListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery';
import { addControlAdapterPreprocessor } from 'app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor';
import { addControlNetAutoProcessListener } from 'app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess';
import { addControlNetImageProcessedListener } from 'app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed';
import { addEnqueueRequestedCanvasListener } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas';
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
@ -26,7 +37,16 @@ import { addModelSelectedListener } from 'app/store/middleware/listenerMiddlewar
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged';
import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved';
import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested';
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store';
@ -63,6 +83,7 @@ addGalleryImageClickedListener(startAppListening);
addGalleryOffsetChangedListener(startAppListening);
// User Invoked
addEnqueueRequestedCanvasListener(startAppListening);
addEnqueueRequestedNodes(startAppListening);
addEnqueueRequestedLinear(startAppListening);
addEnqueueRequestedUpscale(startAppListening);
@ -70,23 +91,32 @@ addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);
// Canvas actions
// addCanvasSavedToGalleryListener(startAppListening);
// addCanvasMaskSavedToGalleryListener(startAppListening);
// addCanvasImageToControlNetListener(startAppListening);
// addCanvasMaskToControlNetListener(startAppListening);
// addCanvasDownloadedAsImageListener(startAppListening);
// addCanvasCopiedToClipboardListener(startAppListening);
// addCanvasMergedListener(startAppListening);
// addStagingAreaImageSavedListener(startAppListening);
// addCommitStagingAreaImageListener(startAppListening);
addStagingListeners(startAppListening);
addCanvasSavedToGalleryListener(startAppListening);
addCanvasMaskSavedToGalleryListener(startAppListening);
addCanvasImageToControlNetListener(startAppListening);
addCanvasMaskToControlNetListener(startAppListening);
addCanvasDownloadedAsImageListener(startAppListening);
addCanvasCopiedToClipboardListener(startAppListening);
addCanvasMergedListener(startAppListening);
addStagingAreaImageSavedListener(startAppListening);
addCommitStagingAreaImageListener(startAppListening);
// Socket.IO
addGeneratorProgressEventListener(startAppListening);
addInvocationCompleteEventListener(startAppListening);
addInvocationErrorEventListener(startAppListening);
addInvocationStartedEventListener(startAppListening);
addSocketConnectedEventListener(startAppListening);
// Gallery bulk download
addSocketDisconnectedEventListener(startAppListening);
addModelLoadEventListener(startAppListening);
addModelInstallEventListener(startAppListening);
addSocketQueueItemStatusChangedEventListener(startAppListening);
addBulkDownloadListeners(startAppListening);
// ControlNet
addControlNetImageProcessedListener(startAppListening);
addControlNetAutoProcessListener(startAppListening);
// Boards
addImageAddedToBoardFulfilledListener(startAppListening);
addImageRemovedFromBoardFulfilledListener(startAppListening);
@ -118,4 +148,4 @@ addAdHocPostProcessingRequestedListener(startAppListening);
addDynamicPromptsListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);
// addControlAdapterPreprocessor(startAppListening);
addControlAdapterPreprocessor(startAppListening);

View File

@ -1,21 +1,21 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import { parseify } from 'common/util/serialize';
import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO } from 'services/api/types';
const log = logger('queue');
export const adHocPostProcessingRequested = createAction<{ imageDTO: ImageDTO }>(`upscaling/postProcessingRequested`);
export const addAdHocPostProcessingRequestedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: adHocPostProcessingRequested,
effect: async (action, { dispatch, getState }) => {
const log = logger('session');
const { imageDTO } = action.payload;
const state = getState();
@ -39,9 +39,9 @@ export const addAdHocPostProcessingRequestedListener = (startAppListening: AppSt
const enqueueResult = await req.unwrap();
req.reset();
log.debug({ enqueueResult } as SerializableObject, t('queue.graphQueued'));
log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));
} catch (error) {
log.error({ enqueueBatchArg } as SerializableObject, t('queue.graphFailedToQueue'));
log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
if (error instanceof Object && 'status' in error && error.status === 403) {
return;

View File

@ -23,7 +23,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
*/
startAppListening({
matcher: matchAnyBoardDeleted,
effect: (action, { dispatch, getState }) => {
effect: async (action, { dispatch, getState }) => {
const state = getState();
const deletedBoardId = action.meta.arg.originalArgs;
const { autoAddBoardId, selectedBoardId } = state.gallery;
@ -44,7 +44,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
// If we archived a board, it may end up hidden. If it's selected or the auto-add board, we should reset those.
startAppListening({
matcher: boardsApi.endpoints.updateBoard.matchFulfilled,
effect: (action, { dispatch, getState }) => {
effect: async (action, { dispatch, getState }) => {
const state = getState();
const { shouldShowArchivedBoards } = state.gallery;
@ -61,7 +61,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
// When we hide archived boards, if the selected or the auto-add board is archived, we should reset those.
startAppListening({
actionCreator: shouldShowArchivedBoardsChanged,
effect: (action, { dispatch, getState }) => {
effect: async (action, { dispatch, getState }) => {
const shouldShowArchivedBoards = action.payload;
// We only need to take action if we have just hidden archived boards.
@ -100,7 +100,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
*/
startAppListening({
matcher: boardsApi.endpoints.listAllBoards.matchFulfilled,
effect: (action, { dispatch, getState }) => {
effect: async (action, { dispatch, getState }) => {
const boards = action.payload;
const state = getState();
const { selectedBoardId, autoAddBoardId } = state.gallery;

View File

@ -1,37 +1,33 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import {
sessionStagingAreaImageAccepted,
sessionStagingAreaReset,
} from 'features/controlLayers/store/canvasSessionSlice';
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
canvasBatchIdsReset,
commitStagingAreaImage,
discardStagedImages,
resetCanvas,
setInitialCanvasImage,
} from 'features/canvas/store/canvasSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
import { assert } from 'tsafe';
const log = logger('canvas');
const matcher = isAnyOf(commitStagingAreaImage, discardStagedImages, resetCanvas, setInitialCanvasImage);
export const addStagingListeners = (startAppListening: AppStartListening) => {
export const addCommitStagingAreaImageListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: sessionStagingAreaReset,
effect: async (_, { dispatch }) => {
matcher,
effect: async (_, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const { batchIds } = state.canvas;
try {
const req = dispatch(
queueApi.endpoints.cancelByBatchOrigin.initiate(
{ origin: 'canvas' },
{ fixedCacheKey: 'cancelByBatchOrigin' }
)
queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: batchIds }, { fixedCacheKey: 'cancelByBatchIds' })
);
const { canceled } = await req.unwrap();
req.reset();
$lastCanvasProgressEvent.set(null);
if (canceled > 0) {
log.debug(`Canceled ${canceled} canvas batches`);
toast({
@ -40,6 +36,7 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
status: 'success',
});
}
dispatch(canvasBatchIdsReset());
} catch {
log.error('Failed to cancel canvas batches');
toast({
@ -50,26 +47,4 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
}
},
});
startAppListening({
actionCreator: sessionStagingAreaImageAccepted,
effect: (action, api) => {
const { index } = action.payload;
const state = api.getState();
const stagingAreaImage = state.canvasSession.stagedImages[index];
assert(stagingAreaImage, 'No staged image found to accept');
const { x, y } = selectCanvasSlice(state).bbox.rect;
const { imageDTO, offsetX, offsetY } = stagingAreaImage;
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasRasterLayerState> = {
position: { x: x + offsetX, y: y + offsetY },
objects: [imageObject],
};
api.dispatch(rasterLayerAdded({ overrides, isSelected: true }));
api.dispatch(sessionStagingAreaReset());
},
});
};

View File

@ -4,7 +4,7 @@ import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
export const addAnyEnqueuedListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
effect: (_, { dispatch, getState }) => {
effect: async (_, { dispatch, getState }) => {
const { data } = selectQueueStatus(getState());
if (!data || data.processor.is_started) {

View File

@ -1,14 +1,14 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setInfillMethod } from 'features/controlLayers/store/paramsSlice';
import { setInfillMethod } from 'features/parameters/store/generationSlice';
import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged } from 'features/system/store/systemSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo';
export const addAppConfigReceivedListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled,
effect: (action, { getState, dispatch }) => {
effect: async (action, { getState, dispatch }) => {
const { infill_methods = [], nsfw_methods = [], watermarking_methods = [] } = action.payload;
const infillMethod = getState().params.infillMethod;
const infillMethod = getState().generation.infillMethod;
if (!infill_methods.includes(infillMethod)) {
// if there is no infill method, set it to the first one

View File

@ -6,7 +6,7 @@ export const appStarted = createAction('app/appStarted');
export const addAppStartedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: appStarted,
effect: (action, { unsubscribe, cancelActiveListeners }) => {
effect: async (action, { unsubscribe, cancelActiveListeners }) => {
// this should only run once
cancelActiveListeners();
unsubscribe();

View File

@ -1,30 +1,27 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import { parseify } from 'common/util/serialize';
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { truncate, upperFirst } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { queueApi } from 'services/api/endpoints/queue';
const log = logger('queue');
export const addBatchEnqueuedListener = (startAppListening: AppStartListening) => {
// success
startAppListening({
matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
effect: (action) => {
const enqueueResult = action.payload;
effect: async (action) => {
const response = action.payload;
const arg = action.meta.arg.originalArgs;
log.debug({ enqueueResult } as SerializableObject, 'Batch enqueued');
logger('queue').debug({ enqueueResult: parseify(response) }, 'Batch enqueued');
toast({
id: 'QUEUE_BATCH_SUCCEEDED',
title: t('queue.batchQueued'),
status: 'success',
description: t('queue.batchQueuedDesc', {
count: enqueueResult.enqueued,
count: response.enqueued,
direction: arg.prepend ? t('queue.front') : t('queue.back'),
}),
});
@ -34,9 +31,9 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
// error
startAppListening({
matcher: queueApi.endpoints.enqueueBatch.matchRejected,
effect: (action) => {
effect: async (action) => {
const response = action.payload;
const batchConfig = action.meta.arg.originalArgs;
const arg = action.meta.arg.originalArgs;
if (!response) {
toast({
@ -45,7 +42,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
status: 'error',
description: t('common.unknownError'),
});
log.error({ batchConfig } as SerializableObject, t('queue.batchFailedToQueue'));
logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue'));
return;
}
@ -71,7 +68,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
description: t('common.unknownError'),
});
}
log.error({ batchConfig, error: serializeError(response) } as SerializableObject, t('queue.batchFailedToQueue'));
logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue'));
},
});
};

View File

@ -1,31 +1,47 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice';
import { allLayersDeleted } from 'features/controlLayers/store/controlLayersSlice';
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { imagesApi } from 'services/api/endpoints/images';
export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.deleteBoardAndImages.matchFulfilled,
effect: (action, { dispatch, getState }) => {
effect: async (action, { dispatch, getState }) => {
const { deleted_images } = action.payload;
// Remove all deleted images from the UI
let wasCanvasReset = false;
let wasNodeEditorReset = false;
let wereControlAdaptersReset = false;
let wereControlLayersReset = false;
const state = getState();
const nodes = selectNodesSlice(state);
const canvas = selectCanvasSlice(state);
const { canvas, nodes, controlAdapters, controlLayers } = getState();
deleted_images.forEach((image_name) => {
const imageUsage = getImageUsage(nodes, canvas, image_name);
const imageUsage = getImageUsage(canvas, nodes.present, controlAdapters, controlLayers.present, image_name);
if (imageUsage.isCanvasImage && !wasCanvasReset) {
dispatch(resetCanvas());
wasCanvasReset = true;
}
if (imageUsage.isNodesImage && !wasNodeEditorReset) {
dispatch(nodeEditorReset());
wasNodeEditorReset = true;
}
if (imageUsage.isControlImage && !wereControlAdaptersReset) {
dispatch(controlAdaptersReset());
wereControlAdaptersReset = true;
}
if (imageUsage.isControlLayerImage && !wereControlLayersReset) {
dispatch(allLayersDeleted());
wereControlLayersReset = true;
}
});
},
});

View File

@ -1,15 +1,21 @@
import { ExternalLink } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import {
socketBulkDownloadComplete,
socketBulkDownloadError,
socketBulkDownloadStarted,
} from 'services/events/actions';
const log = logger('gallery');
const log = logger('images');
export const addBulkDownloadListeners = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.bulkDownloadImages.matchFulfilled,
effect: (action) => {
effect: async (action) => {
log.debug(action.payload, 'Bulk download requested');
// If we have an item name, we are processing the bulk download locally and should use it as the toast id to
@ -27,7 +33,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
startAppListening({
matcher: imagesApi.endpoints.bulkDownloadImages.matchRejected,
effect: () => {
effect: async () => {
log.debug('Bulk download request failed');
// There isn't any toast to update if we get this event.
@ -38,4 +44,55 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
});
},
});
startAppListening({
actionCreator: socketBulkDownloadStarted,
effect: async (action) => {
// This should always happen immediately after the bulk download request, so we don't need to show a toast here.
log.debug(action.payload.data, 'Bulk download preparation started');
},
});
startAppListening({
actionCreator: socketBulkDownloadComplete,
effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation completed');
const { bulk_download_item_name } = action.payload.data;
// TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first
const url = `/api/v1/images/download/${bulk_download_item_name}`;
toast({
id: bulk_download_item_name,
title: t('gallery.bulkDownloadReady', 'Download ready'),
status: 'success',
description: (
<ExternalLink
label={t('gallery.clickToDownload', 'Click here to download')}
href={url}
download={bulk_download_item_name}
/>
),
duration: null,
});
},
});
startAppListening({
actionCreator: socketBulkDownloadError,
effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation failed');
const { bulk_download_item_name } = action.payload.data;
toast({
id: bulk_download_item_name,
title: t('gallery.bulkDownloadFailed'),
status: 'error',
description: action.payload.data.error,
duration: null,
});
},
});
};

View File

@ -0,0 +1,38 @@
import { $logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasCopiedToClipboard } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
export const addCanvasCopiedToClipboardListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasCopiedToClipboard,
effect: async (action, { getState }) => {
const moduleLog = $logger.get().child({ namespace: 'canvasCopiedToClipboardListener' });
const state = getState();
try {
const blob = getBaseLayerBlob(state);
copyBlobToClipboard(blob);
} catch (err) {
moduleLog.error(String(err));
toast({
id: 'CANVAS_COPY_FAILED',
title: t('toast.problemCopyingCanvas'),
description: t('toast.problemCopyingCanvasDesc'),
status: 'error',
});
return;
}
toast({
id: 'CANVAS_COPY_SUCCEEDED',
title: t('toast.canvasCopiedClipboard'),
status: 'success',
});
},
});
};

View File

@ -0,0 +1,34 @@
import { $logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasDownloadedAsImage } from 'features/canvas/store/actions';
import { downloadBlob } from 'features/canvas/util/downloadBlob';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
export const addCanvasDownloadedAsImageListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasDownloadedAsImage,
effect: async (action, { getState }) => {
const moduleLog = $logger.get().child({ namespace: 'canvasSavedToGalleryListener' });
const state = getState();
let blob;
try {
blob = await getBaseLayerBlob(state);
} catch (err) {
moduleLog.error(String(err));
toast({
id: 'CANVAS_DOWNLOAD_FAILED',
title: t('toast.problemDownloadingCanvas'),
description: t('toast.problemDownloadingCanvasDesc'),
status: 'error',
});
return;
}
downloadBlob(blob, 'canvas.png');
toast({ id: 'CANVAS_DOWNLOAD_SUCCEEDED', title: t('toast.canvasDownloaded'), status: 'success' });
},
});
};

View File

@ -0,0 +1,60 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasImageToControlAdapter } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
export const addCanvasImageToControlNetListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasImageToControlAdapter,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const { id } = action.payload;
let blob: Blob;
try {
blob = await getBaseLayerBlob(state, true);
} catch (err) {
log.error(String(err));
toast({
id: 'PROBLEM_SAVING_CANVAS',
title: t('toast.problemSavingCanvas'),
description: t('toast.problemSavingCanvasDesc'),
status: 'error',
});
return;
}
const { autoAddBoardId } = state.gallery;
const imageDTO = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([blob], 'savedCanvas.png', {
type: 'image/png',
}),
image_category: 'control',
is_intermediate: true,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: false,
postUploadAction: {
type: 'TOAST',
title: t('toast.canvasSentControlnetAssets'),
},
})
).unwrap();
const { image_name } = imageDTO;
dispatch(
controlAdapterImageChanged({
id,
controlImage: image_name,
})
);
},
});
};

View File

@ -0,0 +1,60 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasMaskSavedToGallery } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
export const addCanvasMaskSavedToGalleryListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasMaskSavedToGallery,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const canvasBlobsAndImageData = await getCanvasData(
state.canvas.layerState,
state.canvas.boundingBoxCoordinates,
state.canvas.boundingBoxDimensions,
state.canvas.isMaskEnabled,
state.canvas.shouldPreserveMaskedArea
);
if (!canvasBlobsAndImageData) {
return;
}
const { maskBlob } = canvasBlobsAndImageData;
if (!maskBlob) {
log.error('Problem getting mask layer blob');
toast({
id: 'PROBLEM_SAVING_MASK',
title: t('toast.problemSavingMask'),
description: t('toast.problemSavingMaskDesc'),
status: 'error',
});
return;
}
const { autoAddBoardId } = state.gallery;
dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([maskBlob], 'canvasMaskImage.png', {
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: true,
postUploadAction: {
type: 'TOAST',
title: t('toast.maskSavedAssets'),
},
})
);
},
});
};

View File

@ -0,0 +1,70 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasMaskToControlAdapter } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
export const addCanvasMaskToControlNetListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasMaskToControlAdapter,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const { id } = action.payload;
const canvasBlobsAndImageData = await getCanvasData(
state.canvas.layerState,
state.canvas.boundingBoxCoordinates,
state.canvas.boundingBoxDimensions,
state.canvas.isMaskEnabled,
state.canvas.shouldPreserveMaskedArea
);
if (!canvasBlobsAndImageData) {
return;
}
const { maskBlob } = canvasBlobsAndImageData;
if (!maskBlob) {
log.error('Problem getting mask layer blob');
toast({
id: 'PROBLEM_IMPORTING_MASK',
title: t('toast.problemImportingMask'),
description: t('toast.problemImportingMaskDesc'),
status: 'error',
});
return;
}
const { autoAddBoardId } = state.gallery;
const imageDTO = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([maskBlob], 'canvasMaskImage.png', {
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: true,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: false,
postUploadAction: {
type: 'TOAST',
title: t('toast.maskSentControlnetAssets'),
},
})
).unwrap();
const { image_name } = imageDTO;
dispatch(
controlAdapterImageChanged({
id,
controlImage: image_name,
})
);
},
});
};

View File

@ -0,0 +1,73 @@
import { $logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasMerged } from 'features/canvas/store/actions';
import { $canvasBaseLayer } from 'features/canvas/store/canvasNanostore';
import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
import { getFullBaseLayerBlob } from 'features/canvas/util/getFullBaseLayerBlob';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
export const addCanvasMergedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasMerged,
effect: async (action, { dispatch }) => {
const moduleLog = $logger.get().child({ namespace: 'canvasCopiedToClipboardListener' });
const blob = await getFullBaseLayerBlob();
if (!blob) {
moduleLog.error('Problem getting base layer blob');
toast({
id: 'PROBLEM_MERGING_CANVAS',
title: t('toast.problemMergingCanvas'),
description: t('toast.problemMergingCanvasDesc'),
status: 'error',
});
return;
}
const canvasBaseLayer = $canvasBaseLayer.get();
if (!canvasBaseLayer) {
moduleLog.error('Problem getting canvas base layer');
toast({
id: 'PROBLEM_MERGING_CANVAS',
title: t('toast.problemMergingCanvas'),
description: t('toast.problemMergingCanvasDesc'),
status: 'error',
});
return;
}
const baseLayerRect = canvasBaseLayer.getClientRect({
relativeTo: canvasBaseLayer.getParent() ?? undefined,
});
const imageDTO = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([blob], 'mergedCanvas.png', {
type: 'image/png',
}),
image_category: 'general',
is_intermediate: true,
postUploadAction: {
type: 'TOAST',
title: t('toast.canvasMerged'),
},
})
).unwrap();
// TODO: I can't figure out how to do the type narrowing in the `take()` so just brute forcing it here
const { image_name } = imageDTO;
dispatch(
setMergedCanvas({
kind: 'image',
layer: 'base',
imageName: image_name,
...baseLayerRect,
})
);
},
});
};

View File

@ -0,0 +1,53 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { canvasSavedToGallery } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasSavedToGallery,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
let blob;
try {
blob = await getBaseLayerBlob(state);
} catch (err) {
log.error(String(err));
toast({
id: 'CANVAS_SAVE_FAILED',
title: t('toast.problemSavingCanvas'),
description: t('toast.problemSavingCanvasDesc'),
status: 'error',
});
return;
}
const { autoAddBoardId } = state.gallery;
dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([blob], 'savedCanvas.png', {
type: 'image/png',
}),
image_category: 'general',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: true,
postUploadAction: {
type: 'TOAST',
title: t('toast.canvasSavedGallery'),
},
metadata: {
_canvas_objects: parseify(state.canvas.layerState.objects),
},
})
);
},
});
};

View File

@ -0,0 +1,194 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch } from 'app/store/store';
import { parseify } from 'common/util/serialize';
import {
caLayerImageChanged,
caLayerModelChanged,
caLayerProcessedImageChanged,
caLayerProcessorConfigChanged,
caLayerProcessorPendingBatchIdChanged,
caLayerRecalled,
isControlAdapterLayer,
} from 'features/controlLayers/store/controlLayersSlice';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { isEqual } from 'lodash-es';
import { getImageDTO } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions';
import { assert } from 'tsafe';
const matcher = isAnyOf(
caLayerImageChanged,
caLayerProcessedImageChanged,
caLayerProcessorConfigChanged,
caLayerModelChanged,
caLayerRecalled
);
const DEBOUNCE_MS = 300;
const log = logger('session');
/**
* Simple helper to cancel a batch and reset the pending batch ID
*/
const cancelProcessorBatch = async (dispatch: AppDispatch, layerId: string, batchId: string) => {
const req = dispatch(queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: [batchId] }));
log.trace({ batchId }, 'Cancelling existing preprocessor batch');
try {
await req.unwrap();
} catch {
// no-op
} finally {
req.reset();
// Always reset the pending batch ID - the cancel req could fail if the batch doesn't exist
dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: null }));
}
};
export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => {
startAppListening({
matcher,
effect: async (action, { dispatch, getState, getOriginalState, cancelActiveListeners, delay, take, signal }) => {
const layerId = caLayerRecalled.match(action) ? action.payload.id : action.payload.layerId;
const state = getState();
const originalState = getOriginalState();
// Cancel any in-progress instances of this listener
cancelActiveListeners();
log.trace('Control Layer CA auto-process triggered');
// Delay before starting actual work
await delay(DEBOUNCE_MS);
const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId);
if (!layer) {
return;
}
// We should only process if the processor settings or image have changed
const originalLayer = originalState.controlLayers.present.layers
.filter(isControlAdapterLayer)
.find((l) => l.id === layerId);
const originalImage = originalLayer?.controlAdapter.image;
const originalConfig = originalLayer?.controlAdapter.processorConfig;
const image = layer.controlAdapter.image;
const processedImage = layer.controlAdapter.processedImage;
const config = layer.controlAdapter.processorConfig;
if (isEqual(config, originalConfig) && isEqual(image, originalImage) && processedImage) {
// Neither config nor image have changed, we can bail
return;
}
if (!image || !config) {
// - If we have no image, we have nothing to process
// - If we have no processor config, we have nothing to process
// Clear the processed image and bail
dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null }));
return;
}
// At this point, the user has stopped fiddling with the processor settings and there is a processor selected.
// If there is a pending processor batch, cancel it.
if (layer.controlAdapter.processorPendingBatchId) {
cancelProcessorBatch(dispatch, layerId, layer.controlAdapter.processorPendingBatchId);
}
// TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config as never);
const enqueueBatchArg: BatchConfig = {
prepend: true,
batch: {
graph: {
nodes: {
[processorNode.id]: {
...processorNode,
// Control images are always intermediate - do not save to gallery
is_intermediate: true,
},
},
edges: [],
},
runs: 1,
},
};
// Kick off the processor batch
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
fixedCacheKey: 'enqueueBatch',
})
);
try {
const enqueueResult = await req.unwrap();
// TODO(psyche): Update the pydantic models, pretty sure we will _always_ have a batch_id here, but the model says it's optional
assert(enqueueResult.batch.batch_id, 'Batch ID not returned from queue');
dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: enqueueResult.batch.batch_id }));
log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));
// Wait for the processor node to complete
const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) &&
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
action.payload.data.invocation_source_id === processorNode.id
);
// We still have to check the output type
assert(
invocationCompleteAction.payload.data.result.type === 'image_output',
`Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}`
);
const { image_name } = invocationCompleteAction.payload.data.result.image;
const imageDTO = await getImageDTO(image_name);
assert(imageDTO, "Failed to fetch processor output's image DTO");
// Whew! We made it. Update the layer with the processed image
log.debug({ layerId, imageDTO }, 'ControlNet image processed');
dispatch(caLayerProcessedImageChanged({ layerId, imageDTO }));
dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: null }));
} catch (error) {
if (signal.aborted) {
// The listener was canceled - we need to cancel the pending processor batch, if there is one (could have changed by now).
const pendingBatchId = getState()
.controlLayers.present.layers.filter(isControlAdapterLayer)
.find((l) => l.id === layerId)?.controlAdapter.processorPendingBatchId;
if (pendingBatchId) {
cancelProcessorBatch(dispatch, layerId, pendingBatchId);
}
log.trace('Control Adapter preprocessor cancelled');
} else {
// Some other error condition...
log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
if (error instanceof Object) {
if ('data' in error && 'status' in error) {
if (error.status === 403) {
dispatch(caLayerImageChanged({ layerId, imageDTO: null }));
return;
}
}
}
toast({
id: 'GRAPH_QUEUE_FAILED',
title: t('queue.graphFailedToQueue'),
status: 'error',
});
}
} finally {
req.reset();
}
},
});
};

View File

@ -0,0 +1,85 @@
import type { AnyListenerPredicate } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { RootState } from 'app/store/store';
import { controlAdapterImageProcessed } from 'features/controlAdapters/store/actions';
import {
controlAdapterAutoConfigToggled,
controlAdapterImageChanged,
controlAdapterModelChanged,
controlAdapterProcessorParamsChanged,
controlAdapterProcessortTypeChanged,
selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
type AnyControlAdapterParamChangeAction =
| ReturnType<typeof controlAdapterProcessorParamsChanged>
| ReturnType<typeof controlAdapterModelChanged>
| ReturnType<typeof controlAdapterImageChanged>
| ReturnType<typeof controlAdapterProcessortTypeChanged>
| ReturnType<typeof controlAdapterAutoConfigToggled>;
const predicate: AnyListenerPredicate<RootState> = (action, state, prevState) => {
const isActionMatched =
controlAdapterProcessorParamsChanged.match(action) ||
controlAdapterModelChanged.match(action) ||
controlAdapterImageChanged.match(action) ||
controlAdapterProcessortTypeChanged.match(action) ||
controlAdapterAutoConfigToggled.match(action);
if (!isActionMatched) {
return false;
}
const { id } = action.payload;
const prevCA = selectControlAdapterById(prevState.controlAdapters, id);
const ca = selectControlAdapterById(state.controlAdapters, id);
if (!prevCA || !isControlNetOrT2IAdapter(prevCA) || !ca || !isControlNetOrT2IAdapter(ca)) {
return false;
}
if (controlAdapterAutoConfigToggled.match(action)) {
// do not process if the user just disabled auto-config
if (prevCA.shouldAutoConfig === true) {
return false;
}
}
const { controlImage, processorType, shouldAutoConfig } = ca;
if (controlAdapterModelChanged.match(action) && !shouldAutoConfig) {
// do not process if the action is a model change but the processor settings are dirty
return false;
}
const isProcessorSelected = processorType !== 'none';
const hasControlImage = Boolean(controlImage);
return isProcessorSelected && hasControlImage;
};
const DEBOUNCE_MS = 300;
/**
* Listener that automatically processes a ControlNet image when its processor parameters are changed.
*
* The network request is debounced.
*/
export const addControlNetAutoProcessListener = (startAppListening: AppStartListening) => {
startAppListening({
predicate,
effect: async (action, { dispatch, cancelActiveListeners, delay }) => {
const log = logger('session');
const { id } = (action as AnyControlAdapterParamChangeAction).payload;
// Cancel any in-progress instances of this listener
cancelActiveListeners();
log.trace('ControlNet auto-process triggered');
// Delay before starting actual work
await delay(DEBOUNCE_MS);
dispatch(controlAdapterImageProcessed({ id }));
},
});
};

View File

@ -0,0 +1,118 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { controlAdapterImageProcessed } from 'features/controlAdapters/store/actions';
import {
controlAdapterImageChanged,
controlAdapterProcessedImageChanged,
pendingControlImagesCleared,
selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions';
export const addControlNetImageProcessedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: controlAdapterImageProcessed,
effect: async (action, { dispatch, getState, take }) => {
const log = logger('session');
const { id } = action.payload;
const ca = selectControlAdapterById(getState().controlAdapters, id);
if (!ca?.controlImage || !isControlNetOrT2IAdapter(ca)) {
log.error('Unable to process ControlNet image');
return;
}
if (ca.processorType === 'none' || ca.processorNode.type === 'none') {
return;
}
// ControlNet one-off procressing graph is just the processor node, no edges.
// Also we need to grab the image.
const nodeId = ca.processorNode.id;
const enqueueBatchArg: BatchConfig = {
prepend: true,
batch: {
graph: {
nodes: {
[ca.processorNode.id]: {
...ca.processorNode,
is_intermediate: true,
use_cache: false,
image: { image_name: ca.controlImage },
},
},
edges: [],
},
runs: 1,
},
};
try {
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
fixedCacheKey: 'enqueueBatch',
})
);
const enqueueResult = await req.unwrap();
req.reset();
log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));
const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) &&
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
action.payload.data.invocation_source_id === nodeId
);
// We still have to check the output type
if (invocationCompleteAction.payload.data.result.type === 'image_output') {
const { image_name } = invocationCompleteAction.payload.data.result.image;
// Wait for the ImageDTO to be received
const [{ payload }] = await take(
(action) =>
imagesApi.endpoints.getImageDTO.matchFulfilled(action) && action.payload.image_name === image_name
);
const processedControlImage = payload as ImageDTO;
log.debug({ controlNetId: action.payload, processedControlImage }, 'ControlNet image processed');
// Update the processed image in the store
dispatch(
controlAdapterProcessedImageChanged({
id,
processedControlImage: processedControlImage.image_name,
})
);
}
} catch (error) {
log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
if (error instanceof Object) {
if ('data' in error && 'status' in error) {
if (error.status === 403) {
dispatch(pendingControlImagesCleared());
dispatch(controlAdapterImageChanged({ id, controlImage: null }));
return;
}
}
}
toast({
id: 'GRAPH_QUEUE_FAILED',
title: t('queue.graphFailedToQueue'),
status: 'error',
});
}
},
});
};

View File

@ -0,0 +1,144 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { parseify } from 'common/util/serialize';
import { canvasBatchIdAdded, stagingAreaInitialized } from 'features/canvas/store/canvasSlice';
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
import { canvasGraphBuilt } from 'features/nodes/store/actions';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildCanvasGraph } from 'features/nodes/util/graph/canvas/buildCanvasGraph';
import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { ImageDTO } from 'services/api/types';
/**
* This listener is responsible invoking the canvas. This involves a number of steps:
*
* 1. Generate image blobs from the canvas layers
* 2. Determine the generation mode from the layers (txt2img, img2img, inpaint)
* 3. Build the canvas graph
* 4. Create the session with the graph
* 5. Upload the init image if necessary
* 6. Upload the mask image if necessary
* 7. Update the init and mask images with the session ID
* 8. Initialize the staging area if not yet initialized
* 9. Dispatch the sessionReadyToInvoke action to invoke the session
*/
export const addEnqueueRequestedCanvasListener = (startAppListening: AppStartListening) => {
startAppListening({
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
enqueueRequested.match(action) && action.payload.tabName === 'canvas',
effect: async (action, { getState, dispatch }) => {
const log = logger('queue');
const { prepend } = action.payload;
const state = getState();
const { layerState, boundingBoxCoordinates, boundingBoxDimensions, isMaskEnabled, shouldPreserveMaskedArea } =
state.canvas;
// Build canvas blobs
const canvasBlobsAndImageData = await getCanvasData(
layerState,
boundingBoxCoordinates,
boundingBoxDimensions,
isMaskEnabled,
shouldPreserveMaskedArea
);
if (!canvasBlobsAndImageData) {
log.error('Unable to create canvas data');
return;
}
const { baseBlob, baseImageData, maskBlob, maskImageData } = canvasBlobsAndImageData;
// Determine the generation mode
const generationMode = getCanvasGenerationMode(baseImageData, maskImageData);
if (state.system.enableImageDebugging) {
const baseDataURL = await blobToDataURL(baseBlob);
const maskDataURL = await blobToDataURL(maskBlob);
openBase64ImageInTab([
{ base64: maskDataURL, caption: 'mask b64' },
{ base64: baseDataURL, caption: 'image b64' },
]);
}
log.debug(`Generation mode: ${generationMode}`);
// Temp placeholders for the init and mask images
let canvasInitImage: ImageDTO | undefined;
let canvasMaskImage: ImageDTO | undefined;
// For img2img and inpaint/outpaint, we need to upload the init images
if (['img2img', 'inpaint', 'outpaint'].includes(generationMode)) {
// upload the image, saving the request id
canvasInitImage = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([baseBlob], 'canvasInitImage.png', {
type: 'image/png',
}),
image_category: 'general',
is_intermediate: true,
})
).unwrap();
}
// For inpaint/outpaint, we also need to upload the mask layer
if (['inpaint', 'outpaint'].includes(generationMode)) {
// upload the image, saving the request id
canvasMaskImage = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([maskBlob], 'canvasMaskImage.png', {
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: true,
})
).unwrap();
}
const graph = await buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
log.debug({ graph: parseify(graph) }, `Canvas graph built`);
// currently this action is just listened to for logging
dispatch(canvasGraphBuilt(graph));
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
try {
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
const enqueueResult = await req.unwrap();
req.reset();
const batchId = enqueueResult.batch.batch_id as string; // we know the is a string, backend provides it
// Prep the canvas staging area if it is not yet initialized
if (!state.canvas.layerState.stagingArea.boundingBox) {
dispatch(
stagingAreaInitialized({
boundingBox: {
...state.canvas.boundingBoxCoordinates,
...state.canvas.boundingBoxDimensions,
},
})
);
}
// Associate the session with the canvas session ID
dispatch(canvasBatchIdAdded(batchId));
} catch {
// no-op
}
},
});
};

View File

@ -1,21 +1,10 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import type { Result } from 'common/util/result';
import { isErr, withResult, withResultAsync } from 'common/util/result';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import { sessionStagingAreaReset, sessionStartedStaging } from 'features/controlLayers/store/canvasSessionSlice';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { serializeError } from 'serialize-error';
import { buildGenerationTabGraph } from 'features/nodes/util/graph/generation/buildGenerationTabGraph';
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/generation/buildGenerationTabSDXLGraph';
import { queueApi } from 'services/api/endpoints/queue';
import type { Invocation } from 'services/api/types';
import { assert } from 'tsafe';
const log = logger('generation');
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
startAppListening({
@ -23,77 +12,33 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
enqueueRequested.match(action) && action.payload.tabName === 'generation',
effect: async (action, { getState, dispatch }) => {
const state = getState();
const model = state.params.model;
const { shouldShowProgressInViewer } = state.ui;
const model = state.generation.model;
const { prepend } = action.payload;
const manager = $canvasManager.get();
assert(manager, 'No model found in state');
let graph;
let didStartStaging = false;
if (!state.canvasSession.isStaging && state.canvasSession.mode === 'compose') {
dispatch(sessionStartedStaging());
didStartStaging = true;
if (model?.base === 'sdxl') {
graph = await buildGenerationTabSDXLGraph(state);
} else {
graph = await buildGenerationTabGraph(state);
}
const abortStaging = () => {
if (didStartStaging && getState().canvasSession.isStaging) {
dispatch(sessionStagingAreaReset());
}
};
let buildGraphResult: Result<
{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel' | 'sdxl_compel_prompt'> },
Error
>;
assert(model, 'No model found in state');
const base = model.base;
switch (base) {
case 'sdxl':
buildGraphResult = await withResultAsync(() => buildSDXLGraph(state, manager));
break;
case 'sd-1':
case `sd-2`:
buildGraphResult = await withResultAsync(() => buildSD1Graph(state, manager));
break;
default:
assert(false, `No graph builders for base ${base}`);
}
if (isErr(buildGraphResult)) {
log.error({ error: serializeError(buildGraphResult.error) }, 'Failed to build graph');
abortStaging();
return;
}
const { g, noise, posCond } = buildGraphResult.value;
const prepareBatchResult = withResult(() => prepareLinearUIBatch(state, g, prepend, noise, posCond));
if (isErr(prepareBatchResult)) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
abortStaging();
return;
}
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(prepareBatchResult.value, {
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
req.reset();
const enqueueResult = await withResultAsync(() => req.unwrap());
if (isErr(enqueueResult)) {
log.error({ error: serializeError(enqueueResult.error) }, 'Failed to enqueue batch');
abortStaging();
return;
try {
await req.unwrap();
if (shouldShowProgressInViewer) {
dispatch(isImageViewerOpenChanged(true));
}
} finally {
req.reset();
}
log.debug({ batchConfig: prepareBatchResult.value } as SerializableObject, 'Enqueued batch');
},
});
};

View File

@ -1,6 +1,5 @@
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
import { queueApi } from 'services/api/endpoints/queue';
@ -12,12 +11,12 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
enqueueRequested.match(action) && action.payload.tabName === 'workflows',
effect: async (action, { getState, dispatch }) => {
const state = getState();
const nodes = selectNodesSlice(state);
const { nodes, edges } = state.nodes.present;
const workflow = state.workflow;
const graph = buildNodesGraph(nodes);
const graph = buildNodesGraph(state.nodes.present);
const builtWorkflow = buildWorkflowWithValidation({
nodes: nodes.nodes,
edges: nodes.edges,
nodes,
edges,
workflow,
});
@ -30,8 +29,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
batch: {
graph,
workflow: builtWorkflow,
runs: state.params.iterations,
origin: 'workflows',
runs: state.generation.iterations,
},
prepend: action.payload.prepend,
};

View File

@ -14,9 +14,9 @@ export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening)
const { shouldShowProgressInViewer } = state.ui;
const { prepend } = action.payload;
const { g, noise, posCond } = await buildMultidiffusionUpscaleGraph(state);
const graph = await buildMultidiffusionUpscaleGraph(state);
const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond);
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {

View File

@ -27,7 +27,7 @@ export const galleryImageClicked = createAction<{
export const addGalleryImageClickedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: galleryImageClicked,
effect: (action, { dispatch, getState }) => {
effect: async (action, { dispatch, getState }) => {
const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
const state = getState();
const queryArgs = selectListImagesQueryArgs(state);

View File

@ -1,27 +1,24 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import { parseify } from 'common/util/serialize';
import { $templates } from 'features/nodes/store/nodesSlice';
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
import { size } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { appInfoApi } from 'services/api/endpoints/appInfo';
const log = logger('system');
export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: appInfoApi.endpoints.getOpenAPISchema.matchFulfilled,
effect: (action, { getState }) => {
const log = logger('system');
const schemaJSON = action.payload;
log.debug({ schemaJSON: parseify(schemaJSON) } as SerializableObject, 'Received OpenAPI schema');
log.debug({ schemaJSON: parseify(schemaJSON) }, 'Received OpenAPI schema');
const { nodesAllowlist, nodesDenylist } = getState().config;
const nodeTemplates = parseSchema(schemaJSON, nodesAllowlist, nodesDenylist);
log.debug({ nodeTemplates } as SerializableObject, `Built ${size(nodeTemplates)} node templates`);
log.debug({ nodeTemplates: parseify(nodeTemplates) }, `Built ${size(nodeTemplates)} node templates`);
$templates.set(nodeTemplates);
},
@ -33,7 +30,8 @@ export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening
// If action.meta.condition === true, the request was canceled/skipped because another request was in flight or
// the value was already in the cache. We don't want to log these errors.
if (!action.meta.condition) {
log.error({ error: serializeError(action.error) }, 'Problem retrieving OpenAPI Schema');
const log = logger('system');
log.error({ error: parseify(action.error) }, 'Problem retrieving OpenAPI Schema');
}
},
});

View File

@ -2,13 +2,15 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { imagesApi } from 'services/api/endpoints/images';
const log = logger('gallery');
export const addImageAddedToBoardFulfilledListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.addImageToBoard.matchFulfilled,
effect: (action) => {
const log = logger('images');
const { board_id, imageDTO } = action.meta.arg.originalArgs;
// TODO: update listImages cache for this board
log.debug({ board_id, imageDTO }, 'Image added to board');
},
});
@ -16,7 +18,9 @@ export const addImageAddedToBoardFulfilledListener = (startAppListening: AppStar
startAppListening({
matcher: imagesApi.endpoints.addImageToBoard.matchRejected,
effect: (action) => {
const log = logger('images');
const { board_id, imageDTO } = action.meta.arg.originalArgs;
log.debug({ board_id, imageDTO }, 'Problem adding image to board');
},
});

View File

@ -1,9 +1,20 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import { entityDeleted, ipaImageChanged } from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import {
controlAdapterImageChanged,
controlAdapterProcessedImageChanged,
selectControlAdapterAll,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import {
isControlAdapterLayer,
isInitialImageLayer,
isIPAdapterLayer,
isRegionalGuidanceLayer,
layerDeleted,
} from 'features/controlLayers/store/controlLayersSlice';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
@ -15,10 +26,6 @@ import { forEach, intersectionBy } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
const log = logger('gallery');
//TODO(psyche): handle image deletion (canvas sessions?)
// Some utils to delete images from different parts of the app
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
state.nodes.present.nodes.forEach((node) => {
@ -40,37 +47,52 @@ const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: Im
});
};
// const deleteControlAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
// state.canvas.present.controlAdapters.entities.forEach(({ id, imageObject, processedImageObject }) => {
// if (
// imageObject?.image.image_name === imageDTO.image_name ||
// processedImageObject?.image.image_name === imageDTO.image_name
// ) {
// dispatch(caImageChanged({ id, imageDTO: null }));
// dispatch(caProcessedImageChanged({ id, imageDTO: null }));
// }
// });
// };
const deleteIPAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
selectCanvasSlice(state).ipAdapters.entities.forEach((entity) => {
if (entity.ipAdapter.image?.image_name === imageDTO.image_name) {
dispatch(ipaImageChanged({ entityIdentifier: getEntityIdentifier(entity), imageDTO: null }));
const deleteControlAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
forEach(selectControlAdapterAll(state.controlAdapters), (ca) => {
if (
ca.controlImage === imageDTO.image_name ||
(isControlNetOrT2IAdapter(ca) && ca.processedControlImage === imageDTO.image_name)
) {
dispatch(
controlAdapterImageChanged({
id: ca.id,
controlImage: null,
})
);
dispatch(
controlAdapterProcessedImageChanged({
id: ca.id,
processedControlImage: null,
})
);
}
});
};
const deleteLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
selectCanvasSlice(state).rasterLayers.entities.forEach(({ id, objects }) => {
let shouldDelete = false;
for (const obj of objects) {
if (obj.type === 'image' && obj.image.image_name === imageDTO.image_name) {
shouldDelete = true;
break;
const deleteControlLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
state.controlLayers.present.layers.forEach((l) => {
if (isRegionalGuidanceLayer(l)) {
if (l.ipAdapters.some((ipa) => ipa.image?.name === imageDTO.image_name)) {
dispatch(layerDeleted(l.id));
}
}
if (shouldDelete) {
dispatch(entityDeleted({ entityIdentifier: { id, type: 'raster_layer' } }));
if (isControlAdapterLayer(l)) {
if (
l.controlAdapter.image?.name === imageDTO.image_name ||
l.controlAdapter.processedImage?.name === imageDTO.image_name
) {
dispatch(layerDeleted(l.id));
}
}
if (isIPAdapterLayer(l)) {
if (l.ipAdapter.image?.name === imageDTO.image_name) {
dispatch(layerDeleted(l.id));
}
}
if (isInitialImageLayer(l)) {
if (l.image?.name === imageDTO.image_name) {
dispatch(layerDeleted(l.id));
}
}
});
};
@ -123,10 +145,14 @@ export const addImageDeletionListeners = (startAppListening: AppStartListening)
}
}
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
if (imageUsage.isCanvasImage) {
dispatch(resetCanvas());
}
deleteControlAdapterImages(state, dispatch, imageDTO);
deleteNodesImages(state, dispatch, imageDTO);
// deleteControlAdapterImages(state, dispatch, imageDTO);
deleteIPAdapterImages(state, dispatch, imageDTO);
deleteLayerImages(state, dispatch, imageDTO);
deleteControlLayerImages(state, dispatch, imageDTO);
} catch {
// no-op
} finally {
@ -163,11 +189,14 @@ export const addImageDeletionListeners = (startAppListening: AppStartListening)
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
if (imagesUsage.some((i) => i.isCanvasImage)) {
dispatch(resetCanvas());
}
imageDTOs.forEach((imageDTO) => {
deleteControlAdapterImages(state, dispatch, imageDTO);
deleteNodesImages(state, dispatch, imageDTO);
// deleteControlAdapterImages(state, dispatch, imageDTO);
deleteIPAdapterImages(state, dispatch, imageDTO);
deleteLayerImages(state, dispatch, imageDTO);
deleteControlLayerImages(state, dispatch, imageDTO);
});
} catch {
// no-op
@ -191,6 +220,7 @@ export const addImageDeletionListeners = (startAppListening: AppStartListening)
startAppListening({
matcher: imagesApi.endpoints.deleteImage.matchFulfilled,
effect: (action) => {
const log = logger('images');
log.debug({ imageDTO: action.meta.arg.originalArgs }, 'Image deleted');
},
});
@ -198,6 +228,7 @@ export const addImageDeletionListeners = (startAppListening: AppStartListening)
startAppListening({
matcher: imagesApi.endpoints.deleteImage.matchRejected,
effect: (action) => {
const log = logger('images');
log.debug({ imageDTO: action.meta.arg.originalArgs }, 'Unable to delete image');
},
});

View File

@ -1,19 +1,28 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
controlLayerAdded,
ipaImageChanged,
rasterLayerAdded,
rgIPAdapterImageChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasControlLayerState, CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
controlAdapterImageChanged,
controlAdapterIsEnabledChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import {
caLayerImageChanged,
iiLayerImageChanged,
ipaLayerImageChanged,
rgLayerIPAdapterImageChanged,
} from 'features/controlLayers/store/controlLayersSlice';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { imageToCompareChanged, isImageViewerOpenChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import {
imageSelected,
imageToCompareChanged,
isImageViewerOpenChanged,
selectionChanged,
} from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { imagesApi } from 'services/api/endpoints/images';
@ -22,12 +31,11 @@ export const dndDropped = createAction<{
activeData: TypesafeDraggableData;
}>('dnd/dndDropped');
const log = logger('system');
export const addImageDroppedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: dndDropped,
effect: (action, { dispatch, getState }) => {
effect: async (action, { dispatch, getState }) => {
const log = logger('dnd');
const { activeData, overData } = action.payload;
if (!isValidDrop(overData, activeData)) {
return;
@ -38,22 +46,80 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
} else if (activeData.payloadType === 'GALLERY_SELECTION') {
log.debug({ activeData, overData }, `Images (${getState().gallery.selection.length}) dropped`);
} else if (activeData.payloadType === 'NODE_FIELD') {
log.debug({ activeData, overData }, 'Node field dropped');
log.debug({ activeData: parseify(activeData), overData: parseify(overData) }, 'Node field dropped');
} else {
log.debug({ activeData, overData }, `Unknown payload dropped`);
}
/**
* Image dropped on current image
*/
if (
overData.actionType === 'SET_CURRENT_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(imageSelected(activeData.payload.imageDTO));
dispatch(isImageViewerOpenChanged(true));
return;
}
/**
* Image dropped on ControlNet
*/
if (
overData.actionType === 'SET_CONTROL_ADAPTER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { id } = overData.context;
dispatch(
controlAdapterImageChanged({
id,
controlImage: activeData.payload.imageDTO.image_name,
})
);
dispatch(
controlAdapterIsEnabledChanged({
id,
isEnabled: true,
})
);
return;
}
/**
* Image dropped on Control Adapter Layer
*/
if (
overData.actionType === 'SET_CA_LAYER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { layerId } = overData.context;
dispatch(
caLayerImageChanged({
layerId,
imageDTO: activeData.payload.imageDTO,
})
);
return;
}
/**
* Image dropped on IP Adapter Layer
*/
if (
overData.actionType === 'SET_IPA_IMAGE' &&
overData.actionType === 'SET_IPA_LAYER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { id } = overData.context;
const { layerId } = overData.context;
dispatch(
ipaImageChanged({ entityIdentifier: { id, type: 'ip_adapter' }, imageDTO: activeData.payload.imageDTO })
ipaLayerImageChanged({
layerId,
imageDTO: activeData.payload.imageDTO,
})
);
return;
}
@ -62,14 +128,14 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
* Image dropped on RG Layer IP Adapter
*/
if (
overData.actionType === 'SET_RG_IP_ADAPTER_IMAGE' &&
overData.actionType === 'SET_RG_LAYER_IP_ADAPTER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { id, ipAdapterId } = overData.context;
const { layerId, ipAdapterId } = overData.context;
dispatch(
rgIPAdapterImageChanged({
entityIdentifier: { id, type: 'regional_guidance' },
rgLayerIPAdapterImageChanged({
layerId,
ipAdapterId,
imageDTO: activeData.payload.imageDTO,
})
@ -78,38 +144,32 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
}
/**
* Image dropped on Raster layer
* Image dropped on II Layer Image
*/
if (
overData.actionType === 'ADD_RASTER_LAYER_FROM_IMAGE' &&
overData.actionType === 'SET_II_LAYER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
const overrides: Partial<CanvasRasterLayerState> = {
objects: [imageObject],
position: { x, y },
};
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
const { layerId } = overData.context;
dispatch(
iiLayerImageChanged({
layerId,
imageDTO: activeData.payload.imageDTO,
})
);
return;
}
/**
* Image dropped on Raster layer
* Image dropped on Canvas
*/
if (
overData.actionType === 'ADD_CONTROL_LAYER_FROM_IMAGE' &&
overData.actionType === 'SET_CANVAS_INITIAL_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
const overrides: Partial<CanvasControlLayerState> = {
objects: [imageObject],
position: { x, y },
};
dispatch(controlLayerAdded({ overrides, isSelected: true }));
dispatch(setInitialCanvasImage(activeData.payload.imageDTO, selectOptimalDimension(getState())));
return;
}

View File

@ -2,13 +2,13 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { imagesApi } from 'services/api/endpoints/images';
const log = logger('gallery');
export const addImageRemovedFromBoardFulfilledListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.removeImageFromBoard.matchFulfilled,
effect: (action) => {
const log = logger('images');
const imageDTO = action.meta.arg.originalArgs;
log.debug({ imageDTO }, 'Image removed from board');
},
});
@ -16,7 +16,9 @@ export const addImageRemovedFromBoardFulfilledListener = (startAppListening: App
startAppListening({
matcher: imagesApi.endpoints.removeImageFromBoard.matchRejected,
effect: (action) => {
const log = logger('images');
const imageDTO = action.meta.arg.originalArgs;
log.debug({ imageDTO }, 'Problem removing image from board');
},
});

View File

@ -6,17 +6,16 @@ import { imagesToDeleteSelected, isModalOpenChanged } from 'features/deleteImage
export const addImageToDeleteSelectedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: imagesToDeleteSelected,
effect: (action, { dispatch, getState }) => {
effect: async (action, { dispatch, getState }) => {
const imageDTOs = action.payload;
const state = getState();
const { shouldConfirmOnDelete } = state.system;
const imagesUsage = selectImageUsage(getState());
const isImageInUse =
imagesUsage.some((i) => i.isLayerImage) ||
imagesUsage.some((i) => i.isControlAdapterImage) ||
imagesUsage.some((i) => i.isIPAdapterImage) ||
imagesUsage.some((i) => i.isLayerImage);
imagesUsage.some((i) => i.isCanvasImage) ||
imagesUsage.some((i) => i.isControlImage) ||
imagesUsage.some((i) => i.isNodesImage);
if (shouldConfirmOnDelete || isImageInUse) {
dispatch(isModalOpenChanged(true));

View File

@ -1,8 +1,19 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { ipaImageChanged, rgIPAdapterImageChanged } from 'features/controlLayers/store/canvasSlice';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
controlAdapterImageChanged,
controlAdapterIsEnabledChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import {
caLayerImageChanged,
iiLayerImageChanged,
ipaLayerImageChanged,
rgLayerIPAdapterImageChanged,
} from 'features/controlLayers/store/controlLayersSlice';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
@ -10,12 +21,11 @@ import { omit } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
const log = logger('gallery');
export const addImageUploadedFulfilledListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.uploadImage.matchFulfilled,
effect: (action, { dispatch, getState }) => {
const log = logger('images');
const imageDTO = action.payload;
const state = getState();
const { autoAddBoardId } = state.gallery;
@ -71,6 +81,15 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
return;
}
if (postUploadAction?.type === 'SET_CANVAS_INITIAL_IMAGE') {
dispatch(setInitialCanvasImage(imageDTO, selectOptimalDimension(state)));
toast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setAsCanvasInitialImage'),
});
return;
}
if (postUploadAction?.type === 'SET_UPSCALE_INITIAL_IMAGE') {
dispatch(upscaleInitialImageChanged(imageDTO));
toast({
@ -80,33 +99,70 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
return;
}
// if (postUploadAction?.type === 'SET_CA_IMAGE') {
// const { id } = postUploadAction;
// dispatch(caImageChanged({ id, imageDTO }));
// toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
// return;
// }
if (postUploadAction?.type === 'SET_IPA_IMAGE') {
if (postUploadAction?.type === 'SET_CONTROL_ADAPTER_IMAGE') {
const { id } = postUploadAction;
dispatch(ipaImageChanged({ entityIdentifier: { id, type: 'ip_adapter' }, imageDTO }));
toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
dispatch(
controlAdapterIsEnabledChanged({
id,
isEnabled: true,
})
);
dispatch(
controlAdapterImageChanged({
id,
controlImage: imageDTO.image_name,
})
);
toast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
});
return;
}
if (postUploadAction?.type === 'SET_RG_IP_ADAPTER_IMAGE') {
const { id, ipAdapterId } = postUploadAction;
dispatch(
rgIPAdapterImageChanged({ entityIdentifier: { id, type: 'regional_guidance' }, ipAdapterId, imageDTO })
);
toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
return;
if (postUploadAction?.type === 'SET_CA_LAYER_IMAGE') {
const { layerId } = postUploadAction;
dispatch(caLayerImageChanged({ layerId, imageDTO }));
toast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
});
}
if (postUploadAction?.type === 'SET_IPA_LAYER_IMAGE') {
const { layerId } = postUploadAction;
dispatch(ipaLayerImageChanged({ layerId, imageDTO }));
toast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
});
}
if (postUploadAction?.type === 'SET_RG_LAYER_IP_ADAPTER_IMAGE') {
const { layerId, ipAdapterId } = postUploadAction;
dispatch(rgLayerIPAdapterImageChanged({ layerId, ipAdapterId, imageDTO }));
toast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
});
}
if (postUploadAction?.type === 'SET_II_LAYER_IMAGE') {
const { layerId } = postUploadAction;
dispatch(iiLayerImageChanged({ layerId, imageDTO }));
toast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
});
}
if (postUploadAction?.type === 'SET_NODES_IMAGE') {
const { nodeId, fieldName } = postUploadAction;
dispatch(fieldImageValueChanged({ nodeId, fieldName, value: imageDTO }));
toast({ ...DEFAULT_UPLOADED_TOAST, description: `${t('toast.setNodeField')} ${fieldName}` });
toast({
...DEFAULT_UPLOADED_TOAST,
description: `${t('toast.setNodeField')} ${fieldName}`,
});
return;
}
},
@ -115,6 +171,7 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
startAppListening({
matcher: imagesApi.endpoints.uploadImage.matchRejected,
effect: (action) => {
const log = logger('images');
const sanitizedData = {
arg: {
...omit(action.meta.arg.originalArgs, ['file', 'postUploadAction']),

View File

@ -6,7 +6,7 @@ import type { ImageDTO } from 'services/api/types';
export const addImagesStarredListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.starImages.matchFulfilled,
effect: (action, { dispatch, getState }) => {
effect: async (action, { dispatch, getState }) => {
const { updated_image_names: starredImages } = action.payload;
const state = getState();

View File

@ -6,7 +6,7 @@ import type { ImageDTO } from 'services/api/types';
export const addImagesUnstarredListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.unstarImages.matchFulfilled,
effect: (action, { dispatch, getState }) => {
effect: async (action, { dispatch, getState }) => {
const { updated_image_names: unstarredImages } = action.payload;
const state = getState();

View File

@ -1,18 +1,23 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { modelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import {
controlAdapterIsEnabledChanged,
selectControlAdapterAll,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { modelSelected } from 'features/parameters/store/actions';
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
const log = logger('models');
import { forEach } from 'lodash-es';
export const addModelSelectedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: modelSelected,
effect: (action, { getState, dispatch }) => {
const log = logger('models');
const state = getState();
const result = zParameterModel.safeParse(action.payload);
@ -24,36 +29,34 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
const newModel = result.data;
const newBaseModel = newModel.base;
const didBaseModelChange = state.params.model?.base !== newBaseModel;
const didBaseModelChange = state.generation.model?.base !== newBaseModel;
if (didBaseModelChange) {
// we may need to reset some incompatible submodels
let modelsCleared = 0;
// handle incompatible loras
state.loras.loras.forEach((lora) => {
forEach(state.lora.loras, (lora, id) => {
if (lora.model.base !== newBaseModel) {
dispatch(loraDeleted({ id: lora.id }));
dispatch(loraRemoved(id));
modelsCleared += 1;
}
});
// handle incompatible vae
const { vae } = state.params;
const { vae } = state.generation;
if (vae && vae.base !== newBaseModel) {
dispatch(vaeSelected(null));
modelsCleared += 1;
}
// handle incompatible controlnets
// state.canvas.present.controlAdapters.entities.forEach((ca) => {
// if (ca.model?.base !== newBaseModel) {
// modelsCleared += 1;
// if (ca.isEnabled) {
// dispatch(entityIsEnabledToggled({ entityIdentifier: { id: ca.id, type: 'control_adapter' } }));
// }
// }
// });
selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
if (ca.model?.base !== newBaseModel) {
dispatch(controlAdapterIsEnabledChanged({ id: ca.id, isEnabled: false }));
modelsCleared += 1;
}
});
if (modelsCleared > 0) {
toast({
@ -67,7 +70,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
}
}
dispatch(modelChanged({ model: newModel, previousModel: state.params.model }));
dispatch(modelChanged(newModel, state.generation.model));
},
});
};

View File

@ -1,42 +1,36 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import type { SerializableObject } from 'common/types';
import type { JSONObject } from 'common/types';
import {
bboxHeightChanged,
bboxWidthChanged,
controlLayerModelChanged,
ipaModelChanged,
rgIPAdapterModelChanged,
} from 'features/controlLayers/store/canvasSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { modelChanged, refinerModelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { calculateNewSize } from 'features/parameters/components/DocumentSize/calculateNewSize';
controlAdapterModelCleared,
selectControlAdapterAll,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { heightChanged, widthChanged } from 'features/controlLayers/store/controlLayersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
import { postProcessingModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { forEach } from 'lodash-es';
import type { Logger } from 'roarr';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import {
isControlNetOrT2IAdapterModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isRefinerMainModelModelConfig,
isSpandrelImageToImageModelConfig,
isVAEModelConfig,
} from 'services/api/types';
const log = logger('models');
export const addModelsLoadedListener = (startAppListening: AppStartListening) => {
startAppListening({
predicate: modelsApi.endpoints.getModelConfigs.matchFulfilled,
effect: (action, { getState, dispatch }) => {
effect: async (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one
const log = logger('models');
log.info({ models: action.payload.entities }, `Models loaded (${action.payload.ids.length})`);
const state = getState();
@ -49,7 +43,6 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
handleLoRAModels(models, state, dispatch, log);
handleControlAdapterModels(models, state, dispatch, log);
handleSpandrelImageToImageModels(models, state, dispatch, log);
handleIPAdapterModels(models, state, dispatch, log);
},
});
};
@ -58,15 +51,15 @@ type ModelHandler = (
models: AnyModelConfig[],
state: RootState,
dispatch: AppDispatch,
log: Logger<SerializableObject>
log: Logger<JSONObject>
) => undefined;
const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
const currentModel = state.params.model;
const currentModel = state.generation.model;
const mainModels = models.filter(isNonRefinerMainModelConfig);
if (mainModels.length === 0) {
// No models loaded at all
dispatch(modelChanged({ model: null }));
dispatch(modelChanged(null));
return;
}
@ -81,16 +74,25 @@ const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
if (defaultModelInList) {
const result = zParameterModel.safeParse(defaultModelInList);
if (result.success) {
dispatch(modelChanged({ model: defaultModelInList, previousModel: currentModel }));
const { bbox } = selectCanvasSlice(state);
dispatch(modelChanged(defaultModelInList, currentModel));
const optimalDimension = getOptimalDimension(defaultModelInList);
if (getIsSizeOptimal(bbox.rect.width, bbox.rect.height, optimalDimension)) {
if (
getIsSizeOptimal(
state.controlLayers.present.size.width,
state.controlLayers.present.size.height,
optimalDimension
)
) {
return;
}
const { width, height } = calculateNewSize(bbox.aspectRatio.value, optimalDimension * optimalDimension);
const { width, height } = calculateNewSize(
state.controlLayers.present.size.aspectRatio.value,
optimalDimension * optimalDimension
);
dispatch(bboxWidthChanged({ width }));
dispatch(bboxHeightChanged({ height }));
dispatch(widthChanged({ width }));
dispatch(heightChanged({ height }));
return;
}
}
@ -102,11 +104,11 @@ const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
return;
}
dispatch(modelChanged({ model: result.data, previousModel: currentModel }));
dispatch(modelChanged(result.data, currentModel));
};
const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => {
const currentRefinerModel = state.params.refinerModel;
const currentRefinerModel = state.sdxl.refinerModel;
const refinerModels = models.filter(isRefinerMainModelModelConfig);
if (models.length === 0) {
// No models loaded at all
@ -125,7 +127,7 @@ const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => {
};
const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
const currentVae = state.params.vae;
const currentVae = state.generation.vae;
if (currentVae === null) {
// null is a valid VAE! it means "use the default with the main model"
@ -158,47 +160,28 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
};
const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
const loraModels = models.filter(isLoRAModelConfig);
state.loras.loras.forEach((lora) => {
const isLoRAAvailable = loraModels.some((m) => m.key === lora.model.key);
const loras = state.lora.loras;
forEach(loras, (lora, id) => {
const isLoRAAvailable = models.some((m) => m.key === lora.model.key);
if (isLoRAAvailable) {
return;
}
dispatch(loraDeleted({ id: lora.id }));
dispatch(loraRemoved(id));
});
};
const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
const caModels = models.filter(isControlNetOrT2IAdapterModelConfig);
selectCanvasSlice(state).controlLayers.entities.forEach((entity) => {
const isModelAvailable = caModels.some((m) => m.key === entity.controlAdapter.model?.key);
selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
const isModelAvailable = models.some((m) => m.key === ca.model?.key);
if (isModelAvailable) {
return;
}
dispatch(controlLayerModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
});
};
const handleIPAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
const ipaModels = models.filter(isIPAdapterModelConfig);
selectCanvasSlice(state).ipAdapters.entities.forEach((entity) => {
const isModelAvailable = ipaModels.some((m) => m.key === entity.ipAdapter.model?.key);
if (isModelAvailable) {
return;
}
dispatch(ipaModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
});
selectCanvasSlice(state).regions.entities.forEach((entity) => {
entity.ipAdapters.forEach(({ id: ipAdapterId, model }) => {
const isModelAvailable = ipaModels.some((m) => m.key === model?.key);
if (isModelAvailable) {
return;
}
dispatch(
rgIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), ipAdapterId, modelConfig: null })
);
});
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
};

View File

@ -1,6 +1,6 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { positivePromptChanged } from 'features/controlLayers/store/paramsSlice';
import { positivePromptChanged } from 'features/controlLayers/store/controlLayersSlice';
import {
combinatorialToggled,
isErrorChanged,
@ -15,7 +15,7 @@ import { getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilder
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
import { utilitiesApi } from 'services/api/endpoints/utilities';
import { socketConnected } from 'services/events/setEventListeners';
import { socketConnected } from 'services/events/actions';
const matcher = isAnyOf(
positivePromptChanged,
@ -24,6 +24,8 @@ const matcher = isAnyOf(
maxPromptsReset,
socketConnected,
activeStylePresetIdChanged,
stylePresetsApi.endpoints.deleteStylePreset.matchFulfilled,
stylePresetsApi.endpoints.updateStylePreset.matchFulfilled,
stylePresetsApi.endpoints.listStylePresets.matchFulfilled
);

View File

@ -1,5 +1,6 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
import { heightChanged, widthChanged } from 'features/controlLayers/store/controlLayersSlice';
import { setDefaultSettings } from 'features/parameters/store/actions';
import {
setCfgRescaleMultiplier,
setCfgScale,
@ -7,8 +8,7 @@ import {
setSteps,
vaePrecisionChanged,
vaeSelected,
} from 'features/controlLayers/store/paramsSlice';
import { setDefaultSettings } from 'features/parameters/store/actions';
} from 'features/parameters/store/generationSlice';
import {
isParameterCFGRescaleMultiplier,
isParameterCFGScale,
@ -30,7 +30,7 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
effect: async (action, { dispatch, getState }) => {
const state = getState();
const currentModel = state.params.model;
const currentModel = state.generation.model;
if (!currentModel) {
return;
@ -98,13 +98,13 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
const setSizeOptions = { updateAspectRatio: true, clamp: true };
if (width) {
if (isParameterWidth(width)) {
dispatch(bboxWidthChanged({ width, ...setSizeOptions }));
dispatch(widthChanged({ width, ...setSizeOptions }));
}
}
if (height) {
if (isParameterHeight(height)) {
dispatch(bboxHeightChanged({ height, ...setSizeOptions }));
dispatch(heightChanged({ height, ...setSizeOptions }));
}
}

View File

@ -6,9 +6,9 @@ import { atom } from 'nanostores';
import { api } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
import { socketConnected } from 'services/events/setEventListeners';
import { socketConnected } from 'services/events/actions';
const log = logger('events');
const log = logger('socketio');
const $isFirstConnection = atom(true);

View File

@ -0,0 +1,14 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketDisconnected } from 'services/events/actions';
const log = logger('socketio');
export const addSocketDisconnectedEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketDisconnected,
effect: () => {
log.debug('Disconnected');
},
});
};

View File

@ -0,0 +1,26 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketGeneratorProgress } from 'services/events/actions';
const log = logger('socketio');
export const addGeneratorProgressEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketGeneratorProgress,
effect: (action) => {
log.trace(parseify(action.payload), `Generator progress`);
const { invocation_source_id, step, total_steps, progress_image } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
nes.progress = (step + 1) / total_steps;
nes.progressImage = progress_image ?? null;
upsertExecutionState(nes.nodeId, nes);
}
},
});
};

View File

@ -0,0 +1,122 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import {
boardIdSelected,
galleryViewChanged,
imageSelected,
isImageViewerOpenChanged,
offsetChanged,
} from 'features/gallery/store/gallerySlice';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { getCategories, getListImagesUrl } from 'services/api/util';
import { socketInvocationComplete } from 'services/events/actions';
// These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them
const nodeTypeDenylist = ['load_image', 'image'];
const log = logger('socketio');
export const addInvocationCompleteEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketInvocationComplete,
effect: async (action, { dispatch, getState }) => {
const { data } = action.payload;
log.debug({ data: parseify(data) }, `Invocation complete (${data.invocation.type})`);
const { result, invocation_source_id } = data;
// This complete event has an associated image output
if (data.result.type === 'image_output' && !nodeTypeDenylist.includes(data.invocation.type)) {
const { image_name } = data.result.image;
const { canvas, gallery } = getState();
// This populates the `getImageDTO` cache
const imageDTORequest = dispatch(
imagesApi.endpoints.getImageDTO.initiate(image_name, {
forceRefetch: true,
})
);
const imageDTO = await imageDTORequest.unwrap();
imageDTORequest.unsubscribe();
// Add canvas images to the staging area
if (canvas.batchIds.includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) {
dispatch(addImageToStagingArea(imageDTO));
}
if (!imageDTO.is_intermediate) {
// update the total images for the board
dispatch(
boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
draft.total += 1;
})
);
dispatch(
imagesApi.util.invalidateTags([
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
{
type: 'ImageList',
id: getListImagesUrl({
board_id: imageDTO.board_id ?? 'none',
categories: getCategories(imageDTO),
}),
},
])
);
const { shouldAutoSwitch } = gallery;
// If auto-switch is enabled, select the new image
if (shouldAutoSwitch) {
// if auto-add is enabled, switch the gallery view and board if needed as the image comes in
if (gallery.galleryView !== 'images') {
dispatch(galleryViewChanged('images'));
}
if (imageDTO.board_id && imageDTO.board_id !== gallery.selectedBoardId) {
dispatch(
boardIdSelected({
boardId: imageDTO.board_id,
selectedImageName: imageDTO.image_name,
})
);
}
dispatch(offsetChanged({ offset: 0 }));
if (!imageDTO.board_id && gallery.selectedBoardId !== 'none') {
dispatch(
boardIdSelected({
boardId: 'none',
selectedImageName: imageDTO.image_name,
})
);
}
dispatch(imageSelected(imageDTO));
dispatch(isImageViewerOpenChanged(true));
}
}
}
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.COMPLETED;
if (nes.progress !== null) {
nes.progress = 1;
}
nes.outputs.push(result);
upsertExecutionState(nes.nodeId, nes);
}
},
});
};

View File

@ -0,0 +1,31 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketInvocationError } from 'services/events/actions';
const log = logger('socketio');
export const addInvocationErrorEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketInvocationError,
effect: (action) => {
const { invocation_source_id, invocation, error_type, error_message, error_traceback } = action.payload.data;
log.error(parseify(action.payload), `Invocation error (${invocation.type})`);
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.FAILED;
nes.progress = null;
nes.progressImage = null;
nes.error = {
error_type,
error_message,
error_traceback,
};
upsertExecutionState(nes.nodeId, nes);
}
},
});
};

View File

@ -0,0 +1,24 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketInvocationStarted } from 'services/events/actions';
const log = logger('socketio');
export const addInvocationStartedEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketInvocationStarted,
effect: (action) => {
log.debug(parseify(action.payload), `Invocation started (${action.payload.data.invocation.type})`);
const { invocation_source_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
upsertExecutionState(nes.nodeId, nes);
}
},
});
};

View File

@ -0,0 +1,196 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { api, LIST_TAG } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import {
socketModelInstallCancelled,
socketModelInstallComplete,
socketModelInstallDownloadProgress,
socketModelInstallDownloadsComplete,
socketModelInstallDownloadStarted,
socketModelInstallError,
socketModelInstallStarted,
} from 'services/events/actions';
/**
* A model install has two main stages - downloading and installing. All these events are namespaced under `model_install_`
* which is a bit misleading. For example, a `model_install_started` event is actually fired _after_ the model has fully
* downloaded and is being "physically" installed.
*
* Note: the download events are only fired for remote model installs, not local.
*
* Here's the expected flow:
* - API receives install request, model manager preps the install
* - `model_install_download_started` fired when the download starts
* - `model_install_download_progress` fired continually until the download is complete
* - `model_install_download_complete` fired when the download is complete
* - `model_install_started` fired when the "physical" installation starts
* - `model_install_complete` fired when the installation is complete
* - `model_install_cancelled` fired if the installation is cancelled
* - `model_install_error` fired if the installation has an error
*/
const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select();
export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketModelInstallDownloadStarted,
effect: async (action, { dispatch, getState }) => {
const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'downloading';
}
return draft;
})
);
}
},
});
startAppListening({
actionCreator: socketModelInstallStarted,
effect: async (action, { dispatch, getState }) => {
const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'running';
}
return draft;
})
);
}
},
});
startAppListening({
actionCreator: socketModelInstallDownloadProgress,
effect: async (action, { dispatch, getState }) => {
const { bytes, total_bytes, id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.bytes = bytes;
modelImport.total_bytes = total_bytes;
modelImport.status = 'downloading';
}
return draft;
})
);
}
},
});
startAppListening({
actionCreator: socketModelInstallComplete,
effect: (action, { dispatch, getState }) => {
const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'completed';
}
return draft;
})
);
}
dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }]));
dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }]));
},
});
startAppListening({
actionCreator: socketModelInstallError,
effect: (action, { dispatch, getState }) => {
const { id, error, error_type } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'error';
modelImport.error_reason = error_type;
modelImport.error = error;
}
return draft;
})
);
}
},
});
startAppListening({
actionCreator: socketModelInstallCancelled,
effect: (action, { dispatch, getState }) => {
const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'cancelled';
}
return draft;
})
);
}
},
});
startAppListening({
actionCreator: socketModelInstallDownloadsComplete,
effect: (action, { dispatch, getState }) => {
const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'downloads_done';
}
return draft;
})
);
}
},
});
};

View File

@ -0,0 +1,42 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketModelLoadComplete, socketModelLoadStarted } from 'services/events/actions';
const log = logger('socketio');
export const addModelLoadEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketModelLoadStarted,
effect: (action) => {
const { config, submodel_type } = action.payload.data;
const { name, base, type } = config;
const extras: string[] = [base, type];
if (submodel_type) {
extras.push(submodel_type);
}
const message = `Model load started: ${name} (${extras.join(', ')})`;
log.debug(action.payload, message);
},
});
startAppListening({
actionCreator: socketModelLoadComplete,
effect: (action) => {
const { config, submodel_type } = action.payload.data;
const { name, base, type } = config;
const extras: string[] = [base, type];
if (submodel_type) {
extras.push(submodel_type);
}
const message = `Model load complete: ${name} (${extras.join(', ')})`;
log.debug(action.payload, message);
},
});
};

View File

@ -0,0 +1,114 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription';
import { toast } from 'features/toast/toast';
import { forEach } from 'lodash-es';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import { socketQueueItemStatusChanged } from 'services/events/actions';
const log = logger('socketio');
export const addSocketQueueItemStatusChangedEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketQueueItemStatusChanged,
effect: async (action, { dispatch, getState }) => {
// we've got new status for the queue item, batch and queue
const {
item_id,
session_id,
status,
started_at,
updated_at,
completed_at,
batch_status,
queue_status,
error_type,
error_message,
error_traceback,
} = action.payload.data;
log.debug(action.payload, `Queue item ${item_id} status updated: ${status}`);
// Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
dispatch(
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
queueItemsAdapter.updateOne(draft, {
id: String(item_id),
changes: {
status,
started_at,
updated_at: updated_at ?? undefined,
completed_at: completed_at ?? undefined,
error_type,
error_message,
error_traceback,
},
});
})
);
// Update the queue status (we do not get the processor status here)
dispatch(
queueApi.util.updateQueryData('getQueueStatus', undefined, (draft) => {
if (!draft) {
return;
}
Object.assign(draft.queue, queue_status);
})
);
// Update the batch status
dispatch(
queueApi.util.updateQueryData('getBatchStatus', { batch_id: batch_status.batch_id }, () => batch_status)
);
// Invalidate caches for things we cannot update
// TODO: technically, we could possibly update the current session queue item, but feels safer to just request it again
dispatch(
queueApi.util.invalidateTags([
'CurrentSessionQueueItem',
'NextSessionQueueItem',
'InvocationCacheStatus',
{ type: 'SessionQueueItem', id: item_id },
])
);
if (status === 'in_progress') {
forEach($nodeExecutionStates.get(), (nes) => {
if (!nes) {
return;
}
const clone = deepClone(nes);
clone.status = zNodeStatus.enum.PENDING;
clone.error = null;
clone.progress = null;
clone.progressImage = null;
clone.outputs = [];
$nodeExecutionStates.setKey(clone.nodeId, clone);
});
} else if (status === 'failed' && error_type) {
const isLocal = getState().config.isLocal ?? true;
const sessionId = session_id;
toast({
id: `INVOCATION_ERROR_${error_type}`,
title: getTitleFromErrorType(error_type),
status: 'error',
duration: null,
updateDescription: isLocal,
description: (
<ErrorToastDescription
errorType={error_type}
errorMessage={error_message}
sessionId={sessionId}
isLocal={isLocal}
/>
),
});
}
},
});
};

View File

@ -0,0 +1,43 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { stagingAreaImageSaved } from 'features/canvas/store/actions';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
export const addStagingAreaImageSavedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: stagingAreaImageSaved,
effect: async (action, { dispatch, getState }) => {
const { imageDTO } = action.payload;
try {
const newImageDTO = await dispatch(
imagesApi.endpoints.changeImageIsIntermediate.initiate({
imageDTO,
is_intermediate: false,
})
).unwrap();
// we may need to add it to the autoadd board
const { autoAddBoardId } = getState().gallery;
if (autoAddBoardId && autoAddBoardId !== 'none') {
await dispatch(
imagesApi.endpoints.addImageToBoard.initiate({
imageDTO: newImageDTO,
board_id: autoAddBoardId,
})
);
}
toast({ id: 'IMAGE_SAVED', title: t('toast.imageSaved'), status: 'success' });
} catch (error) {
toast({
id: 'IMAGE_SAVE_FAILED',
title: t('toast.imageSavingFailed'),
description: (error as Error)?.message,
status: 'error',
});
}
},
});
};

View File

@ -2,20 +2,18 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { updateAllNodesRequested } from 'features/nodes/store/actions';
import { $templates, nodesChanged } from 'features/nodes/store/nodesSlice';
import { selectNodes } from 'features/nodes/store/selectors';
import { NodeUpdateError } from 'features/nodes/types/error';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
const log = logger('workflows');
export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: updateAllNodesRequested,
effect: (action, { dispatch, getState }) => {
const nodes = selectNodes(getState());
const log = logger('nodes');
const { nodes } = getState().nodes.present;
const templates = $templates.get();
let unableToUpdateCount = 0;

View File

@ -1,6 +1,6 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
import { parseify } from 'common/util/serialize';
import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions';
import { $templates } from 'features/nodes/store/nodesSlice';
import { $needsFit } from 'features/nodes/store/reactFlowInstance';
@ -10,14 +10,11 @@ import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { serializeError } from 'serialize-error';
import { checkBoardAccess, checkImageAccess, checkModelAccess } from 'services/api/hooks/accessChecks';
import type { GraphAndWorkflowResponse, NonNullableGraph } from 'services/api/types';
import { z } from 'zod';
import { fromZodError } from 'zod-validation-error';
const log = logger('workflows');
const getWorkflow = async (data: GraphAndWorkflowResponse, templates: Templates) => {
if (data.workflow) {
// Prefer to load the workflow if it's available - it has more information
@ -37,6 +34,7 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
startAppListening({
actionCreator: workflowLoadRequested,
effect: async (action, { dispatch }) => {
const log = logger('nodes');
const { data, asCopy } = action.payload;
const nodeTemplates = $templates.get();
@ -48,7 +46,6 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
delete workflow.id;
}
$nodeExecutionStates.set({});
dispatch(workflowLoaded(workflow));
if (!warnings.length) {
toast({
@ -72,7 +69,7 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
} catch (e) {
if (e instanceof WorkflowVersionError) {
// The workflow version was not recognized in the valid list of versions
log.error({ error: serializeError(e) }, e.message);
log.error({ error: parseify(e) }, e.message);
toast({
id: 'UNABLE_TO_VALIDATE_WORKFLOW',
title: t('nodes.unableToValidateWorkflow'),
@ -81,7 +78,7 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
});
} else if (e instanceof WorkflowMigrationError) {
// There was a problem migrating the workflow to the latest version
log.error({ error: serializeError(e) }, e.message);
log.error({ error: parseify(e) }, e.message);
toast({
id: 'UNABLE_TO_VALIDATE_WORKFLOW',
title: t('nodes.unableToValidateWorkflow'),
@ -93,7 +90,7 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
const { message } = fromZodError(e, {
prefix: t('nodes.workflowValidation'),
});
log.error({ error: serializeError(e) }, message);
log.error({ error: parseify(e) }, message);
toast({
id: 'UNABLE_TO_VALIDATE_WORKFLOW',
title: t('nodes.unableToValidateWorkflow'),
@ -102,7 +99,7 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
});
} else {
// Some other error occurred
log.error({ error: serializeError(e) }, t('nodes.unknownErrorValidatingWorkflow'));
log.error({ error: parseify(e) }, t('nodes.unknownErrorValidatingWorkflow'));
toast({
id: 'UNABLE_TO_VALIDATE_WORKFLOW',
title: t('nodes.unableToValidateWorkflow'),

View File

@ -1,5 +1,4 @@
import { useStore } from '@nanostores/react';
import type { AppStore } from 'app/store/store';
import type { createStore } from 'app/store/store';
import { atom } from 'nanostores';
// Inject socket options and url into window for debugging
@ -23,7 +22,7 @@ class ReduxStoreNotInitialized extends Error {
}
}
export const $store = atom<Readonly<AppStore | undefined>>();
export const $store = atom<Readonly<ReturnType<typeof createStore>> | undefined>();
export const getStore = () => {
const store = $store.get();
@ -32,11 +31,3 @@ export const getStore = () => {
}
return store;
};
export const useAppStore = () => {
const store = useStore($store);
if (!store) {
throw new ReduxStoreNotInitialized();
}
return store;
};

View File

@ -3,31 +3,37 @@ import { autoBatchEnhancer, combineReducers, configureStore } from '@reduxjs/too
import { logger } from 'app/logging/logger';
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
import type { SerializableObject } from 'common/types';
import { deepClone } from 'common/util/deepClone';
import type { JSONObject } from 'common/types';
import { canvasPersistConfig, canvasSlice } from 'features/canvas/store/canvasSlice';
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
import { canvasSessionPersistConfig, canvasSessionSlice } from 'features/controlLayers/store/canvasSessionSlice';
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { canvasPersistConfig, canvasSlice, canvasUndoableConfig } from 'features/controlLayers/store/canvasSlice';
import { lorasPersistConfig, lorasSlice } from 'features/controlLayers/store/lorasSlice';
import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/paramsSlice';
import { toolPersistConfig, toolSlice } from 'features/controlLayers/store/toolSlice';
import {
controlAdaptersPersistConfig,
controlAdaptersSlice,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import {
controlLayersPersistConfig,
controlLayersSlice,
controlLayersUndoableConfig,
} from 'features/controlLayers/store/controlLayersSlice';
import { deleteImageModalSlice } from 'features/deleteImageModal/store/slice';
import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice';
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice';
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
import { upscalePersistConfig, upscaleSlice } from 'features/parameters/store/upscaleSlice';
import { queueSlice } from 'features/queue/store/queueSlice';
import { sdxlPersistConfig, sdxlSlice } from 'features/sdxl/store/sdxlSlice';
import { stylePresetPersistConfig, stylePresetSlice } from 'features/stylePresets/store/stylePresetSlice';
import { configSlice } from 'features/system/store/configSlice';
import { systemPersistConfig, systemSlice } from 'features/system/store/systemSlice';
import { uiPersistConfig, uiSlice } from 'features/ui/store/uiSlice';
import { diff } from 'jsondiffpatch';
import { keys, mergeWith, omit, pick } from 'lodash-es';
import { defaultsDeep, keys, omit, pick } from 'lodash-es';
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import type { SerializeFunction, UnserializeFunction } from 'redux-remember';
import { rememberEnhancer, rememberReducer } from 'redux-remember';
@ -42,31 +48,29 @@ import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { listenerMiddleware } from './middleware/listenerMiddleware';
const log = logger('system');
const allReducers = {
[api.reducerPath]: api.reducer,
[canvasSlice.name]: canvasSlice.reducer,
[gallerySlice.name]: gallerySlice.reducer,
[generationSlice.name]: generationSlice.reducer,
[nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig),
[systemSlice.name]: systemSlice.reducer,
[configSlice.name]: configSlice.reducer,
[uiSlice.name]: uiSlice.reducer,
[controlAdaptersSlice.name]: controlAdaptersSlice.reducer,
[dynamicPromptsSlice.name]: dynamicPromptsSlice.reducer,
[deleteImageModalSlice.name]: deleteImageModalSlice.reducer,
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
[loraSlice.name]: loraSlice.reducer,
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer,
[sdxlSlice.name]: sdxlSlice.reducer,
[queueSlice.name]: queueSlice.reducer,
[workflowSlice.name]: workflowSlice.reducer,
[hrfSlice.name]: hrfSlice.reducer,
[canvasSlice.name]: undoable(canvasSlice.reducer, canvasUndoableConfig),
[controlLayersSlice.name]: undoable(controlLayersSlice.reducer, controlLayersUndoableConfig),
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer,
[api.reducerPath]: api.reducer,
[upscaleSlice.name]: upscaleSlice.reducer,
[stylePresetSlice.name]: stylePresetSlice.reducer,
[paramsSlice.name]: paramsSlice.reducer,
[toolSlice.name]: toolSlice.reducer,
[canvasSettingsSlice.name]: canvasSettingsSlice.reducer,
[canvasSessionSlice.name]: canvasSessionSlice.reducer,
[lorasSlice.name]: lorasSlice.reducer,
};
const rootReducer = combineReducers(allReducers);
@ -96,26 +100,27 @@ export type PersistConfig<T = any> = {
};
const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[canvasPersistConfig.name]: canvasPersistConfig,
[galleryPersistConfig.name]: galleryPersistConfig,
[generationPersistConfig.name]: generationPersistConfig,
[nodesPersistConfig.name]: nodesPersistConfig,
[systemPersistConfig.name]: systemPersistConfig,
[workflowPersistConfig.name]: workflowPersistConfig,
[uiPersistConfig.name]: uiPersistConfig,
[controlAdaptersPersistConfig.name]: controlAdaptersPersistConfig,
[dynamicPromptsPersistConfig.name]: dynamicPromptsPersistConfig,
[sdxlPersistConfig.name]: sdxlPersistConfig,
[loraPersistConfig.name]: loraPersistConfig,
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
[hrfPersistConfig.name]: hrfPersistConfig,
[canvasPersistConfig.name]: canvasPersistConfig,
[controlLayersPersistConfig.name]: controlLayersPersistConfig,
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
[upscalePersistConfig.name]: upscalePersistConfig,
[stylePresetPersistConfig.name]: stylePresetPersistConfig,
[paramsPersistConfig.name]: paramsPersistConfig,
[toolPersistConfig.name]: toolPersistConfig,
[canvasSettingsPersistConfig.name]: canvasSettingsPersistConfig,
[canvasSessionPersistConfig.name]: canvasSessionPersistConfig,
[lorasPersistConfig.name]: lorasPersistConfig,
};
const unserialize: UnserializeFunction = (data, key) => {
const log = logger('system');
const persistConfig = persistConfigs[key as keyof typeof persistConfigs];
if (!persistConfig) {
throw new Error(`No persist config for slice "${key}"`);
@ -125,21 +130,17 @@ const unserialize: UnserializeFunction = (data, key) => {
const parsed = JSON.parse(data);
// strip out old keys
const stripped = pick(deepClone(parsed), keys(initialState));
const stripped = pick(parsed, keys(initialState));
// run (additive) migrations
const migrated = migrate(stripped);
/*
* Merge in initial state as default values, covering any missing keys. You might be tempted to use _.defaultsDeep,
* but that merges arrays by index and partial objects by key. Using an identity function as the customizer results
* in behaviour like defaultsDeep, but doesn't overwrite any values that are not undefined in the migrated state.
*/
const transformed = mergeWith(migrated, initialState, (objVal) => objVal);
// merge in initial state as default values, covering any missing keys
const transformed = defaultsDeep(migrated, initialState);
log.debug(
{
persistedData: parsed,
rehydratedData: transformed,
diff: diff(parsed, transformed) as SerializableObject, // this is always serializable
diff: diff(parsed, transformed) as JSONObject, // this is always serializable
},
`Rehydrated slice "${key}"`
);
@ -201,8 +202,7 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
},
});
export type AppStore = ReturnType<typeof createStore>;
export type RootState = ReturnType<AppStore['getState']>;
export type RootState = ReturnType<ReturnType<typeof createStore>['getState']>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type AppThunkDispatch = ThunkDispatch<RootState, any, UnknownAction>;
export type AppDispatch = ReturnType<typeof createStore>['dispatch'];

View File

@ -1,6 +1,6 @@
import type { FilterType } from 'features/controlLayers/store/types';
import type { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import type { TabName } from 'features/ui/store/uiTypes';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import type { O } from 'ts-toolbelt';
/**
@ -72,7 +72,7 @@ export type AppConfig = {
maxUpscaleDimension?: number;
allowPrivateBoards: boolean;
allowPrivateStylePresets: boolean;
disabledTabs: TabName[];
disabledTabs: InvokeTabName[];
disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];
nodesAllowlist: string[] | undefined;
@ -83,7 +83,7 @@ export type AppConfig = {
sd: {
defaultModel?: string;
disabledControlNetModels: string[];
disabledControlNetProcessors: FilterType[];
disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
// Core parameters
iterations: NumericalParameterConfig;
width: NumericalParameterConfig; // initial value comes from model

View File

@ -33,23 +33,28 @@ type IAINoImageFallbackProps = FlexProps & {
};
export const IAINoContentFallback = memo((props: IAINoImageFallbackProps) => {
const { icon = PiImageBold, boxSize = 16, ...rest } = props;
const { icon = PiImageBold, boxSize = 16, sx, ...rest } = props;
const styles = useMemo(
() => ({
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
flexDir: 'column',
gap: 2,
userSelect: 'none',
opacity: 0.7,
color: 'base.500',
fontSize: 'md',
...sx,
}),
[sx]
);
return (
<Flex
w="full"
h="full"
alignItems="center"
justifyContent="center"
borderRadius="base"
flexDir="column"
gap={2}
userSelect="none"
opacity={0.7}
color="base.500"
fontSize="md"
{...rest}
>
<Flex sx={styles} {...rest}>
{icon && <Icon as={icon} boxSize={boxSize} opacity={0.7} />}
{props.label && <Text textAlign="center">{props.label}</Text>}
</Flex>

View File

@ -13,9 +13,8 @@ import {
Spacer,
Text,
} from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectSystemSlice, setShouldEnableInformationalPopovers } from 'features/system/store/systemSlice';
import { setShouldEnableInformationalPopovers } from 'features/system/store/systemSlice';
import { toast } from 'features/toast/toast';
import { merge, omit } from 'lodash-es';
import type { ReactElement } from 'react';
@ -32,13 +31,8 @@ type Props = {
children: ReactElement;
};
const selectShouldEnableInformationalPopovers = createSelector(
selectSystemSlice,
(system) => system.shouldEnableInformationalPopovers
);
export const InformationalPopover = memo(({ feature, children, inPortal = true, ...rest }: Props) => {
const shouldEnableInformationalPopovers = useAppSelector(selectShouldEnableInformationalPopovers);
const shouldEnableInformationalPopovers = useAppSelector((s) => s.system.shouldEnableInformationalPopovers);
const data = useMemo(() => POPOVER_DATA[feature], [feature]);

View File

@ -1,158 +0,0 @@
import { logger } from 'app/logging/logger';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { objectKeys } from 'common/util/objectKeys';
import { isEqual } from 'lodash-es';
import type { Atom } from 'nanostores';
import { atom, computed } from 'nanostores';
import type { RefObject } from 'react';
import { useEffect, useMemo } from 'react';
const log = logger('system');
const _INTERACTION_SCOPES = ['gallery', 'canvas', 'stagingArea', 'workflows', 'imageViewer'] as const;
type InteractionScope = (typeof _INTERACTION_SCOPES)[number];
export const $activeScopes = atom<Set<InteractionScope>>(new Set());
type InteractionScopeData = {
targets: Set<HTMLElement>;
$isActive: Atom<boolean>;
};
export const INTERACTION_SCOPES: Record<InteractionScope, InteractionScopeData> = _INTERACTION_SCOPES.reduce(
(acc, region) => {
acc[region] = {
targets: new Set(),
$isActive: computed($activeScopes, (activeScopes) => activeScopes.has(region)),
};
return acc;
},
{} as Record<InteractionScope, InteractionScopeData>
);
const formatScopes = (interactionScopes: Set<InteractionScope>) => {
if (interactionScopes.size === 0) {
return 'none';
}
return Array.from(interactionScopes).join(', ');
};
export const addScope = (scope: InteractionScope) => {
const currentScopes = $activeScopes.get();
if (currentScopes.has(scope)) {
return;
}
const newScopes = new Set(currentScopes);
newScopes.add(scope);
$activeScopes.set(newScopes);
log.trace(`Added scope ${scope}: ${formatScopes($activeScopes.get())}`);
};
export const removeScope = (scope: InteractionScope) => {
const currentScopes = $activeScopes.get();
if (!currentScopes.has(scope)) {
return;
}
const newScopes = new Set(currentScopes);
newScopes.delete(scope);
$activeScopes.set(newScopes);
log.trace(`Removed scope ${scope}: ${formatScopes($activeScopes.get())}`);
};
export const setScopes = (scopes: InteractionScope[]) => {
const newScopes = new Set(scopes);
$activeScopes.set(newScopes);
log.trace(`Set scopes: ${formatScopes($activeScopes.get())}`);
};
export const useScopeOnFocus = (scope: InteractionScope, ref: RefObject<HTMLElement>) => {
useEffect(() => {
const element = ref.current;
if (!element) {
return;
}
INTERACTION_SCOPES[scope].targets.add(element);
return () => {
INTERACTION_SCOPES[scope].targets.delete(element);
};
}, [ref, scope]);
};
type UseScopeOnMountOptions = {
mount?: boolean;
unmount?: boolean;
};
const defaultUseScopeOnMountOptions: UseScopeOnMountOptions = {
mount: true,
unmount: true,
};
export const useScopeOnMount = (scope: InteractionScope, options?: UseScopeOnMountOptions) => {
useEffect(() => {
const { mount, unmount } = { ...defaultUseScopeOnMountOptions, ...options };
if (mount) {
addScope(scope);
}
return () => {
if (unmount) {
removeScope(scope);
}
};
}, [options, scope]);
};
export const useScopeImperativeApi = (scope: InteractionScope) => {
const api = useMemo(() => {
return {
add: () => {
addScope(scope);
},
remove: () => {
removeScope(scope);
},
};
}, [scope]);
return api;
};
const handleFocusEvent = (_event: FocusEvent) => {
const activeElement = document.activeElement;
if (!(activeElement instanceof HTMLElement)) {
return;
}
const newActiveScopes = new Set<InteractionScope>();
for (const scope of objectKeys(INTERACTION_SCOPES)) {
for (const element of INTERACTION_SCOPES[scope].targets) {
if (element.contains(activeElement)) {
newActiveScopes.add(scope);
}
}
}
const oldActiveScopes = $activeScopes.get();
if (!isEqual(oldActiveScopes, newActiveScopes)) {
$activeScopes.set(newActiveScopes);
log.trace(`Scopes changed: ${formatScopes($activeScopes.get())}`);
}
};
export const useScopeFocusWatcher = () => {
useAssertSingleton('useScopeFocusWatcher');
useEffect(() => {
window.addEventListener('focus', handleFocusEvent, true);
return () => {
window.removeEventListener('focus', handleFocusEvent, true);
};
}, []);
};

View File

@ -1,4 +1,3 @@
import type { WritableAtom } from 'nanostores';
import { useCallback, useMemo, useState } from 'react';
export const useBoolean = (initialValue: boolean) => {
@ -20,33 +19,3 @@ export const useBoolean = (initialValue: boolean) => {
return api;
};
export const buildUseBoolean = ($boolean: WritableAtom<boolean>) => {
return () => {
const setTrue = useCallback(() => {
$boolean.set(true);
}, []);
const setFalse = useCallback(() => {
$boolean.set(false);
}, []);
const set = useCallback((value: boolean) => {
$boolean.set(value);
}, []);
const toggle = useCallback(() => {
$boolean.set(!$boolean.get());
}, []);
const api = useMemo(
() => ({
setTrue,
setFalse,
set,
toggle,
$boolean,
}),
[set, setFalse, setTrue, toggle]
);
return api;
};
};

View File

@ -1,8 +1,7 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useCallback, useEffect, useState } from 'react';
import type { Accept, FileRejection } from 'react-dropzone';
import { useDropzone } from 'react-dropzone';
@ -15,9 +14,13 @@ const accept: Accept = {
'image/jpeg': ['.jpg', '.jpeg', '.png'],
};
const selectPostUploadAction = createMemoizedSelector(selectActiveTab, (activeTabName) => {
const selectPostUploadAction = createMemoizedSelector(activeTabNameSelector, (activeTabName) => {
let postUploadAction: PostUploadAction = { type: 'TOAST' };
if (activeTabName === 'canvas') {
postUploadAction = { type: 'SET_CANVAS_INITIAL_IMAGE' };
}
if (activeTabName === 'upscaling') {
postUploadAction = { type: 'SET_UPSCALE_INITIAL_IMAGE' };
}
@ -27,9 +30,10 @@ const selectPostUploadAction = createMemoizedSelector(selectActiveTab, (activeTa
export const useFullscreenDropzone = () => {
const { t } = useTranslation();
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const [isHandlingUpload, setIsHandlingUpload] = useState<boolean>(false);
const postUploadAction = useAppSelector(selectPostUploadAction);
const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId);
const [isHandlingUpload, setIsHandlingUpload] = useState<boolean>(false);
const [uploadImage] = useUploadImageMutation();
const fileRejectionCallback = useCallback(
@ -47,7 +51,7 @@ export const useFullscreenDropzone = () => {
);
const fileAcceptedCallback = useCallback(
(file: File) => {
async (file: File) => {
uploadImage({
file,
image_category: 'user',
@ -97,7 +101,7 @@ export const useFullscreenDropzone = () => {
useEffect(() => {
// This is a hack to allow pasting images into the uploader
const handlePaste = (e: ClipboardEvent) => {
const handlePaste = async (e: ClipboardEvent) => {
if (!dropzone.inputRef.current) {
return;
}

View File

@ -1,5 +1,4 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { addScope, removeScope, setScopes } from 'common/hooks/interactionScopes';
import { useCancelCurrentQueueItem } from 'features/queue/hooks/useCancelCurrentQueueItem';
import { useClearQueue } from 'features/queue/hooks/useClearQueue';
import { useQueueBack } from 'features/queue/hooks/useQueueBack';
@ -17,7 +16,7 @@ export const useGlobalHotkeys = () => {
['ctrl+enter', 'meta+enter'],
queueBack,
{
enabled: !isDisabledQueueBack && !isLoadingQueueBack,
enabled: () => !isDisabledQueueBack && !isLoadingQueueBack,
preventDefault: true,
enableOnFormTags: ['input', 'textarea', 'select'],
},
@ -30,7 +29,7 @@ export const useGlobalHotkeys = () => {
['ctrl+shift+enter', 'meta+shift+enter'],
queueFront,
{
enabled: !isDisabledQueueFront && !isLoadingQueueFront,
enabled: () => !isDisabledQueueFront && !isLoadingQueueFront,
preventDefault: true,
enableOnFormTags: ['input', 'textarea', 'select'],
},
@ -47,7 +46,7 @@ export const useGlobalHotkeys = () => {
['shift+x'],
cancelQueueItem,
{
enabled: !isDisabledCancelQueueItem && !isLoadingCancelQueueItem,
enabled: () => !isDisabledCancelQueueItem && !isLoadingCancelQueueItem,
preventDefault: true,
},
[cancelQueueItem, isDisabledCancelQueueItem, isLoadingCancelQueueItem]
@ -59,7 +58,7 @@ export const useGlobalHotkeys = () => {
['ctrl+shift+x', 'meta+shift+x'],
clearQueue,
{
enabled: !isDisabledClearQueue && !isLoadingClearQueue,
enabled: () => !isDisabledClearQueue && !isLoadingClearQueue,
preventDefault: true,
},
[clearQueue, isDisabledClearQueue, isLoadingClearQueue]
@ -69,8 +68,6 @@ export const useGlobalHotkeys = () => {
'1',
() => {
dispatch(setActiveTab('generation'));
addScope('canvas');
removeScope('workflows');
},
[dispatch]
);
@ -78,9 +75,7 @@ export const useGlobalHotkeys = () => {
useHotkeys(
'2',
() => {
dispatch(setActiveTab('upscaling'));
removeScope('canvas');
removeScope('workflows');
dispatch(setActiveTab('canvas'));
},
[dispatch]
);
@ -89,8 +84,6 @@ export const useGlobalHotkeys = () => {
'3',
() => {
dispatch(setActiveTab('workflows'));
removeScope('canvas');
addScope('workflows');
},
[dispatch]
);
@ -100,7 +93,6 @@ export const useGlobalHotkeys = () => {
() => {
if (isModelManagerEnabled) {
dispatch(setActiveTab('models'));
setScopes([]);
}
},
[dispatch, isModelManagerEnabled]
@ -110,7 +102,6 @@ export const useGlobalHotkeys = () => {
isModelManagerEnabled ? '5' : '4',
() => {
dispatch(setActiveTab('queue'));
setScopes([]);
},
[dispatch, isModelManagerEnabled]
);

View File

@ -1,8 +1,6 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { groupBy, reduce } from 'lodash-es';
import { useCallback, useMemo } from 'react';
@ -30,13 +28,11 @@ const groupByBaseFunc = <T extends AnyModelConfig>(model: T) => model.base.toUpp
const groupByBaseAndTypeFunc = <T extends AnyModelConfig>(model: T) =>
`${model.base.toUpperCase()} / ${model.type.replaceAll('_', ' ').toUpperCase()}`;
const selectBaseWithSDXLFallback = createSelector(selectParamsSlice, (params) => params.model?.base ?? 'sdxl');
export const useGroupedModelCombobox = <T extends AnyModelConfig>(
arg: UseGroupedModelComboboxArg<T>
): UseGroupedModelComboboxReturn => {
const { t } = useTranslation();
const base = useAppSelector(selectBaseWithSDXLFallback);
const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, groupByType = false } = arg;
const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
if (!modelConfigs) {
@ -58,9 +54,9 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
},
[] as GroupBase<ComboboxOption>[]
);
_options.sort((a) => (a.label?.split('/')[0]?.toLowerCase().includes(base) ? -1 : 1));
_options.sort((a) => (a.label?.split('/')[0]?.toLowerCase().includes(base_model) ? -1 : 1));
return _options;
}, [modelConfigs, groupByType, getIsDisabled, base]);
}, [modelConfigs, groupByType, getIsDisabled, base_model]);
const value = useMemo(
() =>

Some files were not shown because too many files have changed in this diff Show More