tidy(ui): remove unused files

This commit is contained in:
psychedelicious 2024-05-13 21:01:57 +10:00
parent 5425526d50
commit 4897ce2a13
10 changed files with 274 additions and 2011 deletions

View File

@ -1,8 +1,8 @@
import { enqueueRequested } from 'app/store/actions'; import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice'; import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { buildGenerationTabGraph2 } from 'features/nodes/util/graph/buildGenerationTabGraph2'; import { buildGenerationTabGraph } from 'features/nodes/util/graph/buildGenerationTabGraph';
import { buildGenerationTabSDXLGraph2 } from 'features/nodes/util/graph/buildGenerationTabSDXLGraph2'; import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/buildGenerationTabSDXLGraph';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
@ -18,10 +18,10 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
let graph; let graph;
if (model && model.base === 'sdxl') { if (model?.base === 'sdxl') {
graph = await buildGenerationTabSDXLGraph2(state); graph = await buildGenerationTabSDXLGraph(state);
} else { } else {
graph = await buildGenerationTabGraph2(state); graph = await buildGenerationTabGraph(state);
} }
const batchConfig = prepareLinearUIBatch(state, graph, prepend); const batchConfig = prepareLinearUIBatch(state, graph, prepend);

View File

@ -1,705 +0,0 @@
import { getStore } from 'app/store/nanostores/store';
import type { RootState } from 'app/store/store';
import {
isControlAdapterLayer,
isInitialImageLayer,
isIPAdapterLayer,
isRegionalGuidanceLayer,
rgLayerMaskImageUploaded,
} from 'features/controlLayers/store/controlLayersSlice';
import type { InitialImageLayer, Layer, RegionalGuidanceLayer } from 'features/controlLayers/store/types';
import type {
ControlNetConfigV2,
ImageWithDims,
IPAdapterConfigV2,
ProcessorConfig,
T2IAdapterConfigV2,
} from 'features/controlLayers/util/controlAdapters';
import { getRegionalPromptLayerBlobs } from 'features/controlLayers/util/getLayerBlobs';
import type { ImageField } from 'features/nodes/types/common';
import {
CONTROL_NET_COLLECT,
IMAGE_TO_LATENTS,
IP_ADAPTER_COLLECT,
NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT,
NOISE,
POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT,
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
PROMPT_REGION_MASK_TO_TENSOR_PREFIX,
PROMPT_REGION_NEGATIVE_COND_PREFIX,
PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX,
PROMPT_REGION_POSITIVE_COND_PREFIX,
RESIZE,
T2I_ADAPTER_COLLECT,
} from 'features/nodes/util/graph/constants';
import { upsertMetadata } from 'features/nodes/util/graph/metadata';
import { size } from 'lodash-es';
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type {
BaseModelType,
CollectInvocation,
ControlNetInvocation,
Edge,
ImageDTO,
Invocation,
IPAdapterInvocation,
NonNullableGraph,
T2IAdapterInvocation,
} from 'services/api/types';
import { assert } from 'tsafe';
export const addControlLayersToGraph = async (
state: RootState,
graph: NonNullableGraph,
denoiseNodeId: string
): Promise<Layer[]> => {
const mainModel = state.generation.model;
assert(mainModel, 'Missing main model when building graph');
const isSDXL = mainModel.base === 'sdxl';
// Filter out layers with incompatible base model, missing control image
const validLayers = state.controlLayers.present.layers.filter((l) => isValidLayer(l, mainModel.base));
const validControlAdapters = validLayers.filter(isControlAdapterLayer).map((l) => l.controlAdapter);
for (const ca of validControlAdapters) {
addGlobalControlAdapterToGraph(ca, graph, denoiseNodeId);
}
const validIPAdapters = validLayers.filter(isIPAdapterLayer).map((l) => l.ipAdapter);
for (const ipAdapter of validIPAdapters) {
addGlobalIPAdapterToGraph(ipAdapter, graph, denoiseNodeId);
}
const initialImageLayers = validLayers.filter(isInitialImageLayer);
assert(initialImageLayers.length <= 1, 'Only one initial image layer allowed');
if (initialImageLayers[0]) {
addInitialImageLayerToGraph(state, graph, denoiseNodeId, initialImageLayers[0]);
}
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
// the existing conditioning nodes.
// With regional prompts we have multiple conditioning nodes which much be routed into collectors. Set those up
const posCondCollectNode: CollectInvocation = {
id: POSITIVE_CONDITIONING_COLLECT,
type: 'collect',
};
graph.nodes[POSITIVE_CONDITIONING_COLLECT] = posCondCollectNode;
const negCondCollectNode: CollectInvocation = {
id: NEGATIVE_CONDITIONING_COLLECT,
type: 'collect',
};
graph.nodes[NEGATIVE_CONDITIONING_COLLECT] = negCondCollectNode;
// Re-route the denoise node's OG conditioning inputs to the collect nodes
const newEdges: Edge[] = [];
for (const edge of graph.edges) {
if (edge.destination.node_id === denoiseNodeId && edge.destination.field === 'positive_conditioning') {
newEdges.push({
source: edge.source,
destination: {
node_id: POSITIVE_CONDITIONING_COLLECT,
field: 'item',
},
});
} else if (edge.destination.node_id === denoiseNodeId && edge.destination.field === 'negative_conditioning') {
newEdges.push({
source: edge.source,
destination: {
node_id: NEGATIVE_CONDITIONING_COLLECT,
field: 'item',
},
});
} else {
newEdges.push(edge);
}
}
graph.edges = newEdges;
// Connect collectors to the denoise nodes - must happen _after_ rerouting else you get cycles
graph.edges.push({
source: {
node_id: POSITIVE_CONDITIONING_COLLECT,
field: 'collection',
},
destination: {
node_id: denoiseNodeId,
field: 'positive_conditioning',
},
});
graph.edges.push({
source: {
node_id: NEGATIVE_CONDITIONING_COLLECT,
field: 'collection',
},
destination: {
node_id: denoiseNodeId,
field: 'negative_conditioning',
},
});
const validRGLayers = validLayers.filter(isRegionalGuidanceLayer);
const layerIds = validRGLayers.map((l) => l.id);
const blobs = await getRegionalPromptLayerBlobs(layerIds);
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
for (const layer of validRGLayers) {
const blob = blobs[layer.id];
assert(blob, `Blob for layer ${layer.id} not found`);
// Upload the mask image, or get the cached image if it exists
const { image_name } = await getMaskImage(layer, blob);
// The main mask-to-tensor node
const maskToTensorNode: Invocation<'alpha_mask_to_tensor'> = {
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${layer.id}`,
type: 'alpha_mask_to_tensor',
image: {
image_name,
},
};
graph.nodes[maskToTensorNode.id] = maskToTensorNode;
if (layer.positivePrompt) {
// The main positive conditioning node
const regionalPositiveCondNode: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'> = isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
style: layer.positivePrompt, // TODO: Should we put the positive prompt in both fields?
}
: {
type: 'compel',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
};
graph.nodes[regionalPositiveCondNode.id] = regionalPositiveCondNode;
// Connect the mask to the conditioning
graph.edges.push({
source: { node_id: maskToTensorNode.id, field: 'mask' },
destination: { node_id: regionalPositiveCondNode.id, field: 'mask' },
});
// Connect the conditioning to the collector
graph.edges.push({
source: { node_id: regionalPositiveCondNode.id, field: 'conditioning' },
destination: { node_id: posCondCollectNode.id, field: 'item' },
});
// Copy the connections to the "global" positive conditioning node to the regional cond
for (const edge of graph.edges) {
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({
source: edge.source,
destination: { node_id: regionalPositiveCondNode.id, field: edge.destination.field },
});
}
}
}
if (layer.negativePrompt) {
// The main negative conditioning node
const regionalNegativeCondNode: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'> = isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.negativePrompt,
style: layer.negativePrompt,
}
: {
type: 'compel',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.negativePrompt,
};
graph.nodes[regionalNegativeCondNode.id] = regionalNegativeCondNode;
// Connect the mask to the conditioning
graph.edges.push({
source: { node_id: maskToTensorNode.id, field: 'mask' },
destination: { node_id: regionalNegativeCondNode.id, field: 'mask' },
});
// Connect the conditioning to the collector
graph.edges.push({
source: { node_id: regionalNegativeCondNode.id, field: 'conditioning' },
destination: { node_id: negCondCollectNode.id, field: 'item' },
});
// Copy the connections to the "global" negative conditioning node to the regional cond
for (const edge of graph.edges) {
if (edge.destination.node_id === NEGATIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({
source: edge.source,
destination: { node_id: regionalNegativeCondNode.id, field: edge.destination.field },
});
}
}
}
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
if (layer.autoNegative === 'invert' && layer.positivePrompt) {
// We re-use the mask image, but invert it when converting to tensor
const invertTensorMaskNode: Invocation<'invert_tensor_mask'> = {
id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${layer.id}`,
type: 'invert_tensor_mask',
};
graph.nodes[invertTensorMaskNode.id] = invertTensorMaskNode;
// Connect the OG mask image to the inverted mask-to-tensor node
graph.edges.push({
source: {
node_id: maskToTensorNode.id,
field: 'mask',
},
destination: {
node_id: invertTensorMaskNode.id,
field: 'mask',
},
});
// Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the
// positive prompt
const regionalPositiveCondInvertedNode: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'> = isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
style: layer.positivePrompt,
}
: {
type: 'compel',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
};
graph.nodes[regionalPositiveCondInvertedNode.id] = regionalPositiveCondInvertedNode;
// Connect the inverted mask to the conditioning
graph.edges.push({
source: { node_id: invertTensorMaskNode.id, field: 'mask' },
destination: { node_id: regionalPositiveCondInvertedNode.id, field: 'mask' },
});
// Connect the conditioning to the negative collector
graph.edges.push({
source: { node_id: regionalPositiveCondInvertedNode.id, field: 'conditioning' },
destination: { node_id: negCondCollectNode.id, field: 'item' },
});
// Copy the connections to the "global" positive conditioning node to our regional node
for (const edge of graph.edges) {
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({
source: edge.source,
destination: { node_id: regionalPositiveCondInvertedNode.id, field: edge.destination.field },
});
}
}
}
const validRegionalIPAdapters: IPAdapterConfigV2[] = layer.ipAdapters.filter((ipa) =>
isValidIPAdapter(ipa, mainModel.base)
);
for (const ipAdapter of validRegionalIPAdapters) {
addIPAdapterCollectorSafe(graph, denoiseNodeId);
const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
assert(model, 'IP Adapter model is required');
assert(image, 'IP Adapter image is required');
const ipAdapterNode: IPAdapterInvocation = {
id: `ip_adapter_${id}`,
type: 'ip_adapter',
is_intermediate: true,
weight,
method,
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
image: {
image_name: image.name,
},
};
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
// Connect the mask to the conditioning
graph.edges.push({
source: { node_id: maskToTensorNode.id, field: 'mask' },
destination: { node_id: ipAdapterNode.id, field: 'mask' },
});
graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
destination: {
node_id: IP_ADAPTER_COLLECT,
field: 'item',
},
});
}
}
upsertMetadata(graph, { control_layers: { layers: validLayers, version: state.controlLayers.present._version } });
return validLayers;
};
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
if (layer.uploadedMaskImage) {
const imageDTO = await getImageDTO(layer.uploadedMaskImage.name);
if (imageDTO) {
return imageDTO;
}
}
const { dispatch } = getStore();
// No cached mask, or the cached image no longer exists - we need to upload the mask image
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
req.reset();
const imageDTO = await req.unwrap();
dispatch(rgLayerMaskImageUploaded({ layerId: layer.id, imageDTO }));
return imageDTO;
};
const buildControlImage = (
image: ImageWithDims | null,
processedImage: ImageWithDims | null,
processorConfig: ProcessorConfig | null
): ImageField => {
if (processedImage && processorConfig) {
// We've processed the image in the app - use it for the control image.
return {
image_name: processedImage.name,
};
} else if (image) {
// No processor selected, and we have an image - the user provided a processed image, use it for the control image.
return {
image_name: image.name,
};
}
assert(false, 'Attempted to add unprocessed control image');
};
const addGlobalControlAdapterToGraph = (
controlAdapter: ControlNetConfigV2 | T2IAdapterConfigV2,
graph: NonNullableGraph,
denoiseNodeId: string
) => {
if (controlAdapter.type === 'controlnet') {
addGlobalControlNetToGraph(controlAdapter, graph, denoiseNodeId);
}
if (controlAdapter.type === 't2i_adapter') {
addGlobalT2IAdapterToGraph(controlAdapter, graph, denoiseNodeId);
}
};
const addControlNetCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
if (graph.nodes[CONTROL_NET_COLLECT]) {
// You see, we've already got one!
return;
}
// Add the ControlNet collector
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
type: 'collect',
is_intermediate: true,
};
graph.nodes[CONTROL_NET_COLLECT] = controlNetIterateNode;
graph.edges.push({
source: { node_id: CONTROL_NET_COLLECT, field: 'collection' },
destination: {
node_id: denoiseNodeId,
field: 'control',
},
});
};
const addGlobalControlNetToGraph = (controlNet: ControlNetConfigV2, graph: NonNullableGraph, denoiseNodeId: string) => {
const { id, beginEndStepPct, controlMode, image, model, processedImage, processorConfig, weight } = controlNet;
assert(model, 'ControlNet model is required');
const controlImage = buildControlImage(image, processedImage, processorConfig);
addControlNetCollectorSafe(graph, denoiseNodeId);
const controlNetNode: ControlNetInvocation = {
id: `control_net_${id}`,
type: 'controlnet',
is_intermediate: true,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
control_mode: controlMode,
resize_mode: 'just_resize',
control_model: model,
control_weight: weight,
image: controlImage,
};
graph.nodes[controlNetNode.id] = controlNetNode;
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
destination: {
node_id: CONTROL_NET_COLLECT,
field: 'item',
},
});
};
const addT2IAdapterCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
if (graph.nodes[T2I_ADAPTER_COLLECT]) {
// You see, we've already got one!
return;
}
// Even though denoise_latents' t2i adapter input is collection or scalar, keep it simple and always use a collect
const t2iAdapterCollectNode: CollectInvocation = {
id: T2I_ADAPTER_COLLECT,
type: 'collect',
is_intermediate: true,
};
graph.nodes[T2I_ADAPTER_COLLECT] = t2iAdapterCollectNode;
graph.edges.push({
source: { node_id: T2I_ADAPTER_COLLECT, field: 'collection' },
destination: {
node_id: denoiseNodeId,
field: 't2i_adapter',
},
});
};
const addGlobalT2IAdapterToGraph = (t2iAdapter: T2IAdapterConfigV2, graph: NonNullableGraph, denoiseNodeId: string) => {
const { id, beginEndStepPct, image, model, processedImage, processorConfig, weight } = t2iAdapter;
assert(model, 'T2I Adapter model is required');
const controlImage = buildControlImage(image, processedImage, processorConfig);
addT2IAdapterCollectorSafe(graph, denoiseNodeId);
const t2iAdapterNode: T2IAdapterInvocation = {
id: `t2i_adapter_${id}`,
type: 't2i_adapter',
is_intermediate: true,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
resize_mode: 'just_resize',
t2i_adapter_model: model,
weight: weight,
image: controlImage,
};
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode;
graph.edges.push({
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
destination: {
node_id: T2I_ADAPTER_COLLECT,
field: 'item',
},
});
};
const addIPAdapterCollectorSafe = (graph: NonNullableGraph, denoiseNodeId: string) => {
if (graph.nodes[IP_ADAPTER_COLLECT]) {
// You see, we've already got one!
return;
}
const ipAdapterCollectNode: CollectInvocation = {
id: IP_ADAPTER_COLLECT,
type: 'collect',
is_intermediate: true,
};
graph.nodes[IP_ADAPTER_COLLECT] = ipAdapterCollectNode;
graph.edges.push({
source: { node_id: IP_ADAPTER_COLLECT, field: 'collection' },
destination: {
node_id: denoiseNodeId,
field: 'ip_adapter',
},
});
};
const addGlobalIPAdapterToGraph = (ipAdapter: IPAdapterConfigV2, graph: NonNullableGraph, denoiseNodeId: string) => {
addIPAdapterCollectorSafe(graph, denoiseNodeId);
const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
assert(image, 'IP Adapter image is required');
assert(model, 'IP Adapter model is required');
const ipAdapterNode: IPAdapterInvocation = {
id: `ip_adapter_${id}`,
type: 'ip_adapter',
is_intermediate: true,
weight,
method,
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
image: {
image_name: image.name,
},
};
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
destination: {
node_id: IP_ADAPTER_COLLECT,
field: 'item',
},
});
};
const addInitialImageLayerToGraph = (
state: RootState,
graph: NonNullableGraph,
denoiseNodeId: string,
layer: InitialImageLayer
) => {
const { vaePrecision, model } = state.generation;
const { refinerModel, refinerStart } = state.sdxl;
const { width, height } = state.controlLayers.present.size;
assert(layer.isEnabled, 'Initial image layer is not enabled');
assert(layer.image, 'Initial image layer has no image');
const isSDXL = model?.base === 'sdxl';
const useRefinerStartEnd = isSDXL && Boolean(refinerModel);
const denoiseNode = graph.nodes[denoiseNodeId];
assert(denoiseNode?.type === 'denoise_latents', `Missing denoise node or incorrect type: ${denoiseNode?.type}`);
const { denoisingStrength } = layer;
denoiseNode.denoising_start = useRefinerStartEnd
? Math.min(refinerStart, 1 - denoisingStrength)
: 1 - denoisingStrength;
denoiseNode.denoising_end = useRefinerStartEnd ? refinerStart : 1;
const i2lNode: Invocation<'i2l'> = {
type: 'i2l',
id: IMAGE_TO_LATENTS,
is_intermediate: true,
use_cache: true,
fp32: vaePrecision === 'fp32',
};
graph.nodes[i2lNode.id] = i2lNode;
graph.edges.push({
source: {
node_id: IMAGE_TO_LATENTS,
field: 'latents',
},
destination: {
node_id: denoiseNode.id,
field: 'latents',
},
});
if (layer.image.width !== width || layer.image.height !== height) {
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
// Create a resize node, explicitly setting its image
const resizeNode: Invocation<'img_resize'> = {
id: RESIZE,
type: 'img_resize',
image: {
image_name: layer.image.name,
},
is_intermediate: true,
width,
height,
};
graph.nodes[RESIZE] = resizeNode;
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
graph.edges.push({
source: { node_id: RESIZE, field: 'image' },
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'image',
},
});
// The `RESIZE` node also passes its width and height to `NOISE`
graph.edges.push({
source: { node_id: RESIZE, field: 'width' },
destination: {
node_id: NOISE,
field: 'width',
},
});
graph.edges.push({
source: { node_id: RESIZE, field: 'height' },
destination: {
node_id: NOISE,
field: 'height',
},
});
} else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
i2lNode.image = {
image_name: layer.image.name,
};
// Pass the image's dimensions to the `NOISE` node
graph.edges.push({
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
destination: {
node_id: NOISE,
field: 'width',
},
});
graph.edges.push({
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
destination: {
node_id: NOISE,
field: 'height',
},
});
}
upsertMetadata(graph, { generation_mode: isSDXL ? 'sdxl_img2img' : 'img2img' });
};
const isValidControlAdapter = (ca: ControlNetConfigV2 | T2IAdapterConfigV2, base: BaseModelType): boolean => {
// Must be have a model that matches the current base and must have a control image
const hasModel = Boolean(ca.model);
const modelMatchesBase = ca.model?.base === base;
const hasControlImage = Boolean(ca.image || (ca.processedImage && ca.processorConfig));
return hasModel && modelMatchesBase && hasControlImage;
};
const isValidIPAdapter = (ipa: IPAdapterConfigV2, base: BaseModelType): boolean => {
// Must be have a model that matches the current base and must have a control image
const hasModel = Boolean(ipa.model);
const modelMatchesBase = ipa.model?.base === base;
const hasImage = Boolean(ipa.image);
return hasModel && modelMatchesBase && hasImage;
};
const isValidLayer = (layer: Layer, base: BaseModelType) => {
if (!layer.isEnabled) {
return false;
}
if (isControlAdapterLayer(layer)) {
return isValidControlAdapter(layer.controlAdapter, base);
}
if (isIPAdapterLayer(layer)) {
return isValidIPAdapter(layer.ipAdapter, base);
}
if (isInitialImageLayer(layer)) {
if (!layer.image) {
return false;
}
return true;
}
if (isRegionalGuidanceLayer(layer)) {
if (layer.maskObjects.length === 0) {
// Layer has no mask, meaning any guidance would be applied to an empty region.
return false;
}
const hasTextPrompt = Boolean(layer.positivePrompt) || Boolean(layer.negativePrompt);
const hasIPAdapter = layer.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)).length > 0;
return hasTextPrompt || hasIPAdapter;
}
return false;
};

