Merge branch 'main' into psyche/fix/ui/cl-listening-layers

commit 6ec3dc0c0d
blessedcoolant authored 2024-05-13 04:05:35 +05:30, committed by GitHub
21 changed files with 360 additions and 362 deletions

View File

@@ -586,13 +586,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
         unet: UNet2DConditionModel,
         scheduler: Scheduler,
     ) -> StableDiffusionGeneratorPipeline:
-        # TODO:
-        # configure_model_padding(
-        #     unet,
-        #     self.seamless,
-        #     self.seamless_axes,
-        # )
         class FakeVae:
             class FakeVaeConfig:
                 def __init__(self) -> None:

View File

@@ -4,5 +4,4 @@ Initialization file for invokeai.backend.image_util methods.
 from .infill_methods.patchmatch import PatchMatch  # noqa: F401
 from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata  # noqa: F401
-from .seamless import configure_model_padding  # noqa: F401
 from .util import InitImageResizer, make_grid  # noqa: F401

View File

@@ -1,52 +0,0 @@
-import torch.nn as nn
-
-
-def _conv_forward_asymmetric(self, input, weight, bias):
-    """
-    Patch for Conv2d._conv_forward that supports asymmetric padding
-    """
-    working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
-    working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
-    return nn.functional.conv2d(
-        working,
-        weight,
-        bias,
-        self.stride,
-        nn.modules.utils._pair(0),
-        self.dilation,
-        self.groups,
-    )
-
-
-def configure_model_padding(model, seamless, seamless_axes):
-    """
-    Modifies the 2D convolution layers to use a circular padding mode based on
-    the `seamless` and `seamless_axes` options.
-    """
-    # TODO: get an explicit interface for this in diffusers: https://github.com/huggingface/diffusers/issues/556
-    for m in model.modules():
-        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
-            if seamless:
-                m.asymmetric_padding_mode = {}
-                m.asymmetric_padding = {}
-                m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
-                m.asymmetric_padding["x"] = (
-                    m._reversed_padding_repeated_twice[0],
-                    m._reversed_padding_repeated_twice[1],
-                    0,
-                    0,
-                )
-                m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
-                m.asymmetric_padding["y"] = (
-                    0,
-                    0,
-                    m._reversed_padding_repeated_twice[2],
-                    m._reversed_padding_repeated_twice[3],
-                )
-                m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
-            else:
-                m._conv_forward = nn.Conv2d._conv_forward.__get__(m, nn.Conv2d)
-                if hasattr(m, "asymmetric_padding_mode"):
-                    del m.asymmetric_padding_mode
-                if hasattr(m, "asymmetric_padding"):
-                    del m.asymmetric_padding

View File

@@ -1,89 +1,51 @@
-from __future__ import annotations
-
 from contextlib import contextmanager
-from typing import Callable, List, Union
+from typing import Callable, List, Optional, Tuple, Union
 
+import torch
 import torch.nn as nn
 from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
 from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
+from diffusers.models.lora import LoRACompatibleConv
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 
 
-def _conv_forward_asymmetric(self, input, weight, bias):
-    """
-    Patch for Conv2d._conv_forward that supports asymmetric padding
-    """
-    working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
-    working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
-    return nn.functional.conv2d(
-        working,
-        weight,
-        bias,
-        self.stride,
-        nn.modules.utils._pair(0),
-        self.dilation,
-        self.groups,
-    )
-
-
 @contextmanager
 def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
     if not seamless_axes:
         yield
         return
 
-    # Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
-    to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
+    # override conv_forward
+    # https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
+    def _conv_forward_asymmetric(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
+        self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
+        self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
+        working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
+        working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
+        return torch.nn.functional.conv2d(
+            working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
+        )
+
+    original_layers: List[Tuple[nn.Conv2d, Callable]] = []
 
     try:
-        # Hard coded to skip down block layers, allowing for seamless tiling at the expense of prompt adherence
-        skipped_layers = 1
-        for m_name, m in model.named_modules():
-            if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
-                continue
-
-            if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." in m_name:
-                # down_blocks.1.resnets.1.conv1
-                _, block_num, _, resnet_num, submodule_name = m_name.split(".")
-                block_num = int(block_num)
-                resnet_num = int(resnet_num)
-
-                if block_num >= len(model.down_blocks) - skipped_layers:
-                    continue
-
-                # Skip the second resnet (could be configurable)
-                if resnet_num > 0:
-                    continue
-
-                # Skip Conv2d layers (could be configurable)
-                if submodule_name == "conv2":
-                    continue
-
-            m.asymmetric_padding_mode = {}
-            m.asymmetric_padding = {}
-            m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
-            m.asymmetric_padding["x"] = (
-                m._reversed_padding_repeated_twice[0],
-                m._reversed_padding_repeated_twice[1],
-                0,
-                0,
-            )
-            m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
-            m.asymmetric_padding["y"] = (
-                0,
-                0,
-                m._reversed_padding_repeated_twice[2],
-                m._reversed_padding_repeated_twice[3],
-            )
-            to_restore.append((m, m._conv_forward))
-            m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
+        x_mode = "circular" if "x" in seamless_axes else "constant"
+        y_mode = "circular" if "y" in seamless_axes else "constant"
+
+        conv_layers: List[torch.nn.Conv2d] = []
+
+        for module in model.modules():
+            if isinstance(module, torch.nn.Conv2d):
+                conv_layers.append(module)
+
+        for layer in conv_layers:
+            if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
+                layer.lora_layer = lambda *x: 0
+            original_layers.append((layer, layer._conv_forward))
+            layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
 
         yield
+
     finally:
-        for module, orig_conv_forward in to_restore:
-            module._conv_forward = orig_conv_forward
-            if hasattr(module, "asymmetric_padding_mode"):
-                del module.asymmetric_padding_mode
-            if hasattr(module, "asymmetric_padding"):
-                del module.asymmetric_padding
+        for layer, orig_conv_forward in original_layers:
+            layer._conv_forward = orig_conv_forward

View File

@@ -261,7 +261,6 @@
     "queue": "Queue",
     "queueFront": "Add to Front of Queue",
     "queueBack": "Add to Queue",
-    "queueCountPrediction": "{{promptsCount}} prompts \u00d7 {{iterations}} iterations -> {{count}} generations",
    "queueEmpty": "Queue Empty",
     "enqueueing": "Queueing Batch",
     "resume": "Resume",
@@ -314,7 +313,13 @@
     "batchFailedToQueue": "Failed to Queue Batch",
     "graphQueued": "Graph queued",
     "graphFailedToQueue": "Failed to queue graph",
-    "openQueue": "Open Queue"
+    "openQueue": "Open Queue",
+    "prompts_one": "Prompt",
+    "prompts_other": "Prompts",
+    "iterations_one": "Iteration",
+    "iterations_other": "Iterations",
+    "generations_one": "Generation",
+    "generations_other": "Generations"
   },
   "invocationCache": {
     "invocationCache": "Invocation Cache",
@@ -934,7 +939,20 @@
       "noModelSelected": "No model selected",
       "noPrompts": "No prompts generated",
       "noNodesInGraph": "No nodes in graph",
-      "systemDisconnected": "System disconnected"
+      "systemDisconnected": "System disconnected",
+      "layer": {
+        "initialImageNoImageSelected": "no initial image selected",
+        "controlAdapterNoModelSelected": "no Control Adapter model selected",
+        "controlAdapterIncompatibleBaseModel": "incompatible Control Adapter base model",
+        "controlAdapterNoImageSelected": "no Control Adapter image selected",
+        "controlAdapterImageNotProcessed": "Control Adapter image not processed",
+        "t2iAdapterIncompatibleDimensions": "T2I Adapter requires image dimension to be multiples of 64",
+        "ipAdapterNoModelSelected": "no IP adapter selected",
+        "ipAdapterIncompatibleBaseModel": "incompatible IP Adapter base model",
+        "ipAdapterNoImageSelected": "no IP Adapter image selected",
+        "rgNoPromptsOrIPAdapters": "no text prompts or IP Adapters",
+        "rgNoRegion": "no region selected"
+      }
     },
     "maskBlur": "Mask Blur",
     "negativePromptPlaceholder": "Negative Prompt",
@@ -945,8 +963,6 @@
     "positivePromptPlaceholder": "Positive Prompt",
     "globalPositivePromptPlaceholder": "Global Positive Prompt",
     "iterations": "Iterations",
-    "iterationsWithCount_one": "{{count}} Iteration",
-    "iterationsWithCount_other": "{{count}} Iterations",
     "scale": "Scale",
     "scaleBeforeProcessing": "Scale Before Processing",
     "scaledHeight": "Scaled H",

View File

@ -1,13 +1,14 @@
import { isAnyOf } from '@reduxjs/toolkit'; import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch } from 'app/store/store';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { import {
caLayerImageChanged, caLayerImageChanged,
caLayerIsProcessingImageChanged,
caLayerModelChanged, caLayerModelChanged,
caLayerProcessedImageChanged, caLayerProcessedImageChanged,
caLayerProcessorConfigChanged, caLayerProcessorConfigChanged,
caLayerProcessorPendingBatchIdChanged,
caLayerRecalled, caLayerRecalled,
isControlAdapterLayer, isControlAdapterLayer,
} from 'features/controlLayers/store/controlLayersSlice'; } from 'features/controlLayers/store/controlLayersSlice';
@ -15,47 +16,39 @@ import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
import { isImageOutput } from 'features/nodes/types/common'; import { isImageOutput } from 'features/nodes/types/common';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next'; import { t } from 'i18next';
import { isEqual } from 'lodash-es'; import { getImageDTO } from 'services/api/endpoints/images';
import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO } from 'services/api/types'; import type { BatchConfig } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions'; import { socketInvocationComplete } from 'services/events/actions';
import { assert } from 'tsafe';
const matcher = isAnyOf(caLayerImageChanged, caLayerProcessorConfigChanged, caLayerModelChanged, caLayerRecalled); const matcher = isAnyOf(caLayerImageChanged, caLayerProcessorConfigChanged, caLayerModelChanged, caLayerRecalled);
const DEBOUNCE_MS = 300; const DEBOUNCE_MS = 300;
const log = logger('session'); const log = logger('session');
/**
* Simple helper to cancel a batch and reset the pending batch ID
*/
const cancelProcessorBatch = async (dispatch: AppDispatch, layerId: string, batchId: string) => {
const req = dispatch(queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: [batchId] }));
log.trace({ batchId }, 'Cancelling existing preprocessor batch');
try {
await req.unwrap();
} catch {
// no-op
} finally {
req.reset();
// Always reset the pending batch ID - the cancel req could fail if the batch doesn't exist
dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: null }));
}
};
export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => { export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
matcher, matcher,
effect: async (action, { dispatch, getState, getOriginalState, cancelActiveListeners, delay, take }) => { effect: async (action, { dispatch, getState, cancelActiveListeners, delay, take, signal }) => {
const layerId = caLayerRecalled.match(action) ? action.payload.id : action.payload.layerId; const layerId = caLayerRecalled.match(action) ? action.payload.id : action.payload.layerId;
const precheckLayerOriginal = getOriginalState()
.controlLayers.present.layers.filter(isControlAdapterLayer)
.find((l) => l.id === layerId);
const precheckLayer = getState()
.controlLayers.present.layers.filter(isControlAdapterLayer)
.find((l) => l.id === layerId);
// Conditions to bail
const layerDoesNotExist = !precheckLayer;
const layerHasNoImage = !precheckLayer?.controlAdapter.image;
const layerHasNoProcessorConfig = !precheckLayer?.controlAdapter.processorConfig;
const layerIsAlreadyProcessingImage = precheckLayer?.controlAdapter.isProcessingImage;
const areImageAndProcessorUnchanged =
isEqual(precheckLayer?.controlAdapter.image, precheckLayerOriginal?.controlAdapter.image) &&
isEqual(precheckLayer?.controlAdapter.processorConfig, precheckLayerOriginal?.controlAdapter.processorConfig);
if (
layerDoesNotExist ||
layerHasNoImage ||
layerHasNoProcessorConfig ||
areImageAndProcessorUnchanged ||
layerIsAlreadyProcessingImage
) {
return;
}
// Cancel any in-progress instances of this listener // Cancel any in-progress instances of this listener
cancelActiveListeners(); cancelActiveListeners();
@ -63,19 +56,31 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
// Delay before starting actual work // Delay before starting actual work
await delay(DEBOUNCE_MS); await delay(DEBOUNCE_MS);
dispatch(caLayerIsProcessingImageChanged({ layerId, isProcessingImage: true }));
// Double-check that we are still eligible for processing // Double-check that we are still eligible for processing
const state = getState(); const state = getState();
const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId); const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId);
const image = layer?.controlAdapter.image;
const config = layer?.controlAdapter.processorConfig;
// If we have no image or there is no processor config, bail // If we have no image or there is no processor config, bail
if (!layer || !image || !config) { if (!layer) {
return; return;
} }
const image = layer.controlAdapter.image;
const config = layer.controlAdapter.processorConfig;
if (!image || !config) {
// The user has reset the image or config, so we should clear the processed image
dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null }));
}
// At this point, the user has stopped fiddling with the processor settings and there is a processor selected.
// If there is a pending processor batch, cancel it.
if (layer.controlAdapter.processorPendingBatchId) {
cancelProcessorBatch(dispatch, layerId, layer.controlAdapter.processorPendingBatchId);
}
// @ts-expect-error: TS isn't able to narrow the typing of buildNode and `config` will error... // @ts-expect-error: TS isn't able to narrow the typing of buildNode and `config` will error...
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config); const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config);
const enqueueBatchArg: BatchConfig = { const enqueueBatchArg: BatchConfig = {
@ -83,7 +88,11 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
batch: { batch: {
graph: { graph: {
nodes: { nodes: {
[processorNode.id]: { ...processorNode, is_intermediate: true }, [processorNode.id]: {
...processorNode,
// Control images are always intermediate - do not save to gallery
is_intermediate: true,
},
}, },
edges: [], edges: [],
}, },
@ -91,16 +100,21 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
}, },
}; };
// Kick off the processor batch
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
fixedCacheKey: 'enqueueBatch',
})
);
try { try {
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
fixedCacheKey: 'enqueueBatch',
})
);
const enqueueResult = await req.unwrap(); const enqueueResult = await req.unwrap();
req.reset(); // TODO(psyche): Update the pydantic models, pretty sure we will _always_ have a batch_id here, but the model says it's optional
assert(enqueueResult.batch.batch_id, 'Batch ID not returned from queue');
dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: enqueueResult.batch.batch_id }));
log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued')); log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));
// Wait for the processor node to complete
const [invocationCompleteAction] = await take( const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> => (action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) && socketInvocationComplete.match(action) &&
@ -109,47 +123,52 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
); );
// We still have to check the output type // We still have to check the output type
if (isImageOutput(invocationCompleteAction.payload.data.result)) { assert(
const { image_name } = invocationCompleteAction.payload.data.result.image; isImageOutput(invocationCompleteAction.payload.data.result),
`Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}`
);
const { image_name } = invocationCompleteAction.payload.data.result.image;
// Wait for the ImageDTO to be received const imageDTO = await getImageDTO(image_name);
const [{ payload }] = await take( assert(imageDTO, "Failed to fetch processor output's image DTO");
(action) =>
imagesApi.endpoints.getImageDTO.matchFulfilled(action) && action.payload.image_name === image_name
);
const imageDTO = payload as ImageDTO; // Whew! We made it. Update the layer with the processed image
log.debug({ layerId, imageDTO }, 'ControlNet image processed');
log.debug({ layerId, imageDTO }, 'ControlNet image processed'); dispatch(caLayerProcessedImageChanged({ layerId, imageDTO }));
dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: null }));
// Update the processed image in the store
dispatch(
caLayerProcessedImageChanged({
layerId,
imageDTO,
})
);
dispatch(caLayerIsProcessingImageChanged({ layerId, isProcessingImage: false }));
}
} catch (error) { } catch (error) {
log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue')); if (signal.aborted) {
dispatch(caLayerIsProcessingImageChanged({ layerId, isProcessingImage: false })); // The listener was canceled - we need to cancel the pending processor batch, if there is one (could have changed by now).
const pendingBatchId = getState()
.controlLayers.present.layers.filter(isControlAdapterLayer)
.find((l) => l.id === layerId)?.controlAdapter.processorPendingBatchId;
if (pendingBatchId) {
cancelProcessorBatch(dispatch, layerId, pendingBatchId);
}
log.trace('Control Adapter preprocessor cancelled');
} else {
// Some other error condition...
console.log(error);
log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
if (error instanceof Object) { if (error instanceof Object) {
if ('data' in error && 'status' in error) { if ('data' in error && 'status' in error) {
if (error.status === 403) { if (error.status === 403) {
dispatch(caLayerImageChanged({ layerId, imageDTO: null })); dispatch(caLayerImageChanged({ layerId, imageDTO: null }));
return; return;
}
} }
} }
}
dispatch( dispatch(
addToast({ addToast({
title: t('queue.graphFailedToQueue'), title: t('queue.graphFailedToQueue'),
status: 'error', status: 'error',
}) })
); );
}
} finally {
req.reset();
} }
}, },
}); });
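Reviewer note: the listener now follows a debounce-and-cancel shape: each triggering action cancels in-flight runs, waits out the debounce window, then does its work, distinguishing cancellation from real errors via the listener's AbortSignal. A minimal sketch of the pattern with RTK listener middleware (hypothetical myTrigger action and doWork helper):

startAppListening({
  actionCreator: myTrigger,
  effect: async (action, { cancelActiveListeners, delay, signal }) => {
    // Supersede any in-flight run for this trigger
    cancelActiveListeners();
    // Debounce; throws if this run is itself canceled while waiting
    await delay(300);
    try {
      await doWork(signal);
    } catch (error) {
      if (signal.aborted) {
        // Canceled by a newer run - undo side effects here (e.g. cancel the queued batch)
        return;
      }
      throw error;
    }
  },
});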

View File

@@ -13,6 +13,7 @@ type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
   onChange: (value: T | null) => void;
   getIsDisabled?: (model: T) => boolean;
   isLoading?: boolean;
+  groupByType?: boolean;
 };
 
 type UseGroupedModelComboboxReturn = {
@@ -23,17 +24,21 @@ type UseGroupedModelComboboxReturn = {
   noOptionsMessage: () => string;
 };
 
+const groupByBaseFunc = <T extends AnyModelConfig>(model: T) => model.base.toUpperCase();
+const groupByBaseAndTypeFunc = <T extends AnyModelConfig>(model: T) =>
+  `${model.base.toUpperCase()} / ${model.type.replaceAll('_', ' ').toUpperCase()}`;
+
 export const useGroupedModelCombobox = <T extends AnyModelConfig>(
   arg: UseGroupedModelComboboxArg<T>
 ): UseGroupedModelComboboxReturn => {
   const { t } = useTranslation();
   const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
-  const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading } = arg;
+  const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, groupByType = false } = arg;
   const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
     if (!modelConfigs) {
       return [];
     }
-    const groupedModels = groupBy(modelConfigs, 'base');
+    const groupedModels = groupBy(modelConfigs, groupByType ? groupByBaseAndTypeFunc : groupByBaseFunc);
     const _options = reduce(
       groupedModels,
       (acc, val, label) => {
@@ -49,9 +54,9 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
       },
       [] as GroupBase<ComboboxOption>[]
     );
-    _options.sort((a) => (a.label === base_model ? -1 : 1));
+    _options.sort((a) => (a.label?.split('/')[0]?.toLowerCase().includes(base_model) ? -1 : 1));
     return _options;
-  }, [getIsDisabled, modelConfigs, base_model]);
+  }, [modelConfigs, groupByType, getIsDisabled, base_model]);
   const value = useMemo(
     () =>
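Reviewer note: the switch from the 'base' property shorthand to iteratee functions is what enables the composite group labels. A small standalone illustration of lodash's groupBy with both forms (hypothetical data):

import { groupBy } from 'lodash-es';

const models = [
  { base: 'sd-1', type: 'controlnet' },
  { base: 'sd-1', type: 't2i_adapter' },
  { base: 'sdxl', type: 'controlnet' },
];

// Property shorthand: one group per base value.
groupBy(models, 'base'); // { 'sd-1': [two items], 'sdxl': [one item] }

// Iteratee function: composite "BASE / TYPE" keys, as groupByBaseAndTypeFunc does above.
groupBy(models, (m) => `${m.base.toUpperCase()} / ${m.type.replaceAll('_', ' ').toUpperCase()}`);
// { 'SD-1 / CONTROLNET': [...], 'SD-1 / T2I ADAPTER': [...], 'SDXL / CONTROLNET': [...] }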

View File

@@ -6,6 +6,7 @@ import {
 } from 'features/controlAdapters/store/controlAdaptersSlice';
 import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
 import { selectControlLayersSlice } from 'features/controlLayers/store/controlLayersSlice';
+import type { Layer } from 'features/controlLayers/store/types';
 import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
 import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
@@ -14,9 +15,16 @@ import { selectGenerationSlice } from 'features/parameters/store/generationSlice
 import { selectSystemSlice } from 'features/system/store/systemSlice';
 import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
 import i18n from 'i18next';
-import { forEach } from 'lodash-es';
+import { forEach, upperFirst } from 'lodash-es';
 import { getConnectedEdges } from 'reactflow';
 
+const LAYER_TYPE_TO_TKEY: Record<Layer['type'], string> = {
+  initial_image_layer: 'controlLayers.globalInitialImage',
+  control_adapter_layer: 'controlLayers.globalControlAdapter',
+  ip_adapter_layer: 'controlLayers.globalIPAdapter',
+  regional_guidance_layer: 'controlLayers.regionalGuidance',
+};
+
 const selector = createMemoizedSelector(
   [
     selectControlAdaptersSlice,
@@ -29,21 +37,22 @@ const selector = createMemoizedSelector(
   ],
   (controlAdapters, generation, system, nodes, dynamicPrompts, controlLayers, activeTabName) => {
     const { model } = generation;
+    const { size } = controlLayers.present;
     const { positivePrompt } = controlLayers.present;
     const { isConnected } = system;
 
-    const reasons: string[] = [];
+    const reasons: { prefix?: string; content: string }[] = [];
 
     // Cannot generate if not connected
     if (!isConnected) {
-      reasons.push(i18n.t('parameters.invoke.systemDisconnected'));
+      reasons.push({ content: i18n.t('parameters.invoke.systemDisconnected') });
     }
 
     if (activeTabName === 'workflows') {
       if (nodes.shouldValidateGraph) {
         if (!nodes.nodes.length) {
-          reasons.push(i18n.t('parameters.invoke.noNodesInGraph'));
+          reasons.push({ content: i18n.t('parameters.invoke.noNodesInGraph') });
         }
 
         nodes.nodes.forEach((node) => {
@@ -55,7 +64,7 @@ const selector = createMemoizedSelector(
           if (!nodeTemplate) {
             // Node type not found
-            reasons.push(i18n.t('parameters.invoke.missingNodeTemplate'));
+            reasons.push({ content: i18n.t('parameters.invoke.missingNodeTemplate') });
             return;
           }
@@ -68,17 +77,17 @@ const selector = createMemoizedSelector(
           );
 
           if (!fieldTemplate) {
-            reasons.push(i18n.t('parameters.invoke.missingFieldTemplate'));
+            reasons.push({ content: i18n.t('parameters.invoke.missingFieldTemplate') });
             return;
           }
 
           if (fieldTemplate.required && field.value === undefined && !hasConnection) {
-            reasons.push(
-              i18n.t('parameters.invoke.missingInputForField', {
+            reasons.push({
+              content: i18n.t('parameters.invoke.missingInputForField', {
                 nodeLabel: node.data.label || nodeTemplate.title,
                 fieldLabel: field.label || fieldTemplate.title,
-              })
-            );
+              }),
+            });
             return;
           }
         });
@@ -86,62 +95,94 @@ const selector = createMemoizedSelector(
       }
     } else {
       if (dynamicPrompts.prompts.length === 0 && getShouldProcessPrompt(positivePrompt)) {
-        reasons.push(i18n.t('parameters.invoke.noPrompts'));
+        reasons.push({ content: i18n.t('parameters.invoke.noPrompts') });
       }
 
       if (!model) {
-        reasons.push(i18n.t('parameters.invoke.noModelSelected'));
+        reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
       }
 
       if (activeTabName === 'generation') {
         // Handling for generation tab
         controlLayers.present.layers
           .filter((l) => l.isEnabled)
-          .flatMap((l) => {
-            if (l.type === 'control_adapter_layer') {
-              return l.controlAdapter;
-            } else if (l.type === 'ip_adapter_layer') {
-              return l.ipAdapter;
-            } else if (l.type === 'regional_guidance_layer') {
-              return l.ipAdapters;
-            }
-            return [];
-          })
-          .forEach((ca, i) => {
-            const hasNoModel = !ca.model;
-            const mismatchedModelBase = ca.model?.base !== model?.base;
-            const hasNoImage = !ca.image;
-            const imageNotProcessed =
-              (ca.type === 'controlnet' || ca.type === 't2i_adapter') && !ca.processedImage && ca.processorConfig;
-
-            if (hasNoModel) {
-              reasons.push(
-                i18n.t('parameters.invoke.noModelForControlAdapter', {
-                  number: i + 1,
-                })
-              );
-            }
-            if (mismatchedModelBase) {
-              // This should never happen, just a sanity check
-              reasons.push(
-                i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', {
-                  number: i + 1,
-                })
-              );
-            }
-            if (hasNoImage) {
-              reasons.push(
-                i18n.t('parameters.invoke.noControlImageForControlAdapter', {
-                  number: i + 1,
-                })
-              );
-            }
-            if (imageNotProcessed) {
-              reasons.push(
-                i18n.t('parameters.invoke.imageNotProcessedForControlAdapter', {
-                  number: i + 1,
-                })
-              );
-            }
+          .forEach((l, i) => {
+            const layerLiteral = i18n.t('controlLayers.layers_one');
+            const layerNumber = i + 1;
+            const layerType = i18n.t(LAYER_TYPE_TO_TKEY[l.type]);
+            const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
+            const problems: string[] = [];
+            if (l.type === 'control_adapter_layer') {
+              // Must have model
+              if (!l.controlAdapter.model) {
+                problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected'));
+              }
+              // Model base must match
+              if (l.controlAdapter.model?.base !== model?.base) {
+                problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'));
+              }
+              // Must have a control image OR, if it has a processor, it must have a processed image
+              if (!l.controlAdapter.image) {
+                problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoImageSelected'));
+              } else if (l.controlAdapter.processorConfig && !l.controlAdapter.processedImage) {
+                problems.push(i18n.t('parameters.invoke.layer.controlAdapterImageNotProcessed'));
+              }
+              // T2I Adapters require images have dimensions that are multiples of 64
+              if (l.controlAdapter.type === 't2i_adapter' && (size.width % 64 !== 0 || size.height % 64 !== 0)) {
+                problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions'));
+              }
+            }
+
+            if (l.type === 'ip_adapter_layer') {
+              // Must have model
+              if (!l.ipAdapter.model) {
+                problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
+              }
+              // Model base must match
+              if (l.ipAdapter.model?.base !== model?.base) {
+                problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
+              }
+              // Must have an image
+              if (!l.ipAdapter.image) {
+                problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
+              }
+            }
+
+            if (l.type === 'initial_image_layer') {
+              // Must have an image
+              if (!l.image) {
+                problems.push(i18n.t('parameters.invoke.layer.initialImageNoImageSelected'));
+              }
+            }
+
+            if (l.type === 'regional_guidance_layer') {
+              // Must have a region
+              if (l.maskObjects.length === 0) {
+                problems.push(i18n.t('parameters.invoke.layer.rgNoRegion'));
+              }
+              // Must have at least 1 prompt or IP Adapter
+              if (l.positivePrompt === null && l.negativePrompt === null && l.ipAdapters.length === 0) {
+                problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters'));
+              }
+              l.ipAdapters.forEach((ipAdapter) => {
+                // Must have model
+                if (!ipAdapter.model) {
+                  problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
+                }
+                // Model base must match
+                if (ipAdapter.model?.base !== model?.base) {
+                  problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
+                }
+                // Must have an image
+                if (!ipAdapter.image) {
+                  problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
+                }
+              });
+            }
+
+            if (problems.length) {
+              const content = upperFirst(problems.join(', '));
+              reasons.push({ prefix, content });
+            }
           });
       } else {
@@ -154,29 +195,19 @@ const selector = createMemoizedSelector(
         }
 
           if (!ca.model) {
-            reasons.push(
-              i18n.t('parameters.invoke.noModelForControlAdapter', {
-                number: i + 1,
-              })
-            );
+            reasons.push({ content: i18n.t('parameters.invoke.noModelForControlAdapter', { number: i + 1 }) });
          } else if (ca.model.base !== model?.base) {
             // This should never happen, just a sanity check
-            reasons.push(
-              i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', {
-                number: i + 1,
-              })
-            );
+            reasons.push({
+              content: i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', { number: i + 1 }),
+            });
           }
 
           if (
             !ca.controlImage ||
             (isControlNetOrT2IAdapter(ca) && !ca.processedControlImage && ca.processorType !== 'none')
           ) {
-            reasons.push(
-              i18n.t('parameters.invoke.noControlImageForControlAdapter', {
-                number: i + 1,
-              })
-            );
+            reasons.push({ content: i18n.t('parameters.invoke.noControlImageForControlAdapter', { number: i + 1 }) });
           }
         });
       }
@@ -187,6 +218,6 @@ const selector = createMemoizedSelector(
 );
 
 export const useIsReadyToEnqueue = () => {
-  const { isReady, reasons } = useAppSelector(selector);
-  return { isReady, reasons };
+  const value = useAppSelector(selector);
+  return value;
 };

View File

@@ -21,7 +21,6 @@ import {
   setShouldShowBoundingBox,
 } from 'features/canvas/store/canvasSlice';
 import type { CanvasLayer } from 'features/canvas/store/canvasTypes';
-import { LAYER_NAMES_DICT } from 'features/canvas/store/canvasTypes';
 import { memo, useCallback, useMemo } from 'react';
 import { useHotkeys } from 'react-hotkeys-hook';
 import { useTranslation } from 'react-i18next';
@@ -216,13 +215,20 @@ const IAICanvasToolbar = () => {
     [dispatch, isMaskEnabled]
   );
 
-  const value = useMemo(() => LAYER_NAMES_DICT.filter((o) => o.value === layer)[0], [layer]);
+  const layerOptions = useMemo<{ label: string; value: CanvasLayer }[]>(
+    () => [
+      { label: t('unifiedCanvas.base'), value: 'base' },
+      { label: t('unifiedCanvas.mask'), value: 'mask' },
+    ],
+    [t]
+  );
+  const layerValue = useMemo(() => layerOptions.filter((o) => o.value === layer)[0] ?? null, [layer, layerOptions]);
 
   return (
     <Flex alignItems="center" gap={2} flexWrap="wrap">
       <Tooltip label={`${t('unifiedCanvas.layer')} (Q)`}>
         <FormControl isDisabled={isStaging} w="5rem">
-          <Combobox value={value} options={LAYER_NAMES_DICT} onChange={handleChangeLayer} />
+          <Combobox value={layerValue} options={layerOptions} onChange={handleChangeLayer} />
         </FormControl>
       </Tooltip>

View File

@@ -5,11 +5,6 @@ import { z } from 'zod';
 
 export type CanvasLayer = 'base' | 'mask';
 
-export const LAYER_NAMES_DICT: { label: string; value: CanvasLayer }[] = [
-  { label: 'Base', value: 'base' },
-  { label: 'Mask', value: 'mask' },
-];
-
 const zBoundingBoxScaleMethod = z.enum(['none', 'auto', 'manual']);
 export type BoundingBoxScaleMethod = z.infer<typeof zBoundingBoxScaleMethod>;
 export const isBoundingBoxScaleMethod = (v: unknown): v is BoundingBoxScaleMethod =>

View File

@@ -124,7 +124,7 @@ export const ControlAdapterImagePreview = memo(
       controlImage &&
       processedControlImage &&
       !isMouseOverImage &&
-      !controlAdapter.isProcessingImage &&
+      !controlAdapter.processorPendingBatchId &&
       controlAdapter.processorConfig !== null;
 
   useEffect(() => {
@@ -190,7 +190,7 @@ export const ControlAdapterImagePreview = memo(
         />
       </>
 
-      {controlAdapter.isProcessingImage && (
+      {controlAdapter.processorPendingBatchId !== null && (
         <Flex
           position="absolute"
           top={0}

View File

@@ -42,6 +42,7 @@ export const ControlAdapterModelCombobox = memo(({ modelKey, onChange: onChangeM
     selectedModel,
     getIsDisabled,
     isLoading,
+    groupByType: true,
   });
 
   return (
return ( return (

View File

@@ -2,14 +2,13 @@ import type { ComboboxOnChange } from '@invoke-ai/ui-library';
 import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
 import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAndIPAdapter/processors/types';
 import type { DepthAnythingModelSize, DepthAnythingProcessorConfig } from 'features/controlLayers/util/controlAdapters';
-import { CA_PROCESSOR_DATA, isDepthAnythingModelSize } from 'features/controlLayers/util/controlAdapters';
+import { isDepthAnythingModelSize } from 'features/controlLayers/util/controlAdapters';
 import { memo, useCallback, useMemo } from 'react';
 import { useTranslation } from 'react-i18next';
 
 import ProcessorWrapper from './ProcessorWrapper';
 
 type Props = ProcessorComponentProps<DepthAnythingProcessorConfig>;
-const DEFAULTS = CA_PROCESSOR_DATA['depth_anything_image_processor'].buildDefaults();
 
 export const DepthAnythingProcessor = memo(({ onChange, config }: Props) => {
   const { t } = useTranslation();
@@ -38,12 +37,7 @@ export const DepthAnythingProcessor = memo(({ onChange, config }: Props) => {
     <ProcessorWrapper>
       <FormControl>
         <FormLabel m={0}>{t('controlnet.modelSize')}</FormLabel>
-        <Combobox
-          value={value}
-          defaultInputValue={DEFAULTS.model_size}
-          options={options}
-          onChange={handleModelSizeChange}
-        />
+        <Combobox value={value} options={options} onChange={handleModelSizeChange} isSearchable={false} />
       </FormControl>
     </ProcessorWrapper>
   );

View File

@@ -27,7 +27,7 @@ import { modelChanged } from 'features/parameters/store/generationSlice';
 import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
 import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
 import type { IRect, Vector2d } from 'konva/lib/types';
-import { isEqual, partition } from 'lodash-es';
+import { isEqual, partition, unset } from 'lodash-es';
 import { atom } from 'nanostores';
 import type { RgbColor } from 'react-colorful';
 import type { UndoableOptions } from 'redux-undo';
@@ -49,7 +49,7 @@ import type {
 } from './types';
 
 export const initialControlLayersState: ControlLayersState = {
-  _version: 2,
+  _version: 3,
   selectedLayerId: null,
   brushSize: 100,
   layers: [],
@@ -334,13 +334,13 @@ export const controlLayersSlice = createSlice({
       const layer = selectCALayerOrThrow(state, layerId);
       layer.opacity = opacity;
     },
-    caLayerIsProcessingImageChanged: (
+    caLayerProcessorPendingBatchIdChanged: (
       state,
-      action: PayloadAction<{ layerId: string; isProcessingImage: boolean }>
+      action: PayloadAction<{ layerId: string; batchId: string | null }>
     ) => {
-      const { layerId, isProcessingImage } = action.payload;
+      const { layerId, batchId } = action.payload;
       const layer = selectCALayerOrThrow(state, layerId);
-      layer.controlAdapter.isProcessingImage = isProcessingImage;
+      layer.controlAdapter.processorPendingBatchId = batchId;
     },
     //#endregion
@@ -800,7 +800,7 @@ export const {
   caLayerProcessorConfigChanged,
   caLayerIsFilterEnabledChanged,
   caLayerOpacityChanged,
-  caLayerIsProcessingImageChanged,
+  caLayerProcessorPendingBatchIdChanged,
   // IPA Layers
   ipaLayerAdded,
   ipaLayerRecalled,
@@ -857,7 +857,16 @@ export const selectControlLayersSlice = (state: RootState) => state.controlLayer
 const migrateControlLayersState = (state: any): any => {
   if (state._version === 1) {
     // Reset state for users on v1 (e.g. beta users), some changes could cause
-    return deepClone(initialControlLayersState);
+    state = deepClone(initialControlLayersState);
+  }
+  if (state._version === 2) {
+    // The CA `isProcessingImage` flag was replaced with a `processorPendingBatchId` property, fix up CA layers
+    for (const layer of (state as ControlLayersState).layers) {
+      if (layer.type === 'control_adapter_layer') {
+        layer.controlAdapter.processorPendingBatchId = null;
+        unset(layer.controlAdapter, 'isProcessingImage');
+      }
+    }
   }
   return state;
 };
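Reviewer note: the v1 branch now assigns instead of returning, so reset state falls through to any later migration steps and a single pass can walk a client forward across versions (here the reset initial state is already at v3, so the v2 step is skipped for it). A minimal sketch of the cascading pattern (hypothetical fields, with the version bump made explicit):

const migrate = (state: any): any => {
  if (state._version === 1) {
    // Too old to patch up - reset, then fall through to the later steps
    state = deepClone(initialState);
  }
  if (state._version === 2) {
    state.newField = null; // field introduced in v3
    delete state.oldField; // field it replaced
    state._version = 3;
  }
  return state;
};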

View File

@@ -113,7 +113,7 @@ export const zLayer = z.discriminatedUnion('type', [
 export type Layer = z.infer<typeof zLayer>;
 
 export type ControlLayersState = {
-  _version: 2;
+  _version: 3;
   selectedLayerId: string | null;
   layers: Layer[];
   brushSize: number;

View File

@@ -198,8 +198,8 @@ const zControlAdapterBase = z.object({
   weight: z.number().gte(0).lte(1),
   image: zImageWithDims.nullable(),
   processedImage: zImageWithDims.nullable(),
-  isProcessingImage: z.boolean(),
   processorConfig: zProcessorConfig.nullable(),
+  processorPendingBatchId: z.string().nullable().default(null),
   beginEndStepPct: zBeginEndStepPct,
 });
 
@@ -521,8 +521,8 @@ export const initialControlNetV2: Omit<ControlNetConfigV2, 'id'> = {
   controlMode: 'balanced',
   image: null,
   processedImage: null,
-  isProcessingImage: false,
   processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
+  processorPendingBatchId: null,
 };
 
 export const initialT2IAdapterV2: Omit<T2IAdapterConfigV2, 'id'> = {
@@ -532,8 +532,8 @@ export const initialT2IAdapterV2: Omit<T2IAdapterConfigV2, 'id'> = {
   beginEndStepPct: [0, 1],
   image: null,
   processedImage: null,
-  isProcessingImage: false,
   processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
+  processorPendingBatchId: null,
 };
 
 export const initialIPAdapterV2: Omit<IPAdapterConfigV2, 'id'> = {

View File

@@ -587,7 +587,7 @@ const parseControlNetToControlAdapterLayer: MetadataParseFunc<ControlAdapterLaye
       image: imageDTO ? imageDTOToImageWithDims(imageDTO) : null,
       processedImage: processedImageDTO ? imageDTOToImageWithDims(processedImageDTO) : null,
       processorConfig,
-      isProcessingImage: false,
+      processorPendingBatchId: null,
     },
   };
 
@@ -651,7 +651,7 @@ const parseT2IAdapterToControlAdapterLayer: MetadataParseFunc<ControlAdapterLaye
       image: imageDTO ? imageDTOToImageWithDims(imageDTO) : null,
      processedImage: processedImageDTO ? imageDTOToImageWithDims(processedImageDTO) : null,
       processorConfig,
-      isProcessingImage: false,
+      processorPendingBatchId: null,
     },
  };

View File

@@ -16,25 +16,26 @@ export const InvokeQueueBackButton = memo(() => {
 
   return (
     <Flex pos="relative" flexGrow={1} minW="240px">
       <QueueIterationsNumberInput />
-      <Button
-        onClick={queueBack}
-        isLoading={isLoading || isLoadingDynamicPrompts}
-        loadingText={invoke}
-        isDisabled={isDisabled}
-        rightIcon={<RiSparkling2Fill />}
-        tooltip={<QueueButtonTooltip />}
-        variant="solid"
-        zIndex={1}
-        colorScheme="invokeYellow"
-        size="lg"
-        w="calc(100% - 60px)"
-        flexShrink={0}
-        justifyContent="space-between"
-        spinnerPlacement="end"
-      >
-        <span>{invoke}</span>
-        <Spacer />
-      </Button>
+      <QueueButtonTooltip>
+        <Button
+          onClick={queueBack}
+          isLoading={isLoading || isLoadingDynamicPrompts}
+          loadingText={invoke}
+          isDisabled={isDisabled}
+          rightIcon={<RiSparkling2Fill />}
+          variant="solid"
+          zIndex={1}
+          colorScheme="invokeYellow"
+          size="lg"
+          w="calc(100% - 60px)"
+          flexShrink={0}
+          justifyContent="space-between"
+          spinnerPlacement="end"
+        >
+          <span>{invoke}</span>
+          <Spacer />
+        </Button>
+      </QueueButtonTooltip>
     </Flex>
   );
 });
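Reviewer note: the tooltip moves from a Button/IconButton `tooltip` prop to a wrapping component, so one QueueButtonTooltip can decorate any trigger element. A minimal sketch of the composition pattern (simplified props; `TooltipContent` is the private component defined later in this diff):

import { Tooltip } from '@invoke-ai/ui-library';
import type { PropsWithChildren } from 'react';

export const QueueButtonTooltip = (props: PropsWithChildren<{ prepend?: boolean }>) => (
  <Tooltip label={<TooltipContent prepend={props.prepend} />} maxW={512}>
    {props.children}
  </Tooltip>
);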

View File

@@ -1,10 +1,11 @@
-import { Divider, Flex, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library';
+import { Divider, Flex, ListItem, Text, Tooltip, UnorderedList } from '@invoke-ai/ui-library';
 import { createSelector } from '@reduxjs/toolkit';
 import { useAppSelector } from 'app/store/storeHooks';
 import { useIsReadyToEnqueue } from 'common/hooks/useIsReadyToEnqueue';
 import { selectControlLayersSlice } from 'features/controlLayers/store/controlLayersSlice';
 import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
 import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
+import type { PropsWithChildren } from 'react';
 import { memo, useMemo } from 'react';
 import { useTranslation } from 'react-i18next';
 import { useEnqueueBatchMutation } from 'services/api/endpoints/queue';
@@ -21,17 +22,32 @@ type Props = {
   prepend?: boolean;
 };
 
-export const QueueButtonTooltip = memo(({ prepend = false }: Props) => {
+export const QueueButtonTooltip = (props: PropsWithChildren<Props>) => {
+  return (
+    <Tooltip label={<TooltipContent prepend={props.prepend} />} maxW={512}>
+      {props.children}
+    </Tooltip>
+  );
+};
+
+const TooltipContent = memo(({ prepend = false }: Props) => {
   const { t } = useTranslation();
   const { isReady, reasons } = useIsReadyToEnqueue();
   const isLoadingDynamicPrompts = useAppSelector((s) => s.dynamicPrompts.isLoading);
   const promptsCount = useAppSelector(selectPromptsCount);
-  const iterations = useAppSelector((s) => s.generation.iterations);
+  const iterationsCount = useAppSelector((s) => s.generation.iterations);
   const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId);
   const autoAddBoardName = useBoardName(autoAddBoardId);
   const [_, { isLoading }] = useEnqueueBatchMutation({
     fixedCacheKey: 'enqueueBatch',
   });
 
+  const queueCountPredictionLabel = useMemo(() => {
+    const generationCount = Math.min(promptsCount * iterationsCount, 10000);
+    const prompts = t('queue.prompts', { count: promptsCount });
+    const iterations = t('queue.iterations', { count: iterationsCount });
+    const generations = t('queue.generations', { count: generationCount });
+    return `${promptsCount} ${prompts} \u00d7 ${iterationsCount} ${iterations} -> ${generationCount} ${generations}`.toLowerCase();
+  }, [iterationsCount, promptsCount, t]);
+
   const label = useMemo(() => {
     if (isLoading) {
@@ -52,20 +68,21 @@ const TooltipContent = memo(({ prepend = false }: Props) => {
   return (
     <Flex flexDir="column" gap={1}>
       <Text fontWeight="semibold">{label}</Text>
-      <Text>
-        {t('queue.queueCountPrediction', {
-          promptsCount,
-          iterations,
-          count: Math.min(promptsCount * iterations, 10000),
-        })}
-      </Text>
+      <Text>{queueCountPredictionLabel}</Text>
       {reasons.length > 0 && (
         <>
           <Divider opacity={0.2} borderColor="base.900" />
           <UnorderedList>
             {reasons.map((reason, i) => (
-              <ListItem key={`${reason}.${i}`}>
-                <Text>{reason}</Text>
+              <ListItem key={`${reason.content}.${i}`}>
+                <span>
+                  {reason.prefix && (
+                    <Text as="span" fontWeight="semibold">
+                      {reason.prefix}:{' '}
+                    </Text>
+                  )}
+                  <Text as="span">{reason.content}</Text>
+                </span>
               </ListItem>
             ))}
           </UnorderedList>
@@ -82,4 +99,4 @@ const TooltipContent = memo(({ prepend = false }: Props) => {
   );
 });
 
-QueueButtonTooltip.displayName = 'QueueButtonTooltip';
+TooltipContent.displayName = 'QueueButtonTooltipContent';

View File

@@ -10,15 +10,16 @@ const QueueFrontButton = () => {
   const { t } = useTranslation();
   const { queueFront, isLoading, isDisabled } = useQueueFront();
   return (
-    <IconButton
-      aria-label={t('queue.queueFront')}
-      isDisabled={isDisabled}
-      isLoading={isLoading}
-      onClick={queueFront}
-      tooltip={<QueueButtonTooltip prepend />}
-      icon={<AiFillThunderbolt />}
-      size="lg"
-    />
+    <QueueButtonTooltip prepend>
+      <IconButton
+        aria-label={t('queue.queueFront')}
+        isDisabled={isDisabled}
+        isLoading={isLoading}
+        onClick={queueFront}
+        icon={<AiFillThunderbolt />}
+        size="lg"
+      />
+    </QueueButtonTooltip>
   );
 };

View File

@@ -63,16 +63,17 @@ const FloatingSidePanelButtons = (props: Props) => {
           sx={floatingButtonStyles}
           icon={<PiSlidersHorizontalBold size="16px" />}
         />
-        <IconButton
-          aria-label={t('queue.queueBack')}
-          onClick={queueBack}
-          isLoading={isLoading}
-          isDisabled={isDisabled}
-          icon={queueButtonIcon}
-          colorScheme="invokeYellow"
-          tooltip={<QueueButtonTooltip />}
-          sx={floatingButtonStyles}
-        />
+        <QueueButtonTooltip>
+          <IconButton
+            aria-label={t('queue.queueBack')}
+            onClick={queueBack}
+            isLoading={isLoading}
+            isDisabled={isDisabled}
+            icon={queueButtonIcon}
+            colorScheme="invokeYellow"
+            sx={floatingButtonStyles}
+          />
+        </QueueButtonTooltip>
         <CancelCurrentQueueItemIconButton sx={floatingButtonStyles} />
       </ButtonGroup>
       <ClearAllQueueIconButton sx={floatingButtonStyles} onOpen={disclosure.onOpen} />