feat(ui): revise graph building for control layers, fix issues w/ invocation complete events

psychedelicious 2024-08-23 14:49:16 +10:00
parent d9f4266630
commit 427ea6da5c
24 changed files with 469 additions and 508 deletions
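At a glance: the graph builders now return the nodes that downstream code needs instead of publishing them under hard-coded IDs from constants.ts, node IDs are generated with getPrefixedId, and the ControlNet/T2I-adapter/IP-adapter helpers take an explicit collector node. A minimal TypeScript sketch of the assumed builder return shape (the BuilderResult alias is illustrative, not part of the commit; the field names come from the diff below):

// Sketch only: field names taken from the diff below; the type alias itself is hypothetical.
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { Invocation } from 'services/api/types';

type BuilderResult = {
  g: Graph;
  noise: Invocation<'noise'>;
  posCond: Invocation<'compel' | 'sdxl_compel_prompt'>;
};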

View File

@@ -1,3 +1,4 @@
+import { logger } from 'app/logging/logger';
 import { enqueueRequested } from 'app/store/actions';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
 import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
@@ -5,9 +6,14 @@ import { sessionStagingAreaReset, sessionStartedStaging } from 'features/control
 import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
 import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
 import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
+import type { Graph } from 'features/nodes/util/graph/generation/Graph';
+import { serializeError } from 'serialize-error';
 import { queueApi } from 'services/api/endpoints/queue';
+import type { Invocation } from 'services/api/types';
 import { assert } from 'tsafe';

+const log = logger('generation');

 export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
   startAppListening({
     predicate: (action): action is ReturnType<typeof enqueueRequested> =>
@@ -27,20 +33,28 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
       }

       try {
-        let g;
+        let g: Graph;
+        let noise: Invocation<'noise'>;
+        let posCond: Invocation<'compel' | 'sdxl_compel_prompt'>;

         assert(model, 'No model found in state');
         const base = model.base;

         if (base === 'sdxl') {
-          g = await buildSDXLGraph(state, manager);
+          const result = await buildSDXLGraph(state, manager);
+          g = result.g;
+          noise = result.noise;
+          posCond = result.posCond;
         } else if (base === 'sd-1' || base === 'sd-2') {
-          g = await buildSD1Graph(state, manager);
+          const result = await buildSD1Graph(state, manager);
+          g = result.g;
+          noise = result.noise;
+          posCond = result.posCond;
         } else {
           assert(false, `No graph builders for base ${base}`);
         }

-        const batchConfig = prepareLinearUIBatch(state, g, prepend);
+        const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond);

         const req = dispatch(
           queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
@@ -50,6 +64,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
         req.reset();
         await req.unwrap();
       } catch (error) {
+        log.error({ error: serializeError(error) }, 'Failed to enqueue batch');
         if (didStartStaging && getState().canvasV2.session.isStaging) {
           dispatch(sessionStagingAreaReset());
         }

View File

@@ -1,5 +1,5 @@
-import type { Coordinate, Rect, RgbaColor } from 'features/controlLayers/store/types';
-import Konva from 'konva';
+import type { Coordinate, Rect } from 'features/controlLayers/store/types';
+import type Konva from 'konva';
 import type { KonvaEventObject } from 'konva/lib/Node';
 import type { Vector2d } from 'konva/lib/types';
 import { customAlphabet } from 'nanoid';
@@ -279,30 +279,6 @@ export const konvaNodeToBlob = (node: Konva.Node, bbox?: Rect): Promise<Blob> =>
   return canvasToBlob(canvas);
 };

-/**
- * Gets the pixel under the cursor on the stage, or null if the cursor is not over the stage.
- * @param stage The konva stage
- */
-export const getPixelUnderCursor = (stage: Konva.Stage): RgbaColor | null => {
-  const cursorPos = stage.getPointerPosition();
-  const pixelRatio = Konva.pixelRatio;
-  if (!cursorPos) {
-    return null;
-  }
-  const ctx = stage.toCanvas().getContext('2d');
-  if (!ctx) {
-    return null;
-  }
-  const [r, g, b, a] = ctx.getImageData(cursorPos.x * pixelRatio, cursorPos.y * pixelRatio, 1, 1).data;
-  if (r === undefined || g === undefined || b === undefined || a === undefined) {
-    return null;
-  }
-  return { r, g, b, a };
-};

 export const previewBlob = (blob: Blob, label?: string) => {
   const url = URL.createObjectURL(blob);
   const w = window.open('');

View File

@@ -1,4 +1,5 @@
 import type { RootState } from 'app/store/store';
+import { getPrefixedId } from 'features/controlLayers/konva/util';
 import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
 import type { GraphType } from 'features/nodes/util/graph/generation/Graph';
 import { Graph } from 'features/nodes/util/graph/generation/Graph';
@@ -7,8 +8,6 @@ import type { ImageDTO } from 'services/api/types';
 import { isSpandrelImageToImageModelConfig } from 'services/api/types';
 import { assert } from 'tsafe';

-import { SPANDREL } from './constants';

 type Arg = {
   image: ImageDTO;
   state: RootState;
@@ -21,8 +20,8 @@ export const buildAdHocPostProcessingGraph = async ({ image, state }: Arg): Prom
   const g = new Graph('adhoc-post-processing-graph');
   g.addNode({
-    id: SPANDREL,
     type: 'spandrel_image_to_image',
+    id: getPrefixedId('spandrel'),
     image_to_image_model: postProcessingModel,
     image,
     board: getBoardField(state),

View File

@@ -3,11 +3,15 @@ import { generateSeeds } from 'common/util/generateSeeds';
 import type { Graph } from 'features/nodes/util/graph/generation/Graph';
 import { range } from 'lodash-es';
 import type { components } from 'services/api/schema';
-import type { Batch, BatchConfig } from 'services/api/types';
+import type { Batch, BatchConfig, Invocation } from 'services/api/types';

-import { NOISE, POSITIVE_CONDITIONING } from './constants';

-export const prepareLinearUIBatch = (state: RootState, g: Graph, prepend: boolean): BatchConfig => {
+export const prepareLinearUIBatch = (
+  state: RootState,
+  g: Graph,
+  prepend: boolean,
+  noise: Invocation<'noise'>,
+  posCond: Invocation<'compel' | 'sdxl_compel_prompt'>
+): BatchConfig => {
   const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.canvasV2.params;
   const { prompts, seedBehaviour } = state.dynamicPrompts;
@@ -22,13 +26,11 @@ export const prepareLinearUIBatch = (state: RootState, g: Graph, prepend: boolea
       start: shouldRandomizeSeed ? undefined : seed,
     });

-    if (g.hasNode(NOISE)) {
     firstBatchDatumList.push({
-      node_path: NOISE,
+      node_path: noise.id,
       field_name: 'seed',
       items: seeds,
     });
-    }

     // add to metadata
     g.removeMetadata(['seed']);
@@ -44,13 +46,11 @@ export const prepareLinearUIBatch = (state: RootState, g: Graph, prepend: boolea
       start: shouldRandomizeSeed ? undefined : seed,
     });

-    if (g.hasNode(NOISE)) {
     secondBatchDatumList.push({
-      node_path: NOISE,
+      node_path: noise.id,
       field_name: 'seed',
       items: seeds,
     });
-    }

     // add to metadata
     g.removeMetadata(['seed']);
@@ -65,13 +65,11 @@ export const prepareLinearUIBatch = (state: RootState, g: Graph, prepend: boolea
   const extendedPrompts = seedBehaviour === 'PER_PROMPT' ? range(iterations).flatMap(() => prompts) : prompts;

   // zipped batch of prompts
-  if (g.hasNode(POSITIVE_CONDITIONING)) {
   firstBatchDatumList.push({
-    node_path: POSITIVE_CONDITIONING,
+    node_path: posCond.id,
     field_name: 'prompt',
     items: extendedPrompts,
   });
-  }

   // add to metadata
   g.removeMetadata(['positive_prompt']);
@@ -82,13 +80,11 @@ export const prepareLinearUIBatch = (state: RootState, g: Graph, prepend: boolea
   });

   if (shouldConcatPrompts && model?.base === 'sdxl') {
-    if (g.hasNode(POSITIVE_CONDITIONING)) {
     firstBatchDatumList.push({
-      node_path: POSITIVE_CONDITIONING,
+      node_path: posCond.id,
       field_name: 'style',
       items: extendedPrompts,
     });
-    }

     // add to metadata
     g.removeMetadata(['positive_style_prompt']);
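Because the noise and positive-conditioning nodes are now passed in directly, the g.hasNode(...) guards above are no longer needed. A hypothetical call site, mirroring the listener change earlier in this commit (variable names assumed, taken from that diff):

// `noise` and `posCond` come from the graph builder's return value.
const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond);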

View File

@@ -1,25 +1,11 @@
 import type { RootState } from 'app/store/store';
+import { getPrefixedId } from 'features/controlLayers/konva/util';
 import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
 import { addSDXLLoRAs } from 'features/nodes/util/graph/generation/addSDXLLoRAs';
 import { Graph } from 'features/nodes/util/graph/generation/Graph';
 import { isNonRefinerMainModelConfig, isSpandrelImageToImageModelConfig } from 'services/api/types';
 import { assert } from 'tsafe';

-import {
-  CLIP_SKIP,
-  CONTROL_NET_COLLECT,
-  IMAGE_TO_LATENTS,
-  LATENTS_TO_IMAGE,
-  MAIN_MODEL_LOADER,
-  NEGATIVE_CONDITIONING,
-  NOISE,
-  POSITIVE_CONDITIONING,
-  SDXL_MODEL_LOADER,
-  SPANDREL,
-  TILED_MULTI_DIFFUSION_DENOISE_LATENTS,
-  UNSHARP_MASK,
-  VAE_LOADER,
-} from './constants';
 import { addLoRAs } from './generation/addLoRAs';
 import { getBoardField, getPresetModifiedPrompts } from './graphBuilderUtils';
@@ -35,8 +21,8 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
   const g = new Graph();

   const upscaleNode = g.addNode({
-    id: SPANDREL,
     type: 'spandrel_image_to_image_autoscale',
+    id: getPrefixedId('spandrel_autoscale'),
     image: upscaleInitialImage,
     image_to_image_model: upscaleModel,
     fit_to_multiple_of_8: true,
@@ -44,8 +30,8 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
   });

   const unsharpMaskNode2 = g.addNode({
-    id: `${UNSHARP_MASK}_2`,
     type: 'unsharp_mask',
+    id: getPrefixedId('unsharp_2'),
     radius: 2,
     strength: 60,
   });
@@ -53,8 +39,8 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
   g.addEdge(upscaleNode, 'image', unsharpMaskNode2, 'image');

   const noiseNode = g.addNode({
-    id: NOISE,
     type: 'noise',
+    id: getPrefixedId('noise'),
     seed,
   });
@@ -62,8 +48,8 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
   g.addEdge(unsharpMaskNode2, 'height', noiseNode, 'height');

   const i2lNode = g.addNode({
-    id: IMAGE_TO_LATENTS,
     type: 'i2l',
+    id: getPrefixedId('i2l'),
     fp32: vaePrecision === 'fp32',
     tiled: true,
   });
@@ -72,7 +58,7 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
   const l2iNode = g.addNode({
     type: 'l2i',
-    id: LATENTS_TO_IMAGE,
+    id: getPrefixedId('l2i'),
     fp32: vaePrecision === 'fp32',
     tiled: true,
     board: getBoardField(state),
@@ -80,8 +66,8 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
   });

   const tiledMultidiffusionNode = g.addNode({
-    id: TILED_MULTI_DIFFUSION_DENOISE_LATENTS,
     type: 'tiled_multi_diffusion_denoise_latents',
+    id: getPrefixedId('tiled_multidiffusion_denoise_latents'),
     tile_height: 1024, // is this dependent on base model
     tile_width: 1024, // is this dependent on base model
     tile_overlap: 128,
@@ -102,19 +88,19 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
     posCondNode = g.addNode({
       type: 'sdxl_compel_prompt',
-      id: POSITIVE_CONDITIONING,
+      id: getPrefixedId('pos_cond'),
       prompt: positivePrompt,
       style: positiveStylePrompt,
     });
     negCondNode = g.addNode({
       type: 'sdxl_compel_prompt',
-      id: NEGATIVE_CONDITIONING,
+      id: getPrefixedId('neg_cond'),
       prompt: negativePrompt,
       style: negativeStylePrompt,
     });
     modelNode = g.addNode({
       type: 'sdxl_model_loader',
-      id: SDXL_MODEL_LOADER,
+      id: getPrefixedId('sdxl_model_loader'),
       model,
     });
     g.addEdge(modelNode, 'clip', posCondNode, 'clip');
@@ -135,22 +121,22 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
     posCondNode = g.addNode({
       type: 'compel',
-      id: POSITIVE_CONDITIONING,
+      id: getPrefixedId('pos_cond'),
       prompt: positivePrompt,
     });
     negCondNode = g.addNode({
       type: 'compel',
-      id: NEGATIVE_CONDITIONING,
+      id: getPrefixedId('neg_cond'),
       prompt: negativePrompt,
     });
     modelNode = g.addNode({
       type: 'main_model_loader',
-      id: MAIN_MODEL_LOADER,
+      id: getPrefixedId('sd1_model_loader'),
       model,
     });
     const clipSkipNode = g.addNode({
       type: 'clip_skip',
-      id: CLIP_SKIP,
+      id: getPrefixedId('clip_skip'),
     });
     g.addEdge(modelNode, 'clip', clipSkipNode, 'clip');
@@ -193,8 +179,8 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
   let vaeNode;
   if (vae) {
     vaeNode = g.addNode({
-      id: VAE_LOADER,
       type: 'vae_loader',
+      id: getPrefixedId('vae'),
       vae_model: vae,
     });
   }
@@ -236,8 +222,8 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
   g.addEdge(unsharpMaskNode2, 'image', controlnetNode2, 'image');

   const collectNode = g.addNode({
-    id: CONTROL_NET_COLLECT,
     type: 'collect',
+    id: getPrefixedId('controlnet_collector'),
   });
   g.addEdge(controlnetNode1, 'control', collectNode, 'item');
   g.addEdge(controlnetNode2, 'control', collectNode, 'item');

View File

@@ -1,69 +0,0 @@
// friendly node ids
export const POSITIVE_CONDITIONING = 'positive_conditioning';
export const NEGATIVE_CONDITIONING = 'negative_conditioning';
export const DENOISE_LATENTS = 'denoise_latents';
export const DENOISE_LATENTS_HRF = 'denoise_latents_hrf';
export const LATENTS_TO_IMAGE = 'latents_to_image';
export const LATENTS_TO_IMAGE_HRF_HR = 'latents_to_image_hrf_hr';
export const LATENTS_TO_IMAGE_HRF_LR = 'latents_to_image_hrf_lr';
export const IMAGE_TO_LATENTS_HRF = 'image_to_latents_hrf';
export const RESIZE_HRF = 'resize_hrf';
export const ESRGAN_HRF = 'esrgan_hrf';
export const NSFW_CHECKER = 'nsfw_checker';
export const WATERMARKER = 'invisible_watermark';
export const NOISE = 'noise';
export const NOISE_HRF = 'noise_hrf';
export const MAIN_MODEL_LOADER = 'main_model_loader';
export const VAE_LOADER = 'vae_loader';
export const LORA_LOADER = 'lora_loader';
export const CLIP_SKIP = 'clip_skip';
export const IMAGE_TO_LATENTS = 'image_to_latents';
export const RESIZE = 'resize_image';
export const IMG2IMG_RESIZE = 'img2img_resize';
export const CANVAS_OUTPUT = 'canvas_output';
export const INPAINT_IMAGE = 'inpaint_image';
export const INPAINT_IMAGE_RESIZE_UP = 'inpaint_image_resize_up';
export const INPAINT_IMAGE_RESIZE_DOWN = 'inpaint_image_resize_down';
export const INPAINT_INFILL = 'inpaint_infill';
export const INPAINT_INFILL_RESIZE_DOWN = 'inpaint_infill_resize_down';
export const INPAINT_CREATE_MASK = 'inpaint_create_mask';
export const CANVAS_COHERENCE_NOISE = 'canvas_coherence_noise';
export const MASK_FROM_ALPHA = 'tomask';
export const MASK_COMBINE = 'mask_combine';
export const MASK_RESIZE_UP = 'mask_resize_up';
export const MASK_RESIZE_DOWN = 'mask_resize_down';
export const CONTROL_NET_COLLECT = 'control_net_collect';
export const IP_ADAPTER_COLLECT = 'ip_adapter_collect';
export const T2I_ADAPTER_COLLECT = 't2i_adapter_collect';
export const METADATA = 'core_metadata';
export const SPANDREL = 'spandrel';
export const SDXL_MODEL_LOADER = 'sdxl_model_loader';
export const SDXL_DENOISE_LATENTS = 'sdxl_denoise_latents';
export const SDXL_REFINER_MODEL_LOADER = 'sdxl_refiner_model_loader';
export const SDXL_REFINER_POSITIVE_CONDITIONING = 'sdxl_refiner_positive_conditioning';
export const SDXL_REFINER_NEGATIVE_CONDITIONING = 'sdxl_refiner_negative_conditioning';
export const SDXL_REFINER_DENOISE_LATENTS = 'sdxl_refiner_denoise_latents';
export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
export const SEAMLESS = 'seamless';
export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
export const PROMPT_REGION_MASK_TO_TENSOR_PREFIX = 'prompt_region_mask_to_tensor';
export const PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX = 'prompt_region_invert_tensor_mask';
export const PROMPT_REGION_POSITIVE_COND_PREFIX = 'prompt_region_positive_cond';
export const PROMPT_REGION_NEGATIVE_COND_PREFIX = 'prompt_region_negative_cond';
export const PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX = 'prompt_region_positive_cond_inverted';
export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';
export const UNSHARP_MASK = 'unsharp_mask';
export const TILED_MULTI_DIFFUSION_DENOISE_LATENTS = 'tiled_multi_diffusion_denoise_latents';
// friendly graph ids
export const CONTROL_LAYERS_GRAPH = 'control_layers_graph';
export const SDXL_CONTROL_LAYERS_GRAPH = 'sdxl_control_layers_graph';
export const CANVAS_TEXT_TO_IMAGE_GRAPH = 'canvas_text_to_image_graph';
export const CANVAS_IMAGE_TO_IMAGE_GRAPH = 'canvas_image_to_image_graph';
export const CANVAS_INPAINT_GRAPH = 'canvas_inpaint_graph';
export const CANVAS_OUTPAINT_GRAPH = 'canvas_outpaint_graph';
export const SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH = 'sdxl_canvas_text_to_image_graph';
export const SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH = 'sdxl_canvas_image_to_image_graph';
export const SDXL_CANVAS_INPAINT_GRAPH = 'sdxl_canvas_inpaint_graph';
export const SDXL_CANVAS_OUTPAINT_GRAPH = 'sdxl_canvas_outpaint_graph';

View File

@@ -5,58 +5,81 @@ import type {
   Rect,
   T2IAdapterConfig,
 } from 'features/controlLayers/store/types';
-import { CONTROL_NET_COLLECT, T2I_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants';
 import type { Graph } from 'features/nodes/util/graph/generation/Graph';
 import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
 import { assert } from 'tsafe';

-export const addControlAdapters = async (
+type AddControlNetsResult = {
+  addedControlNets: number;
+};
+
+export const addControlNets = async (
   manager: CanvasManager,
   layers: CanvasControlLayerState[],
   g: Graph,
   bbox: Rect,
-  denoise: Invocation<'denoise_latents'>,
+  collector: Invocation<'collect'>,
   base: BaseModelType
-): Promise<CanvasControlLayerState[]> => {
+): Promise<AddControlNetsResult> => {
   const validControlLayers = layers
     .filter((layer) => layer.isEnabled)
-    .filter((layer) => isValidControlAdapter(layer.controlAdapter, base));
+    .filter((layer) => isValidControlAdapter(layer.controlAdapter, base))
+    .filter((layer) => layer.controlAdapter.type === 'controlnet');
+
+  const result: AddControlNetsResult = {
+    addedControlNets: 0,
+  };

   for (const layer of validControlLayers) {
+    result.addedControlNets++;
     const adapter = manager.adapters.controlLayers.get(layer.id);
     assert(adapter, 'Adapter not found');
     const imageDTO = await adapter.renderer.rasterize({ rect: bbox, attrs: { opacity: 1, filters: [] } });
-    if (layer.controlAdapter.type === 'controlnet') {
-      await addControlNetToGraph(g, layer, imageDTO, denoise);
-    } else {
-      await addT2IAdapterToGraph(g, layer, imageDTO, denoise);
-    }
+    await addControlNetToGraph(g, layer, imageDTO, collector);
   }

-  return validControlLayers;
+  return result;
 };

-const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => {
-  try {
-    // Attempt to retrieve the collector
-    const controlNetCollect = g.getNode(CONTROL_NET_COLLECT);
-    assert(controlNetCollect.type === 'collect');
-    return controlNetCollect;
-  } catch {
-    // Add the ControlNet collector
-    const controlNetCollect = g.addNode({
-      id: CONTROL_NET_COLLECT,
-      type: 'collect',
-    });
-    g.addEdge(controlNetCollect, 'collection', denoise, 'control');
-    return controlNetCollect;
-  }
-};
+type AddT2IAdaptersResult = {
+  addedT2IAdapters: number;
+};
+
+export const addT2IAdapters = async (
+  manager: CanvasManager,
+  layers: CanvasControlLayerState[],
+  g: Graph,
+  bbox: Rect,
+  collector: Invocation<'collect'>,
+  base: BaseModelType
+): Promise<AddT2IAdaptersResult> => {
+  const validControlLayers = layers
+    .filter((layer) => layer.isEnabled)
+    .filter((layer) => isValidControlAdapter(layer.controlAdapter, base))
+    .filter((layer) => layer.controlAdapter.type === 't2i_adapter');
+
+  const result: AddT2IAdaptersResult = {
+    addedT2IAdapters: 0,
+  };
+
+  for (const layer of validControlLayers) {
+    result.addedT2IAdapters++;
+    const adapter = manager.adapters.controlLayers.get(layer.id);
+    assert(adapter, 'Adapter not found');
+    const imageDTO = await adapter.renderer.rasterize({ rect: bbox, attrs: { opacity: 1, filters: [] } });
+    await addT2IAdapterToGraph(g, layer, imageDTO, collector);
+  }
+
+  return result;
+};

 const addControlNetToGraph = (
   g: Graph,
   layer: CanvasControlLayerState,
   imageDTO: ImageDTO,
-  denoise: Invocation<'denoise_latents'>
+  collector: Invocation<'collect'>
 ) => {
   const { id, controlAdapter } = layer;
   assert(controlAdapter.type === 'controlnet');
@@ -64,8 +87,6 @@ const addControlNetToGraph = (
   assert(model !== null);
   const { image_name } = imageDTO;

-  const controlNetCollect = addControlNetCollectorSafe(g, denoise);

   const controlNet = g.addNode({
     id: `control_net_${id}`,
     type: 'controlnet',
@@ -77,32 +98,14 @@ const addControlNetToGraph = (
     control_weight: weight,
     image: { image_name },
   });
-  g.addEdge(controlNet, 'control', controlNetCollect, 'item');
+  g.addEdge(controlNet, 'control', collector, 'item');
 };

-const addT2IAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => {
-  try {
-    // You see, we've already got one!
-    const t2iAdapterCollect = g.getNode(T2I_ADAPTER_COLLECT);
-    assert(t2iAdapterCollect.type === 'collect');
-    return t2iAdapterCollect;
-  } catch {
-    const t2iAdapterCollect = g.addNode({
-      id: T2I_ADAPTER_COLLECT,
-      type: 'collect',
-    });
-    g.addEdge(t2iAdapterCollect, 'collection', denoise, 't2i_adapter');
-    return t2iAdapterCollect;
-  }
-};

 const addT2IAdapterToGraph = (
   g: Graph,
   layer: CanvasControlLayerState,
   imageDTO: ImageDTO,
-  denoise: Invocation<'denoise_latents'>
+  collector: Invocation<'collect'>
 ) => {
   const { id, controlAdapter } = layer;
   assert(controlAdapter.type === 't2i_adapter');
@@ -110,8 +113,6 @@ const addT2IAdapterToGraph = (
   assert(model !== null);
   const { image_name } = imageDTO;

-  const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise);

   const t2iAdapter = g.addNode({
     id: `t2i_adapter_${id}`,
     type: 't2i_adapter',
@@ -123,7 +124,7 @@ const addT2IAdapterToGraph = (
     image: { image_name },
   });

-  g.addEdge(t2iAdapter, 't2i_adapter', t2iAdapterCollect, 'item');
+  g.addEdge(t2iAdapter, 't2i_adapter', collector, 'item');
 };

 const isValidControlAdapter = (controlAdapter: ControlNetConfig | T2IAdapterConfig, base: BaseModelType): boolean => {
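The collector nodes are no longer created lazily inside these helpers; the graph builder is now expected to create them and connect them to the denoise node before calling addControlNets / addT2IAdapters. A rough builder-side sketch under that assumption (the ids, variable names, and layer list are illustrative; the edge field names come from the removed *CollectorSafe helpers above):

// Sketch: builder-side wiring assumed by the new helper signatures.
const controlNetCollect = g.addNode({ id: getPrefixedId('control_net_collect'), type: 'collect' });
g.addEdge(controlNetCollect, 'collection', denoise, 'control');
const { addedControlNets } = await addControlNets(manager, controlLayers, g, bbox, controlNetCollect, base);

const t2iAdapterCollect = g.addNode({ id: getPrefixedId('t2i_adapter_collect'), type: 'collect' });
g.addEdge(t2iAdapterCollect, 'collection', denoise, 't2i_adapter');
const { addedT2IAdapters } = await addT2IAdapters(manager, controlLayers, g, bbox, t2iAdapterCollect, base);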

View File

@@ -1,44 +1,38 @@
 import type { CanvasIPAdapterState, IPAdapterConfig } from 'features/controlLayers/store/types';
-import { IP_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants';
 import type { Graph } from 'features/nodes/util/graph/generation/Graph';
 import type { BaseModelType, Invocation } from 'services/api/types';
 import { assert } from 'tsafe';

+type AddIPAdaptersResult = {
+  addedIPAdapters: number;
+};

 export const addIPAdapters = (
   ipAdapters: CanvasIPAdapterState[],
   g: Graph,
-  denoise: Invocation<'denoise_latents'>,
+  collector: Invocation<'collect'>,
   base: BaseModelType
-): CanvasIPAdapterState[] => {
+): AddIPAdaptersResult => {
   const validIPAdapters = ipAdapters.filter((entity) => isValidIPAdapter(entity.ipAdapter, base));
+  const result: AddIPAdaptersResult = {
+    addedIPAdapters: 0,
+  };
   for (const ipa of validIPAdapters) {
-    addIPAdapter(ipa, g, denoise);
+    result.addedIPAdapters++;
+    addIPAdapter(ipa, g, collector);
   }
-  return validIPAdapters;
+  return result;
 };

-export const addIPAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => {
-  try {
-    // You see, we've already got one!
-    const ipAdapterCollect = g.getNode(IP_ADAPTER_COLLECT);
-    assert(ipAdapterCollect.type === 'collect');
-    return ipAdapterCollect;
-  } catch {
-    const ipAdapterCollect = g.addNode({
-      id: IP_ADAPTER_COLLECT,
-      type: 'collect',
-    });
-    g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
-    return ipAdapterCollect;
-  }
-};

-const addIPAdapter = (entity: CanvasIPAdapterState, g: Graph, denoise: Invocation<'denoise_latents'>) => {
+const addIPAdapter = (entity: CanvasIPAdapterState, g: Graph, collector: Invocation<'collect'>) => {
   const { id, ipAdapter } = entity;
   const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
   assert(image, 'IP Adapter image is required');
   assert(model, 'IP Adapter model is required');

-  const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);

   const ipAdapterNode = g.addNode({
     id: `ip_adapter_${id}`,
@@ -53,7 +47,7 @@ const addIPAdapter = (entity: CanvasIPAdapterState, g: Graph, denoise: Invocatio
       image_name: image.image_name,
     },
   });

-  g.addEdge(ipAdapterNode, 'ip_adapter', ipAdapterCollect, 'item');
+  g.addEdge(ipAdapterNode, 'ip_adapter', collector, 'item');
 };

 export const isValidIPAdapter = (ipAdapter: IPAdapterConfig, base: BaseModelType): boolean => {

View File

@@ -1,4 +1,5 @@
 import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
+import { getPrefixedId } from 'features/controlLayers/konva/util';
 import type { CanvasV2State, Dimensions } from 'features/controlLayers/store/types';
 import type { Graph } from 'features/nodes/util/graph/generation/Graph';
 import { isEqual } from 'lodash-es';
@@ -22,15 +23,15 @@ export const addImageToImage = async (
   if (!isEqual(scaledSize, originalSize)) {
     // Resize the initial image to the scaled size, denoise, then resize back to the original size
     const resizeImageToScaledSize = g.addNode({
-      id: 'initial_image_resize_in',
       type: 'img_resize',
+      id: getPrefixedId('initial_image_resize_in'),
       image: { image_name },
       ...scaledSize,
     });
     const i2l = g.addNode({ id: 'i2l', type: 'i2l' });
     const resizeImageToOriginalSize = g.addNode({
-      id: 'initial_image_resize_out',
       type: 'img_resize',
+      id: getPrefixedId('initial_image_resize_out'),
       ...originalSize,
     });

View File

@@ -1,4 +1,5 @@
 import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
+import { getPrefixedId } from 'features/controlLayers/konva/util';
 import type { CanvasV2State, Dimensions } from 'features/controlLayers/store/types';
 import type { Graph } from 'features/nodes/util/graph/generation/Graph';
 import type { ParameterPrecision } from 'features/parameters/types/parameterSchemas';
@@ -26,36 +27,36 @@ export const addInpaint = async (
   if (!isEqual(scaledSize, originalSize)) {
     // Scale before processing requires some resizing
-    const i2l = g.addNode({ id: 'i2l', type: 'i2l' });
+    const i2l = g.addNode({ id: getPrefixedId('i2l'), type: 'i2l' });
     const resizeImageToScaledSize = g.addNode({
-      id: 'resize_image_to_scaled_size',
       type: 'img_resize',
+      id: getPrefixedId('resize_image_to_scaled_size'),
       image: { image_name: initialImage.image_name },
       ...scaledSize,
     });
     const alphaToMask = g.addNode({
-      id: 'alpha_to_mask',
+      id: getPrefixedId('alpha_to_mask'),
       type: 'tomask',
       image: { image_name: maskImage.image_name },
       invert: true,
     });
     const resizeMaskToScaledSize = g.addNode({
-      id: 'resize_mask_to_scaled_size',
+      id: getPrefixedId('resize_mask_to_scaled_size'),
       type: 'img_resize',
       ...scaledSize,
     });
     const resizeImageToOriginalSize = g.addNode({
-      id: 'resize_image_to_original_size',
+      id: getPrefixedId('resize_image_to_original_size'),
       type: 'img_resize',
       ...originalSize,
     });
     const resizeMaskToOriginalSize = g.addNode({
-      id: 'resize_mask_to_original_size',
+      id: getPrefixedId('resize_mask_to_original_size'),
       type: 'img_resize',
       ...originalSize,
     });
     const createGradientMask = g.addNode({
-      id: 'create_gradient_mask',
+      id: getPrefixedId('create_gradient_mask'),
       type: 'create_gradient_mask',
       coherence_mode: compositing.canvasCoherenceMode,
       minimum_denoise: compositing.canvasCoherenceMinDenoise,
@@ -63,7 +64,7 @@ export const addInpaint = async (
       fp32: vaePrecision === 'fp32',
     });
     const canvasPasteBack = g.addNode({
-      id: 'canvas_v2_mask_and_crop',
+      id: getPrefixedId('canvas_v2_mask_and_crop'),
       type: 'canvas_v2_mask_and_crop',
       mask_blur: compositing.maskBlur,
     });
@@ -92,15 +93,15 @@ export const addInpaint = async (
     return canvasPasteBack;
   } else {
     // No scale before processing, much simpler
-    const i2l = g.addNode({ id: 'i2l', type: 'i2l', image: { image_name: initialImage.image_name } });
+    const i2l = g.addNode({ id: getPrefixedId('i2l'), type: 'i2l', image: { image_name: initialImage.image_name } });
     const alphaToMask = g.addNode({
-      id: 'alpha_to_mask',
+      id: getPrefixedId('alpha_to_mask'),
       type: 'tomask',
       image: { image_name: maskImage.image_name },
       invert: true,
     });
     const createGradientMask = g.addNode({
-      id: 'create_gradient_mask',
+      id: getPrefixedId('create_gradient_mask'),
       type: 'create_gradient_mask',
       coherence_mode: compositing.canvasCoherenceMode,
       minimum_denoise: compositing.canvasCoherenceMinDenoise,
@@ -109,7 +110,7 @@ export const addInpaint = async (
       image: { image_name: initialImage.image_name },
     });
     const canvasPasteBack = g.addNode({
-      id: 'canvas_v2_mask_and_crop',
+      id: getPrefixedId('canvas_v2_mask_and_crop'),
       type: 'canvas_v2_mask_and_crop',
       mask_blur: compositing.maskBlur,
     });

View File

@@ -1,6 +1,6 @@
 import type { RootState } from 'app/store/store';
+import { getPrefixedId } from 'features/controlLayers/konva/util';
 import { zModelIdentifierField } from 'features/nodes/types/common';
-import { LORA_LOADER } from 'features/nodes/util/graph/constants';
 import type { Graph } from 'features/nodes/util/graph/generation/Graph';
 import type { Invocation, S } from 'services/api/types';
@@ -28,12 +28,12 @@ export const addLoRAs = (
   // We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies
   // each LoRA to the UNet and CLIP.
   const loraCollector = g.addNode({
-    id: `${LORA_LOADER}_collect`,
     type: 'collect',
+    id: getPrefixedId('lora_collector'),
   });
   const loraCollectionLoader = g.addNode({
-    id: LORA_LOADER,
     type: 'lora_collection_loader',
+    id: getPrefixedId('lora_collection_loader'),
   });

   g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras');
@@ -50,12 +50,11 @@ export const addLoRAs = (
   for (const lora of enabledLoRAs) {
     const { weight } = lora;
-    const { key } = lora.model;
     const parsedModel = zModelIdentifierField.parse(lora.model);

     const loraSelector = g.addNode({
       type: 'lora_selector',
-      id: `${LORA_LOADER}_${key}`,
+      id: getPrefixedId('lora_selector'),
       lora: parsedModel,
       weight,
     });

View File

@@ -1,4 +1,4 @@
-import { NSFW_CHECKER } from 'features/nodes/util/graph/constants';
+import { getPrefixedId } from 'features/controlLayers/konva/util';
 import type { Graph } from 'features/nodes/util/graph/generation/Graph';
 import type { Invocation } from 'services/api/types';
@@ -13,8 +13,8 @@ export const addNSFWChecker = (
   imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop'>
 ): Invocation<'img_nsfw'> => {
   const nsfw = g.addNode({
-    id: NSFW_CHECKER,
     type: 'img_nsfw',
+    id: getPrefixedId('nsfw_checker'),
   });

   g.addEdge(imageOutput, 'image', nsfw, 'image');

View File

@@ -1,4 +1,5 @@
 import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
+import { getPrefixedId } from 'features/controlLayers/konva/util';
 import type { CanvasV2State, Dimensions } from 'features/controlLayers/store/types';
 import type { Graph } from 'features/nodes/util/graph/generation/Graph';
 import { getInfill } from 'features/nodes/util/graph/graphBuilderUtils';
@@ -31,18 +32,18 @@ export const addOutpaint = async (
     // Combine the inpaint mask and the initial image's alpha channel into a single mask
     const maskAlphaToMask = g.addNode({
-      id: 'alpha_to_mask',
+      id: getPrefixedId('alpha_to_mask'),
       type: 'tomask',
       image: { image_name: maskImage.image_name },
       invert: true,
     });
     const initialImageAlphaToMask = g.addNode({
-      id: 'image_alpha_to_mask',
+      id: getPrefixedId('image_alpha_to_mask'),
       type: 'tomask',
       image: { image_name: initialImage.image_name },
     });
     const maskCombine = g.addNode({
-      id: 'mask_combine',
+      id: getPrefixedId('mask_combine'),
       type: 'mask_combine',
     });
     g.addEdge(maskAlphaToMask, 'image', maskCombine, 'mask1');
@@ -50,7 +51,7 @@ export const addOutpaint = async (
     // Resize the combined and initial image to the scaled size
     const resizeInputMaskToScaledSize = g.addNode({
-      id: 'resize_mask_to_scaled_size',
+      id: getPrefixedId('resize_mask_to_scaled_size'),
       type: 'img_resize',
       ...scaledSize,
     });
@@ -58,7 +59,7 @@ export const addOutpaint = async (
     // Resize the initial image to the scaled size and infill
     const resizeInputImageToScaledSize = g.addNode({
-      id: 'resize_image_to_scaled_size',
+      id: getPrefixedId('resize_image_to_scaled_size'),
       type: 'img_resize',
       image: { image_name: initialImage.image_name },
       ...scaledSize,
@@ -67,7 +68,7 @@ export const addOutpaint = async (
     // Create the gradient denoising mask from the combined mask
     const createGradientMask = g.addNode({
-      id: 'create_gradient_mask',
+      id: getPrefixedId('create_gradient_mask'),
       type: 'create_gradient_mask',
       coherence_mode: compositing.canvasCoherenceMode,
       minimum_denoise: compositing.canvasCoherenceMinDenoise,
@@ -81,24 +82,24 @@ export const addOutpaint = async (
     g.addEdge(createGradientMask, 'denoise_mask', denoise, 'denoise_mask');

     // Decode infilled image and connect to denoise
-    const i2l = g.addNode({ id: 'i2l', type: 'i2l' });
+    const i2l = g.addNode({ id: getPrefixedId('i2l'), type: 'i2l' });
     g.addEdge(infill, 'image', i2l, 'image');
     g.addEdge(vaeSource, 'vae', i2l, 'vae');
     g.addEdge(i2l, 'latents', denoise, 'latents');

     // Resize the output image back to the original size
     const resizeOutputImageToOriginalSize = g.addNode({
-      id: 'resize_image_to_original_size',
+      id: getPrefixedId('resize_image_to_original_size'),
       type: 'img_resize',
       ...originalSize,
     });
     const resizeOutputMaskToOriginalSize = g.addNode({
-      id: 'resize_mask_to_original_size',
+      id: getPrefixedId('resize_mask_to_original_size'),
       type: 'img_resize',
       ...originalSize,
     });
     const canvasPasteBack = g.addNode({
-      id: 'canvas_v2_mask_and_crop',
+      id: getPrefixedId('canvas_v2_mask_and_crop'),
       type: 'canvas_v2_mask_and_crop',
       mask_blur: compositing.maskBlur,
     });
@@ -117,24 +118,24 @@ export const addOutpaint = async (
   } else {
     infill.image = { image_name: initialImage.image_name };
     // No scale before processing, much simpler
-    const i2l = g.addNode({ id: 'i2l', type: 'i2l' });
+    const i2l = g.addNode({ id: getPrefixedId('i2l'), type: 'i2l' });
     const maskAlphaToMask = g.addNode({
-      id: 'mask_alpha_to_mask',
+      id: getPrefixedId('mask_alpha_to_mask'),
       type: 'tomask',
       image: { image_name: maskImage.image_name },
       invert: true,
     });
     const initialImageAlphaToMask = g.addNode({
-      id: 'image_alpha_to_mask',
+      id: getPrefixedId('image_alpha_to_mask'),
       type: 'tomask',
       image: { image_name: initialImage.image_name },
     });
     const maskCombine = g.addNode({
-      id: 'mask_combine',
+      id: getPrefixedId('mask_combine'),
       type: 'mask_combine',
     });
     const createGradientMask = g.addNode({
-      id: 'create_gradient_mask',
+      id: getPrefixedId('create_gradient_mask'),
       type: 'create_gradient_mask',
       coherence_mode: compositing.canvasCoherenceMode,
       minimum_denoise: compositing.canvasCoherenceMinDenoise,
@@ -143,7 +144,7 @@ export const addOutpaint = async (
       image: { image_name: initialImage.image_name },
     });
     const canvasPasteBack = g.addNode({
-      id: 'canvas_v2_mask_and_crop',
+      id: getPrefixedId('canvas_v2_mask_and_crop'),
       type: 'canvas_v2_mask_and_crop',
       mask_blur: compositing.maskBlur,
     });

View File

@ -1,22 +1,23 @@
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { import type {
CanvasRegionalGuidanceState, CanvasRegionalGuidanceState,
Rect, Rect,
RegionalGuidanceIPAdapterConfig, RegionalGuidanceIPAdapterConfig,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import { import { isValidIPAdapter } from 'features/nodes/util/graph/generation/addIPAdapters';
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
PROMPT_REGION_MASK_TO_TENSOR_PREFIX,
PROMPT_REGION_NEGATIVE_COND_PREFIX,
PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX,
PROMPT_REGION_POSITIVE_COND_PREFIX,
} from 'features/nodes/util/graph/constants';
import { addIPAdapterCollectorSafe, isValidIPAdapter } from 'features/nodes/util/graph/generation/addIPAdapters';
import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { BaseModelType, Invocation } from 'services/api/types'; import type { BaseModelType, Invocation } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
type AddedRegionResult = {
addedPositivePrompt: boolean;
addedNegativePrompt: boolean;
addedAutoNegativePositivePrompt: boolean;
addedIPAdapters: number;
};
/** /**
* Adds regional guidance to the graph * Adds regional guidance to the graph
* @param regions Array of regions to add * @param regions Array of regions to add
@ -27,6 +28,7 @@ import { assert } from 'tsafe';
* @param negCond The negative conditioning node * @param negCond The negative conditioning node
* @param posCondCollect The positive conditioning collector * @param posCondCollect The positive conditioning collector
* @param negCondCollect The negative conditioning collector * @param negCondCollect The negative conditioning collector
* @param ipAdapterCollect The IP adapter collector
* @returns A promise that resolves to the regions that were successfully added to the graph * @returns A promise that resolves to the regions that were successfully added to the graph
*/ */
@ -40,21 +42,29 @@ export const addRegions = async (
posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>, posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>, negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
posCondCollect: Invocation<'collect'>, posCondCollect: Invocation<'collect'>,
negCondCollect: Invocation<'collect'> negCondCollect: Invocation<'collect'>,
): Promise<CanvasRegionalGuidanceState[]> => { ipAdapterCollect: Invocation<'collect'>
): Promise<AddedRegionResult[]> => {
const isSDXL = base === 'sdxl'; const isSDXL = base === 'sdxl';
const validRegions = regions.filter((rg) => isValidRegion(rg, base)); const validRegions = regions.filter((rg) => isValidRegion(rg, base));
const results: AddedRegionResult[] = [];
for (const region of validRegions) { for (const region of validRegions) {
const result: AddedRegionResult = {
addedPositivePrompt: false,
addedNegativePrompt: false,
addedAutoNegativePositivePrompt: false,
addedIPAdapters: 0,
};
const adapter = manager.adapters.regionMasks.get(region.id); const adapter = manager.adapters.regionMasks.get(region.id);
assert(adapter, 'Adapter not found'); assert(adapter, 'Adapter not found');
const imageDTO = await adapter.renderer.rasterize({ rect: bbox }); const imageDTO = await adapter.renderer.rasterize({ rect: bbox });
// The main mask-to-tensor node // The main mask-to-tensor node
const maskToTensor = g.addNode({ const maskToTensor = g.addNode({
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${region.id}`,
type: 'alpha_mask_to_tensor', type: 'alpha_mask_to_tensor',
id: getPrefixedId('prompt_region_mask_to_tensor'),
image: { image: {
image_name: imageDTO.image_name, image_name: imageDTO.image_name,
}, },
@ -62,17 +72,18 @@ export const addRegions = async (
if (region.positivePrompt) { if (region.positivePrompt) {
// The main positive conditioning node // The main positive conditioning node
result.addedPositivePrompt = true;
const regionalPosCond = g.addNode( const regionalPosCond = g.addNode(
isSDXL isSDXL
? { ? {
type: 'sdxl_compel_prompt', type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${region.id}`, id: getPrefixedId('prompt_region_positive_cond'),
prompt: region.positivePrompt, prompt: region.positivePrompt,
style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields? style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields?
} }
: { : {
type: 'compel', type: 'compel',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${region.id}`, id: getPrefixedId('prompt_region_positive_cond'),
prompt: region.positivePrompt, prompt: region.positivePrompt,
} }
); );
@ -99,18 +110,19 @@ export const addRegions = async (
} }
if (region.negativePrompt) { if (region.negativePrompt) {
result.addedNegativePrompt = true;
// The main negative conditioning node // The main negative conditioning node
const regionalNegCond = g.addNode( const regionalNegCond = g.addNode(
isSDXL isSDXL
? { ? {
type: 'sdxl_compel_prompt', type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${region.id}`, id: getPrefixedId('prompt_region_negative_cond'),
prompt: region.negativePrompt, prompt: region.negativePrompt,
style: region.negativePrompt, style: region.negativePrompt,
} }
: { : {
type: 'compel', type: 'compel',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${region.id}`, id: getPrefixedId('prompt_region_negative_cond'),
prompt: region.negativePrompt, prompt: region.negativePrompt,
} }
); );
@ -135,10 +147,11 @@ export const addRegions = async (
} }
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node // If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
if (region.autoNegative === 'invert' && region.positivePrompt) { if (region.autoNegative && region.positivePrompt) {
result.addedAutoNegativePositivePrompt = true;
// We re-use the mask image, but invert it when converting to tensor // We re-use the mask image, but invert it when converting to tensor
const invertTensorMask = g.addNode({ const invertTensorMask = g.addNode({
id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${region.id}`, id: getPrefixedId('prompt_region_invert_tensor_mask'),
type: 'invert_tensor_mask', type: 'invert_tensor_mask',
}); });
// Connect the OG mask image to the inverted mask-to-tensor node // Connect the OG mask image to the inverted mask-to-tensor node
@ -148,13 +161,13 @@ export const addRegions = async (
isSDXL isSDXL
? { ? {
type: 'sdxl_compel_prompt', type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${region.id}`, id: getPrefixedId('prompt_region_positive_cond_inverted'),
prompt: region.positivePrompt, prompt: region.positivePrompt,
style: region.positivePrompt, style: region.positivePrompt,
} }
: { : {
type: 'compel', type: 'compel',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${region.id}`, id: getPrefixedId('prompt_region_positive_cond_inverted'),
prompt: region.positivePrompt, prompt: region.positivePrompt,
} }
); );
@ -183,7 +196,7 @@ export const addRegions = async (
); );
for (const ipa of validRGIPAdapters) { for (const ipa of validRGIPAdapters) {
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise); result.addedIPAdapters++;
const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipa; const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipa;
assert(model, 'IP Adapter model is required'); assert(model, 'IP Adapter model is required');
assert(image, 'IP Adapter image is required'); assert(image, 'IP Adapter image is required');
@ -206,14 +219,18 @@ export const addRegions = async (
g.addEdge(maskToTensor, 'mask', ipAdapter, 'mask'); g.addEdge(maskToTensor, 'mask', ipAdapter, 'mask');
g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item'); g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item');
} }
results.push(result);
} }
g.upsertMetadata({ regions: validRegions }); g.upsertMetadata({ regions: validRegions });
return validRegions;
return results;
}; };
export const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) => { export const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) => {
const isEnabled = rg.isEnabled;
const hasTextPrompt = Boolean(rg.positivePrompt || rg.negativePrompt); const hasTextPrompt = Boolean(rg.positivePrompt || rg.negativePrompt);
const hasIPAdapter = rg.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)).length > 0; const hasIPAdapter = rg.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)).length > 0;
return hasTextPrompt || hasIPAdapter; return isEnabled && (hasTextPrompt || hasIPAdapter);
}; };
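
The rewritten addRegions returns one result per valid region instead of the region list itself, and isValidRegion now also requires the region to be enabled. A small sketch of the result shape as far as it can be inferred from the fields this diff reads and writes (the type name is an assumption):

    // Assumed result shape, inferred from the fields this diff reads and writes.
    interface AddRegionsResult {
      addedIPAdapters: number;
      addedAutoNegativePositivePrompt: boolean;
    }

    // The graph builders below total the regional IP adapters exactly like this.
    const totalRegionalIPAdapters = (results: AddRegionsResult[]): number =>
      results.reduce((acc, r) => acc + r.addedIPAdapters, 0);

    // totalRegionalIPAdapters([{ addedIPAdapters: 2, addedAutoNegativePositivePrompt: false }]) === 2
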
View File
@ -1,6 +1,6 @@
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { zModelIdentifierField } from 'features/nodes/types/common'; import { zModelIdentifierField } from 'features/nodes/types/common';
import { LORA_LOADER } from 'features/nodes/util/graph/constants';
import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { Invocation, S } from 'services/api/types'; import type { Invocation, S } from 'services/api/types';
@ -25,12 +25,12 @@ export const addSDXLLoRAs = (
// We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies // We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies
// each LoRA to the UNet and CLIP. // each LoRA to the UNet and CLIP.
const loraCollector = g.addNode({ const loraCollector = g.addNode({
id: `${LORA_LOADER}_collect`, id: getPrefixedId('lora_collector'),
type: 'collect', type: 'collect',
}); });
const loraCollectionLoader = g.addNode({ const loraCollectionLoader = g.addNode({
id: LORA_LOADER,
type: 'sdxl_lora_collection_loader', type: 'sdxl_lora_collection_loader',
id: getPrefixedId('sdxl_lora_collection_loader'),
}); });
g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras'); g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras');
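
The comment above describes the fan-in this hunk wires up: each enabled LoRA gets a lora_selector, every selector feeds one collect node, and its collection is handed to the SDXL LoRA collection loader, which applies each LoRA to the UNet and CLIP. A toy sketch of that wiring; the 'lora' output field name on the selector is an assumption, while 'item' and 'collection' match the collect-node edges used elsewhere in this diff:

    type Edge = { from: [string, string]; to: [string, string] };
    const loraEdges: Edge[] = [];
    const connect = (fromId: string, fromField: string, toId: string, toField: string): void => {
      loraEdges.push({ from: [fromId, fromField], to: [toId, toField] });
    };

    const selectorIds = ['lora_selector:aaa', 'lora_selector:bbb']; // one per enabled LoRA
    for (const id of selectorIds) {
      connect(id, 'lora', 'lora_collector:1', 'item'); // fan in to the single collect node
    }
    connect('lora_collector:1', 'collection', 'sdxl_lora_collection_loader:1', 'loras');
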
@ -50,12 +50,11 @@ export const addSDXLLoRAs = (
for (const lora of enabledLoRAs) { for (const lora of enabledLoRAs) {
const { weight } = lora; const { weight } = lora;
const { key } = lora.model;
const parsedModel = zModelIdentifierField.parse(lora.model); const parsedModel = zModelIdentifierField.parse(lora.model);
const loraSelector = g.addNode({ const loraSelector = g.addNode({
type: 'lora_selector', type: 'lora_selector',
id: `${LORA_LOADER}_${key}`, id: getPrefixedId('lora_selector'),
lora: parsedModel, lora: parsedModel,
weight, weight,
}); });
View File
@ -1,12 +1,6 @@
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
SDXL_REFINER_DENOISE_LATENTS,
SDXL_REFINER_MODEL_LOADER,
SDXL_REFINER_NEGATIVE_CONDITIONING,
SDXL_REFINER_POSITIVE_CONDITIONING,
SDXL_REFINER_SEAMLESS,
} from 'features/nodes/util/graph/constants';
import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { Invocation } from 'services/api/types'; import type { Invocation } from 'services/api/types';
import { isRefinerMainModelModelConfig } from 'services/api/types'; import { isRefinerMainModelModelConfig } from 'services/api/types';
@ -42,24 +36,24 @@ export const addSDXLRefiner = async (
const refinerModelLoader = g.addNode({ const refinerModelLoader = g.addNode({
type: 'sdxl_refiner_model_loader', type: 'sdxl_refiner_model_loader',
id: SDXL_REFINER_MODEL_LOADER, id: getPrefixedId('refiner_model_loader'),
model: refinerModel, model: refinerModel,
}); });
const refinerPosCond = g.addNode({ const refinerPosCond = g.addNode({
type: 'sdxl_refiner_compel_prompt', type: 'sdxl_refiner_compel_prompt',
id: SDXL_REFINER_POSITIVE_CONDITIONING, id: getPrefixedId('refiner_pos_cond'),
style: posCond.style, style: posCond.style,
aesthetic_score: refinerPositiveAestheticScore, aesthetic_score: refinerPositiveAestheticScore,
}); });
const refinerNegCond = g.addNode({ const refinerNegCond = g.addNode({
type: 'sdxl_refiner_compel_prompt', type: 'sdxl_refiner_compel_prompt',
id: SDXL_REFINER_NEGATIVE_CONDITIONING, id: getPrefixedId('refiner_neg_cond'),
style: negCond.style, style: negCond.style,
aesthetic_score: refinerNegativeAestheticScore, aesthetic_score: refinerNegativeAestheticScore,
}); });
const refinerDenoise = g.addNode({ const refinerDenoise = g.addNode({
type: 'denoise_latents', type: 'denoise_latents',
id: SDXL_REFINER_DENOISE_LATENTS, id: getPrefixedId('refiner_denoise_latents'),
cfg_scale: refinerCFGScale, cfg_scale: refinerCFGScale,
steps: refinerSteps, steps: refinerSteps,
scheduler: refinerScheduler, scheduler: refinerScheduler,
@ -69,8 +63,8 @@ export const addSDXLRefiner = async (
if (seamless) { if (seamless) {
const refinerSeamless = g.addNode({ const refinerSeamless = g.addNode({
id: SDXL_REFINER_SEAMLESS,
type: 'seamless', type: 'seamless',
id: getPrefixedId('refiner_seamless'),
seamless_x: seamless.seamless_x, seamless_x: seamless.seamless_x,
seamless_y: seamless.seamless_y, seamless_y: seamless.seamless_y,
}); });
View File
@ -1,5 +1,5 @@
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { SEAMLESS } from 'features/nodes/util/graph/constants'; import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { Invocation } from 'services/api/types'; import type { Invocation } from 'services/api/types';
@ -28,8 +28,8 @@ export const addSeamless = (
} }
const seamless = g.addNode({ const seamless = g.addNode({
id: SEAMLESS,
type: 'seamless', type: 'seamless',
id: getPrefixedId('seamless'),
seamless_x, seamless_x,
seamless_y, seamless_y,
}); });
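
Throughout this commit, fixed node-id constants such as SEAMLESS and VAE_LOADER are replaced with getPrefixedId calls, keeping ids human-readable while making them unique per graph build. The helper's implementation is not shown in this diff; a plausible sketch of the idea, consistent with the split(':') prefix check used later in onInvocationComplete, is:

    import { randomUUID } from 'node:crypto'; // stand-in; the app presumably has its own uuid helper

    // Assumed behaviour: "<prefix>:<random suffix>".
    const getPrefixedId = (prefix: string): string => `${prefix}:${randomUUID()}`;

    // The prefix can then be recovered the same way isCanvasOutput does further down.
    const getIdPrefix = (id: string): string => id.split(':')[0] ?? id;

    const id = getPrefixedId('canvas_output'); // e.g. "canvas_output:3f2c9..."
    console.assert(getIdPrefix(id) === 'canvas_output');
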
View File
@ -1,3 +1,4 @@
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { Dimensions } from 'features/controlLayers/store/types'; import type { Dimensions } from 'features/controlLayers/store/types';
import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
@ -12,7 +13,7 @@ export const addTextToImage = (
if (!isEqual(scaledSize, originalSize)) { if (!isEqual(scaledSize, originalSize)) {
// We need to resize the output image back to the original size // We need to resize the output image back to the original size
const resizeImageToOriginalSize = g.addNode({ const resizeImageToOriginalSize = g.addNode({
id: 'resize_image_to_original_size', id: getPrefixedId('resize_image_to_original_size'),
type: 'img_resize', type: 'img_resize',
...originalSize, ...originalSize,
}); });
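
The comment above captures the intent: when generation runs at a scaled size, an img_resize node brings the output back to the original size; when the sizes already match, no extra node is needed. A small sketch of that decision, assuming the resize node becomes the new image output when it is added (Dimensions reduced to width/height):

    import { isEqual } from 'lodash-es';

    type Dimensions = { width: number; height: number };

    // Returns the id of whichever node should be treated as the final image output.
    const pickImageOutput = (l2iId: string, scaledSize: Dimensions, originalSize: Dimensions): string => {
      if (isEqual(scaledSize, originalSize)) {
        return l2iId; // sizes match: the latents-to-image output is already final
      }
      return 'resize_image_to_original_size:1'; // the img_resize node added in the hunk above
    };

    // pickImageOutput('l2i:1', { width: 512, height: 512 }, { width: 1000, height: 1000 })
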
View File
@ -1,4 +1,4 @@
import { WATERMARKER } from 'features/nodes/util/graph/constants'; import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { Invocation } from 'services/api/types'; import type { Invocation } from 'services/api/types';
@ -13,8 +13,8 @@ export const addWatermarker = (
imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop'> imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop'>
): Invocation<'img_watermark'> => { ): Invocation<'img_watermark'> => {
const watermark = g.addNode({ const watermark = g.addNode({
id: WATERMARKER,
type: 'img_watermark', type: 'img_watermark',
id: getPrefixedId('watermarker'),
}); });
g.addEdge(imageOutput, 'image', watermark, 'image'); g.addEdge(imageOutput, 'image', watermark, 'image');
View File
@ -1,22 +1,9 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { import { addControlNets, addT2IAdapters } from 'features/nodes/util/graph/generation/addControlAdapters';
CANVAS_OUTPUT,
CLIP_SKIP,
CONTROL_LAYERS_GRAPH,
DENOISE_LATENTS,
LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT,
NOISE,
POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT,
VAE_LOADER,
} from 'features/nodes/util/graph/constants';
import { addControlAdapters } from 'features/nodes/util/graph/generation/addControlAdapters';
import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage'; import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint'; import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
// import { addHRF } from 'features/nodes/util/graph/generation/addHRF'; // import { addHRF } from 'features/nodes/util/graph/generation/addHRF';
@ -37,7 +24,10 @@ import { addRegions } from './addRegions';
const log = logger('system'); const log = logger('system');
export const buildSD1Graph = async (state: RootState, manager: CanvasManager): Promise<Graph> => { export const buildSD1Graph = async (
state: RootState,
manager: CanvasManager
): Promise<{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel'> }> => {
const generationMode = manager.compositor.getGenerationMode(); const generationMode = manager.compositor.getGenerationMode();
log.debug({ generationMode }, 'Building SD1/SD2 graph'); log.debug({ generationMode }, 'Building SD1/SD2 graph');
@ -62,38 +52,38 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P
const { originalSize, scaledSize } = getSizes(bbox); const { originalSize, scaledSize } = getSizes(bbox);
const g = new Graph(CONTROL_LAYERS_GRAPH); const g = new Graph(getPrefixedId('sd1_graph'));
const modelLoader = g.addNode({ const modelLoader = g.addNode({
type: 'main_model_loader', type: 'main_model_loader',
id: MAIN_MODEL_LOADER, id: getPrefixedId('sd1_model_loader'),
model, model,
}); });
const clipSkip = g.addNode({ const clipSkip = g.addNode({
type: 'clip_skip', type: 'clip_skip',
id: CLIP_SKIP, id: getPrefixedId('clip_skip'),
skipped_layers, skipped_layers,
}); });
const posCond = g.addNode({ const posCond = g.addNode({
type: 'compel', type: 'compel',
id: POSITIVE_CONDITIONING, id: getPrefixedId('pos_cond'),
prompt: positivePrompt, prompt: positivePrompt,
}); });
const posCondCollect = g.addNode({ const posCondCollect = g.addNode({
type: 'collect', type: 'collect',
id: POSITIVE_CONDITIONING_COLLECT, id: getPrefixedId('pos_cond_collect'),
}); });
const negCond = g.addNode({ const negCond = g.addNode({
type: 'compel', type: 'compel',
id: NEGATIVE_CONDITIONING, id: getPrefixedId('neg_cond'),
prompt: negativePrompt, prompt: negativePrompt,
}); });
const negCondCollect = g.addNode({ const negCondCollect = g.addNode({
type: 'collect', type: 'collect',
id: NEGATIVE_CONDITIONING_COLLECT, id: getPrefixedId('neg_cond_collect'),
}); });
const noise = g.addNode({ const noise = g.addNode({
type: 'noise', type: 'noise',
id: NOISE, id: getPrefixedId('noise'),
seed, seed,
width: scaledSize.width, width: scaledSize.width,
height: scaledSize.height, height: scaledSize.height,
@ -101,7 +91,7 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P
}); });
const denoise = g.addNode({ const denoise = g.addNode({
type: 'denoise_latents', type: 'denoise_latents',
id: DENOISE_LATENTS, id: getPrefixedId('denoise_latents'),
cfg_scale, cfg_scale,
cfg_rescale_multiplier, cfg_rescale_multiplier,
scheduler, scheduler,
@ -111,14 +101,14 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P
}); });
const l2i = g.addNode({ const l2i = g.addNode({
type: 'l2i', type: 'l2i',
id: LATENTS_TO_IMAGE, id: getPrefixedId('l2i'),
fp32: vaePrecision === 'fp32', fp32: vaePrecision === 'fp32',
}); });
const vaeLoader = const vaeLoader =
vae?.base === model.base vae?.base === model.base
? g.addNode({ ? g.addNode({
type: 'vae_loader', type: 'vae_loader',
id: VAE_LOADER, id: getPrefixedId('vae'),
vae_model: vae, vae_model: vae,
}) })
: null; : null;
@ -214,16 +204,49 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P
); );
} }
const _addedCAs = await addControlAdapters( const controlNetCollector = g.addNode({
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
const controlNetResult = await addControlNets(
manager, manager,
state.canvasV2.controlLayers.entities, state.canvasV2.controlLayers.entities,
g, g,
state.canvasV2.bbox.rect, state.canvasV2.bbox.rect,
denoise, controlNetCollector,
modelConfig.base modelConfig.base
); );
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base); if (controlNetResult.addedControlNets > 0) {
const _addedRegions = await addRegions( g.addEdge(controlNetCollector, 'collection', denoise, 'control');
} else {
g.deleteNode(controlNetCollector.id);
}
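
The ControlNet block above shows the pattern this builder now uses for every adapter family: a collect node is added up front, the adapter helper wires its outputs into it, and afterwards the collector is either connected to the matching denoise input or deleted if nothing was collected. A stripped-down sketch of that shape, with a stand-in graph exposing just the three methods involved:

    // Minimal stand-in graph exposing the three calls used by this pattern.
    type GraphNode = { id: string; type: string };
    class MiniGraph {
      nodes = new Map<string, GraphNode>();
      edges: Array<[string, string, string, string]> = [];
      addNode(node: GraphNode): GraphNode {
        this.nodes.set(node.id, node);
        return node;
      }
      addEdge(from: GraphNode, fromField: string, to: GraphNode, toField: string): void {
        this.edges.push([from.id, fromField, to.id, toField]);
      }
      deleteNode(id: string): void {
        this.nodes.delete(id);
      }
    }

    const g = new MiniGraph();
    const denoise = g.addNode({ id: 'denoise_latents:1', type: 'denoise_latents' });
    const collector = g.addNode({ id: 'control_net_collector:1', type: 'collect' });

    const addedControlNets: number = 0; // comes from addControlNets(...) in the real builder
    if (addedControlNets > 0) {
      g.addEdge(collector, 'collection', denoise, 'control');
    } else {
      g.deleteNode(collector.id); // keep the graph free of dangling, unconsumed collectors
    }
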
const t2iAdapterCollector = g.addNode({
type: 'collect',
id: getPrefixedId('t2i_adapter_collector'),
});
const t2iAdapterResult = await addT2IAdapters(
manager,
state.canvasV2.controlLayers.entities,
g,
state.canvasV2.bbox.rect,
t2iAdapterCollector,
modelConfig.base
);
if (t2iAdapterResult.addedT2IAdapters > 0) {
g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter');
} else {
g.deleteNode(t2iAdapterCollector.id);
}
const ipAdapterCollector = g.addNode({
type: 'collect',
id: getPrefixedId('ip_adapter_collector'),
});
const ipAdapterResult = addIPAdapters(state.canvasV2.ipAdapters.entities, g, ipAdapterCollector, modelConfig.base);
const regionsResult = await addRegions(
manager, manager,
state.canvasV2.regions.entities, state.canvasV2.regions.entities,
g, g,
@ -233,13 +256,17 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P
posCond, posCond,
negCond, negCond,
posCondCollect, posCondCollect,
negCondCollect negCondCollect,
ipAdapterCollector
); );
// const isHRFAllowed = !addedLayers.some((l) => isInitialImageLayer(l) || isRegionalGuidanceLayer(l)); const totalIPAdaptersAdded =
// if (isHRFAllowed && state.hrf.hrfEnabled) { ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
// imageOutput = addHRF(state, g, denoise, noise, l2i, vaeSource); if (totalIPAdaptersAdded > 0) {
// } g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
} else {
g.deleteNode(ipAdapterCollector.id);
}
if (state.system.shouldUseNSFWChecker) { if (state.system.shouldUseNSFWChecker) {
canvasOutput = addNSFWChecker(g, canvasOutput); canvasOutput = addNSFWChecker(g, canvasOutput);
@ -252,12 +279,12 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P
const shouldSaveToGallery = session.mode === 'generate' || settings.autoSave; const shouldSaveToGallery = session.mode === 'generate' || settings.autoSave;
g.updateNode(canvasOutput, { g.updateNode(canvasOutput, {
id: CANVAS_OUTPUT, id: getPrefixedId('canvas_output'),
is_intermediate: !shouldSaveToGallery, is_intermediate: !shouldSaveToGallery,
use_cache: false, use_cache: false,
board: getBoardField(state), board: getBoardField(state),
}); });
g.setMetadataReceivingNode(canvasOutput); g.setMetadataReceivingNode(canvasOutput);
return g; return { g, noise, posCond };
}; };
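
Two details at the tail of this builder are easy to miss: the ip_adapter collector is only kept when adapters were added either at the builder level or inside any region, and the function now returns the noise and positive-conditioning nodes alongside the graph. A compact sketch of the totaling logic, with result shapes inferred from this diff:

    // Result shapes inferred from this diff.
    type IPAdapterResult = { addedIPAdapters: number };

    const shouldKeepIPAdapterCollector = (
      builderResult: IPAdapterResult,
      regionResults: IPAdapterResult[]
    ): boolean => {
      const total =
        builderResult.addedIPAdapters + regionResults.reduce((acc, r) => acc + r.addedIPAdapters, 0);
      return total > 0;
    };

    // shouldKeepIPAdapterCollector({ addedIPAdapters: 0 }, [{ addedIPAdapters: 2 }]) === true
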
View File
@ -1,21 +1,9 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { import { addControlNets, addT2IAdapters } from 'features/nodes/util/graph/generation/addControlAdapters';
CANVAS_OUTPUT,
LATENTS_TO_IMAGE,
NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT,
NOISE,
POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT,
SDXL_CONTROL_LAYERS_GRAPH,
SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER,
VAE_LOADER,
} from 'features/nodes/util/graph/constants';
import { addControlAdapters } from 'features/nodes/util/graph/generation/addControlAdapters';
import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage'; import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint'; import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
import { addIPAdapters } from 'features/nodes/util/graph/generation/addIPAdapters'; import { addIPAdapters } from 'features/nodes/util/graph/generation/addIPAdapters';
@ -36,7 +24,10 @@ import { addRegions } from './addRegions';
const log = logger('system'); const log = logger('system');
export const buildSDXLGraph = async (state: RootState, manager: CanvasManager): Promise<Graph> => { export const buildSDXLGraph = async (
state: RootState,
manager: CanvasManager
): Promise<{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'sdxl_compel_prompt'> }> => {
const generationMode = manager.compositor.getGenerationMode(); const generationMode = manager.compositor.getGenerationMode();
log.debug({ generationMode }, 'Building SDXL graph'); log.debug({ generationMode }, 'Building SDXL graph');
@ -62,35 +53,35 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager):
const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } = getPresetModifiedPrompts(state); const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } = getPresetModifiedPrompts(state);
const g = new Graph(SDXL_CONTROL_LAYERS_GRAPH); const g = new Graph(getPrefixedId('sdxl_graph'));
const modelLoader = g.addNode({ const modelLoader = g.addNode({
type: 'sdxl_model_loader', type: 'sdxl_model_loader',
id: SDXL_MODEL_LOADER, id: getPrefixedId('sdxl_model_loader'),
model, model,
}); });
const posCond = g.addNode({ const posCond = g.addNode({
type: 'sdxl_compel_prompt', type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING, id: getPrefixedId('pos_cond'),
prompt: positivePrompt, prompt: positivePrompt,
style: positiveStylePrompt, style: positiveStylePrompt,
}); });
const posCondCollect = g.addNode({ const posCondCollect = g.addNode({
type: 'collect', type: 'collect',
id: POSITIVE_CONDITIONING_COLLECT, id: getPrefixedId('pos_cond_collect'),
}); });
const negCond = g.addNode({ const negCond = g.addNode({
type: 'sdxl_compel_prompt', type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING, id: getPrefixedId('neg_cond'),
prompt: negativePrompt, prompt: negativePrompt,
style: negativeStylePrompt, style: negativeStylePrompt,
}); });
const negCondCollect = g.addNode({ const negCondCollect = g.addNode({
type: 'collect', type: 'collect',
id: NEGATIVE_CONDITIONING_COLLECT, id: getPrefixedId('neg_cond_collect'),
}); });
const noise = g.addNode({ const noise = g.addNode({
type: 'noise', type: 'noise',
id: NOISE, id: getPrefixedId('noise'),
seed, seed,
width: scaledSize.width, width: scaledSize.width,
height: scaledSize.height, height: scaledSize.height,
@ -98,7 +89,7 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager):
}); });
const denoise = g.addNode({ const denoise = g.addNode({
type: 'denoise_latents', type: 'denoise_latents',
id: SDXL_DENOISE_LATENTS, id: getPrefixedId('denoise_latents'),
cfg_scale, cfg_scale,
cfg_rescale_multiplier, cfg_rescale_multiplier,
scheduler, scheduler,
@ -108,14 +99,14 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager):
}); });
const l2i = g.addNode({ const l2i = g.addNode({
type: 'l2i', type: 'l2i',
id: LATENTS_TO_IMAGE, id: getPrefixedId('l2i'),
fp32: vaePrecision === 'fp32', fp32: vaePrecision === 'fp32',
}); });
const vaeLoader = const vaeLoader =
vae?.base === model.base vae?.base === model.base
? g.addNode({ ? g.addNode({
type: 'vae_loader', type: 'vae_loader',
id: VAE_LOADER, id: getPrefixedId('vae'),
vae_model: vae, vae_model: vae,
}) })
: null; : null;
@ -216,16 +207,47 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager):
); );
} }
const _addedCAs = await addControlAdapters( const controlNetCollector = g.createNode({
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
const controlNetResult = await addControlNets(
manager, manager,
state.canvasV2.controlLayers.entities, state.canvasV2.controlLayers.entities,
g, g,
state.canvasV2.bbox.rect, state.canvasV2.bbox.rect,
denoise, controlNetCollector,
modelConfig.base modelConfig.base
); );
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base); if (controlNetResult.addedControlNets > 0) {
const _addedRegions = await addRegions( g.addNode(controlNetCollector);
g.addEdge(controlNetCollector, 'collection', denoise, 'control');
}
const t2iAdapterCollector = g.createNode({
type: 'collect',
id: getPrefixedId('t2i_adapter_collector'),
});
const t2iAdapterResult = await addT2IAdapters(
manager,
state.canvasV2.controlLayers.entities,
g,
state.canvasV2.bbox.rect,
t2iAdapterCollector,
modelConfig.base
);
if (t2iAdapterResult.addedT2IAdapters > 0) {
g.addNode(t2iAdapterCollector);
g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter');
}
const ipAdapterCollector = g.createNode({
type: 'collect',
id: getPrefixedId('ip_adapter_collector'),
});
const ipAdapterResult = addIPAdapters(state.canvasV2.ipAdapters.entities, g, ipAdapterCollector, modelConfig.base);
const regionsResult = await addRegions(
manager, manager,
state.canvasV2.regions.entities, state.canvasV2.regions.entities,
g, g,
@ -235,9 +257,17 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager):
posCond, posCond,
negCond, negCond,
posCondCollect, posCondCollect,
negCondCollect negCondCollect,
ipAdapterCollector
); );
const totalIPAdaptersAdded =
ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
if (totalIPAdaptersAdded > 0) {
g.addNode(ipAdapterCollector);
g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
}
if (state.system.shouldUseNSFWChecker) { if (state.system.shouldUseNSFWChecker) {
canvasOutput = addNSFWChecker(g, canvasOutput); canvasOutput = addNSFWChecker(g, canvasOutput);
} }
@ -249,12 +279,12 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager):
const shouldSaveToGallery = session.mode === 'generate' || settings.autoSave; const shouldSaveToGallery = session.mode === 'generate' || settings.autoSave;
g.updateNode(canvasOutput, { g.updateNode(canvasOutput, {
id: CANVAS_OUTPUT, id: getPrefixedId('canvas_output'),
is_intermediate: !shouldSaveToGallery, is_intermediate: !shouldSaveToGallery,
use_cache: false, use_cache: false,
board: getBoardField(state), board: getBoardField(state),
}); });
g.setMetadataReceivingNode(canvasOutput); g.setMetadataReceivingNode(canvasOutput);
return g; return { g, noise, posCond };
}; };
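
One design difference between the two builders is visible here: the SD1 builder adds each collector eagerly and deletes it when unused, while the SDXL builder uses g.createNode to construct the collector detached and only g.addNode-s it once something was actually collected. Both approaches end with the same graph; a toy sketch of the SDXL-style variant, with stand-in graph methods named after the calls in this hunk:

    // Stand-in graph with the two methods the SDXL builder calls in this hunk.
    type GNode = { id: string; type: string };
    class SketchGraph {
      private nodes = new Map<string, GNode>();
      // Build a node object without inserting it into the graph yet.
      createNode(node: GNode): GNode {
        return node;
      }
      addNode(node: GNode): GNode {
        this.nodes.set(node.id, node);
        return node;
      }
      has(id: string): boolean {
        return this.nodes.has(id);
      }
    }

    const g = new SketchGraph();
    const t2iAdapterCollector = g.createNode({ id: 't2i_adapter_collector:1', type: 'collect' });
    const addedT2IAdapters: number = 0; // from addT2IAdapters(...) in the real builder
    if (addedT2IAdapters > 0) {
      g.addNode(t2iAdapterCollector); // only now does the collector become part of the graph
    }
    console.assert(g.has('t2i_adapter_collector:1') === false);
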
View File
@ -8,18 +8,21 @@ import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks
import { zNodeStatus } from 'features/nodes/types/invocation'; import { zNodeStatus } from 'features/nodes/types/invocation';
import { boardsApi } from 'services/api/endpoints/boards'; import { boardsApi } from 'services/api/endpoints/boards';
import { getImageDTO, imagesApi } from 'services/api/endpoints/images'; import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types'; import type { ImageDTO, S } from 'services/api/types';
import { getCategories, getListImagesUrl } from 'services/api/util'; import { getCategories, getListImagesUrl } from 'services/api/util';
import type { InvocationCompleteEvent, InvocationDenoiseProgressEvent } from 'services/events/types';
const log = logger('events'); const log = logger('events');
const isCanvasOutput = (data: S['InvocationCompleteEvent']) => {
return data.invocation_source_id.split(':')[0] === 'canvas_output';
};
export const buildOnInvocationComplete = ( export const buildOnInvocationComplete = (
getState: () => RootState, getState: () => RootState,
dispatch: AppDispatch, dispatch: AppDispatch,
nodeTypeDenylist: string[], nodeTypeDenylist: string[],
setLastProgressEvent: (event: InvocationDenoiseProgressEvent | null) => void, setLastProgressEvent: (event: S['InvocationDenoiseProgressEvent'] | null) => void,
setLastCanvasProgressEvent: (event: InvocationDenoiseProgressEvent | null) => void setLastCanvasProgressEvent: (event: S['InvocationDenoiseProgressEvent'] | null) => void
) => { ) => {
const addImageToGallery = (imageDTO: ImageDTO) => { const addImageToGallery = (imageDTO: ImageDTO) => {
if (imageDTO.is_intermediate) { if (imageDTO.is_intermediate) {
@ -80,16 +83,19 @@ export const buildOnInvocationComplete = (
} }
}; };
return async (data: InvocationCompleteEvent) => { const getResultImageDTO = (data: S['InvocationCompleteEvent']) => {
log.debug( const { result } = data;
{ data } as SerializableObject, if (result.type === 'image_output') {
`Invocation complete (${data.invocation.type}, ${data.invocation_source_id})` return getImageDTO(result.image.image_name);
); } else if (result.type === 'canvas_v2_mask_and_crop_output') {
return getImageDTO(result.image.image_name);
}
return null;
};
const handleOriginWorkflows = async (data: S['InvocationCompleteEvent']) => {
const { result, invocation_source_id } = data; const { result, invocation_source_id } = data;
// Update the node execution states - the image output is handled below
if (data.origin === 'workflows') {
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) { if (nes) {
nes.status = zNodeStatus.enum.COMPLETED; nes.status = zNodeStatus.enum.COMPLETED;
@ -99,29 +105,25 @@ export const buildOnInvocationComplete = (
nes.outputs.push(result); nes.outputs.push(result);
upsertExecutionState(nes.nodeId, nes); upsertExecutionState(nes.nodeId, nes);
} }
const imageDTO = await getResultImageDTO(data);
if (imageDTO) {
addImageToGallery(imageDTO);
} }
};
// This complete event has an associated image output const handleOriginCanvas = async (data: S['InvocationCompleteEvent']) => {
if ( const session = getState().canvasV2.session;
(data.result.type === 'image_output' || data.result.type === 'canvas_v2_mask_and_crop_output') &&
!nodeTypeDenylist.includes(data.invocation.type)
) {
const { image_name } = data.result.image;
const { session } = getState().canvasV2;
const imageDTO = await getImageDTO(image_name); const imageDTO = await getResultImageDTO(data);
if (!imageDTO) { if (!imageDTO) {
log.error({ data } as SerializableObject, 'Failed to fetch image DTO after generation');
return; return;
} }
if (data.origin === 'canvas') { if (session.mode === 'compose') {
if (data.invocation_source_id !== 'canvas_output') { if (session.isStaging && isCanvasOutput(data)) {
// Not a canvas output image - ignore
return;
}
if (session.mode === 'compose' && session.isStaging) {
if (data.result.type === 'canvas_v2_mask_and_crop_output') { if (data.result.type === 'canvas_v2_mask_and_crop_output') {
const { offset_x, offset_y } = data.result; const { offset_x, offset_y } = data.result;
if (session.isStaging) { if (session.isStaging) {
@ -132,12 +134,36 @@ export const buildOnInvocationComplete = (
dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } })); dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } }));
} }
} }
addImageToGallery(imageDTO); }
} else { } else {
addImageToGallery(imageDTO); // session.mode === 'generate'
setLastCanvasProgressEvent(null); setLastCanvasProgressEvent(null);
} }
addImageToGallery(imageDTO);
};
const handleOriginOther = async (data: S['InvocationCompleteEvent']) => {
const imageDTO = await getResultImageDTO(data);
if (imageDTO) {
addImageToGallery(imageDTO);
} }
};
return async (data: S['InvocationCompleteEvent']) => {
log.debug(
{ data } as SerializableObject,
`Invocation complete (${data.invocation.type}, ${data.invocation_source_id})`
);
// Update the node execution states - the image output is handled below
if (data.origin === 'workflows') {
await handleOriginWorkflows(data);
} else if (data.origin === 'canvas') {
await handleOriginCanvas(data);
} else {
await handleOriginOther(data);
} }
setLastProgressEvent(null); setLastProgressEvent(null);
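
The handler is now organized around the event's origin: node execution state is only updated for workflows, canvas events go through the staging and gallery logic, and everything else simply lands in the gallery, with getResultImageDTO centralizing the two result types that carry an image. A reduced sketch of that narrowing, using a hypothetical three-member result union in place of the full schema:

    // Hypothetical reduced result union; the real S['InvocationCompleteEvent']['result'] has many members.
    type InvocationResult =
      | { type: 'image_output'; image: { image_name: string } }
      | { type: 'canvas_v2_mask_and_crop_output'; image: { image_name: string }; offset_x: number; offset_y: number }
      | { type: 'string_output'; value: string }; // stand-in for the non-image results

    // Mirrors getResultImageDTO: only the two image-bearing results yield a gallery image.
    const resultImageName = (result: InvocationResult): string | null => {
      if (result.type === 'image_output' || result.type === 'canvas_v2_mask_and_crop_output') {
        return result.image.image_name;
      }
      return null;
    };

    // resultImageName({ type: 'string_output', value: 'foo' }) === null
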
View File
@ -17,8 +17,9 @@ import { atom, computed } from 'nanostores';
import { api, LIST_TAG } from 'services/api'; import { api, LIST_TAG } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models'; import { modelsApi } from 'services/api/endpoints/models';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue'; import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import type { S } from 'services/api/types';
import { buildOnInvocationComplete } from 'services/events/onInvocationComplete'; import { buildOnInvocationComplete } from 'services/events/onInvocationComplete';
import type { ClientToServerEvents, InvocationDenoiseProgressEvent, ServerToClientEvents } from 'services/events/types'; import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
import type { Socket } from 'socket.io-client'; import type { Socket } from 'socket.io-client';
export const socketConnected = createAction('socket/connected'); export const socketConnected = createAction('socket/connected');
@ -34,8 +35,8 @@ type SetEventListenersArg = {
const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select(); const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select();
const nodeTypeDenylist = ['load_image', 'image']; const nodeTypeDenylist = ['load_image', 'image'];
export const $lastProgressEvent = atom<InvocationDenoiseProgressEvent | null>(null); export const $lastProgressEvent = atom<S['InvocationDenoiseProgressEvent'] | null>(null);
export const $lastCanvasProgressEvent = atom<InvocationDenoiseProgressEvent | null>(null); export const $lastCanvasProgressEvent = atom<S['InvocationDenoiseProgressEvent'] | null>(null);
export const $hasProgress = computed($lastProgressEvent, (val) => Boolean(val)); export const $hasProgress = computed($lastProgressEvent, (val) => Boolean(val));
export const $progressImage = computed($lastProgressEvent, (val) => val?.progress_image ?? null); export const $progressImage = computed($lastProgressEvent, (val) => val?.progress_image ?? null);
const cancellations = new Set<string>(); const cancellations = new Set<string>();
View File
@ -1,69 +1,35 @@
import type { S } from 'services/api/types'; import type { S } from 'services/api/types';
type ModelLoadStartedEvent = S['ModelLoadStartedEvent']; type ClientEmitSubscribeQueue = { queue_id: string };
type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent'];
type InvocationStartedEvent = S['InvocationStartedEvent'];
type InvocationDenoiseProgressEvent = S['InvocationDenoiseProgressEvent'];
type InvocationCompleteEvent = S['InvocationCompleteEvent'];
type InvocationErrorEvent = S['InvocationErrorEvent'];
type ModelInstallDownloadStartedEvent = S['ModelInstallDownloadStartedEvent'];
type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];
type ModelInstallDownloadsCompleteEvent = S['ModelInstallDownloadsCompleteEvent'];
type ModelInstallCompleteEvent = S['ModelInstallCompleteEvent'];
type ModelInstallErrorEvent = S['ModelInstallErrorEvent'];
type ModelInstallStartedEvent = S['ModelInstallStartedEvent'];
type ModelInstallCancelledEvent = S['ModelInstallCancelledEvent'];
type DownloadStartedEvent = S['DownloadStartedEvent'];
type DownloadProgressEvent = S['DownloadProgressEvent'];
type DownloadCompleteEvent = S['DownloadCompleteEvent'];
type DownloadCancelledEvent = S['DownloadCancelledEvent'];
type DownloadErrorEvent = S['DownloadErrorEvent'];
type QueueItemStatusChangedEvent = S['QueueItemStatusChangedEvent'];
type QueueClearedEvent = S['QueueClearedEvent'];
type BatchEnqueuedEvent = S['BatchEnqueuedEvent'];
type BulkDownloadStartedEvent = S['BulkDownloadStartedEvent'];
type BulkDownloadCompleteEvent = S['BulkDownloadCompleteEvent'];
type BulkDownloadFailedEvent = S['BulkDownloadErrorEvent'];
type ClientEmitSubscribeQueue = {
queue_id: string;
};
type ClientEmitUnsubscribeQueue = ClientEmitSubscribeQueue; type ClientEmitUnsubscribeQueue = ClientEmitSubscribeQueue;
type ClientEmitSubscribeBulkDownload = { type ClientEmitSubscribeBulkDownload = { bulk_download_id: string };
bulk_download_id: string;
};
type ClientEmitUnsubscribeBulkDownload = ClientEmitSubscribeBulkDownload; type ClientEmitUnsubscribeBulkDownload = ClientEmitSubscribeBulkDownload;
export type ServerToClientEvents = { export type ServerToClientEvents = {
invocation_denoise_progress: (payload: InvocationDenoiseProgressEvent) => void; invocation_denoise_progress: (payload: S['InvocationDenoiseProgressEvent']) => void;
invocation_complete: (payload: InvocationCompleteEvent) => void; invocation_complete: (payload: S['InvocationCompleteEvent']) => void;
invocation_error: (payload: InvocationErrorEvent) => void; invocation_error: (payload: S['InvocationErrorEvent']) => void;
invocation_started: (payload: InvocationStartedEvent) => void; invocation_started: (payload: S['InvocationStartedEvent']) => void;
download_started: (payload: DownloadStartedEvent) => void; download_started: (payload: S['DownloadStartedEvent']) => void;
download_progress: (payload: DownloadProgressEvent) => void; download_progress: (payload: S['DownloadProgressEvent']) => void;
download_complete: (payload: DownloadCompleteEvent) => void; download_complete: (payload: S['DownloadCompleteEvent']) => void;
download_cancelled: (payload: DownloadCancelledEvent) => void; download_cancelled: (payload: S['DownloadCancelledEvent']) => void;
download_error: (payload: DownloadErrorEvent) => void; download_error: (payload: S['DownloadErrorEvent']) => void;
model_load_started: (payload: ModelLoadStartedEvent) => void; model_load_started: (payload: S['ModelLoadStartedEvent']) => void;
model_install_started: (payload: ModelInstallStartedEvent) => void; model_install_started: (payload: S['ModelInstallStartedEvent']) => void;
model_install_download_started: (payload: ModelInstallDownloadStartedEvent) => void; model_install_download_started: (payload: S['ModelInstallDownloadStartedEvent']) => void;
model_install_download_progress: (payload: ModelInstallDownloadProgressEvent) => void; model_install_download_progress: (payload: S['ModelInstallDownloadProgressEvent']) => void;
model_install_downloads_complete: (payload: ModelInstallDownloadsCompleteEvent) => void; model_install_downloads_complete: (payload: S['ModelInstallDownloadsCompleteEvent']) => void;
model_install_complete: (payload: ModelInstallCompleteEvent) => void; model_install_complete: (payload: S['ModelInstallCompleteEvent']) => void;
model_install_error: (payload: ModelInstallErrorEvent) => void; model_install_error: (payload: S['ModelInstallErrorEvent']) => void;
model_install_cancelled: (payload: ModelInstallCancelledEvent) => void; model_install_cancelled: (payload: S['ModelInstallCancelledEvent']) => void;
model_load_complete: (payload: ModelLoadCompleteEvent) => void; model_load_complete: (payload: S['ModelLoadCompleteEvent']) => void;
queue_item_status_changed: (payload: QueueItemStatusChangedEvent) => void; queue_item_status_changed: (payload: S['QueueItemStatusChangedEvent']) => void;
queue_cleared: (payload: QueueClearedEvent) => void; queue_cleared: (payload: S['QueueClearedEvent']) => void;
batch_enqueued: (payload: BatchEnqueuedEvent) => void; batch_enqueued: (payload: S['BatchEnqueuedEvent']) => void;
bulk_download_started: (payload: BulkDownloadStartedEvent) => void; bulk_download_started: (payload: S['BulkDownloadStartedEvent']) => void;
bulk_download_complete: (payload: BulkDownloadCompleteEvent) => void; bulk_download_complete: (payload: S['BulkDownloadCompleteEvent']) => void;
bulk_download_error: (payload: BulkDownloadFailedEvent) => void; bulk_download_error: (payload: S['BulkDownloadErrorEvent']) => void;
}; };
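
The hand-maintained per-event aliases are dropped in favour of indexed access into the generated S schema types, so socket payload types always track the OpenAPI-generated definitions. A minimal sketch of the pattern against a hypothetical schema map:

    // Hypothetical stand-in for the generated schema map that S indexes into.
    type Schema = {
      InvocationCompleteEvent: { invocation_source_id: string; origin: string };
      InvocationDenoiseProgressEvent: { progress_image?: { width: number; height: number } };
    };

    // Indexed access: payload types are looked up by key instead of re-declared as aliases.
    type SketchServerEvents = {
      invocation_complete: (payload: Schema['InvocationCompleteEvent']) => void;
      invocation_denoise_progress: (payload: Schema['InvocationDenoiseProgressEvent']) => void;
    };

    const handlers: SketchServerEvents = {
      invocation_complete: (payload) => console.log(payload.invocation_source_id),
      invocation_denoise_progress: (payload) => console.log(payload.progress_image?.width ?? 0),
    };
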
export type ClientToServerEvents = { export type ClientToServerEvents = {