View File

@ -1,79 +0,0 @@
import type { RootState } from 'app/store/store';
import { isInitialImageLayer } from 'features/controlLayers/store/controlLayersSlice';
import type { ImageField } from 'features/nodes/types/common';
import type { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import type { Invocation } from 'services/api/types';
import { IMAGE_TO_LATENTS, RESIZE } from './constants';
/**
* Adds the initial image to the graph and connects it to the denoise and noise nodes.
* @param state The current Redux state
* @param g The graph to add the initial image to
* @param denoise The denoise node in the graph
* @param noise The noise node in the graph
* @returns Whether the initial image was added to the graph
*/
export const addGenerationTabInitialImage = (
state: RootState,
g: Graph,
denoise: Invocation<'denoise_latents'>,
noise: Invocation<'noise'>
): Invocation<'i2l'> | null => {
// Remove Existing UNet Connections
const { img2imgStrength, vaePrecision, model } = state.generation;
const { refinerModel, refinerStart } = state.sdxl;
const { width, height } = state.controlLayers.present.size;
const initialImageLayer = state.controlLayers.present.layers.find(isInitialImageLayer);
const initialImage = initialImageLayer?.isEnabled ? initialImageLayer?.image : null;
if (!initialImage) {
return null;
}
const isSDXL = model?.base === 'sdxl';
const useRefinerStartEnd = isSDXL && Boolean(refinerModel);
const image: ImageField = {
image_name: initialImage.imageName,
};
denoise.denoising_start = useRefinerStartEnd ? Math.min(refinerStart, 1 - img2imgStrength) : 1 - img2imgStrength;
denoise.denoising_end = useRefinerStartEnd ? refinerStart : 1;
const i2l = g.addNode({
type: 'i2l',
id: IMAGE_TO_LATENTS,
fp32: vaePrecision === 'fp32',
});
g.addEdge(i2l, 'latents', denoise, 'latents');
if (initialImage.width !== width || initialImage.height !== height) {
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
const resize = g.addNode({
id: RESIZE,
type: 'img_resize',
image,
width,
height,
});
// The `RESIZE` node then passes its image, to `IMAGE_TO_LATENTS`
g.addEdge(resize, 'image', i2l, 'image');
// The `RESIZE` node also passes its width and height to `NOISE`
g.addEdge(resize, 'width', noise, 'width');
g.addEdge(resize, 'height', noise, 'height');
} else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
i2l.image = image;
g.addEdge(i2l, 'width', noise, 'width');
g.addEdge(i2l, 'height', noise, 'height');
}
MetadataUtil.add(g, {
generation_mode: isSDXL ? 'sdxl_img2img' : 'img2img',
strength: img2imgStrength,
init_image: initialImage.imageName,
});
return i2l;
};

