mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): handle initial image layers in control layers helper
This commit is contained in:
parent
f147f99bef
commit
3f489c92c8
@ -2,11 +2,12 @@ import { getStore } from 'app/store/nanostores/store';
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import {
|
import {
|
||||||
isControlAdapterLayer,
|
isControlAdapterLayer,
|
||||||
|
isInitialImageLayer,
|
||||||
isIPAdapterLayer,
|
isIPAdapterLayer,
|
||||||
isRegionalGuidanceLayer,
|
isRegionalGuidanceLayer,
|
||||||
rgLayerMaskImageUploaded,
|
rgLayerMaskImageUploaded,
|
||||||
} from 'features/controlLayers/store/controlLayersSlice';
|
} from 'features/controlLayers/store/controlLayersSlice';
|
||||||
import type { Layer, RegionalGuidanceLayer } from 'features/controlLayers/store/types';
|
import type { InitialImageLayer, Layer, RegionalGuidanceLayer } from 'features/controlLayers/store/types';
|
||||||
import {
|
import {
|
||||||
type ControlNetConfigV2,
|
type ControlNetConfigV2,
|
||||||
type ImageWithDims,
|
type ImageWithDims,
|
||||||
@ -20,9 +21,11 @@ import { getRegionalPromptLayerBlobs } from 'features/controlLayers/util/getLaye
|
|||||||
import type { ImageField } from 'features/nodes/types/common';
|
import type { ImageField } from 'features/nodes/types/common';
|
||||||
import {
|
import {
|
||||||
CONTROL_NET_COLLECT,
|
CONTROL_NET_COLLECT,
|
||||||
|
IMAGE_TO_LATENTS,
|
||||||
IP_ADAPTER_COLLECT,
|
IP_ADAPTER_COLLECT,
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NEGATIVE_CONDITIONING_COLLECT,
|
NEGATIVE_CONDITIONING_COLLECT,
|
||||||
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
POSITIVE_CONDITIONING_COLLECT,
|
POSITIVE_CONDITIONING_COLLECT,
|
||||||
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
|
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
|
||||||
@ -30,6 +33,7 @@ import {
|
|||||||
PROMPT_REGION_NEGATIVE_COND_PREFIX,
|
PROMPT_REGION_NEGATIVE_COND_PREFIX,
|
||||||
PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX,
|
PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX,
|
||||||
PROMPT_REGION_POSITIVE_COND_PREFIX,
|
PROMPT_REGION_POSITIVE_COND_PREFIX,
|
||||||
|
RESIZE,
|
||||||
T2I_ADAPTER_COLLECT,
|
T2I_ADAPTER_COLLECT,
|
||||||
} from 'features/nodes/util/graph/constants';
|
} from 'features/nodes/util/graph/constants';
|
||||||
import { upsertMetadata } from 'features/nodes/util/graph/metadata';
|
import { upsertMetadata } from 'features/nodes/util/graph/metadata';
|
||||||
@ -38,9 +42,10 @@ import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
|
|||||||
import type {
|
import type {
|
||||||
CollectInvocation,
|
CollectInvocation,
|
||||||
ControlNetInvocation,
|
ControlNetInvocation,
|
||||||
CoreMetadataInvocation,
|
|
||||||
Edge,
|
Edge,
|
||||||
ImageDTO,
|
ImageDTO,
|
||||||
|
ImageResizeInvocation,
|
||||||
|
ImageToLatentsInvocation,
|
||||||
IPAdapterInvocation,
|
IPAdapterInvocation,
|
||||||
NonNullableGraph,
|
NonNullableGraph,
|
||||||
S,
|
S,
|
||||||
@ -67,33 +72,6 @@ const buildControlImage = (
|
|||||||
assert(false, 'Attempted to add unprocessed control image');
|
assert(false, 'Attempted to add unprocessed control image');
|
||||||
};
|
};
|
||||||
|
|
||||||
const buildControlNetMetadata = (controlNet: ControlNetConfigV2): S['ControlNetMetadataField'] => {
|
|
||||||
const { beginEndStepPct, controlMode, image, model, processedImage, processorConfig, weight } = controlNet;
|
|
||||||
|
|
||||||
assert(model, 'ControlNet model is required');
|
|
||||||
assert(image, 'ControlNet image is required');
|
|
||||||
|
|
||||||
const processed_image =
|
|
||||||
processedImage && processorConfig
|
|
||||||
? {
|
|
||||||
image_name: processedImage.name,
|
|
||||||
}
|
|
||||||
: null;
|
|
||||||
|
|
||||||
return {
|
|
||||||
control_model: model,
|
|
||||||
control_weight: weight,
|
|
||||||
control_mode: controlMode,
|
|
||||||
begin_step_percent: beginEndStepPct[0],
|
|
||||||
end_step_percent: beginEndStepPct[1],
|
|
||||||
resize_mode: 'just_resize',
|
|
||||||
image: {
|
|
||||||
image_name: image.name,
|
|
||||||
},
|
|
||||||
processed_image,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
const addControlNetCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
|
const addControlNetCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
|
||||||
if (graph.nodes[CONTROL_NET_COLLECT]) {
|
if (graph.nodes[CONTROL_NET_COLLECT]) {
|
||||||
// You see, we've already got one!
|
// You see, we've already got one!
|
||||||
@ -123,7 +101,6 @@ const addGlobalControlNetsToGraph = async (
|
|||||||
if (controlNets.length === 0) {
|
if (controlNets.length === 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const controlNetMetadata: CoreMetadataInvocation['controlnets'] = [];
|
|
||||||
addControlNetCollectorSafe(graph, denoiseNodeId);
|
addControlNetCollectorSafe(graph, denoiseNodeId);
|
||||||
|
|
||||||
for (const controlNet of controlNets) {
|
for (const controlNet of controlNets) {
|
||||||
@ -147,8 +124,6 @@ const addGlobalControlNetsToGraph = async (
|
|||||||
|
|
||||||
graph.nodes[controlNetNode.id] = controlNetNode;
|
graph.nodes[controlNetNode.id] = controlNetNode;
|
||||||
|
|
||||||
controlNetMetadata.push(buildControlNetMetadata(controlNet));
|
|
||||||
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: { node_id: controlNetNode.id, field: 'control' },
|
source: { node_id: controlNetNode.id, field: 'control' },
|
||||||
destination: {
|
destination: {
|
||||||
@ -157,33 +132,6 @@ const addGlobalControlNetsToGraph = async (
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
upsertMetadata(graph, { controlnets: controlNetMetadata });
|
|
||||||
};
|
|
||||||
|
|
||||||
const buildT2IAdapterMetadata = (t2iAdapter: T2IAdapterConfigV2): S['T2IAdapterMetadataField'] => {
|
|
||||||
const { beginEndStepPct, image, model, processedImage, processorConfig, weight } = t2iAdapter;
|
|
||||||
|
|
||||||
assert(model, 'T2I Adapter model is required');
|
|
||||||
assert(image, 'T2I Adapter image is required');
|
|
||||||
|
|
||||||
const processed_image =
|
|
||||||
processedImage && processorConfig
|
|
||||||
? {
|
|
||||||
image_name: processedImage.name,
|
|
||||||
}
|
|
||||||
: null;
|
|
||||||
|
|
||||||
return {
|
|
||||||
t2i_adapter_model: model,
|
|
||||||
weight,
|
|
||||||
begin_step_percent: beginEndStepPct[0],
|
|
||||||
end_step_percent: beginEndStepPct[1],
|
|
||||||
resize_mode: 'just_resize',
|
|
||||||
image: {
|
|
||||||
image_name: image.name,
|
|
||||||
},
|
|
||||||
processed_image,
|
|
||||||
};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const addT2IAdapterCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
|
const addT2IAdapterCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
|
||||||
@ -215,7 +163,6 @@ const addGlobalT2IAdaptersToGraph = async (
|
|||||||
if (t2iAdapters.length === 0) {
|
if (t2iAdapters.length === 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const t2iAdapterMetadata: CoreMetadataInvocation['t2iAdapters'] = [];
|
|
||||||
addT2IAdapterCollectorSafe(graph, denoiseNodeId);
|
addT2IAdapterCollectorSafe(graph, denoiseNodeId);
|
||||||
|
|
||||||
for (const t2iAdapter of t2iAdapters) {
|
for (const t2iAdapter of t2iAdapters) {
|
||||||
@ -238,8 +185,6 @@ const addGlobalT2IAdaptersToGraph = async (
|
|||||||
|
|
||||||
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode;
|
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode;
|
||||||
|
|
||||||
t2iAdapterMetadata.push(buildT2IAdapterMetadata(t2iAdapter));
|
|
||||||
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
|
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
|
||||||
destination: {
|
destination: {
|
||||||
@ -248,27 +193,6 @@ const addGlobalT2IAdaptersToGraph = async (
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
upsertMetadata(graph, { t2iAdapters: t2iAdapterMetadata });
|
|
||||||
};
|
|
||||||
|
|
||||||
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfigV2): S['IPAdapterMetadataField'] => {
|
|
||||||
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
|
|
||||||
|
|
||||||
assert(model, 'IP Adapter model is required');
|
|
||||||
assert(image, 'IP Adapter image is required');
|
|
||||||
|
|
||||||
return {
|
|
||||||
ip_adapter_model: model,
|
|
||||||
clip_vision_model: clipVisionModel,
|
|
||||||
weight,
|
|
||||||
method,
|
|
||||||
begin_step_percent: beginEndStepPct[0],
|
|
||||||
end_step_percent: beginEndStepPct[1],
|
|
||||||
image: {
|
|
||||||
image_name: image.name,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const addIPAdapterCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
|
const addIPAdapterCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
|
||||||
@ -300,7 +224,6 @@ const addGlobalIPAdaptersToGraph = async (
|
|||||||
if (ipAdapters.length === 0) {
|
if (ipAdapters.length === 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
|
|
||||||
addIPAdapterCollectorSafe(graph, denoiseNodeId);
|
addIPAdapterCollectorSafe(graph, denoiseNodeId);
|
||||||
|
|
||||||
for (const ipAdapter of ipAdapters) {
|
for (const ipAdapter of ipAdapters) {
|
||||||
@ -325,8 +248,6 @@ const addGlobalIPAdaptersToGraph = async (
|
|||||||
|
|
||||||
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
|
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
|
||||||
|
|
||||||
ipAdapterMetdata.push(buildIPAdapterMetadata(ipAdapter));
|
|
||||||
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
|
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
|
||||||
destination: {
|
destination: {
|
||||||
@ -335,16 +256,131 @@ const addGlobalIPAdaptersToGraph = async (
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
upsertMetadata(graph, { ipAdapters: ipAdapterMetdata });
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export const addControlLayersToGraph = async (state: RootState, graph: NonNullableGraph, denoiseNodeId: string) => {
|
const addInitialImageLayerToGraph = (
|
||||||
|
state: RootState,
|
||||||
|
graph: NonNullableGraph,
|
||||||
|
denoiseNodeId: string,
|
||||||
|
layer: InitialImageLayer
|
||||||
|
) => {
|
||||||
|
const { vaePrecision, model } = state.generation;
|
||||||
|
const { refinerModel, refinerStart } = state.sdxl;
|
||||||
|
const { width, height } = state.controlLayers.present.size;
|
||||||
|
assert(layer.isEnabled, 'Initial image layer is not enabled');
|
||||||
|
assert(layer.image, 'Initial image layer has no image');
|
||||||
|
|
||||||
|
const isSDXL = model?.base === 'sdxl';
|
||||||
|
const useRefinerStartEnd = isSDXL && Boolean(refinerModel);
|
||||||
|
|
||||||
|
const denoiseNode = graph.nodes[denoiseNodeId];
|
||||||
|
assert(denoiseNode?.type === 'denoise_latents', `Missing denoise node or incorrect type: ${denoiseNode?.type}`);
|
||||||
|
|
||||||
|
const { denoisingStrength } = layer;
|
||||||
|
denoiseNode.denoising_start = useRefinerStartEnd
|
||||||
|
? Math.min(refinerStart, 1 - denoisingStrength)
|
||||||
|
: 1 - denoisingStrength;
|
||||||
|
denoiseNode.denoising_end = useRefinerStartEnd ? refinerStart : 1;
|
||||||
|
|
||||||
|
// We conditionally hook the image in depending on if a resize is needed
|
||||||
|
const i2lNode: ImageToLatentsInvocation = {
|
||||||
|
type: 'i2l',
|
||||||
|
id: IMAGE_TO_LATENTS,
|
||||||
|
is_intermediate: true,
|
||||||
|
use_cache: true,
|
||||||
|
fp32: vaePrecision === 'fp32',
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.nodes[i2lNode.id] = i2lNode;
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: IMAGE_TO_LATENTS,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: denoiseNode.id,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (layer.image.width !== width || layer.image.height !== height) {
|
||||||
|
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
|
||||||
|
|
||||||
|
// Create a resize node, explicitly setting its image
|
||||||
|
const resizeNode: ImageResizeInvocation = {
|
||||||
|
id: RESIZE,
|
||||||
|
type: 'img_resize',
|
||||||
|
image: {
|
||||||
|
image_name: layer.image.name,
|
||||||
|
},
|
||||||
|
is_intermediate: true,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.nodes[RESIZE] = resizeNode;
|
||||||
|
|
||||||
|
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RESIZE, field: 'image' },
|
||||||
|
destination: {
|
||||||
|
node_id: IMAGE_TO_LATENTS,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// The `RESIZE` node also passes its width and height to `NOISE`
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RESIZE, field: 'width' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'width',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RESIZE, field: 'height' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'height',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
|
||||||
|
i2lNode.image = {
|
||||||
|
image_name: layer.image.name,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Pass the image's dimensions to the `NOISE` node
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'width',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'height',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
upsertMetadata(graph, { generation_mode: isSDXL ? 'sdxl_img2img' : 'img2img' });
|
||||||
|
};
|
||||||
|
|
||||||
|
export const addControlLayersToGraph = async (
|
||||||
|
state: RootState,
|
||||||
|
graph: NonNullableGraph,
|
||||||
|
denoiseNodeId: string
|
||||||
|
): Promise<Layer[]> => {
|
||||||
const mainModel = state.generation.model;
|
const mainModel = state.generation.model;
|
||||||
assert(mainModel, 'Missing main model when building graph');
|
assert(mainModel, 'Missing main model when building graph');
|
||||||
const isSDXL = mainModel.base === 'sdxl';
|
const isSDXL = mainModel.base === 'sdxl';
|
||||||
|
|
||||||
const layersMetadata: Layer[] = [];
|
const validLayers: Layer[] = [];
|
||||||
|
|
||||||
// Add global control adapters
|
// Add global control adapters
|
||||||
const validControlAdapterLayers = state.controlLayers.present.layers
|
const validControlAdapterLayers = state.controlLayers.present.layers
|
||||||
@ -366,6 +402,8 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
|
|||||||
const validT2IAdapters = validControlAdapterLayers.map((l) => l.controlAdapter).filter(isT2IAdapterConfigV2);
|
const validT2IAdapters = validControlAdapterLayers.map((l) => l.controlAdapter).filter(isT2IAdapterConfigV2);
|
||||||
addGlobalT2IAdaptersToGraph(validT2IAdapters, graph, denoiseNodeId);
|
addGlobalT2IAdaptersToGraph(validT2IAdapters, graph, denoiseNodeId);
|
||||||
|
|
||||||
|
validLayers.push(...validControlAdapterLayers);
|
||||||
|
|
||||||
const validIPAdapterLayers = state.controlLayers.present.layers
|
const validIPAdapterLayers = state.controlLayers.present.layers
|
||||||
// Must be an IP Adapter layer
|
// Must be an IP Adapter layer
|
||||||
.filter(isIPAdapterLayer)
|
.filter(isIPAdapterLayer)
|
||||||
@ -381,6 +419,21 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
|
|||||||
});
|
});
|
||||||
const validIPAdapters = validIPAdapterLayers.map((l) => l.ipAdapter);
|
const validIPAdapters = validIPAdapterLayers.map((l) => l.ipAdapter);
|
||||||
addGlobalIPAdaptersToGraph(validIPAdapters, graph, denoiseNodeId);
|
addGlobalIPAdaptersToGraph(validIPAdapters, graph, denoiseNodeId);
|
||||||
|
validLayers.push(...validIPAdapterLayers);
|
||||||
|
|
||||||
|
const initialImageLayer = state.controlLayers.present.layers.filter(isInitialImageLayer).find((l) => {
|
||||||
|
if (!l.isEnabled) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!l.image) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
if (initialImageLayer) {
|
||||||
|
addInitialImageLayerToGraph(state, graph, denoiseNodeId, initialImageLayer);
|
||||||
|
validLayers.push(initialImageLayer);
|
||||||
|
}
|
||||||
|
|
||||||
const validRGLayers = state.controlLayers.present.layers
|
const validRGLayers = state.controlLayers.present.layers
|
||||||
// Only RG layers are get masks
|
// Only RG layers are get masks
|
||||||
@ -393,8 +446,7 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
|
|||||||
const hasIPAdapter = l.ipAdapters.filter((ipa) => ipa.image).length > 0;
|
const hasIPAdapter = l.ipAdapters.filter((ipa) => ipa.image).length > 0;
|
||||||
return hasTextPrompt || hasIPAdapter;
|
return hasTextPrompt || hasIPAdapter;
|
||||||
});
|
});
|
||||||
|
validLayers.push(...validRGLayers);
|
||||||
layersMetadata.push(...validRGLayers, ...validControlAdapterLayers, ...validIPAdapterLayers);
|
|
||||||
|
|
||||||
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
|
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
|
||||||
// the existing conditioning nodes.
|
// the existing conditioning nodes.
|
||||||
@ -660,7 +712,8 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
upsertMetadata(graph, { layers: layersMetadata });
|
upsertMetadata(graph, { layers: validLayers });
|
||||||
|
return validLayers;
|
||||||
};
|
};
|
||||||
|
|
||||||
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
|
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
|
import { isInitialImageLayer, isRegionalGuidanceLayer } from 'features/controlLayers/store/controlLayersSlice';
|
||||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
|
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
|
||||||
import { addInitialImageToLinearGraph } from 'features/nodes/util/graph/addInitialImageToLinearGraph';
|
|
||||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||||
|
|
||||||
@ -232,24 +232,24 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<NonNull
|
|||||||
LATENTS_TO_IMAGE
|
LATENTS_TO_IMAGE
|
||||||
);
|
);
|
||||||
|
|
||||||
const didAddInitialImage = addInitialImageToLinearGraph(state, graph, DENOISE_LATENTS);
|
|
||||||
|
|
||||||
// Add Seamless To Graph
|
// Add Seamless To Graph
|
||||||
if (seamlessXAxis || seamlessYAxis) {
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
|
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
|
||||||
modelLoaderNodeId = SEAMLESS;
|
modelLoaderNodeId = SEAMLESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
// optionally add custom VAE
|
|
||||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
|
||||||
|
|
||||||
// add LoRA support
|
// add LoRA support
|
||||||
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||||
|
|
||||||
await addControlLayersToGraph(state, graph, DENOISE_LATENTS);
|
const addedLayers = await addControlLayersToGraph(state, graph, DENOISE_LATENTS);
|
||||||
|
|
||||||
|
// optionally add custom VAE
|
||||||
|
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||||
|
|
||||||
|
const shouldUseHRF = !addedLayers.some((l) => isInitialImageLayer(l) || isRegionalGuidanceLayer(l));
|
||||||
// High resolution fix.
|
// High resolution fix.
|
||||||
if (state.hrf.hrfEnabled && !didAddInitialImage) {
|
if (state.hrf.hrfEnabled && shouldUseHRF) {
|
||||||
|
console.log('HRFING');
|
||||||
addHrfToGraph(state, graph);
|
addHrfToGraph(state, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,7 +2,6 @@ import { logger } from 'app/logging/logger';
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
|
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
|
||||||
import { addInitialImageToLinearGraph } from 'features/nodes/util/graph/addInitialImageToLinearGraph';
|
|
||||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||||
|
|
||||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||||
@ -242,8 +241,6 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
|||||||
LATENTS_TO_IMAGE
|
LATENTS_TO_IMAGE
|
||||||
);
|
);
|
||||||
|
|
||||||
addInitialImageToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
|
||||||
|
|
||||||
// Add Seamless To Graph
|
// Add Seamless To Graph
|
||||||
if (seamlessXAxis || seamlessYAxis) {
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
|
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
|
||||||
@ -258,14 +255,14 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// optionally add custom VAE
|
|
||||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
|
||||||
|
|
||||||
// add LoRA support
|
// add LoRA support
|
||||||
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||||
|
|
||||||
await addControlLayersToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addControlLayersToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
|
|
||||||
|
// optionally add custom VAE
|
||||||
|
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||||
|
|
||||||
// NSFW & watermark - must be last thing added to graph
|
// NSFW & watermark - must be last thing added to graph
|
||||||
if (state.system.shouldUseNSFWChecker) {
|
if (state.system.shouldUseNSFWChecker) {
|
||||||
// must add before watermarker!
|
// must add before watermarker!
|
||||||
|
Loading…
x
Reference in New Issue
Block a user