WIP, sd1.5 works

This commit is contained in:
psychedelicious 2024-05-06 15:31:08 +10:00
parent dbe22be598
commit f8042ffb41
9 changed files with 840 additions and 53 deletions

View File

@ -1,7 +1,7 @@
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { buildGenerationTabGraph } from 'features/nodes/util/graph/buildGenerationTabGraph';
import { buildGenerationTabGraph2 } from 'features/nodes/util/graph/buildGenerationTabGraph2';
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/buildGenerationTabSDXLGraph';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { queueApi } from 'services/api/endpoints/queue';
@ -21,7 +21,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
if (model && model.base === 'sdxl') {
graph = await buildGenerationTabSDXLGraph(state);
} else {
graph = await buildGenerationTabGraph(state);
graph = await buildGenerationTabGraph2(state);
}
const batchConfig = prepareLinearUIBatch(state, graph, prepend);

View File

@ -289,4 +289,80 @@ describe('Graph', () => {
});
});
});
describe('deleteEdgesFrom', () => {
it('should delete edges from the provided node', () => {
const g = new Graph();
const n1 = g.addNode({
id: 'n1',
type: 'img_resize',
});
const n2 = g.addNode({
id: 'n2',
type: 'add',
});
const _e1 = g.addEdge(n1, 'height', n2, 'a');
const _e2 = g.addEdge(n1, 'width', n2, 'b');
g.deleteEdgesFrom(n1);
expect(g.getEdgesFrom(n1)).toEqual([]);
});
it('should delete edges from the provided node, with the provided field', () => {
const g = new Graph();
const n1 = g.addNode({
id: 'n1',
type: 'img_resize',
});
const n2 = g.addNode({
id: 'n2',
type: 'add',
});
const n3 = g.addNode({
id: 'n3',
type: 'add',
});
const _e1 = g.addEdge(n1, 'height', n2, 'a');
const e2 = g.addEdge(n1, 'width', n2, 'b');
const e3 = g.addEdge(n1, 'width', n3, 'b');
g.deleteEdgesFrom(n1, 'height');
expect(g.getEdgesFrom(n1)).toEqual([e2, e3]);
});
});
describe('deleteEdgesTo', () => {
it('should delete edges to the provided node', () => {
const g = new Graph();
const n1 = g.addNode({
id: 'n1',
type: 'img_resize',
});
const n2 = g.addNode({
id: 'n2',
type: 'add',
});
const _e1 = g.addEdge(n1, 'height', n2, 'a');
const _e2 = g.addEdge(n1, 'width', n2, 'b');
g.deleteEdgesTo(n2);
expect(g.getEdgesTo(n2)).toEqual([]);
});
it('should delete edges to the provided node, with the provided field', () => {
const g = new Graph();
const n1 = g.addNode({
id: 'n1',
type: 'img_resize',
});
const n2 = g.addNode({
id: 'n2',
type: 'img_resize',
});
const n3 = g.addNode({
id: 'n3',
type: 'add',
});
const _e1 = g.addEdge(n1, 'height', n3, 'a');
const e2 = g.addEdge(n1, 'width', n3, 'b');
const _e3 = g.addEdge(n2, 'width', n3, 'a');
g.deleteEdgesTo(n3, 'a');
expect(g.getEdgesTo(n3)).toEqual([e2]);
});
});
});

View File