View File

@ -1,37 +0,0 @@
import type { RootState } from 'app/store/store';
import type { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import type { Invocation } from 'services/api/types';
import { VAE_LOADER } from './constants';
export const addGenerationTabVAE = (
state: RootState,
g: Graph,
modelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'>,
l2i: Invocation<'l2i'>,
i2l: Invocation<'i2l'> | null,
seamless: Invocation<'seamless'> | null
): void => {
const { vae } = state.generation;
// The seamless helper also adds the VAE loader... so we need to check if it's already there
const shouldAddVAELoader = !g.hasNode(VAE_LOADER) && vae;
const vaeLoader = shouldAddVAELoader
? g.addNode({
type: 'vae_loader',
id: VAE_LOADER,
vae_model: vae,
})
: null;
const vaeSource = seamless ? seamless : vaeLoader ? vaeLoader : modelLoader;
g.addEdge(vaeSource, 'vae', l2i, 'vae');
if (i2l) {
g.addEdge(vaeSource, 'vae', i2l, 'vae');
}
if (vae) {
MetadataUtil.add(g, { vae });
}
};

View File

@ -1,356 +0,0 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import type {
DenoiseLatentsInvocation,
Edge,
ESRGANInvocation,
LatentsToImageInvocation,
NoiseInvocation,
NonNullableGraph,
} from 'services/api/types';
import {
DENOISE_LATENTS,
DENOISE_LATENTS_HRF,
ESRGAN_HRF,
IMAGE_TO_LATENTS_HRF,
LATENTS_TO_IMAGE,
LATENTS_TO_IMAGE_HRF_HR,
LATENTS_TO_IMAGE_HRF_LR,
MAIN_MODEL_LOADER,
NOISE,
NOISE_HRF,
RESIZE_HRF,
SEAMLESS,
VAE_LOADER,
} from './constants';
import { setMetadataReceivingNode, upsertMetadata } from './metadata';
// Copy certain connections from previous DENOISE_LATENTS to new DENOISE_LATENTS_HRF.
function copyConnectionsToDenoiseLatentsHrf(graph: NonNullableGraph): void {
const destinationFields = [
'control',
'ip_adapter',
'metadata',
'unet',
'positive_conditioning',
'negative_conditioning',
];
const newEdges: Edge[] = [];
// Loop through the existing edges connected to DENOISE_LATENTS
graph.edges.forEach((edge: Edge) => {
if (edge.destination.node_id === DENOISE_LATENTS && destinationFields.includes(edge.destination.field)) {
// Add a similar connection to DENOISE_LATENTS_HRF
newEdges.push({
source: {
node_id: edge.source.node_id,
field: edge.source.field,
},
destination: {
node_id: DENOISE_LATENTS_HRF,
field: edge.destination.field,
},
});
}
});
graph.edges = graph.edges.concat(newEdges);
}
/**
* Calculates the new resolution for high-resolution features (HRF) based on base model type.
* Adjusts the width and height to maintain the aspect ratio and constrains them by the model's dimension limits,
* rounding down to the nearest multiple of 8.
*
* @param {number} optimalDimension The optimal dimension for the base model.
* @param {number} width The current width to be adjusted for HRF.
* @param {number} height The current height to be adjusted for HRF.
* @return {{newWidth: number, newHeight: number}} The new width and height, adjusted and rounded as needed.
*/
function calculateHrfRes(
optimalDimension: number,
width: number,
height: number
): { newWidth: number; newHeight: number } {
const aspect = width / height;
const minDimension = Math.floor(optimalDimension * 0.5);
const modelArea = optimalDimension * optimalDimension; // Assuming square images for model_area
let initWidth;
let initHeight;
if (aspect > 1.0) {
initHeight = Math.max(minDimension, Math.sqrt(modelArea / aspect));
initWidth = initHeight * aspect;
} else {
initWidth = Math.max(minDimension, Math.sqrt(modelArea * aspect));
initHeight = initWidth / aspect;
}
// Cap initial height and width to final height and width.
initWidth = Math.min(width, initWidth);
initHeight = Math.min(height, initHeight);
const newWidth = roundToMultiple(Math.floor(initWidth), 8);
const newHeight = roundToMultiple(Math.floor(initHeight), 8);
return { newWidth, newHeight };
}
// Adds the high-res fix feature to the given graph.
export const addHrfToGraph = (state: RootState, graph: NonNullableGraph): void => {
// Double check hrf is enabled.
if (!state.hrf.hrfEnabled || state.config.disabledSDFeatures.includes('hrf')) {
return;
}
const log = logger('generation');
const { vae, seamlessXAxis, seamlessYAxis } = state.generation;
const { hrfStrength, hrfEnabled, hrfMethod } = state.hrf;
const { width, height } = state.controlLayers.present.size;
const isAutoVae = !vae;
const isSeamlessEnabled = seamlessXAxis || seamlessYAxis;
const optimalDimension = selectOptimalDimension(state);
const { newWidth: hrfWidth, newHeight: hrfHeight } = calculateHrfRes(optimalDimension, width, height);
// Pre-existing (original) graph nodes.
const originalDenoiseLatentsNode = graph.nodes[DENOISE_LATENTS] as DenoiseLatentsInvocation | undefined;
const originalNoiseNode = graph.nodes[NOISE] as NoiseInvocation | undefined;
const originalLatentsToImageNode = graph.nodes[LATENTS_TO_IMAGE] as LatentsToImageInvocation | undefined;
if (!originalDenoiseLatentsNode) {
log.error('originalDenoiseLatentsNode is undefined');
return;
}
if (!originalNoiseNode) {
log.error('originalNoiseNode is undefined');
return;
}
if (!originalLatentsToImageNode) {
log.error('originalLatentsToImageNode is undefined');
return;
}
// Change height and width of original noise node to initial resolution.
if (originalNoiseNode) {
originalNoiseNode.width = hrfWidth;
originalNoiseNode.height = hrfHeight;
}
// Define new nodes and their connections, roughly in order of operations.
graph.nodes[LATENTS_TO_IMAGE_HRF_LR] = {
type: 'l2i',
id: LATENTS_TO_IMAGE_HRF_LR,
fp32: originalLatentsToImageNode?.fp32,
is_intermediate: true,
};
graph.edges.push(
{
source: {
node_id: DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE_HRF_LR,
field: 'latents',
},
},
{
source: {
node_id: isSeamlessEnabled ? SEAMLESS : isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE_HRF_LR,
field: 'vae',
},
}
);
graph.nodes[RESIZE_HRF] = {
id: RESIZE_HRF,
type: 'img_resize',
is_intermediate: true,
width: width,
height: height,
};
if (hrfMethod === 'ESRGAN') {
let model_name: ESRGANInvocation['model_name'] = 'RealESRGAN_x2plus.pth';
if ((width * height) / (hrfWidth * hrfHeight) > 2) {
model_name = 'RealESRGAN_x4plus.pth';
}
graph.nodes[ESRGAN_HRF] = {
id: ESRGAN_HRF,
type: 'esrgan',
model_name,
is_intermediate: true,
};
graph.edges.push(
{
source: {
node_id: LATENTS_TO_IMAGE_HRF_LR,
field: 'image',
},
destination: {
node_id: ESRGAN_HRF,
field: 'image',
},
},
{
source: {
node_id: ESRGAN_HRF,
field: 'image',
},
destination: {
node_id: RESIZE_HRF,
field: 'image',
},
}
);
} else {
graph.edges.push({
source: {
node_id: LATENTS_TO_IMAGE_HRF_LR,
field: 'image',
},
destination: {
node_id: RESIZE_HRF,
field: 'image',
},
});
}
graph.nodes[NOISE_HRF] = {
type: 'noise',
id: NOISE_HRF,
seed: originalNoiseNode?.seed,
use_cpu: originalNoiseNode?.use_cpu,
is_intermediate: true,
};
graph.edges.push(
{
source: {
node_id: RESIZE_HRF,
field: 'height',
},
destination: {
node_id: NOISE_HRF,
field: 'height',
},
},
{
source: {
node_id: RESIZE_HRF,
field: 'width',
},
destination: {
node_id: NOISE_HRF,
field: 'width',
},
}
);
graph.nodes[IMAGE_TO_LATENTS_HRF] = {
type: 'i2l',
id: IMAGE_TO_LATENTS_HRF,
is_intermediate: true,
};
graph.edges.push(
{
source: {
node_id: isSeamlessEnabled ? SEAMLESS : isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS_HRF,
field: 'vae',
},
},
{
source: {
node_id: RESIZE_HRF,
field: 'image',
},
destination: {
node_id: IMAGE_TO_LATENTS_HRF,
field: 'image',
},
}
);
graph.nodes[DENOISE_LATENTS_HRF] = {
type: 'denoise_latents',
id: DENOISE_LATENTS_HRF,
is_intermediate: true,
cfg_scale: originalDenoiseLatentsNode?.cfg_scale,
scheduler: originalDenoiseLatentsNode?.scheduler,
steps: originalDenoiseLatentsNode?.steps,
denoising_start: 1 - hrfStrength,
denoising_end: 1,
};
graph.edges.push(
{
source: {
node_id: IMAGE_TO_LATENTS_HRF,
field: 'latents',
},
destination: {
node_id: DENOISE_LATENTS_HRF,
field: 'latents',
},
},
{
source: {
node_id: NOISE_HRF,
field: 'noise',
},
destination: {
node_id: DENOISE_LATENTS_HRF,
field: 'noise',
},
}
);
copyConnectionsToDenoiseLatentsHrf(graph);
// The original l2i node is unnecessary now, remove it
graph.edges = graph.edges.filter((edge) => edge.destination.node_id !== LATENTS_TO_IMAGE);
delete graph.nodes[LATENTS_TO_IMAGE];
graph.nodes[LATENTS_TO_IMAGE_HRF_HR] = {
type: 'l2i',
id: LATENTS_TO_IMAGE_HRF_HR,
fp32: originalLatentsToImageNode?.fp32,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
};
graph.edges.push(
{
source: {
node_id: isSeamlessEnabled ? SEAMLESS : isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE_HRF_HR,
field: 'vae',
},
},
{
source: {
node_id: DENOISE_LATENTS_HRF,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE_HRF_HR,
field: 'latents',
},
}
);
upsertMetadata(graph, {
hrf_strength: hrfStrength,
hrf_enabled: hrfEnabled,
hrf_method: hrfMethod,
});
setMetadataReceivingNode(graph, LATENTS_TO_IMAGE_HRF_HR);
};

View File

@ -1,17 +1,20 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { isInitialImageLayer, isRegionalGuidanceLayer } from 'features/controlLayers/store/controlLayersSlice'; import { isInitialImageLayer, isRegionalGuidanceLayer } from 'features/controlLayers/store/controlLayersSlice';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph'; import { addGenerationTabControlLayers } from 'features/nodes/util/graph/addGenerationTabControlLayers';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import { addGenerationTabHRF } from 'features/nodes/util/graph/addGenerationTabHRF';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types'; import { addGenerationTabLoRAs } from 'features/nodes/util/graph/addGenerationTabLoRAs';
import { addGenerationTabNSFWChecker } from 'features/nodes/util/graph/addGenerationTabNSFWChecker';
import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless';
import { addGenerationTabWatermarker } from 'features/nodes/util/graph/addGenerationTabWatermarker';
import type { GraphType } from 'features/nodes/util/graph/Graph';
import { Graph } from 'features/nodes/util/graph/Graph';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { addHrfToGraph } from './addHrfToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import { import {
CLIP_SKIP, CLIP_SKIP,
CONTROL_LAYERS_GRAPH, CONTROL_LAYERS_GRAPH,
@ -19,202 +22,113 @@ import {
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER, MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
SEAMLESS, POSITIVE_CONDITIONING_COLLECT,
VAE_LOADER,
} from './constants'; } from './constants';
import { addCoreMetadataNode, getModelMetadataField } from './metadata'; import { getModelMetadataField } from './metadata';
export const buildGenerationTabGraph = async (state: RootState): Promise<NonNullableGraph> => { export const buildGenerationTabGraph = async (state: RootState): Promise<GraphType> => {
const log = logger('nodes');
const { const {
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier, cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler, scheduler,
steps, steps,
clipSkip, clipSkip: skipped_layers,
shouldUseCpuNoise, shouldUseCpuNoise,
vaePrecision, vaePrecision,
seamlessXAxis,
seamlessYAxis,
seed, seed,
vae,
} = state.generation; } = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present; const { positivePrompt, negativePrompt } = state.controlLayers.present;
const { width, height } = state.controlLayers.present.size; const { width, height } = state.controlLayers.present.size;
const use_cpu = shouldUseCpuNoise; assert(model, 'No model found in state');
if (!model) { const g = new Graph(CONTROL_LAYERS_GRAPH);
log.error('No model found in state'); const modelLoader = g.addNode({
throw new Error('No model found in state');
}
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
let modelLoaderNodeId = MAIN_MODEL_LOADER;
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
* ids.
*
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
* the `fit` param. These are added to the graph at the end.
*/
// copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = {
id: CONTROL_LAYERS_GRAPH,
nodes: {
[modelLoaderNodeId]: {
type: 'main_model_loader', type: 'main_model_loader',
id: modelLoaderNodeId, id: MAIN_MODEL_LOADER,
is_intermediate,
model, model,
}, });
[CLIP_SKIP]: { const clipSkip = g.addNode({
type: 'clip_skip', type: 'clip_skip',
id: CLIP_SKIP, id: CLIP_SKIP,
skipped_layers: clipSkip, skipped_layers,
is_intermediate, });
}, const posCond = g.addNode({
[POSITIVE_CONDITIONING]: {
type: 'compel', type: 'compel',
id: POSITIVE_CONDITIONING, id: POSITIVE_CONDITIONING,
prompt: positivePrompt, prompt: positivePrompt,
is_intermediate, });
}, const posCondCollect = g.addNode({
[NEGATIVE_CONDITIONING]: { type: 'collect',
id: POSITIVE_CONDITIONING_COLLECT,
});
const negCond = g.addNode({
type: 'compel', type: 'compel',
id: NEGATIVE_CONDITIONING, id: NEGATIVE_CONDITIONING,
prompt: negativePrompt, prompt: negativePrompt,
is_intermediate, });
}, const negCondCollect = g.addNode({
[NOISE]: { type: 'collect',
id: NEGATIVE_CONDITIONING_COLLECT,
});
const noise = g.addNode({
type: 'noise', type: 'noise',
id: NOISE, id: NOISE,
seed, seed,
width, width,
height, height,
use_cpu, use_cpu: shouldUseCpuNoise,
is_intermediate, });
}, const denoise = g.addNode({
[DENOISE_LATENTS]: {
type: 'denoise_latents', type: 'denoise_latents',
id: DENOISE_LATENTS, id: DENOISE_LATENTS,
is_intermediate,
cfg_scale, cfg_scale,
cfg_rescale_multiplier, cfg_rescale_multiplier,
scheduler, scheduler,
steps, steps,
denoising_start: 0, denoising_start: 0,
denoising_end: 1, denoising_end: 1,
}, });
[LATENTS_TO_IMAGE]: { const l2i = g.addNode({
type: 'l2i', type: 'l2i',
id: LATENTS_TO_IMAGE, id: LATENTS_TO_IMAGE,
fp32, fp32: vaePrecision === 'fp32',
is_intermediate: getIsIntermediate(state),
board: getBoardField(state), board: getBoardField(state),
// This is the terminal node and must always save to gallery.
is_intermediate: false,
use_cache: false, use_cache: false,
}, });
}, const vaeLoader =
edges: [ vae?.base === model.base
// Connect Model Loader to UNet and CLIP Skip ? g.addNode({
{ type: 'vae_loader',
source: { id: VAE_LOADER,
node_id: modelLoaderNodeId, vae_model: vae,
field: 'unet', })
}, : null;
destination: {
node_id: DENOISE_LATENTS, let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> = l2i;
field: 'unet',
}, g.addEdge(modelLoader, 'unet', denoise, 'unet');
}, g.addEdge(modelLoader, 'clip', clipSkip, 'clip');
{ g.addEdge(clipSkip, 'clip', posCond, 'clip');
source: { g.addEdge(clipSkip, 'clip', negCond, 'clip');
node_id: modelLoaderNodeId, g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
field: 'clip', g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
}, g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
destination: { g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
node_id: CLIP_SKIP, g.addEdge(noise, 'noise', denoise, 'noise');
field: 'clip', g.addEdge(denoise, 'latents', l2i, 'latents');
},
},
// Connect CLIP Skip to Conditioning
{
source: {
node_id: CLIP_SKIP,
field: 'clip',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: CLIP_SKIP,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
},
// Connect everything to Denoise Latents
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'noise',
},
},
// Decode Denoised Latents To Image
{
source: {
node_id: DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
},
],
};
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode( MetadataUtil.add(g, {
graph,
{
generation_mode: 'txt2img', generation_mode: 'txt2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier, cfg_rescale_multiplier,
@ -225,43 +139,49 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<NonNull
model: getModelMetadataField(modelConfig), model: getModelMetadataField(modelConfig),
seed, seed,
steps, steps,
rand_device: use_cpu ? 'cpu' : 'cuda', rand_device: shouldUseCpuNoise ? 'cpu' : 'cuda',
scheduler, scheduler,
clip_skip: clipSkip, clip_skip: skipped_layers,
}, vae: vae ?? undefined,
LATENTS_TO_IMAGE });
g.validate();
const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader, vaeLoader);
g.validate();
addGenerationTabLoRAs(state, g, denoise, modelLoader, seamless, clipSkip, posCond, negCond);
g.validate();
// We might get the VAE from the main model, custom VAE, or seamless node.
const vaeSource = seamless ?? vaeLoader ?? modelLoader;
g.addEdge(vaeSource, 'vae', l2i, 'vae');
const addedLayers = await addGenerationTabControlLayers(
state,
g,
denoise,
posCond,
negCond,
posCondCollect,
negCondCollect,
noise,
vaeSource
); );
g.validate();
// Add Seamless To Graph const isHRFAllowed = !addedLayers.some((l) => isInitialImageLayer(l) || isRegionalGuidanceLayer(l));
if (seamlessXAxis || seamlessYAxis) { if (isHRFAllowed && state.hrf.hrfEnabled) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId); imageOutput = addGenerationTabHRF(state, g, denoise, noise, l2i, vaeSource);
modelLoaderNodeId = SEAMLESS;
} }
// add LoRA support
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
const addedLayers = await addControlLayersToGraph(state, graph, DENOISE_LATENTS);
// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
const shouldUseHRF = !addedLayers.some((l) => isInitialImageLayer(l) || isRegionalGuidanceLayer(l));
// High resolution fix.
if (state.hrf.hrfEnabled && shouldUseHRF) {
addHrfToGraph(state, graph);
}
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) { if (state.system.shouldUseNSFWChecker) {
// must add before watermarker! imageOutput = addGenerationTabNSFWChecker(g, imageOutput);
addNSFWCheckerToGraph(state, graph);
} }
if (state.system.shouldUseWatermarker) { if (state.system.shouldUseWatermarker) {
// must add after nsfw checker! imageOutput = addGenerationTabWatermarker(g, imageOutput);
addWatermarkerToGraph(state, graph);
} }
return graph; MetadataUtil.setMetadataReceivingNode(g, imageOutput);
return g.getGraph();
}; };

