feat(ui): handle initial image layers in control layers helper

This commit is contained in:
psychedelicious 2024-05-08 14:22:19 +10:00 committed by Kent Keirsey
parent f147f99bef
commit 3f489c92c8
3 changed files with 157 additions and 107 deletions

View File

@ -2,11 +2,12 @@ import { getStore } from 'app/store/nanostores/store';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { import {
isControlAdapterLayer, isControlAdapterLayer,
isInitialImageLayer,
isIPAdapterLayer, isIPAdapterLayer,
isRegionalGuidanceLayer, isRegionalGuidanceLayer,
rgLayerMaskImageUploaded, rgLayerMaskImageUploaded,
} from 'features/controlLayers/store/controlLayersSlice'; } from 'features/controlLayers/store/controlLayersSlice';
import type { Layer, RegionalGuidanceLayer } from 'features/controlLayers/store/types'; import type { InitialImageLayer, Layer, RegionalGuidanceLayer } from 'features/controlLayers/store/types';
import { import {
type ControlNetConfigV2, type ControlNetConfigV2,
type ImageWithDims, type ImageWithDims,
@ -20,9 +21,11 @@ import { getRegionalPromptLayerBlobs } from 'features/controlLayers/util/getLaye
import type { ImageField } from 'features/nodes/types/common'; import type { ImageField } from 'features/nodes/types/common';
import { import {
CONTROL_NET_COLLECT, CONTROL_NET_COLLECT,
IMAGE_TO_LATENTS,
IP_ADAPTER_COLLECT, IP_ADAPTER_COLLECT,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT, NEGATIVE_CONDITIONING_COLLECT,
NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT, POSITIVE_CONDITIONING_COLLECT,
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX, PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
@ -30,6 +33,7 @@ import {
PROMPT_REGION_NEGATIVE_COND_PREFIX, PROMPT_REGION_NEGATIVE_COND_PREFIX,
PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX, PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX,
PROMPT_REGION_POSITIVE_COND_PREFIX, PROMPT_REGION_POSITIVE_COND_PREFIX,
RESIZE,
T2I_ADAPTER_COLLECT, T2I_ADAPTER_COLLECT,
} from 'features/nodes/util/graph/constants'; } from 'features/nodes/util/graph/constants';
import { upsertMetadata } from 'features/nodes/util/graph/metadata'; import { upsertMetadata } from 'features/nodes/util/graph/metadata';
@ -38,9 +42,10 @@ import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type { import type {
CollectInvocation, CollectInvocation,
ControlNetInvocation, ControlNetInvocation,
CoreMetadataInvocation,
Edge, Edge,
ImageDTO, ImageDTO,
ImageResizeInvocation,
ImageToLatentsInvocation,
IPAdapterInvocation, IPAdapterInvocation,
NonNullableGraph, NonNullableGraph,
S, S,
@ -67,33 +72,6 @@ const buildControlImage = (
assert(false, 'Attempted to add unprocessed control image'); assert(false, 'Attempted to add unprocessed control image');
}; };
const buildControlNetMetadata = (controlNet: ControlNetConfigV2): S['ControlNetMetadataField'] => {
const { beginEndStepPct, controlMode, image, model, processedImage, processorConfig, weight } = controlNet;
assert(model, 'ControlNet model is required');
assert(image, 'ControlNet image is required');
const processed_image =
processedImage && processorConfig
? {
image_name: processedImage.name,
}
: null;
return {
control_model: model,
control_weight: weight,
control_mode: controlMode,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
resize_mode: 'just_resize',
image: {
image_name: image.name,
},
processed_image,
};
};
const addControlNetCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => { const addControlNetCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
if (graph.nodes[CONTROL_NET_COLLECT]) { if (graph.nodes[CONTROL_NET_COLLECT]) {
// You see, we've already got one! // You see, we've already got one!
@ -123,7 +101,6 @@ const addGlobalControlNetsToGraph = async (
if (controlNets.length === 0) { if (controlNets.length === 0) {
return; return;
} }
const controlNetMetadata: CoreMetadataInvocation['controlnets'] = [];
addControlNetCollectorSafe(graph, denoiseNodeId); addControlNetCollectorSafe(graph, denoiseNodeId);
for (const controlNet of controlNets) { for (const controlNet of controlNets) {
@ -147,8 +124,6 @@ const addGlobalControlNetsToGraph = async (
graph.nodes[controlNetNode.id] = controlNetNode; graph.nodes[controlNetNode.id] = controlNetNode;
controlNetMetadata.push(buildControlNetMetadata(controlNet));
graph.edges.push({ graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' }, source: { node_id: controlNetNode.id, field: 'control' },
destination: { destination: {
@ -157,33 +132,6 @@ const addGlobalControlNetsToGraph = async (
}, },
}); });
} }
upsertMetadata(graph, { controlnets: controlNetMetadata });
};
const buildT2IAdapterMetadata = (t2iAdapter: T2IAdapterConfigV2): S['T2IAdapterMetadataField'] => {
const { beginEndStepPct, image, model, processedImage, processorConfig, weight } = t2iAdapter;
assert(model, 'T2I Adapter model is required');
assert(image, 'T2I Adapter image is required');
const processed_image =
processedImage && processorConfig
? {
image_name: processedImage.name,
}
: null;
return {
t2i_adapter_model: model,
weight,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
resize_mode: 'just_resize',
image: {
image_name: image.name,
},
processed_image,
};
}; };
const addT2IAdapterCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => { const addT2IAdapterCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
@ -215,7 +163,6 @@ const addGlobalT2IAdaptersToGraph = async (
if (t2iAdapters.length === 0) { if (t2iAdapters.length === 0) {
return; return;
} }
const t2iAdapterMetadata: CoreMetadataInvocation['t2iAdapters'] = [];
addT2IAdapterCollectorSafe(graph, denoiseNodeId); addT2IAdapterCollectorSafe(graph, denoiseNodeId);
for (const t2iAdapter of t2iAdapters) { for (const t2iAdapter of t2iAdapters) {
@ -238,8 +185,6 @@ const addGlobalT2IAdaptersToGraph = async (
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode; graph.nodes[t2iAdapterNode.id] = t2iAdapterNode;
t2iAdapterMetadata.push(buildT2IAdapterMetadata(t2iAdapter));
graph.edges.push({ graph.edges.push({
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' }, source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
destination: { destination: {
@ -248,27 +193,6 @@ const addGlobalT2IAdaptersToGraph = async (
}, },
}); });
} }
upsertMetadata(graph, { t2iAdapters: t2iAdapterMetadata });
};
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfigV2): S['IPAdapterMetadataField'] => {
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
assert(model, 'IP Adapter model is required');
assert(image, 'IP Adapter image is required');
return {
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
weight,
method,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
image: {
image_name: image.name,
},
};
}; };
const addIPAdapterCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => { const addIPAdapterCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
@ -300,7 +224,6 @@ const addGlobalIPAdaptersToGraph = async (
if (ipAdapters.length === 0) { if (ipAdapters.length === 0) {
return; return;
} }
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
addIPAdapterCollectorSafe(graph, denoiseNodeId); addIPAdapterCollectorSafe(graph, denoiseNodeId);
for (const ipAdapter of ipAdapters) { for (const ipAdapter of ipAdapters) {
@ -325,8 +248,6 @@ const addGlobalIPAdaptersToGraph = async (
graph.nodes[ipAdapterNode.id] = ipAdapterNode; graph.nodes[ipAdapterNode.id] = ipAdapterNode;
ipAdapterMetdata.push(buildIPAdapterMetadata(ipAdapter));
graph.edges.push({ graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' }, source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
destination: { destination: {
@ -335,16 +256,131 @@ const addGlobalIPAdaptersToGraph = async (
}, },
}); });
} }
upsertMetadata(graph, { ipAdapters: ipAdapterMetdata });
}; };
export const addControlLayersToGraph = async (state: RootState, graph: NonNullableGraph, denoiseNodeId: string) => { const addInitialImageLayerToGraph = (
state: RootState,
graph: NonNullableGraph,
denoiseNodeId: string,
layer: InitialImageLayer
) => {
const { vaePrecision, model } = state.generation;
const { refinerModel, refinerStart } = state.sdxl;
const { width, height } = state.controlLayers.present.size;
assert(layer.isEnabled, 'Initial image layer is not enabled');
assert(layer.image, 'Initial image layer has no image');
const isSDXL = model?.base === 'sdxl';
const useRefinerStartEnd = isSDXL && Boolean(refinerModel);
const denoiseNode = graph.nodes[denoiseNodeId];
assert(denoiseNode?.type === 'denoise_latents', `Missing denoise node or incorrect type: ${denoiseNode?.type}`);
const { denoisingStrength } = layer;
denoiseNode.denoising_start = useRefinerStartEnd
? Math.min(refinerStart, 1 - denoisingStrength)
: 1 - denoisingStrength;
denoiseNode.denoising_end = useRefinerStartEnd ? refinerStart : 1;
// We conditionally hook the image in depending on if a resize is needed
const i2lNode: ImageToLatentsInvocation = {
type: 'i2l',
id: IMAGE_TO_LATENTS,
is_intermediate: true,
use_cache: true,
fp32: vaePrecision === 'fp32',
};
graph.nodes[i2lNode.id] = i2lNode;
graph.edges.push({
source: {
node_id: IMAGE_TO_LATENTS,
field: 'latents',
},
destination: {
node_id: denoiseNode.id,
field: 'latents',
},
});
if (layer.image.width !== width || layer.image.height !== height) {
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
// Create a resize node, explicitly setting its image
const resizeNode: ImageResizeInvocation = {
id: RESIZE,
type: 'img_resize',
image: {
image_name: layer.image.name,
},
is_intermediate: true,
width,
height,
};
graph.nodes[RESIZE] = resizeNode;
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
graph.edges.push({
source: { node_id: RESIZE, field: 'image' },
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'image',
},
});
// The `RESIZE` node also passes its width and height to `NOISE`
graph.edges.push({
source: { node_id: RESIZE, field: 'width' },
destination: {
node_id: NOISE,
field: 'width',
},
});
graph.edges.push({
source: { node_id: RESIZE, field: 'height' },
destination: {
node_id: NOISE,
field: 'height',
},
});
} else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
i2lNode.image = {
image_name: layer.image.name,
};
// Pass the image's dimensions to the `NOISE` node
graph.edges.push({
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
destination: {
node_id: NOISE,
field: 'width',
},
});
graph.edges.push({
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
destination: {
node_id: NOISE,
field: 'height',
},
});
}
upsertMetadata(graph, { generation_mode: isSDXL ? 'sdxl_img2img' : 'img2img' });
};
export const addControlLayersToGraph = async (
state: RootState,
graph: NonNullableGraph,
denoiseNodeId: string
): Promise<Layer[]> => {
const mainModel = state.generation.model; const mainModel = state.generation.model;
assert(mainModel, 'Missing main model when building graph'); assert(mainModel, 'Missing main model when building graph');
const isSDXL = mainModel.base === 'sdxl'; const isSDXL = mainModel.base === 'sdxl';
const layersMetadata: Layer[] = []; const validLayers: Layer[] = [];
// Add global control adapters // Add global control adapters
const validControlAdapterLayers = state.controlLayers.present.layers const validControlAdapterLayers = state.controlLayers.present.layers
@ -366,6 +402,8 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
const validT2IAdapters = validControlAdapterLayers.map((l) => l.controlAdapter).filter(isT2IAdapterConfigV2); const validT2IAdapters = validControlAdapterLayers.map((l) => l.controlAdapter).filter(isT2IAdapterConfigV2);
addGlobalT2IAdaptersToGraph(validT2IAdapters, graph, denoiseNodeId); addGlobalT2IAdaptersToGraph(validT2IAdapters, graph, denoiseNodeId);
validLayers.push(...validControlAdapterLayers);
const validIPAdapterLayers = state.controlLayers.present.layers const validIPAdapterLayers = state.controlLayers.present.layers
// Must be an IP Adapter layer // Must be an IP Adapter layer
.filter(isIPAdapterLayer) .filter(isIPAdapterLayer)
@ -381,6 +419,21 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
}); });
const validIPAdapters = validIPAdapterLayers.map((l) => l.ipAdapter); const validIPAdapters = validIPAdapterLayers.map((l) => l.ipAdapter);
addGlobalIPAdaptersToGraph(validIPAdapters, graph, denoiseNodeId); addGlobalIPAdaptersToGraph(validIPAdapters, graph, denoiseNodeId);
validLayers.push(...validIPAdapterLayers);
const initialImageLayer = state.controlLayers.present.layers.filter(isInitialImageLayer).find((l) => {
if (!l.isEnabled) {
return false;
}
if (!l.image) {
return false;
}
return true;
});
if (initialImageLayer) {
addInitialImageLayerToGraph(state, graph, denoiseNodeId, initialImageLayer);
validLayers.push(initialImageLayer);
}
const validRGLayers = state.controlLayers.present.layers const validRGLayers = state.controlLayers.present.layers
// Only RG layers are get masks // Only RG layers are get masks
@ -393,8 +446,7 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
const hasIPAdapter = l.ipAdapters.filter((ipa) => ipa.image).length > 0; const hasIPAdapter = l.ipAdapters.filter((ipa) => ipa.image).length > 0;
return hasTextPrompt || hasIPAdapter; return hasTextPrompt || hasIPAdapter;
}); });
validLayers.push(...validRGLayers);
layersMetadata.push(...validRGLayers, ...validControlAdapterLayers, ...validIPAdapterLayers);
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing // TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
// the existing conditioning nodes. // the existing conditioning nodes.
@ -660,7 +712,8 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
} }
} }
upsertMetadata(graph, { layers: layersMetadata }); upsertMetadata(graph, { layers: validLayers });
return validLayers;
}; };
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => { const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {

View File

@ -1,8 +1,8 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { isInitialImageLayer, isRegionalGuidanceLayer } from 'features/controlLayers/store/controlLayersSlice';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph'; import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
import { addInitialImageToLinearGraph } from 'features/nodes/util/graph/addInitialImageToLinearGraph';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types'; import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
@ -232,24 +232,24 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<NonNull
LATENTS_TO_IMAGE LATENTS_TO_IMAGE
); );
const didAddInitialImage = addInitialImageToLinearGraph(state, graph, DENOISE_LATENTS);
// Add Seamless To Graph // Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) { if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId); addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS; modelLoaderNodeId = SEAMLESS;
} }
// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support // add LoRA support
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId); await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
await addControlLayersToGraph(state, graph, DENOISE_LATENTS); const addedLayers = await addControlLayersToGraph(state, graph, DENOISE_LATENTS);
// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
const shouldUseHRF = !addedLayers.some((l) => isInitialImageLayer(l) || isRegionalGuidanceLayer(l));
// High resolution fix. // High resolution fix.
if (state.hrf.hrfEnabled && !didAddInitialImage) { if (state.hrf.hrfEnabled && shouldUseHRF) {
console.log('HRFING');
addHrfToGraph(state, graph); addHrfToGraph(state, graph);
} }

View File

@ -2,7 +2,6 @@ import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph'; import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
import { addInitialImageToLinearGraph } from 'features/nodes/util/graph/addInitialImageToLinearGraph';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types'; import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
@ -242,8 +241,6 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
LATENTS_TO_IMAGE LATENTS_TO_IMAGE
); );
addInitialImageToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add Seamless To Graph // Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) { if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId); addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
@ -258,14 +255,14 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
} }
} }
// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support // add LoRA support
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId); await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
await addControlLayersToGraph(state, graph, SDXL_DENOISE_LATENTS); await addControlLayersToGraph(state, graph, SDXL_DENOISE_LATENTS);
// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
// NSFW & watermark - must be last thing added to graph // NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) { if (state.system.shouldUseNSFWChecker) {
// must add before watermarker! // must add before watermarker!