@ -1,4 +1,4 @@
import { isEqual } from 'lodash-es';
import { forEach, groupBy, isEqual, values } from 'lodash-es';
import type {
AnyInvocation,
AnyInvocationInputField,
@ -22,7 +22,7 @@ type Edge = {
};
};
type GraphType = { id: string; nodes: Record<string, AnyInvocation>; edges: Edge[] };
export type GraphType = { id: string; nodes: Record<string, AnyInvocation>; edges: Edge[] };
export class Graph {
_graph: GraphType;
@ -130,6 +130,31 @@ export class Graph {
return edge;
}
/**
* Add an edge to the graph. If an edge with the same source and destination already exists, an `AssertionError` is raised.
* If providing node ids, provide the from and to node types as generics to get type hints for from and to field names.
* @param fromNode The source node or id of the source node.
* @param fromField The field of the source node.
* @param toNode The source node or id of the destination node.
* @param toField The field of the destination node.
* @returns The added edge.
* @raises `AssertionError` if an edge with the same source and destination already exists.
*/
addEdgeFromObj(edge: Edge): Edge {
const edgeAlreadyExists = this._graph.edges.some((e) => isEqual(e, edge));
assert(
!edgeAlreadyExists,
Graph.getEdgeAlreadyExistsMsg(
edge.source.node_id,
edge.source.field,
edge.destination.node_id,
edge.destination.field
)
);
this._graph.edges.push(edge);
return edge;
}
/**
* Get an edge from the graph. If the edge does not exist, an `AssertionError` is raised.
* Provide the from and to node types as generics to get type hints for from and to field names.
@ -255,6 +280,24 @@ export class Graph {
for (const edge of this._graph.edges) {
this.getNode(edge.source.node_id);
this.getNode(edge.destination.node_id);
assert(
!this._graph.edges.filter((e) => e !== edge).find((e) => isEqual(e, edge)),
`Duplicate edge: ${Graph.edgeToString(edge)}`
);
}
for (const node of values(this._graph.nodes)) {
const edgesTo = this.getEdgesTo(node);
// Validate that no node has multiple incoming edges with the same field
forEach(groupBy(edgesTo, 'destination.field'), (group, field) => {
if (node.type === 'collect' && field === 'item') {
// Collectors' item field accepts multiple incoming edges
return;
}
assert(
group.length === 1,
`Node ${node.id} has multiple incoming edges with field ${field}: ${group.map(Graph.edgeToString).join(', ')}`
);
});
}
}
@ -299,6 +342,10 @@ export class Graph {
return `Edge from ${fromNodeId}.${fromField} to ${toNodeId}.${toField} already exists`;
}
static edgeToString(edge: Edge): string {
return `${edge.source.node_id}.${edge.source.field} -> ${edge.destination.node_id}.${edge.destination.field}`;
}
static uuid = uuidv4;
//#endregion
}

View File

@ -0,0 +1,539 @@
import { getStore } from 'app/store/nanostores/store';
import type { RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import {
isControlAdapterLayer,
isIPAdapterLayer,
isRegionalGuidanceLayer,
rgLayerMaskImageUploaded,
} from 'features/controlLayers/store/controlLayersSlice';
import type { RegionalGuidanceLayer } from 'features/controlLayers/store/types';
import {
type ControlNetConfigV2,
type ImageWithDims,
type IPAdapterConfigV2,
isControlNetConfigV2,
isT2IAdapterConfigV2,
type ProcessorConfig,
type T2IAdapterConfigV2,
} from 'features/controlLayers/util/controlAdapters';
import { getRegionalPromptLayerBlobs } from 'features/controlLayers/util/getLayerBlobs';
import type { ImageField } from 'features/nodes/types/common';
import {
CONTROL_NET_COLLECT,
IP_ADAPTER_COLLECT,
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
PROMPT_REGION_MASK_TO_TENSOR_PREFIX,
PROMPT_REGION_NEGATIVE_COND_PREFIX,
PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX,
PROMPT_REGION_POSITIVE_COND_PREFIX,
T2I_ADAPTER_COLLECT,
} from 'features/nodes/util/graph/constants';
import type { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import { size } from 'lodash-es';
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO, Invocation, S } from 'services/api/types';
import { assert } from 'tsafe';
const buildControlImage = (
image: ImageWithDims | null,
processedImage: ImageWithDims | null,
processorConfig: ProcessorConfig | null
): ImageField => {
if (processedImage && processorConfig) {
// We've processed the image in the app - use it for the control image.
return {
image_name: processedImage.imageName,
};
} else if (image) {
// No processor selected, and we have an image - the user provided a processed image, use it for the control image.
return {
image_name: image.imageName,
};
}
assert(false, 'Attempted to add unprocessed control image');
};
const buildControlNetMetadata = (controlNet: ControlNetConfigV2): S['ControlNetMetadataField'] => {
const { beginEndStepPct, controlMode, image, model, processedImage, processorConfig, weight } = controlNet;
assert(model, 'ControlNet model is required');
assert(image, 'ControlNet image is required');
const processed_image =
processedImage && processorConfig
? {
image_name: processedImage.imageName,
}
: null;
return {
control_model: model,
control_weight: weight,
control_mode: controlMode,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
resize_mode: 'just_resize',
image: {
image_name: image.imageName,
},
processed_image,
};
};
const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => {
try {
// Attempt to retrieve the collector
const controlNetCollect = g.getNode(CONTROL_NET_COLLECT);
assert(controlNetCollect.type === 'collect');
return controlNetCollect;
} catch {
// Add the ControlNet collector
const controlNetCollect = g.addNode({
id: CONTROL_NET_COLLECT,
type: 'collect',
});
g.addEdge(controlNetCollect, 'collection', denoise, 'control');
return controlNetCollect;
}
};
const addGlobalControlNetsToGraph = (
controlNetConfigs: ControlNetConfigV2[],
g: Graph,
denoise: Invocation<'denoise_latents'>
): void => {
if (controlNetConfigs.length === 0) {
return;
}
const controlNetMetadata: S['ControlNetMetadataField'][] = [];
const controlNetCollect = addControlNetCollectorSafe(g, denoise);
for (const controlNetConfig of controlNetConfigs) {
if (!controlNetConfig.model) {
return;
}
const { id, beginEndStepPct, controlMode, image, model, processedImage, processorConfig, weight } =
controlNetConfig;
const controlNet = g.addNode({
id: `control_net_${id}`,
type: 'controlnet',
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
control_mode: controlMode,
resize_mode: 'just_resize',
control_model: model,
control_weight: weight,
image: buildControlImage(image, processedImage, processorConfig),
});
g.addEdge(controlNet, 'control', controlNetCollect, 'item');
controlNetMetadata.push(buildControlNetMetadata(controlNetConfig));
}
MetadataUtil.add(g, { controlnets: controlNetMetadata });
};
const buildT2IAdapterMetadata = (t2iAdapter: T2IAdapterConfigV2): S['T2IAdapterMetadataField'] => {
const { beginEndStepPct, image, model, processedImage, processorConfig, weight } = t2iAdapter;
assert(model, 'T2I Adapter model is required');
assert(image, 'T2I Adapter image is required');
const processed_image =
processedImage && processorConfig
? {
image_name: processedImage.imageName,
}
: null;
return {
t2i_adapter_model: model,
weight,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
resize_mode: 'just_resize',
image: {
image_name: image.imageName,
},
processed_image,
};
};
const addT2IAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => {
try {
// You see, we've already got one!
const t2iAdapterCollect = g.getNode(T2I_ADAPTER_COLLECT);
assert(t2iAdapterCollect.type === 'collect');
return t2iAdapterCollect;
} catch {
const t2iAdapterCollect = g.addNode({
id: T2I_ADAPTER_COLLECT,
type: 'collect',
});
g.addEdge(t2iAdapterCollect, 'collection', denoise, 't2i_adapter');
return t2iAdapterCollect;
}
};
const addGlobalT2IAdaptersToGraph = (
t2iAdapterConfigs: T2IAdapterConfigV2[],
g: Graph,
denoise: Invocation<'denoise_latents'>
): void => {
if (t2iAdapterConfigs.length === 0) {
return;
}
const t2iAdapterMetadata: S['T2IAdapterMetadataField'][] = [];
const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise);
for (const t2iAdapterConfig of t2iAdapterConfigs) {
if (!t2iAdapterConfig.model) {
return;
}
const { id, beginEndStepPct, image, model, processedImage, processorConfig, weight } = t2iAdapterConfig;
const t2iAdapter = g.addNode({
id: `t2i_adapter_${id}`,
type: 't2i_adapter',
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
resize_mode: 'just_resize',
t2i_adapter_model: model,
weight: weight,
image: buildControlImage(image, processedImage, processorConfig),
});
g.addEdge(t2iAdapter, 't2i_adapter', t2iAdapterCollect, 'item');
t2iAdapterMetadata.push(buildT2IAdapterMetadata(t2iAdapterConfig));
}
MetadataUtil.add(g, { t2iAdapters: t2iAdapterMetadata });
};
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfigV2): S['IPAdapterMetadataField'] => {
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
assert(model, 'IP Adapter model is required');
assert(image, 'IP Adapter image is required');
return {
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
weight,
method,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
image: {
image_name: image.imageName,
},
};
};
const addIPAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => {
try {
// You see, we've already got one!
const ipAdapterCollect = g.getNode(IP_ADAPTER_COLLECT);
assert(ipAdapterCollect.type === 'collect');
return ipAdapterCollect;
} catch {
const ipAdapterCollect = g.addNode({
id: IP_ADAPTER_COLLECT,
type: 'collect',
});
g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
return ipAdapterCollect;
}
};
const addGlobalIPAdaptersToGraph = (
ipAdapterConfigs: IPAdapterConfigV2[],
g: Graph,
denoise: Invocation<'denoise_latents'>
): void => {
if (ipAdapterConfigs.length === 0) {
return;
}
const ipAdapterMetdata: S['IPAdapterMetadataField'][] = [];
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
for (const ipAdapterConfig of ipAdapterConfigs) {
const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapterConfig;
assert(image, 'IP Adapter image is required');
assert(model, 'IP Adapter model is required');
const ipAdapter = g.addNode({
id: `ip_adapter_${id}`,
type: 'ip_adapter',
weight,
method,
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
image: {
image_name: image.imageName,
},
});
g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item');
ipAdapterMetdata.push(buildIPAdapterMetadata(ipAdapterConfig));
}
MetadataUtil.add(g, { ipAdapters: ipAdapterMetdata });
};
export const addGenerationTabControlLayers = async (
state: RootState,
g: Graph,
denoise: Invocation<'denoise_latents'>,
posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
posCondCollect: Invocation<'collect'>,
negCondCollect: Invocation<'collect'>
) => {
const mainModel = state.generation.model;
assert(mainModel, 'Missing main model when building graph');
const isSDXL = mainModel.base === 'sdxl';
// Add global control adapters
const globalControlNetConfigs = state.controlLayers.present.layers
// Must be a CA layer
.filter(isControlAdapterLayer)
// Must be enabled
.filter((l) => l.isEnabled)
// We want the CAs themselves
.map((l) => l.controlAdapter)
// Must be a ControlNet
.filter(isControlNetConfigV2)
.filter((ca) => {
const hasModel = Boolean(ca.model);
const modelMatchesBase = ca.model?.base === mainModel.base;
const hasControlImage = ca.image || (ca.processedImage && ca.processorConfig);
return hasModel && modelMatchesBase && hasControlImage;
});
addGlobalControlNetsToGraph(globalControlNetConfigs, g, denoise);
const globalT2IAdapterConfigs = state.controlLayers.present.layers
// Must be a CA layer
.filter(isControlAdapterLayer)
// Must be enabled
.filter((l) => l.isEnabled)
// We want the CAs themselves
.map((l) => l.controlAdapter)
// Must have a ControlNet CA
.filter(isT2IAdapterConfigV2)
.filter((ca) => {
const hasModel = Boolean(ca.model);
const modelMatchesBase = ca.model?.base === mainModel.base;
const hasControlImage = ca.image || (ca.processedImage && ca.processorConfig);
return hasModel && modelMatchesBase && hasControlImage;
});
addGlobalT2IAdaptersToGraph(globalT2IAdapterConfigs, g, denoise);
const globalIPAdapterConfigs = state.controlLayers.present.layers
// Must be an IP Adapter layer
.filter(isIPAdapterLayer)
// Must be enabled
.filter((l) => l.isEnabled)
// We want the IP Adapters themselves
.map((l) => l.ipAdapter)
.filter((ca) => {
const hasModel = Boolean(ca.model);
const modelMatchesBase = ca.model?.base === mainModel.base;
const hasControlImage = Boolean(ca.image);
return hasModel && modelMatchesBase && hasControlImage;
});
addGlobalIPAdaptersToGraph(globalIPAdapterConfigs, g, denoise);
const rgLayers = state.controlLayers.present.layers
// Only RG layers are get masks
.filter(isRegionalGuidanceLayer)
// Only visible layers are rendered on the canvas
.filter((l) => l.isEnabled)
// Only layers with prompts get added to the graph
.filter((l) => {
const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt);
const hasIPAdapter = l.ipAdapters.length !== 0;
return hasTextPrompt || hasIPAdapter;
});
const layerIds = rgLayers.map((l) => l.id);
const blobs = await getRegionalPromptLayerBlobs(layerIds);
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
for (const layer of rgLayers) {
const blob = blobs[layer.id];
assert(blob, `Blob for layer ${layer.id} not found`);
// Upload the mask image, or get the cached image if it exists
const { image_name } = await getMaskImage(layer, blob);
// The main mask-to-tensor node
const maskToTensor = g.addNode({
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${layer.id}`,
type: 'alpha_mask_to_tensor',
image: {
image_name,
},
});
if (layer.positivePrompt) {
// The main positive conditioning node
const regionalPosCond = g.addNode(
isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
style: layer.positivePrompt, // TODO: Should we put the positive prompt in both fields?
}
: {
type: 'compel',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
}
);
// Connect the mask to the conditioning
g.addEdge(maskToTensor, 'mask', regionalPosCond, 'mask');
// Connect the conditioning to the collector
g.addEdge(regionalPosCond, 'conditioning', posCondCollect, 'item');
// Copy the connections to the "global" positive conditioning node to the regional cond
for (const edge of g.getEdgesTo(posCond)) {
console.log(edge);
if (edge.destination.field !== 'prompt') {
// Clone the edge, but change the destination node to the regional conditioning node
const clone = deepClone(edge);
clone.destination.node_id = regionalPosCond.id;
g.addEdgeFromObj(clone);
}
}
}
if (layer.negativePrompt) {
// The main negative conditioning node
const regionalNegCond = g.addNode(
isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.negativePrompt,
style: layer.negativePrompt,
}
: {
type: 'compel',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.negativePrompt,
}
);
// Connect the mask to the conditioning
g.addEdge(maskToTensor, 'mask', regionalNegCond, 'mask');
// Connect the conditioning to the collector
g.addEdge(regionalNegCond, 'conditioning', negCondCollect, 'item');
// Copy the connections to the "global" negative conditioning node to the regional cond
for (const edge of g.getEdgesTo(negCond)) {
if (edge.destination.field !== 'prompt') {
// Clone the edge, but change the destination node to the regional conditioning node
const clone = deepClone(edge);
clone.destination.node_id = regionalNegCond.id;
g.addEdgeFromObj(clone);
}
}
}
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
if (layer.autoNegative === 'invert' && layer.positivePrompt) {
// We re-use the mask image, but invert it when converting to tensor
const invertTensorMask = g.addNode({
id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${layer.id}`,
type: 'invert_tensor_mask',
});
// Connect the OG mask image to the inverted mask-to-tensor node
g.addEdge(maskToTensor, 'mask', invertTensorMask, 'mask');
// Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the positive prompt
const regionalPosCondInverted = g.addNode(
isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
style: layer.positivePrompt,
}
: {
type: 'compel',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
}
);
// Connect the inverted mask to the conditioning
g.addEdge(invertTensorMask, 'mask', regionalPosCondInverted, 'mask');
// Connect the conditioning to the negative collector
g.addEdge(regionalPosCondInverted, 'conditioning', negCondCollect, 'item');
// Copy the connections to the "global" positive conditioning node to our regional node
for (const edge of g.getEdgesTo(posCond)) {
if (edge.destination.field !== 'prompt') {
// Clone the edge, but change the destination node to the regional conditioning node
const clone = deepClone(edge);
clone.destination.node_id = regionalPosCondInverted.id;
g.addEdgeFromObj(clone);
}
}
}
// TODO(psyche): For some reason, I have to explicitly annotate regionalIPAdapters here. Not sure why.
const regionalIPAdapters: IPAdapterConfigV2[] = layer.ipAdapters.filter((ipAdapter) => {
const hasModel = Boolean(ipAdapter.model);
const modelMatchesBase = ipAdapter.model?.base === mainModel.base;
const hasControlImage = Boolean(ipAdapter.image);
return hasModel && modelMatchesBase && hasControlImage;
});
for (const ipAdapterConfig of regionalIPAdapters) {
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapterConfig;
assert(model, 'IP Adapter model is required');
assert(image, 'IP Adapter image is required');
const ipAdapter = g.addNode({
id: `ip_adapter_${id}`,
type: 'ip_adapter',
weight,
method,
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
image: {
image_name: image.imageName,
},
});
// Connect the mask to the conditioning
g.addEdge(maskToTensor, 'mask', ipAdapter, 'mask');
g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item');
}
}
};
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
if (layer.uploadedMaskImage) {
const imageDTO = await getImageDTO(layer.uploadedMaskImage.imageName);
if (imageDTO) {
return imageDTO;
}
}
const { dispatch } = getStore();
// No cached mask, or the cached image no longer exists - we need to upload the mask image
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
req.reset();
const imageDTO = await req.unwrap();
dispatch(rgLayerMaskImageUploaded({ layerId: layer.id, imageDTO }));
return imageDTO;
};

View File

@ -15,12 +15,12 @@ import { IMAGE_TO_LATENTS, RESIZE } from './constants';
* @param noise The noise node in the graph
* @returns Whether the initial image was added to the graph
*/
export const addInitialImageToGenerationTabGraph = (
export const addGenerationTabInitialImage = (
state: RootState,
g: Graph,
denoise: Invocation<'denoise_latents'>,
noise: Invocation<'noise'>
): boolean => {
): Invocation<'i2l'> | null => {
// Remove Existing UNet Connections
const { img2imgStrength, vaePrecision, model } = state.generation;
const { refinerModel, refinerStart } = state.sdxl;
@ -29,7 +29,7 @@ export const addInitialImageToGenerationTabGraph = (
const initialImage = initialImageLayer?.isEnabled ? initialImageLayer?.image : null;
if (!initialImage) {
return false;
return null;
}
const isSDXL = model?.base === 'sdxl';
@ -75,5 +75,5 @@ export const addInitialImageToGenerationTabGraph = (
init_image: initialImage.imageName,
});
return true;
return i2l;
};

View File

@ -0,0 +1,94 @@
import type { RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import { filter, size } from 'lodash-es';
import type { Invocation, S } from 'services/api/types';
import { assert } from 'tsafe';
import { LORA_LOADER } from './constants';
export const addGenerationTabLoRAs = (
state: RootState,
g: Graph,
denoise: Invocation<'denoise_latents'>,
unetSource: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> | Invocation<'seamless'>,
clipSkip: Invocation<'clip_skip'>,
posCond: Invocation<'compel'>,
negCond: Invocation<'compel'>
): void => {
/**
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
* or to the inference/conditioning nodes.
*
* So we need to inject a LoRA chain into the graph.
*/
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
const loraCount = size(enabledLoRAs);
if (loraCount === 0) {
return;
}
// Remove modelLoaderNodeId unet connection to feed it to LoRAs
console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e)));
g.deleteEdgesFrom(unetSource, 'unet');
console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e)));
if (clipSkip) {
// Remove CLIP_SKIP connections to conditionings to feed it through LoRAs
g.deleteEdgesFrom(clipSkip, 'clip');
}
console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e)));
// we need to remember the last lora so we can chain from it
let lastLoRALoader: Invocation<'lora_loader'> | null = null;
let currentLoraIndex = 0;
const loraMetadata: S['LoRAMetadataField'][] = [];
for (const lora of enabledLoRAs) {
const { weight } = lora;
const { key } = lora.model;
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
const parsedModel = zModelIdentifierField.parse(lora.model);
const currentLoRALoader = g.addNode({
type: 'lora_loader',
id: currentLoraNodeId,
lora: parsedModel,
weight,
});
loraMetadata.push({
model: parsedModel,
weight,
});
// add to graph
if (currentLoraIndex === 0) {
// first lora = start the lora chain, attach directly to model loader
g.addEdge(unetSource, 'unet', currentLoRALoader, 'unet');
g.addEdge(clipSkip, 'clip', currentLoRALoader, 'clip');
} else {
assert(lastLoRALoader !== null);
// we are in the middle of the lora chain, instead connect to the previous lora
g.addEdge(lastLoRALoader, 'unet', currentLoRALoader, 'unet');
g.addEdge(lastLoRALoader, 'clip', currentLoRALoader, 'clip');
}
if (currentLoraIndex === loraCount - 1) {
// final lora, end the lora chain - we need to connect up to inference and conditioning nodes
g.addEdge(currentLoRALoader, 'unet', denoise, 'unet');
g.addEdge(currentLoRALoader, 'clip', posCond, 'clip');
g.addEdge(currentLoRALoader, 'clip', negCond, 'clip');
}
// increment the lora for the next one in the chain
lastLoRALoader = currentLoRALoader;
currentLoraIndex += 1;
}
MetadataUtil.add(g, { loras: loraMetadata });
};