View File

@ -1,187 +0,0 @@
import type { RootState } from 'app/store/store';
import { isInitialImageLayer, isRegionalGuidanceLayer } from 'features/controlLayers/store/controlLayersSlice';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addGenerationTabControlLayers } from 'features/nodes/util/graph/addGenerationTabControlLayers';
import { addGenerationTabHRF } from 'features/nodes/util/graph/addGenerationTabHRF';
import { addGenerationTabLoRAs } from 'features/nodes/util/graph/addGenerationTabLoRAs';
import { addGenerationTabNSFWChecker } from 'features/nodes/util/graph/addGenerationTabNSFWChecker';
import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless';
import { addGenerationTabWatermarker } from 'features/nodes/util/graph/addGenerationTabWatermarker';
import type { GraphType } from 'features/nodes/util/graph/Graph';
import { Graph } from 'features/nodes/util/graph/Graph';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import {
CLIP_SKIP,
CONTROL_LAYERS_GRAPH,
DENOISE_LATENTS,
LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT,
NOISE,
POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT,
VAE_LOADER,
} from './constants';
import { getModelMetadataField } from './metadata';
export const buildGenerationTabGraph2 = async (state: RootState): Promise<GraphType> => {
const {
model,
cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler,
steps,
clipSkip: skipped_layers,
shouldUseCpuNoise,
vaePrecision,
seed,
vae,
} = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present;
const { width, height } = state.controlLayers.present.size;
assert(model, 'No model found in state');
const g = new Graph(CONTROL_LAYERS_GRAPH);
const modelLoader = g.addNode({
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
model,
});
const clipSkip = g.addNode({
type: 'clip_skip',
id: CLIP_SKIP,
skipped_layers,
});
const posCond = g.addNode({
type: 'compel',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
});
const posCondCollect = g.addNode({
type: 'collect',
id: POSITIVE_CONDITIONING_COLLECT,
});
const negCond = g.addNode({
type: 'compel',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
});
const negCondCollect = g.addNode({
type: 'collect',
id: NEGATIVE_CONDITIONING_COLLECT,
});
const noise = g.addNode({
type: 'noise',
id: NOISE,
seed,
width,
height,
use_cpu: shouldUseCpuNoise,
});
const denoise = g.addNode({
type: 'denoise_latents',
id: DENOISE_LATENTS,
cfg_scale,
cfg_rescale_multiplier,
scheduler,
steps,
denoising_start: 0,
denoising_end: 1,
});
const l2i = g.addNode({
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32: vaePrecision === 'fp32',
board: getBoardField(state),
// This is the terminal node and must always save to gallery.
is_intermediate: false,
use_cache: false,
});
const vaeLoader =
vae?.base === model.base
? g.addNode({
type: 'vae_loader',
id: VAE_LOADER,
vae_model: vae,
})
: null;
let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> = l2i;
g.addEdge(modelLoader, 'unet', denoise, 'unet');
g.addEdge(modelLoader, 'clip', clipSkip, 'clip');
g.addEdge(clipSkip, 'clip', posCond, 'clip');
g.addEdge(clipSkip, 'clip', negCond, 'clip');
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
g.addEdge(noise, 'noise', denoise, 'noise');
g.addEdge(denoise, 'latents', l2i, 'latents');
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
MetadataUtil.add(g, {
generation_mode: 'txt2img',
cfg_scale,
cfg_rescale_multiplier,
height,
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: shouldUseCpuNoise ? 'cpu' : 'cuda',
scheduler,
clip_skip: skipped_layers,
vae: vae ?? undefined,
});
g.validate();
const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader, vaeLoader);
g.validate();
addGenerationTabLoRAs(state, g, denoise, modelLoader, seamless, clipSkip, posCond, negCond);
g.validate();
// We might get the VAE from the main model, custom VAE, or seamless node.
const vaeSource = seamless ?? vaeLoader ?? modelLoader;
g.addEdge(vaeSource, 'vae', l2i, 'vae');
const addedLayers = await addGenerationTabControlLayers(
state,
g,
denoise,
posCond,
negCond,
posCondCollect,
negCondCollect,
noise,
vaeSource
);
g.validate();
const isHRFAllowed = !addedLayers.some((l) => isInitialImageLayer(l) || isRegionalGuidanceLayer(l));
if (isHRFAllowed && state.hrf.hrfEnabled) {
imageOutput = addGenerationTabHRF(state, g, denoise, noise, l2i, vaeSource);
}
if (state.system.shouldUseNSFWChecker) {
imageOutput = addGenerationTabNSFWChecker(g, imageOutput);
}
if (state.system.shouldUseWatermarker) {
imageOutput = addGenerationTabWatermarker(g, imageOutput);
}
MetadataUtil.setMetadataReceivingNode(g, imageOutput);
return g.getGraph();
};

