mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): handle initial image layers in control layers helper
This commit is contained in:
parent
f147f99bef
commit
3f489c92c8
@ -2,11 +2,12 @@ import { getStore } from 'app/store/nanostores/store';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import {
|
||||
isControlAdapterLayer,
|
||||
isInitialImageLayer,
|
||||
isIPAdapterLayer,
|
||||
isRegionalGuidanceLayer,
|
||||
rgLayerMaskImageUploaded,
|
||||
} from 'features/controlLayers/store/controlLayersSlice';
|
||||
import type { Layer, RegionalGuidanceLayer } from 'features/controlLayers/store/types';
|
||||
import type { InitialImageLayer, Layer, RegionalGuidanceLayer } from 'features/controlLayers/store/types';
|
||||
import {
|
||||
type ControlNetConfigV2,
|
||||
type ImageWithDims,
|
||||
@ -20,9 +21,11 @@ import { getRegionalPromptLayerBlobs } from 'features/controlLayers/util/getLaye
|
||||
import type { ImageField } from 'features/nodes/types/common';
|
||||
import {
|
||||
CONTROL_NET_COLLECT,
|
||||
IMAGE_TO_LATENTS,
|
||||
IP_ADAPTER_COLLECT,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NEGATIVE_CONDITIONING_COLLECT,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
POSITIVE_CONDITIONING_COLLECT,
|
||||
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
|
||||
@ -30,6 +33,7 @@ import {
|
||||
PROMPT_REGION_NEGATIVE_COND_PREFIX,
|
||||
PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX,
|
||||
PROMPT_REGION_POSITIVE_COND_PREFIX,
|
||||
RESIZE,
|
||||
T2I_ADAPTER_COLLECT,
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import { upsertMetadata } from 'features/nodes/util/graph/metadata';
|
||||
@ -38,9 +42,10 @@ import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
|
||||
import type {
|
||||
CollectInvocation,
|
||||
ControlNetInvocation,
|
||||
CoreMetadataInvocation,
|
||||
Edge,
|
||||
ImageDTO,
|
||||
ImageResizeInvocation,
|
||||
ImageToLatentsInvocation,
|
||||
IPAdapterInvocation,
|
||||
NonNullableGraph,
|
||||
S,
|
||||
@ -67,33 +72,6 @@ const buildControlImage = (
|
||||
assert(false, 'Attempted to add unprocessed control image');
|
||||
};
|
||||
|
||||
const buildControlNetMetadata = (controlNet: ControlNetConfigV2): S['ControlNetMetadataField'] => {
|
||||
const { beginEndStepPct, controlMode, image, model, processedImage, processorConfig, weight } = controlNet;
|
||||
|
||||
assert(model, 'ControlNet model is required');
|
||||
assert(image, 'ControlNet image is required');
|
||||
|
||||
const processed_image =
|
||||
processedImage && processorConfig
|
||||
? {
|
||||
image_name: processedImage.name,
|
||||
}
|
||||
: null;
|
||||
|
||||
return {
|
||||
control_model: model,
|
||||
control_weight: weight,
|
||||
control_mode: controlMode,
|
||||
begin_step_percent: beginEndStepPct[0],
|
||||
end_step_percent: beginEndStepPct[1],
|
||||
resize_mode: 'just_resize',
|
||||
image: {
|
||||
image_name: image.name,
|
||||
},
|
||||
processed_image,
|
||||
};
|
||||
};
|
||||
|
||||
const addControlNetCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
|
||||
if (graph.nodes[CONTROL_NET_COLLECT]) {
|
||||
// You see, we've already got one!
|
||||
@ -123,7 +101,6 @@ const addGlobalControlNetsToGraph = async (
|
||||
if (controlNets.length === 0) {
|
||||
return;
|
||||
}
|
||||
const controlNetMetadata: CoreMetadataInvocation['controlnets'] = [];
|
||||
addControlNetCollectorSafe(graph, denoiseNodeId);
|
||||
|
||||
for (const controlNet of controlNets) {
|
||||
@ -147,8 +124,6 @@ const addGlobalControlNetsToGraph = async (
|
||||
|
||||
graph.nodes[controlNetNode.id] = controlNetNode;
|
||||
|
||||
controlNetMetadata.push(buildControlNetMetadata(controlNet));
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: controlNetNode.id, field: 'control' },
|
||||
destination: {
|
||||
@ -157,33 +132,6 @@ const addGlobalControlNetsToGraph = async (
|
||||
},
|
||||
});
|
||||
}
|
||||
upsertMetadata(graph, { controlnets: controlNetMetadata });
|
||||
};
|
||||
|
||||
const buildT2IAdapterMetadata = (t2iAdapter: T2IAdapterConfigV2): S['T2IAdapterMetadataField'] => {
|
||||
const { beginEndStepPct, image, model, processedImage, processorConfig, weight } = t2iAdapter;
|
||||
|
||||
assert(model, 'T2I Adapter model is required');
|
||||
assert(image, 'T2I Adapter image is required');
|
||||
|
||||
const processed_image =
|
||||
processedImage && processorConfig
|
||||
? {
|
||||
image_name: processedImage.name,
|
||||
}
|
||||
: null;
|
||||
|
||||
return {
|
||||
t2i_adapter_model: model,
|
||||
weight,
|
||||
begin_step_percent: beginEndStepPct[0],
|
||||
end_step_percent: beginEndStepPct[1],
|
||||
resize_mode: 'just_resize',
|
||||
image: {
|
||||
image_name: image.name,
|
||||
},
|
||||
processed_image,
|
||||
};
|
||||
};
|
||||
|
||||
const addT2IAdapterCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
|
||||
@ -215,7 +163,6 @@ const addGlobalT2IAdaptersToGraph = async (
|
||||
if (t2iAdapters.length === 0) {
|
||||
return;
|
||||
}
|
||||
const t2iAdapterMetadata: CoreMetadataInvocation['t2iAdapters'] = [];
|
||||
addT2IAdapterCollectorSafe(graph, denoiseNodeId);
|
||||
|
||||
for (const t2iAdapter of t2iAdapters) {
|
||||
@ -238,8 +185,6 @@ const addGlobalT2IAdaptersToGraph = async (
|
||||
|
||||
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode;
|
||||
|
||||
t2iAdapterMetadata.push(buildT2IAdapterMetadata(t2iAdapter));
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
|
||||
destination: {
|
||||
@ -248,27 +193,6 @@ const addGlobalT2IAdaptersToGraph = async (
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
upsertMetadata(graph, { t2iAdapters: t2iAdapterMetadata });
|
||||
};
|
||||
|
||||
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfigV2): S['IPAdapterMetadataField'] => {
|
||||
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
|
||||
|
||||
assert(model, 'IP Adapter model is required');
|
||||
assert(image, 'IP Adapter image is required');
|
||||
|
||||
return {
|
||||
ip_adapter_model: model,
|
||||
clip_vision_model: clipVisionModel,
|
||||
weight,
|
||||
method,
|
||||
begin_step_percent: beginEndStepPct[0],
|
||||
end_step_percent: beginEndStepPct[1],
|
||||
image: {
|
||||
image_name: image.name,
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
const addIPAdapterCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
|
||||
@ -300,7 +224,6 @@ const addGlobalIPAdaptersToGraph = async (
|
||||
if (ipAdapters.length === 0) {
|
||||
return;
|
||||
}
|
||||
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
|
||||
addIPAdapterCollectorSafe(graph, denoiseNodeId);
|
||||
|
||||
for (const ipAdapter of ipAdapters) {
|
||||
@ -325,8 +248,6 @@ const addGlobalIPAdaptersToGraph = async (
|
||||
|
||||
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
|
||||
|
||||
ipAdapterMetdata.push(buildIPAdapterMetadata(ipAdapter));
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
|
||||
destination: {
|
||||
@ -335,16 +256,131 @@ const addGlobalIPAdaptersToGraph = async (
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
upsertMetadata(graph, { ipAdapters: ipAdapterMetdata });
|
||||
};
|
||||
|
||||
export const addControlLayersToGraph = async (state: RootState, graph: NonNullableGraph, denoiseNodeId: string) => {
|
||||
const addInitialImageLayerToGraph = (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
denoiseNodeId: string,
|
||||
layer: InitialImageLayer
|
||||
) => {
|
||||
const { vaePrecision, model } = state.generation;
|
||||
const { refinerModel, refinerStart } = state.sdxl;
|
||||
const { width, height } = state.controlLayers.present.size;
|
||||
assert(layer.isEnabled, 'Initial image layer is not enabled');
|
||||
assert(layer.image, 'Initial image layer has no image');
|
||||
|
||||
const isSDXL = model?.base === 'sdxl';
|
||||
const useRefinerStartEnd = isSDXL && Boolean(refinerModel);
|
||||
|
||||
const denoiseNode = graph.nodes[denoiseNodeId];
|
||||
assert(denoiseNode?.type === 'denoise_latents', `Missing denoise node or incorrect type: ${denoiseNode?.type}`);
|
||||
|
||||
const { denoisingStrength } = layer;
|
||||
denoiseNode.denoising_start = useRefinerStartEnd
|
||||
? Math.min(refinerStart, 1 - denoisingStrength)
|
||||
: 1 - denoisingStrength;
|
||||
denoiseNode.denoising_end = useRefinerStartEnd ? refinerStart : 1;
|
||||
|
||||
// We conditionally hook the image in depending on if a resize is needed
|
||||
const i2lNode: ImageToLatentsInvocation = {
|
||||
type: 'i2l',
|
||||
id: IMAGE_TO_LATENTS,
|
||||
is_intermediate: true,
|
||||
use_cache: true,
|
||||
fp32: vaePrecision === 'fp32',
|
||||
};
|
||||
|
||||
graph.nodes[i2lNode.id] = i2lNode;
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: denoiseNode.id,
|
||||
field: 'latents',
|
||||
},
|
||||
});
|
||||
|
||||
if (layer.image.width !== width || layer.image.height !== height) {
|
||||
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
|
||||
|
||||
// Create a resize node, explicitly setting its image
|
||||
const resizeNode: ImageResizeInvocation = {
|
||||
id: RESIZE,
|
||||
type: 'img_resize',
|
||||
image: {
|
||||
image_name: layer.image.name,
|
||||
},
|
||||
is_intermediate: true,
|
||||
width,
|
||||
height,
|
||||
};
|
||||
|
||||
graph.nodes[RESIZE] = resizeNode;
|
||||
|
||||
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'image' },
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
|
||||
// The `RESIZE` node also passes its width and height to `NOISE`
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'width' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'width',
|
||||
},
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'height' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'height',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
|
||||
i2lNode.image = {
|
||||
image_name: layer.image.name,
|
||||
};
|
||||
|
||||
// Pass the image's dimensions to the `NOISE` node
|
||||
graph.edges.push({
|
||||
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'width',
|
||||
},
|
||||
});
|
||||
graph.edges.push({
|
||||
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'height',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
upsertMetadata(graph, { generation_mode: isSDXL ? 'sdxl_img2img' : 'img2img' });
|
||||
};
|
||||
|
||||
export const addControlLayersToGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
denoiseNodeId: string
|
||||
): Promise<Layer[]> => {
|
||||
const mainModel = state.generation.model;
|
||||
assert(mainModel, 'Missing main model when building graph');
|
||||
const isSDXL = mainModel.base === 'sdxl';
|
||||
|
||||
const layersMetadata: Layer[] = [];
|
||||
const validLayers: Layer[] = [];
|
||||
|
||||
// Add global control adapters
|
||||
const validControlAdapterLayers = state.controlLayers.present.layers
|
||||
@ -366,6 +402,8 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
|
||||
const validT2IAdapters = validControlAdapterLayers.map((l) => l.controlAdapter).filter(isT2IAdapterConfigV2);
|
||||
addGlobalT2IAdaptersToGraph(validT2IAdapters, graph, denoiseNodeId);
|
||||
|
||||
validLayers.push(...validControlAdapterLayers);
|
||||
|
||||
const validIPAdapterLayers = state.controlLayers.present.layers
|
||||
// Must be an IP Adapter layer
|
||||
.filter(isIPAdapterLayer)
|
||||
@ -381,6 +419,21 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
|
||||
});
|
||||
const validIPAdapters = validIPAdapterLayers.map((l) => l.ipAdapter);
|
||||
addGlobalIPAdaptersToGraph(validIPAdapters, graph, denoiseNodeId);
|
||||
validLayers.push(...validIPAdapterLayers);
|
||||
|
||||
const initialImageLayer = state.controlLayers.present.layers.filter(isInitialImageLayer).find((l) => {
|
||||
if (!l.isEnabled) {
|
||||
return false;
|
||||
}
|
||||
if (!l.image) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
if (initialImageLayer) {
|
||||
addInitialImageLayerToGraph(state, graph, denoiseNodeId, initialImageLayer);
|
||||
validLayers.push(initialImageLayer);
|
||||
}
|
||||
|
||||
const validRGLayers = state.controlLayers.present.layers
|
||||
// Only RG layers are get masks
|
||||
@ -393,8 +446,7 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
|
||||
const hasIPAdapter = l.ipAdapters.filter((ipa) => ipa.image).length > 0;
|
||||
return hasTextPrompt || hasIPAdapter;
|
||||
});
|
||||
|
||||
layersMetadata.push(...validRGLayers, ...validControlAdapterLayers, ...validIPAdapterLayers);
|
||||
validLayers.push(...validRGLayers);
|
||||
|
||||
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
|
||||
// the existing conditioning nodes.
|
||||
@ -660,7 +712,8 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
|
||||
}
|
||||
}
|
||||
|
||||
upsertMetadata(graph, { layers: layersMetadata });
|
||||
upsertMetadata(graph, { layers: validLayers });
|
||||
return validLayers;
|
||||
};
|
||||
|
||||
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
|
||||
|
@ -1,8 +1,8 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { isInitialImageLayer, isRegionalGuidanceLayer } from 'features/controlLayers/store/controlLayersSlice';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
|
||||
import { addInitialImageToLinearGraph } from 'features/nodes/util/graph/addInitialImageToLinearGraph';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||
|
||||
@ -232,24 +232,24 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<NonNull
|
||||
LATENTS_TO_IMAGE
|
||||
);
|
||||
|
||||
const didAddInitialImage = addInitialImageToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// Add Seamless To Graph
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
|
||||
modelLoaderNodeId = SEAMLESS;
|
||||
}
|
||||
|
||||
// optionally add custom VAE
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
await addControlLayersToGraph(state, graph, DENOISE_LATENTS);
|
||||
const addedLayers = await addControlLayersToGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// optionally add custom VAE
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
const shouldUseHRF = !addedLayers.some((l) => isInitialImageLayer(l) || isRegionalGuidanceLayer(l));
|
||||
// High resolution fix.
|
||||
if (state.hrf.hrfEnabled && !didAddInitialImage) {
|
||||
if (state.hrf.hrfEnabled && shouldUseHRF) {
|
||||
console.log('HRFING');
|
||||
addHrfToGraph(state, graph);
|
||||
}
|
||||
|
||||
|
@ -2,7 +2,6 @@ import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
|
||||
import { addInitialImageToLinearGraph } from 'features/nodes/util/graph/addInitialImageToLinearGraph';
|
||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
@ -242,8 +241,6 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
||||
LATENTS_TO_IMAGE
|
||||
);
|
||||
|
||||
addInitialImageToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// Add Seamless To Graph
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
|
||||
@ -258,14 +255,14 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
||||
}
|
||||
}
|
||||
|
||||
// optionally add custom VAE
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
await addControlLayersToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// optionally add custom VAE
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
// must add before watermarker!
|
||||
|
Loading…
x
Reference in New Issue
Block a user