View File

@ -15,26 +15,28 @@ import { SEAMLESS, VAE_LOADER } from './constants';
* @param modelLoader The model loader node in the graph
* @returns The terminal model loader node in the graph
*/
export const addSeamlessToGenerationTabGraph = (
export const addGenerationTabSeamless = (
state: RootState,
g: Graph,
denoise: Invocation<'denoise_latents'>,
modelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'>
): Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> | Invocation<'seamless'> => {
const { seamlessXAxis, seamlessYAxis, vae } = state.generation;
): Invocation<'seamless'> | null => {
const { seamlessXAxis: seamless_x, seamlessYAxis: seamless_y, vae } = state.generation;
if (!seamlessXAxis && !seamlessYAxis) {
return modelLoader;
if (!seamless_x && !seamless_y) {
return null;
}
const seamless = g.addNode({
id: SEAMLESS,
type: 'seamless',
seamless_x: seamlessXAxis,
seamless_y: seamlessYAxis,
seamless_x,
seamless_y,
});
const vaeLoader = vae
// The VAE helper also adds the VAE loader - so we need to check if it's already there
const shouldAddVAELoader = !g.hasNode(VAE_LOADER) && vae;
const vaeLoader = shouldAddVAELoader
? g.addNode({
type: 'vae_loader',
id: VAE_LOADER,
@ -42,29 +44,18 @@ export const addSeamlessToGenerationTabGraph = (
})
: null;
let terminalModelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> | Invocation<'seamless'> =
modelLoader;
if (seamlessXAxis) {
MetadataUtil.add(g, {
seamless_x: seamlessXAxis,
seamless_x: seamless_x || undefined,
seamless_y: seamless_y || undefined,
});
terminalModelLoader = seamless;
}
if (seamlessYAxis) {
MetadataUtil.add(g, {
seamless_y: seamlessYAxis,
});
terminalModelLoader = seamless;
}
// Seamless slots into the graph between the model loader and the denoise node
g.deleteEdgesFrom(modelLoader, 'unet');
g.deleteEdgesFrom(modelLoader, 'clip');
g.deleteEdgesFrom(modelLoader, 'vae');
g.addEdge(modelLoader, 'unet', seamless, 'unet');
g.addEdge(vaeLoader ?? modelLoader, 'vae', seamless, 'unet');
g.addEdge(vaeLoader ?? modelLoader, 'vae', seamless, 'vae');
g.addEdge(seamless, 'unet', denoise, 'unet');
return terminalModelLoader;
return seamless;
};

View File

@ -0,0 +1,37 @@
import type { RootState } from 'app/store/store';
import type { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import type { Invocation } from 'services/api/types';
import { VAE_LOADER } from './constants';
export const addGenerationTabVAE = (
state: RootState,
g: Graph,
modelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'>,
l2i: Invocation<'l2i'>,
i2l: Invocation<'i2l'> | null,
seamless: Invocation<'seamless'> | null
): void => {
const { vae } = state.generation;
// The seamless helper also adds the VAE loader... so we need to check if it's already there
const shouldAddVAELoader = !g.hasNode(VAE_LOADER) && vae;
const vaeLoader = shouldAddVAELoader
? g.addNode({
type: 'vae_loader',
id: VAE_LOADER,
vae_model: vae,
})
: null;
const vaeSource = seamless ? seamless : vaeLoader ? vaeLoader : modelLoader;
g.addEdge(vaeSource, 'vae', l2i, 'vae');
if (i2l) {
g.addEdge(vaeSource, 'vae', i2l, 'vae');
}
if (vae) {
MetadataUtil.add(g, { vae });
}
};

View File

@ -1,18 +1,19 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
import { addInitialImageToGenerationTabGraph } from 'features/nodes/util/graph/addInitialImageToGenerationTabGraph';
import { addSeamlessToGenerationTabGraph } from 'features/nodes/util/graph/addSeamlessToGenerationTabGraph';
import { addGenerationTabControlLayers } from 'features/nodes/util/graph/addGenerationTabControlLayers';
import { addGenerationTabInitialImage } from 'features/nodes/util/graph/addGenerationTabInitialImage';
import { addGenerationTabLoRAs } from 'features/nodes/util/graph/addGenerationTabLoRAs';
import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless';
import { addGenerationTabVAE } from 'features/nodes/util/graph/addGenerationTabVAE';
import type { GraphType } from 'features/nodes/util/graph/Graph';
import { Graph } from 'features/nodes/util/graph/Graph';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { addHrfToGraph } from './addHrfToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CLIP_SKIP,
@ -29,7 +30,7 @@ import {
import { getModelMetadataField } from './metadata';
const log = logger('nodes');
export const buildGenerationTabGraph = async (state: RootState): Promise<Graph> => {
export const buildGenerationTabGraph2 = async (state: RootState): Promise<GraphType> => {
const {
model,
cfgScale: cfg_scale,
@ -39,8 +40,6 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<Graph>
clipSkip: skipped_layers,
shouldUseCpuNoise,
vaePrecision,
seamlessXAxis,
seamlessYAxis,
seed,
} = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present;
@ -114,6 +113,8 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<Graph>
g.addEdge(clipSkip, 'clip', negCond, 'clip');
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
g.addEdge(noise, 'noise', denoise, 'noise');
g.addEdge(denoise, 'latents', l2i, 'latents');
@ -135,20 +136,22 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<Graph>
clip_skip: skipped_layers,
});
MetadataUtil.setMetadataReceivingNode(g, l2i);
g.validate();
const didAddInitialImage = addInitialImageToGenerationTabGraph(state, g, denoise, noise);
const terminalModelLoader = addSeamlessToGenerationTabGraph(state, g, denoise, modelLoader);
const i2l = addGenerationTabInitialImage(state, g, denoise, noise);
g.validate();
const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader);
g.validate();
addGenerationTabVAE(state, g, modelLoader, l2i, i2l, seamless);
g.validate();
addGenerationTabLoRAs(state, g, denoise, seamless ?? modelLoader, clipSkip, posCond, negCond);
g.validate();
// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
await addControlLayersToGraph(state, graph, DENOISE_LATENTS);
await addGenerationTabControlLayers(state, g, denoise, posCond, negCond, posCondCollect, negCondCollect);
g.validate();
// High resolution fix.
if (state.hrf.hrfEnabled && !didAddInitialImage) {
if (state.hrf.hrfEnabled && !i2l) {
addHrfToGraph(state, graph);
}
@ -163,5 +166,5 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<Graph>
addWatermarkerToGraph(state, graph);
}
return graph;
return g.getGraph();
};