View File

@ -1,31 +1,33 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph'; import { addGenerationTabControlLayers } from 'features/nodes/util/graph/addGenerationTabControlLayers';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types'; import { addGenerationTabNSFWChecker } from 'features/nodes/util/graph/addGenerationTabNSFWChecker';
import { addGenerationTabSDXLLoRAs } from 'features/nodes/util/graph/addGenerationTabSDXLLoRAs';
import { addGenerationTabSDXLRefiner } from 'features/nodes/util/graph/addGenerationTabSDXLRefiner';
import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless';
import { addGenerationTabWatermarker } from 'features/nodes/util/graph/addGenerationTabWatermarker';
import { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import type { Invocation, NonNullableGraph } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import { import {
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT,
SDXL_CONTROL_LAYERS_GRAPH, SDXL_CONTROL_LAYERS_GRAPH,
SDXL_DENOISE_LATENTS, SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SDXL_REFINER_SEAMLESS, VAE_LOADER,
SEAMLESS,
} from './constants'; } from './constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; import { getBoardField, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode, getModelMetadataField } from './metadata'; import { getModelMetadataField } from './metadata';
export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<NonNullableGraph> => { export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<NonNullableGraph> => {
const log = logger('nodes');
const { const {
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
@ -35,73 +37,45 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
steps, steps,
shouldUseCpuNoise, shouldUseCpuNoise,
vaePrecision, vaePrecision,
seamlessXAxis, vae,
seamlessYAxis,
} = state.generation; } = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present; const { positivePrompt, negativePrompt } = state.controlLayers.present;
const { width, height } = state.controlLayers.present.size; const { width, height } = state.controlLayers.present.size;
const { refinerModel, refinerStart } = state.sdxl; const { refinerModel, refinerStart } = state.sdxl;
const use_cpu = shouldUseCpuNoise; assert(model, 'No model found in state');
if (!model) {
log.error('No model found in state');
throw new Error('No model found in state');
}
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
// Construct Style Prompt
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state); const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
// Model Loader ID const g = new Graph(SDXL_CONTROL_LAYERS_GRAPH);
let modelLoaderNodeId = SDXL_MODEL_LOADER; const modelLoader = g.addNode({
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
* ids.
*
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
* the `fit` param. These are added to the graph at the end.
*/
// copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = {
id: SDXL_CONTROL_LAYERS_GRAPH,
nodes: {
[modelLoaderNodeId]: {
type: 'sdxl_model_loader', type: 'sdxl_model_loader',
id: modelLoaderNodeId, id: SDXL_MODEL_LOADER,
model, model,
is_intermediate, });
}, const posCond = g.addNode({
[POSITIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt', type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING, id: POSITIVE_CONDITIONING,
prompt: positivePrompt, prompt: positivePrompt,
style: positiveStylePrompt, style: positiveStylePrompt,
is_intermediate, });
}, const posCondCollect = g.addNode({
[NEGATIVE_CONDITIONING]: { type: 'collect',
id: POSITIVE_CONDITIONING_COLLECT,
});
const negCond = g.addNode({
type: 'sdxl_compel_prompt', type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING, id: NEGATIVE_CONDITIONING,
prompt: negativePrompt, prompt: negativePrompt,
style: negativeStylePrompt, style: negativeStylePrompt,
is_intermediate, });
}, const negCondCollect = g.addNode({
[NOISE]: { type: 'collect',
type: 'noise', id: NEGATIVE_CONDITIONING_COLLECT,
id: NOISE, });
seed, const noise = g.addNode({ type: 'noise', id: NOISE, seed, width, height, use_cpu: shouldUseCpuNoise });
width, const denoise = g.addNode({
height,
use_cpu,
is_intermediate,
},
[SDXL_DENOISE_LATENTS]: {
type: 'denoise_latents', type: 'denoise_latents',
id: SDXL_DENOISE_LATENTS, id: SDXL_DENOISE_LATENTS,
cfg_scale, cfg_scale,
@ -110,119 +84,42 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
steps, steps,
denoising_start: 0, denoising_start: 0,
denoising_end: refinerModel ? refinerStart : 1, denoising_end: refinerModel ? refinerStart : 1,
is_intermediate, });
}, const l2i = g.addNode({
[LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',
id: LATENTS_TO_IMAGE, id: LATENTS_TO_IMAGE,
fp32, fp32: vaePrecision === 'fp32',
is_intermediate: getIsIntermediate(state),
board: getBoardField(state), board: getBoardField(state),
// This is the terminal node and must always save to gallery.
is_intermediate: false,
use_cache: false, use_cache: false,
}, });
}, const vaeLoader =
edges: [ vae?.base === model.base
// Connect Model Loader to UNet, VAE & CLIP ? g.addNode({
{ type: 'vae_loader',
source: { id: VAE_LOADER,
node_id: modelLoaderNodeId, vae_model: vae,
field: 'unet', })
}, : null;
destination: {
node_id: SDXL_DENOISE_LATENTS, let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> = l2i;
field: 'unet',
}, g.addEdge(modelLoader, 'unet', denoise, 'unet');
}, g.addEdge(modelLoader, 'clip', posCond, 'clip');
{ g.addEdge(modelLoader, 'clip', negCond, 'clip');
source: { g.addEdge(modelLoader, 'clip2', posCond, 'clip2');
node_id: modelLoaderNodeId, g.addEdge(modelLoader, 'clip2', negCond, 'clip2');
field: 'clip', g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
}, g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
destination: { g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
node_id: POSITIVE_CONDITIONING, g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
field: 'clip', g.addEdge(noise, 'noise', denoise, 'noise');
}, g.addEdge(denoise, 'latents', l2i, 'latents');
},
{
source: {
node_id: modelLoaderNodeId,
field: 'clip2',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: modelLoaderNodeId,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: modelLoaderNodeId,
field: 'clip2',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip2',
},
},
// Connect everything to Denoise Latents
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_DENOISE_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_DENOISE_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: SDXL_DENOISE_LATENTS,
field: 'noise',
},
},
// Decode Denoised Latents To Image
{
source: {
node_id: SDXL_DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
},
],
};
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode( MetadataUtil.add(g, {
graph,
{
generation_mode: 'txt2img', generation_mode: 'txt2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier, cfg_rescale_multiplier,
@ -233,46 +130,49 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
model: getModelMetadataField(modelConfig), model: getModelMetadataField(modelConfig),
seed, seed,
steps, steps,
rand_device: use_cpu ? 'cpu' : 'cuda', rand_device: shouldUseCpuNoise ? 'cpu' : 'cuda',
scheduler, scheduler,
positive_style_prompt: positiveStylePrompt, positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt, negative_style_prompt: negativeStylePrompt,
}, vae: vae ?? undefined,
LATENTS_TO_IMAGE });
); g.validate();
// Add Seamless To Graph const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader, vaeLoader);
if (seamlessXAxis || seamlessYAxis) { g.validate();
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS; addGenerationTabSDXLLoRAs(state, g, denoise, modelLoader, seamless, posCond, negCond);
} g.validate();
// We might get the VAE from the main model, custom VAE, or seamless node.
const vaeSource = seamless ?? vaeLoader ?? modelLoader;
g.addEdge(vaeSource, 'vae', l2i, 'vae');
// Add Refiner if enabled // Add Refiner if enabled
if (refinerModel) { if (refinerModel) {
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS); await addGenerationTabSDXLRefiner(state, g, denoise, modelLoader, seamless, posCond, negCond, l2i);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
} }
// add LoRA support await addGenerationTabControlLayers(
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId); state,
g,
denoise,
posCond,
negCond,
posCondCollect,
negCondCollect,
noise,
vaeSource
);
await addControlLayersToGraph(state, graph, SDXL_DENOISE_LATENTS);
// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) { if (state.system.shouldUseNSFWChecker) {
// must add before watermarker! imageOutput = addGenerationTabNSFWChecker(g, imageOutput);
addNSFWCheckerToGraph(state, graph);
} }
if (state.system.shouldUseWatermarker) { if (state.system.shouldUseWatermarker) {
// must add after nsfw checker! imageOutput = addGenerationTabWatermarker(g, imageOutput);
addWatermarkerToGraph(state, graph);
} }
return graph; MetadataUtil.setMetadataReceivingNode(g, imageOutput);
return g.getGraph();
}; };

View File

@ -1,178 +0,0 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addGenerationTabControlLayers } from 'features/nodes/util/graph/addGenerationTabControlLayers';
import { addGenerationTabNSFWChecker } from 'features/nodes/util/graph/addGenerationTabNSFWChecker';
import { addGenerationTabSDXLLoRAs } from 'features/nodes/util/graph/addGenerationTabSDXLLoRAs';
import { addGenerationTabSDXLRefiner } from 'features/nodes/util/graph/addGenerationTabSDXLRefiner';
import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless';
import { addGenerationTabWatermarker } from 'features/nodes/util/graph/addGenerationTabWatermarker';
import { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import type { Invocation, NonNullableGraph } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import {
LATENTS_TO_IMAGE,
NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT,
NOISE,
POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT,
SDXL_CONTROL_LAYERS_GRAPH,
SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER,
VAE_LOADER,
} from './constants';
import { getBoardField, getSDXLStylePrompts } from './graphBuilderUtils';
import { getModelMetadataField } from './metadata';
export const buildGenerationTabSDXLGraph2 = async (state: RootState): Promise<NonNullableGraph> => {
const {
model,
cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler,
seed,
steps,
shouldUseCpuNoise,
vaePrecision,
vae,
} = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present;
const { width, height } = state.controlLayers.present.size;
const { refinerModel, refinerStart } = state.sdxl;
assert(model, 'No model found in state');
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
const g = new Graph(SDXL_CONTROL_LAYERS_GRAPH);
const modelLoader = g.addNode({
type: 'sdxl_model_loader',
id: SDXL_MODEL_LOADER,
model,
});
const posCond = g.addNode({
type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
style: positiveStylePrompt,
});
const posCondCollect = g.addNode({
type: 'collect',
id: POSITIVE_CONDITIONING_COLLECT,
});
const negCond = g.addNode({
type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
style: negativeStylePrompt,
});
const negCondCollect = g.addNode({
type: 'collect',
id: NEGATIVE_CONDITIONING_COLLECT,
});
const noise = g.addNode({ type: 'noise', id: NOISE, seed, width, height, use_cpu: shouldUseCpuNoise });
const denoise = g.addNode({
type: 'denoise_latents',
id: SDXL_DENOISE_LATENTS,
cfg_scale,
cfg_rescale_multiplier,
scheduler,
steps,
denoising_start: 0,
denoising_end: refinerModel ? refinerStart : 1,
});
const l2i = g.addNode({
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32: vaePrecision === 'fp32',
board: getBoardField(state),
// This is the terminal node and must always save to gallery.
is_intermediate: false,
use_cache: false,
});
const vaeLoader =
vae?.base === model.base
? g.addNode({
type: 'vae_loader',
id: VAE_LOADER,
vae_model: vae,
})
: null;
let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> = l2i;
g.addEdge(modelLoader, 'unet', denoise, 'unet');
g.addEdge(modelLoader, 'clip', posCond, 'clip');
g.addEdge(modelLoader, 'clip', negCond, 'clip');
g.addEdge(modelLoader, 'clip2', posCond, 'clip2');
g.addEdge(modelLoader, 'clip2', negCond, 'clip2');
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
g.addEdge(noise, 'noise', denoise, 'noise');
g.addEdge(denoise, 'latents', l2i, 'latents');
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
MetadataUtil.add(g, {
generation_mode: 'txt2img',
cfg_scale,
cfg_rescale_multiplier,
height,
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: shouldUseCpuNoise ? 'cpu' : 'cuda',
scheduler,
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
vae: vae ?? undefined,
});
g.validate();
const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader, vaeLoader);
g.validate();
addGenerationTabSDXLLoRAs(state, g, denoise, modelLoader, seamless, posCond, negCond);
g.validate();
// We might get the VAE from the main model, custom VAE, or seamless node.
const vaeSource = seamless ?? vaeLoader ?? modelLoader;
g.addEdge(vaeSource, 'vae', l2i, 'vae');
// Add Refiner if enabled
if (refinerModel) {
await addGenerationTabSDXLRefiner(state, g, denoise, modelLoader, seamless, posCond, negCond, l2i);
}
await addGenerationTabControlLayers(
state,
g,
denoise,
posCond,
negCond,
posCondCollect,
negCondCollect,
noise,
vaeSource
);
if (state.system.shouldUseNSFWChecker) {
imageOutput = addGenerationTabNSFWChecker(g, imageOutput);
}
if (state.system.shouldUseWatermarker) {
imageOutput = addGenerationTabWatermarker(g, imageOutput);
}
MetadataUtil.setMetadataReceivingNode(g, imageOutput);
return g.getGraph();
};

View File

@ -58,21 +58,6 @@ export const getHasMetadata = (graph: NonNullableGraph): boolean => {
return Boolean(metadataNode); return Boolean(metadataNode);
}; };
export const setMetadataReceivingNode = (graph: NonNullableGraph, nodeId: string) => {
graph.edges = graph.edges.filter((edge) => edge.source.node_id !== METADATA);
graph.edges.push({
source: {
node_id: METADATA,
field: 'metadata',
},
destination: {
node_id: nodeId,
field: 'metadata',
},
});
};
export const getModelMetadataField = ({ key, hash, name, base, type }: AnyModelConfig): ModelIdentifierField => ({ export const getModelMetadataField = ({ key, hash, name, base, type }: AnyModelConfig): ModelIdentifierField => ({
key, key,
hash, hash,