mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
WIP, sd1.5 works
This commit is contained in:
parent
dbe22be598
commit
f8042ffb41
@ -1,7 +1,7 @@
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { buildGenerationTabGraph } from 'features/nodes/util/graph/buildGenerationTabGraph';
|
||||
import { buildGenerationTabGraph2 } from 'features/nodes/util/graph/buildGenerationTabGraph2';
|
||||
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/buildGenerationTabSDXLGraph';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
@ -21,7 +21,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
if (model && model.base === 'sdxl') {
|
||||
graph = await buildGenerationTabSDXLGraph(state);
|
||||
} else {
|
||||
graph = await buildGenerationTabGraph(state);
|
||||
graph = await buildGenerationTabGraph2(state);
|
||||
}
|
||||
|
||||
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
|
||||
|
@ -289,4 +289,80 @@ describe('Graph', () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteEdgesFrom', () => {
|
||||
it('should delete edges from the provided node', () => {
|
||||
const g = new Graph();
|
||||
const n1 = g.addNode({
|
||||
id: 'n1',
|
||||
type: 'img_resize',
|
||||
});
|
||||
const n2 = g.addNode({
|
||||
id: 'n2',
|
||||
type: 'add',
|
||||
});
|
||||
const _e1 = g.addEdge(n1, 'height', n2, 'a');
|
||||
const _e2 = g.addEdge(n1, 'width', n2, 'b');
|
||||
g.deleteEdgesFrom(n1);
|
||||
expect(g.getEdgesFrom(n1)).toEqual([]);
|
||||
});
|
||||
it('should delete edges from the provided node, with the provided field', () => {
|
||||
const g = new Graph();
|
||||
const n1 = g.addNode({
|
||||
id: 'n1',
|
||||
type: 'img_resize',
|
||||
});
|
||||
const n2 = g.addNode({
|
||||
id: 'n2',
|
||||
type: 'add',
|
||||
});
|
||||
const n3 = g.addNode({
|
||||
id: 'n3',
|
||||
type: 'add',
|
||||
});
|
||||
const _e1 = g.addEdge(n1, 'height', n2, 'a');
|
||||
const e2 = g.addEdge(n1, 'width', n2, 'b');
|
||||
const e3 = g.addEdge(n1, 'width', n3, 'b');
|
||||
g.deleteEdgesFrom(n1, 'height');
|
||||
expect(g.getEdgesFrom(n1)).toEqual([e2, e3]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteEdgesTo', () => {
|
||||
it('should delete edges to the provided node', () => {
|
||||
const g = new Graph();
|
||||
const n1 = g.addNode({
|
||||
id: 'n1',
|
||||
type: 'img_resize',
|
||||
});
|
||||
const n2 = g.addNode({
|
||||
id: 'n2',
|
||||
type: 'add',
|
||||
});
|
||||
const _e1 = g.addEdge(n1, 'height', n2, 'a');
|
||||
const _e2 = g.addEdge(n1, 'width', n2, 'b');
|
||||
g.deleteEdgesTo(n2);
|
||||
expect(g.getEdgesTo(n2)).toEqual([]);
|
||||
});
|
||||
it('should delete edges to the provided node, with the provided field', () => {
|
||||
const g = new Graph();
|
||||
const n1 = g.addNode({
|
||||
id: 'n1',
|
||||
type: 'img_resize',
|
||||
});
|
||||
const n2 = g.addNode({
|
||||
id: 'n2',
|
||||
type: 'img_resize',
|
||||
});
|
||||
const n3 = g.addNode({
|
||||
id: 'n3',
|
||||
type: 'add',
|
||||
});
|
||||
const _e1 = g.addEdge(n1, 'height', n3, 'a');
|
||||
const e2 = g.addEdge(n1, 'width', n3, 'b');
|
||||
const _e3 = g.addEdge(n2, 'width', n3, 'a');
|
||||
g.deleteEdgesTo(n3, 'a');
|
||||
expect(g.getEdgesTo(n3)).toEqual([e2]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { forEach, groupBy, isEqual, values } from 'lodash-es';
|
||||
import type {
|
||||
AnyInvocation,
|
||||
AnyInvocationInputField,
|
||||
@ -22,7 +22,7 @@ type Edge = {
|
||||
};
|
||||
};
|
||||
|
||||
type GraphType = { id: string; nodes: Record<string, AnyInvocation>; edges: Edge[] };
|
||||
export type GraphType = { id: string; nodes: Record<string, AnyInvocation>; edges: Edge[] };
|
||||
|
||||
export class Graph {
|
||||
_graph: GraphType;
|
||||
@ -130,6 +130,31 @@ export class Graph {
|
||||
return edge;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add an edge to the graph. If an edge with the same source and destination already exists, an `AssertionError` is raised.
|
||||
* If providing node ids, provide the from and to node types as generics to get type hints for from and to field names.
|
||||
* @param fromNode The source node or id of the source node.
|
||||
* @param fromField The field of the source node.
|
||||
* @param toNode The source node or id of the destination node.
|
||||
* @param toField The field of the destination node.
|
||||
* @returns The added edge.
|
||||
* @raises `AssertionError` if an edge with the same source and destination already exists.
|
||||
*/
|
||||
addEdgeFromObj(edge: Edge): Edge {
|
||||
const edgeAlreadyExists = this._graph.edges.some((e) => isEqual(e, edge));
|
||||
assert(
|
||||
!edgeAlreadyExists,
|
||||
Graph.getEdgeAlreadyExistsMsg(
|
||||
edge.source.node_id,
|
||||
edge.source.field,
|
||||
edge.destination.node_id,
|
||||
edge.destination.field
|
||||
)
|
||||
);
|
||||
this._graph.edges.push(edge);
|
||||
return edge;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get an edge from the graph. If the edge does not exist, an `AssertionError` is raised.
|
||||
* Provide the from and to node types as generics to get type hints for from and to field names.
|
||||
@ -255,6 +280,24 @@ export class Graph {
|
||||
for (const edge of this._graph.edges) {
|
||||
this.getNode(edge.source.node_id);
|
||||
this.getNode(edge.destination.node_id);
|
||||
assert(
|
||||
!this._graph.edges.filter((e) => e !== edge).find((e) => isEqual(e, edge)),
|
||||
`Duplicate edge: ${Graph.edgeToString(edge)}`
|
||||
);
|
||||
}
|
||||
for (const node of values(this._graph.nodes)) {
|
||||
const edgesTo = this.getEdgesTo(node);
|
||||
// Validate that no node has multiple incoming edges with the same field
|
||||
forEach(groupBy(edgesTo, 'destination.field'), (group, field) => {
|
||||
if (node.type === 'collect' && field === 'item') {
|
||||
// Collectors' item field accepts multiple incoming edges
|
||||
return;
|
||||
}
|
||||
assert(
|
||||
group.length === 1,
|
||||
`Node ${node.id} has multiple incoming edges with field ${field}: ${group.map(Graph.edgeToString).join(', ')}`
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -299,6 +342,10 @@ export class Graph {
|
||||
return `Edge from ${fromNodeId}.${fromField} to ${toNodeId}.${toField} already exists`;
|
||||
}
|
||||
|
||||
static edgeToString(edge: Edge): string {
|
||||
return `${edge.source.node_id}.${edge.source.field} -> ${edge.destination.node_id}.${edge.destination.field}`;
|
||||
}
|
||||
|
||||
static uuid = uuidv4;
|
||||
//#endregion
|
||||
}
|
||||
|
@ -0,0 +1,539 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import {
|
||||
isControlAdapterLayer,
|
||||
isIPAdapterLayer,
|
||||
isRegionalGuidanceLayer,
|
||||
rgLayerMaskImageUploaded,
|
||||
} from 'features/controlLayers/store/controlLayersSlice';
|
||||
import type { RegionalGuidanceLayer } from 'features/controlLayers/store/types';
|
||||
import {
|
||||
type ControlNetConfigV2,
|
||||
type ImageWithDims,
|
||||
type IPAdapterConfigV2,
|
||||
isControlNetConfigV2,
|
||||
isT2IAdapterConfigV2,
|
||||
type ProcessorConfig,
|
||||
type T2IAdapterConfigV2,
|
||||
} from 'features/controlLayers/util/controlAdapters';
|
||||
import { getRegionalPromptLayerBlobs } from 'features/controlLayers/util/getLayerBlobs';
|
||||
import type { ImageField } from 'features/nodes/types/common';
|
||||
import {
|
||||
CONTROL_NET_COLLECT,
|
||||
IP_ADAPTER_COLLECT,
|
||||
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
|
||||
PROMPT_REGION_MASK_TO_TENSOR_PREFIX,
|
||||
PROMPT_REGION_NEGATIVE_COND_PREFIX,
|
||||
PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX,
|
||||
PROMPT_REGION_POSITIVE_COND_PREFIX,
|
||||
T2I_ADAPTER_COLLECT,
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import type { Graph } from 'features/nodes/util/graph/Graph';
|
||||
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
||||
import { size } from 'lodash-es';
|
||||
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO, Invocation, S } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const buildControlImage = (
|
||||
image: ImageWithDims | null,
|
||||
processedImage: ImageWithDims | null,
|
||||
processorConfig: ProcessorConfig | null
|
||||
): ImageField => {
|
||||
if (processedImage && processorConfig) {
|
||||
// We've processed the image in the app - use it for the control image.
|
||||
return {
|
||||
image_name: processedImage.imageName,
|
||||
};
|
||||
} else if (image) {
|
||||
// No processor selected, and we have an image - the user provided a processed image, use it for the control image.
|
||||
return {
|
||||
image_name: image.imageName,
|
||||
};
|
||||
}
|
||||
assert(false, 'Attempted to add unprocessed control image');
|
||||
};
|
||||
|
||||
const buildControlNetMetadata = (controlNet: ControlNetConfigV2): S['ControlNetMetadataField'] => {
|
||||
const { beginEndStepPct, controlMode, image, model, processedImage, processorConfig, weight } = controlNet;
|
||||
|
||||
assert(model, 'ControlNet model is required');
|
||||
assert(image, 'ControlNet image is required');
|
||||
|
||||
const processed_image =
|
||||
processedImage && processorConfig
|
||||
? {
|
||||
image_name: processedImage.imageName,
|
||||
}
|
||||
: null;
|
||||
|
||||
return {
|
||||
control_model: model,
|
||||
control_weight: weight,
|
||||
control_mode: controlMode,
|
||||
begin_step_percent: beginEndStepPct[0],
|
||||
end_step_percent: beginEndStepPct[1],
|
||||
resize_mode: 'just_resize',
|
||||
image: {
|
||||
image_name: image.imageName,
|
||||
},
|
||||
processed_image,
|
||||
};
|
||||
};
|
||||
|
||||
const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => {
|
||||
try {
|
||||
// Attempt to retrieve the collector
|
||||
const controlNetCollect = g.getNode(CONTROL_NET_COLLECT);
|
||||
assert(controlNetCollect.type === 'collect');
|
||||
return controlNetCollect;
|
||||
} catch {
|
||||
// Add the ControlNet collector
|
||||
const controlNetCollect = g.addNode({
|
||||
id: CONTROL_NET_COLLECT,
|
||||
type: 'collect',
|
||||
});
|
||||
g.addEdge(controlNetCollect, 'collection', denoise, 'control');
|
||||
return controlNetCollect;
|
||||
}
|
||||
};
|
||||
|
||||
const addGlobalControlNetsToGraph = (
|
||||
controlNetConfigs: ControlNetConfigV2[],
|
||||
g: Graph,
|
||||
denoise: Invocation<'denoise_latents'>
|
||||
): void => {
|
||||
if (controlNetConfigs.length === 0) {
|
||||
return;
|
||||
}
|
||||
const controlNetMetadata: S['ControlNetMetadataField'][] = [];
|
||||
const controlNetCollect = addControlNetCollectorSafe(g, denoise);
|
||||
|
||||
for (const controlNetConfig of controlNetConfigs) {
|
||||
if (!controlNetConfig.model) {
|
||||
return;
|
||||
}
|
||||
const { id, beginEndStepPct, controlMode, image, model, processedImage, processorConfig, weight } =
|
||||
controlNetConfig;
|
||||
|
||||
const controlNet = g.addNode({
|
||||
id: `control_net_${id}`,
|
||||
type: 'controlnet',
|
||||
begin_step_percent: beginEndStepPct[0],
|
||||
end_step_percent: beginEndStepPct[1],
|
||||
control_mode: controlMode,
|
||||
resize_mode: 'just_resize',
|
||||
control_model: model,
|
||||
control_weight: weight,
|
||||
image: buildControlImage(image, processedImage, processorConfig),
|
||||
});
|
||||
|
||||
g.addEdge(controlNet, 'control', controlNetCollect, 'item');
|
||||
|
||||
controlNetMetadata.push(buildControlNetMetadata(controlNetConfig));
|
||||
}
|
||||
MetadataUtil.add(g, { controlnets: controlNetMetadata });
|
||||
};
|
||||
|
||||
const buildT2IAdapterMetadata = (t2iAdapter: T2IAdapterConfigV2): S['T2IAdapterMetadataField'] => {
|
||||
const { beginEndStepPct, image, model, processedImage, processorConfig, weight } = t2iAdapter;
|
||||
|
||||
assert(model, 'T2I Adapter model is required');
|
||||
assert(image, 'T2I Adapter image is required');
|
||||
|
||||
const processed_image =
|
||||
processedImage && processorConfig
|
||||
? {
|
||||
image_name: processedImage.imageName,
|
||||
}
|
||||
: null;
|
||||
|
||||
return {
|
||||
t2i_adapter_model: model,
|
||||
weight,
|
||||
begin_step_percent: beginEndStepPct[0],
|
||||
end_step_percent: beginEndStepPct[1],
|
||||
resize_mode: 'just_resize',
|
||||
image: {
|
||||
image_name: image.imageName,
|
||||
},
|
||||
processed_image,
|
||||
};
|
||||
};
|
||||
|
||||
const addT2IAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => {
|
||||
try {
|
||||
// You see, we've already got one!
|
||||
const t2iAdapterCollect = g.getNode(T2I_ADAPTER_COLLECT);
|
||||
assert(t2iAdapterCollect.type === 'collect');
|
||||
return t2iAdapterCollect;
|
||||
} catch {
|
||||
const t2iAdapterCollect = g.addNode({
|
||||
id: T2I_ADAPTER_COLLECT,
|
||||
type: 'collect',
|
||||
});
|
||||
|
||||
g.addEdge(t2iAdapterCollect, 'collection', denoise, 't2i_adapter');
|
||||
|
||||
return t2iAdapterCollect;
|
||||
}
|
||||
};
|
||||
|
||||
const addGlobalT2IAdaptersToGraph = (
|
||||
t2iAdapterConfigs: T2IAdapterConfigV2[],
|
||||
g: Graph,
|
||||
denoise: Invocation<'denoise_latents'>
|
||||
): void => {
|
||||
if (t2iAdapterConfigs.length === 0) {
|
||||
return;
|
||||
}
|
||||
const t2iAdapterMetadata: S['T2IAdapterMetadataField'][] = [];
|
||||
const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise);
|
||||
|
||||
for (const t2iAdapterConfig of t2iAdapterConfigs) {
|
||||
if (!t2iAdapterConfig.model) {
|
||||
return;
|
||||
}
|
||||
const { id, beginEndStepPct, image, model, processedImage, processorConfig, weight } = t2iAdapterConfig;
|
||||
|
||||
const t2iAdapter = g.addNode({
|
||||
id: `t2i_adapter_${id}`,
|
||||
type: 't2i_adapter',
|
||||
begin_step_percent: beginEndStepPct[0],
|
||||
end_step_percent: beginEndStepPct[1],
|
||||
resize_mode: 'just_resize',
|
||||
t2i_adapter_model: model,
|
||||
weight: weight,
|
||||
image: buildControlImage(image, processedImage, processorConfig),
|
||||
});
|
||||
|
||||
g.addEdge(t2iAdapter, 't2i_adapter', t2iAdapterCollect, 'item');
|
||||
|
||||
t2iAdapterMetadata.push(buildT2IAdapterMetadata(t2iAdapterConfig));
|
||||
}
|
||||
|
||||
MetadataUtil.add(g, { t2iAdapters: t2iAdapterMetadata });
|
||||
};
|
||||
|
||||
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfigV2): S['IPAdapterMetadataField'] => {
|
||||
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
|
||||
|
||||
assert(model, 'IP Adapter model is required');
|
||||
assert(image, 'IP Adapter image is required');
|
||||
|
||||
return {
|
||||
ip_adapter_model: model,
|
||||
clip_vision_model: clipVisionModel,
|
||||
weight,
|
||||
method,
|
||||
begin_step_percent: beginEndStepPct[0],
|
||||
end_step_percent: beginEndStepPct[1],
|
||||
image: {
|
||||
image_name: image.imageName,
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
const addIPAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => {
|
||||
try {
|
||||
// You see, we've already got one!
|
||||
const ipAdapterCollect = g.getNode(IP_ADAPTER_COLLECT);
|
||||
assert(ipAdapterCollect.type === 'collect');
|
||||
return ipAdapterCollect;
|
||||
} catch {
|
||||
const ipAdapterCollect = g.addNode({
|
||||
id: IP_ADAPTER_COLLECT,
|
||||
type: 'collect',
|
||||
});
|
||||
g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
|
||||
return ipAdapterCollect;
|
||||
}
|
||||
};
|
||||
|
||||
const addGlobalIPAdaptersToGraph = (
|
||||
ipAdapterConfigs: IPAdapterConfigV2[],
|
||||
g: Graph,
|
||||
denoise: Invocation<'denoise_latents'>
|
||||
): void => {
|
||||
if (ipAdapterConfigs.length === 0) {
|
||||
return;
|
||||
}
|
||||
const ipAdapterMetdata: S['IPAdapterMetadataField'][] = [];
|
||||
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
|
||||
|
||||
for (const ipAdapterConfig of ipAdapterConfigs) {
|
||||
const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapterConfig;
|
||||
assert(image, 'IP Adapter image is required');
|
||||
assert(model, 'IP Adapter model is required');
|
||||
|
||||
const ipAdapter = g.addNode({
|
||||
id: `ip_adapter_${id}`,
|
||||
type: 'ip_adapter',
|
||||
weight,
|
||||
method,
|
||||
ip_adapter_model: model,
|
||||
clip_vision_model: clipVisionModel,
|
||||
begin_step_percent: beginEndStepPct[0],
|
||||
end_step_percent: beginEndStepPct[1],
|
||||
image: {
|
||||
image_name: image.imageName,
|
||||
},
|
||||
});
|
||||
g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item');
|
||||
ipAdapterMetdata.push(buildIPAdapterMetadata(ipAdapterConfig));
|
||||
}
|
||||
|
||||
MetadataUtil.add(g, { ipAdapters: ipAdapterMetdata });
|
||||
};
|
||||
|
||||
export const addGenerationTabControlLayers = async (
|
||||
state: RootState,
|
||||
g: Graph,
|
||||
denoise: Invocation<'denoise_latents'>,
|
||||
posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
|
||||
negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
|
||||
posCondCollect: Invocation<'collect'>,
|
||||
negCondCollect: Invocation<'collect'>
|
||||
) => {
|
||||
const mainModel = state.generation.model;
|
||||
assert(mainModel, 'Missing main model when building graph');
|
||||
const isSDXL = mainModel.base === 'sdxl';
|
||||
|
||||
// Add global control adapters
|
||||
const globalControlNetConfigs = state.controlLayers.present.layers
|
||||
// Must be a CA layer
|
||||
.filter(isControlAdapterLayer)
|
||||
// Must be enabled
|
||||
.filter((l) => l.isEnabled)
|
||||
// We want the CAs themselves
|
||||
.map((l) => l.controlAdapter)
|
||||
// Must be a ControlNet
|
||||
.filter(isControlNetConfigV2)
|
||||
.filter((ca) => {
|
||||
const hasModel = Boolean(ca.model);
|
||||
const modelMatchesBase = ca.model?.base === mainModel.base;
|
||||
const hasControlImage = ca.image || (ca.processedImage && ca.processorConfig);
|
||||
return hasModel && modelMatchesBase && hasControlImage;
|
||||
});
|
||||
addGlobalControlNetsToGraph(globalControlNetConfigs, g, denoise);
|
||||
|
||||
const globalT2IAdapterConfigs = state.controlLayers.present.layers
|
||||
// Must be a CA layer
|
||||
.filter(isControlAdapterLayer)
|
||||
// Must be enabled
|
||||
.filter((l) => l.isEnabled)
|
||||
// We want the CAs themselves
|
||||
.map((l) => l.controlAdapter)
|
||||
// Must have a ControlNet CA
|
||||
.filter(isT2IAdapterConfigV2)
|
||||
.filter((ca) => {
|
||||
const hasModel = Boolean(ca.model);
|
||||
const modelMatchesBase = ca.model?.base === mainModel.base;
|
||||
const hasControlImage = ca.image || (ca.processedImage && ca.processorConfig);
|
||||
return hasModel && modelMatchesBase && hasControlImage;
|
||||
});
|
||||
addGlobalT2IAdaptersToGraph(globalT2IAdapterConfigs, g, denoise);
|
||||
|
||||
const globalIPAdapterConfigs = state.controlLayers.present.layers
|
||||
// Must be an IP Adapter layer
|
||||
.filter(isIPAdapterLayer)
|
||||
// Must be enabled
|
||||
.filter((l) => l.isEnabled)
|
||||
// We want the IP Adapters themselves
|
||||
.map((l) => l.ipAdapter)
|
||||
.filter((ca) => {
|
||||
const hasModel = Boolean(ca.model);
|
||||
const modelMatchesBase = ca.model?.base === mainModel.base;
|
||||
const hasControlImage = Boolean(ca.image);
|
||||
return hasModel && modelMatchesBase && hasControlImage;
|
||||
});
|
||||
addGlobalIPAdaptersToGraph(globalIPAdapterConfigs, g, denoise);
|
||||
|
||||
const rgLayers = state.controlLayers.present.layers
|
||||
// Only RG layers are get masks
|
||||
.filter(isRegionalGuidanceLayer)
|
||||
// Only visible layers are rendered on the canvas
|
||||
.filter((l) => l.isEnabled)
|
||||
// Only layers with prompts get added to the graph
|
||||
.filter((l) => {
|
||||
const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt);
|
||||
const hasIPAdapter = l.ipAdapters.length !== 0;
|
||||
return hasTextPrompt || hasIPAdapter;
|
||||
});
|
||||
|
||||
const layerIds = rgLayers.map((l) => l.id);
|
||||
const blobs = await getRegionalPromptLayerBlobs(layerIds);
|
||||
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
|
||||
|
||||
for (const layer of rgLayers) {
|
||||
const blob = blobs[layer.id];
|
||||
assert(blob, `Blob for layer ${layer.id} not found`);
|
||||
// Upload the mask image, or get the cached image if it exists
|
||||
const { image_name } = await getMaskImage(layer, blob);
|
||||
|
||||
// The main mask-to-tensor node
|
||||
const maskToTensor = g.addNode({
|
||||
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${layer.id}`,
|
||||
type: 'alpha_mask_to_tensor',
|
||||
image: {
|
||||
image_name,
|
||||
},
|
||||
});
|
||||
|
||||
if (layer.positivePrompt) {
|
||||
// The main positive conditioning node
|
||||
const regionalPosCond = g.addNode(
|
||||
isSDXL
|
||||
? {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
|
||||
prompt: layer.positivePrompt,
|
||||
style: layer.positivePrompt, // TODO: Should we put the positive prompt in both fields?
|
||||
}
|
||||
: {
|
||||
type: 'compel',
|
||||
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
|
||||
prompt: layer.positivePrompt,
|
||||
}
|
||||
);
|
||||
// Connect the mask to the conditioning
|
||||
g.addEdge(maskToTensor, 'mask', regionalPosCond, 'mask');
|
||||
// Connect the conditioning to the collector
|
||||
g.addEdge(regionalPosCond, 'conditioning', posCondCollect, 'item');
|
||||
// Copy the connections to the "global" positive conditioning node to the regional cond
|
||||
for (const edge of g.getEdgesTo(posCond)) {
|
||||
console.log(edge);
|
||||
if (edge.destination.field !== 'prompt') {
|
||||
// Clone the edge, but change the destination node to the regional conditioning node
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (layer.negativePrompt) {
|
||||
// The main negative conditioning node
|
||||
const regionalNegCond = g.addNode(
|
||||
isSDXL
|
||||
? {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
|
||||
prompt: layer.negativePrompt,
|
||||
style: layer.negativePrompt,
|
||||
}
|
||||
: {
|
||||
type: 'compel',
|
||||
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
|
||||
prompt: layer.negativePrompt,
|
||||
}
|
||||
);
|
||||
// Connect the mask to the conditioning
|
||||
g.addEdge(maskToTensor, 'mask', regionalNegCond, 'mask');
|
||||
// Connect the conditioning to the collector
|
||||
g.addEdge(regionalNegCond, 'conditioning', negCondCollect, 'item');
|
||||
// Copy the connections to the "global" negative conditioning node to the regional cond
|
||||
for (const edge of g.getEdgesTo(negCond)) {
|
||||
if (edge.destination.field !== 'prompt') {
|
||||
// Clone the edge, but change the destination node to the regional conditioning node
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalNegCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
|
||||
if (layer.autoNegative === 'invert' && layer.positivePrompt) {
|
||||
// We re-use the mask image, but invert it when converting to tensor
|
||||
const invertTensorMask = g.addNode({
|
||||
id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${layer.id}`,
|
||||
type: 'invert_tensor_mask',
|
||||
});
|
||||
// Connect the OG mask image to the inverted mask-to-tensor node
|
||||
g.addEdge(maskToTensor, 'mask', invertTensorMask, 'mask');
|
||||
// Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the positive prompt
|
||||
const regionalPosCondInverted = g.addNode(
|
||||
isSDXL
|
||||
? {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
|
||||
prompt: layer.positivePrompt,
|
||||
style: layer.positivePrompt,
|
||||
}
|
||||
: {
|
||||
type: 'compel',
|
||||
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
|
||||
prompt: layer.positivePrompt,
|
||||
}
|
||||
);
|
||||
// Connect the inverted mask to the conditioning
|
||||
g.addEdge(invertTensorMask, 'mask', regionalPosCondInverted, 'mask');
|
||||
// Connect the conditioning to the negative collector
|
||||
g.addEdge(regionalPosCondInverted, 'conditioning', negCondCollect, 'item');
|
||||
// Copy the connections to the "global" positive conditioning node to our regional node
|
||||
for (const edge of g.getEdgesTo(posCond)) {
|
||||
if (edge.destination.field !== 'prompt') {
|
||||
// Clone the edge, but change the destination node to the regional conditioning node
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCondInverted.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(psyche): For some reason, I have to explicitly annotate regionalIPAdapters here. Not sure why.
|
||||
const regionalIPAdapters: IPAdapterConfigV2[] = layer.ipAdapters.filter((ipAdapter) => {
|
||||
const hasModel = Boolean(ipAdapter.model);
|
||||
const modelMatchesBase = ipAdapter.model?.base === mainModel.base;
|
||||
const hasControlImage = Boolean(ipAdapter.image);
|
||||
return hasModel && modelMatchesBase && hasControlImage;
|
||||
});
|
||||
|
||||
for (const ipAdapterConfig of regionalIPAdapters) {
|
||||
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
|
||||
const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapterConfig;
|
||||
assert(model, 'IP Adapter model is required');
|
||||
assert(image, 'IP Adapter image is required');
|
||||
|
||||
const ipAdapter = g.addNode({
|
||||
id: `ip_adapter_${id}`,
|
||||
type: 'ip_adapter',
|
||||
weight,
|
||||
method,
|
||||
ip_adapter_model: model,
|
||||
clip_vision_model: clipVisionModel,
|
||||
begin_step_percent: beginEndStepPct[0],
|
||||
end_step_percent: beginEndStepPct[1],
|
||||
image: {
|
||||
image_name: image.imageName,
|
||||
},
|
||||
});
|
||||
|
||||
// Connect the mask to the conditioning
|
||||
g.addEdge(maskToTensor, 'mask', ipAdapter, 'mask');
|
||||
g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item');
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
|
||||
if (layer.uploadedMaskImage) {
|
||||
const imageDTO = await getImageDTO(layer.uploadedMaskImage.imageName);
|
||||
if (imageDTO) {
|
||||
return imageDTO;
|
||||
}
|
||||
}
|
||||
const { dispatch } = getStore();
|
||||
// No cached mask, or the cached image no longer exists - we need to upload the mask image
|
||||
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
|
||||
const req = dispatch(
|
||||
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
|
||||
);
|
||||
req.reset();
|
||||
|
||||
const imageDTO = await req.unwrap();
|
||||
dispatch(rgLayerMaskImageUploaded({ layerId: layer.id, imageDTO }));
|
||||
return imageDTO;
|
||||
};
|
@ -15,12 +15,12 @@ import { IMAGE_TO_LATENTS, RESIZE } from './constants';
|
||||
* @param noise The noise node in the graph
|
||||
* @returns Whether the initial image was added to the graph
|
||||
*/
|
||||
export const addInitialImageToGenerationTabGraph = (
|
||||
export const addGenerationTabInitialImage = (
|
||||
state: RootState,
|
||||
g: Graph,
|
||||
denoise: Invocation<'denoise_latents'>,
|
||||
noise: Invocation<'noise'>
|
||||
): boolean => {
|
||||
): Invocation<'i2l'> | null => {
|
||||
// Remove Existing UNet Connections
|
||||
const { img2imgStrength, vaePrecision, model } = state.generation;
|
||||
const { refinerModel, refinerStart } = state.sdxl;
|
||||
@ -29,7 +29,7 @@ export const addInitialImageToGenerationTabGraph = (
|
||||
const initialImage = initialImageLayer?.isEnabled ? initialImageLayer?.image : null;
|
||||
|
||||
if (!initialImage) {
|
||||
return false;
|
||||
return null;
|
||||
}
|
||||
|
||||
const isSDXL = model?.base === 'sdxl';
|
||||
@ -75,5 +75,5 @@ export const addInitialImageToGenerationTabGraph = (
|
||||
init_image: initialImage.imageName,
|
||||
});
|
||||
|
||||
return true;
|
||||
return i2l;
|
||||
};
|
@ -0,0 +1,94 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { Graph } from 'features/nodes/util/graph/Graph';
|
||||
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
||||
import { filter, size } from 'lodash-es';
|
||||
import type { Invocation, S } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { LORA_LOADER } from './constants';
|
||||
|
||||
export const addGenerationTabLoRAs = (
|
||||
state: RootState,
|
||||
g: Graph,
|
||||
denoise: Invocation<'denoise_latents'>,
|
||||
unetSource: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> | Invocation<'seamless'>,
|
||||
clipSkip: Invocation<'clip_skip'>,
|
||||
posCond: Invocation<'compel'>,
|
||||
negCond: Invocation<'compel'>
|
||||
): void => {
|
||||
/**
|
||||
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
|
||||
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
|
||||
* or to the inference/conditioning nodes.
|
||||
*
|
||||
* So we need to inject a LoRA chain into the graph.
|
||||
*/
|
||||
|
||||
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
|
||||
const loraCount = size(enabledLoRAs);
|
||||
|
||||
if (loraCount === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Remove modelLoaderNodeId unet connection to feed it to LoRAs
|
||||
console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e)));
|
||||
g.deleteEdgesFrom(unetSource, 'unet');
|
||||
console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e)));
|
||||
if (clipSkip) {
|
||||
// Remove CLIP_SKIP connections to conditionings to feed it through LoRAs
|
||||
g.deleteEdgesFrom(clipSkip, 'clip');
|
||||
}
|
||||
console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e)));
|
||||
|
||||
// we need to remember the last lora so we can chain from it
|
||||
let lastLoRALoader: Invocation<'lora_loader'> | null = null;
|
||||
let currentLoraIndex = 0;
|
||||
const loraMetadata: S['LoRAMetadataField'][] = [];
|
||||
|
||||
for (const lora of enabledLoRAs) {
|
||||
const { weight } = lora;
|
||||
const { key } = lora.model;
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
||||
const parsedModel = zModelIdentifierField.parse(lora.model);
|
||||
|
||||
const currentLoRALoader = g.addNode({
|
||||
type: 'lora_loader',
|
||||
id: currentLoraNodeId,
|
||||
lora: parsedModel,
|
||||
weight,
|
||||
});
|
||||
|
||||
loraMetadata.push({
|
||||
model: parsedModel,
|
||||
weight,
|
||||
});
|
||||
|
||||
// add to graph
|
||||
if (currentLoraIndex === 0) {
|
||||
// first lora = start the lora chain, attach directly to model loader
|
||||
g.addEdge(unetSource, 'unet', currentLoRALoader, 'unet');
|
||||
g.addEdge(clipSkip, 'clip', currentLoRALoader, 'clip');
|
||||
} else {
|
||||
assert(lastLoRALoader !== null);
|
||||
// we are in the middle of the lora chain, instead connect to the previous lora
|
||||
g.addEdge(lastLoRALoader, 'unet', currentLoRALoader, 'unet');
|
||||
g.addEdge(lastLoRALoader, 'clip', currentLoRALoader, 'clip');
|
||||
}
|
||||
|
||||
if (currentLoraIndex === loraCount - 1) {
|
||||
// final lora, end the lora chain - we need to connect up to inference and conditioning nodes
|
||||
g.addEdge(currentLoRALoader, 'unet', denoise, 'unet');
|
||||
g.addEdge(currentLoRALoader, 'clip', posCond, 'clip');
|
||||
g.addEdge(currentLoRALoader, 'clip', negCond, 'clip');
|
||||
}
|
||||
|
||||
// increment the lora for the next one in the chain
|
||||
lastLoRALoader = currentLoRALoader;
|
||||
currentLoraIndex += 1;
|
||||
}
|
||||
|
||||
MetadataUtil.add(g, { loras: loraMetadata });
|
||||
};
|
@ -15,26 +15,28 @@ import { SEAMLESS, VAE_LOADER } from './constants';
|
||||
* @param modelLoader The model loader node in the graph
|
||||
* @returns The terminal model loader node in the graph
|
||||
*/
|
||||
export const addSeamlessToGenerationTabGraph = (
|
||||
export const addGenerationTabSeamless = (
|
||||
state: RootState,
|
||||
g: Graph,
|
||||
denoise: Invocation<'denoise_latents'>,
|
||||
modelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'>
|
||||
): Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> | Invocation<'seamless'> => {
|
||||
const { seamlessXAxis, seamlessYAxis, vae } = state.generation;
|
||||
): Invocation<'seamless'> | null => {
|
||||
const { seamlessXAxis: seamless_x, seamlessYAxis: seamless_y, vae } = state.generation;
|
||||
|
||||
if (!seamlessXAxis && !seamlessYAxis) {
|
||||
return modelLoader;
|
||||
if (!seamless_x && !seamless_y) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const seamless = g.addNode({
|
||||
id: SEAMLESS,
|
||||
type: 'seamless',
|
||||
seamless_x: seamlessXAxis,
|
||||
seamless_y: seamlessYAxis,
|
||||
seamless_x,
|
||||
seamless_y,
|
||||
});
|
||||
|
||||
const vaeLoader = vae
|
||||
// The VAE helper also adds the VAE loader - so we need to check if it's already there
|
||||
const shouldAddVAELoader = !g.hasNode(VAE_LOADER) && vae;
|
||||
const vaeLoader = shouldAddVAELoader
|
||||
? g.addNode({
|
||||
type: 'vae_loader',
|
||||
id: VAE_LOADER,
|
||||
@ -42,29 +44,18 @@ export const addSeamlessToGenerationTabGraph = (
|
||||
})
|
||||
: null;
|
||||
|
||||
let terminalModelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> | Invocation<'seamless'> =
|
||||
modelLoader;
|
||||
|
||||
if (seamlessXAxis) {
|
||||
MetadataUtil.add(g, {
|
||||
seamless_x: seamlessXAxis,
|
||||
seamless_x: seamless_x || undefined,
|
||||
seamless_y: seamless_y || undefined,
|
||||
});
|
||||
terminalModelLoader = seamless;
|
||||
}
|
||||
if (seamlessYAxis) {
|
||||
MetadataUtil.add(g, {
|
||||
seamless_y: seamlessYAxis,
|
||||
});
|
||||
terminalModelLoader = seamless;
|
||||
}
|
||||
|
||||
// Seamless slots into the graph between the model loader and the denoise node
|
||||
g.deleteEdgesFrom(modelLoader, 'unet');
|
||||
g.deleteEdgesFrom(modelLoader, 'clip');
|
||||
g.deleteEdgesFrom(modelLoader, 'vae');
|
||||
|
||||
g.addEdge(modelLoader, 'unet', seamless, 'unet');
|
||||
g.addEdge(vaeLoader ?? modelLoader, 'vae', seamless, 'unet');
|
||||
g.addEdge(vaeLoader ?? modelLoader, 'vae', seamless, 'vae');
|
||||
g.addEdge(seamless, 'unet', denoise, 'unet');
|
||||
|
||||
return terminalModelLoader;
|
||||
return seamless;
|
||||
};
|
@ -0,0 +1,37 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { Graph } from 'features/nodes/util/graph/Graph';
|
||||
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
|
||||
import { VAE_LOADER } from './constants';
|
||||
|
||||
export const addGenerationTabVAE = (
|
||||
state: RootState,
|
||||
g: Graph,
|
||||
modelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'>,
|
||||
l2i: Invocation<'l2i'>,
|
||||
i2l: Invocation<'i2l'> | null,
|
||||
seamless: Invocation<'seamless'> | null
|
||||
): void => {
|
||||
const { vae } = state.generation;
|
||||
|
||||
// The seamless helper also adds the VAE loader... so we need to check if it's already there
|
||||
const shouldAddVAELoader = !g.hasNode(VAE_LOADER) && vae;
|
||||
const vaeLoader = shouldAddVAELoader
|
||||
? g.addNode({
|
||||
type: 'vae_loader',
|
||||
id: VAE_LOADER,
|
||||
vae_model: vae,
|
||||
})
|
||||
: null;
|
||||
|
||||
const vaeSource = seamless ? seamless : vaeLoader ? vaeLoader : modelLoader;
|
||||
g.addEdge(vaeSource, 'vae', l2i, 'vae');
|
||||
if (i2l) {
|
||||
g.addEdge(vaeSource, 'vae', i2l, 'vae');
|
||||
}
|
||||
|
||||
if (vae) {
|
||||
MetadataUtil.add(g, { vae });
|
||||
}
|
||||
};
|
@ -1,18 +1,19 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
|
||||
import { addInitialImageToGenerationTabGraph } from 'features/nodes/util/graph/addInitialImageToGenerationTabGraph';
|
||||
import { addSeamlessToGenerationTabGraph } from 'features/nodes/util/graph/addSeamlessToGenerationTabGraph';
|
||||
import { addGenerationTabControlLayers } from 'features/nodes/util/graph/addGenerationTabControlLayers';
|
||||
import { addGenerationTabInitialImage } from 'features/nodes/util/graph/addGenerationTabInitialImage';
|
||||
import { addGenerationTabLoRAs } from 'features/nodes/util/graph/addGenerationTabLoRAs';
|
||||
import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless';
|
||||
import { addGenerationTabVAE } from 'features/nodes/util/graph/addGenerationTabVAE';
|
||||
import type { GraphType } from 'features/nodes/util/graph/Graph';
|
||||
import { Graph } from 'features/nodes/util/graph/Graph';
|
||||
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
|
||||
import { addHrfToGraph } from './addHrfToGraph';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
CLIP_SKIP,
|
||||
@ -29,7 +30,7 @@ import {
|
||||
import { getModelMetadataField } from './metadata';
|
||||
|
||||
const log = logger('nodes');
|
||||
export const buildGenerationTabGraph = async (state: RootState): Promise<Graph> => {
|
||||
export const buildGenerationTabGraph2 = async (state: RootState): Promise<GraphType> => {
|
||||
const {
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
@ -39,8 +40,6 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<Graph>
|
||||
clipSkip: skipped_layers,
|
||||
shouldUseCpuNoise,
|
||||
vaePrecision,
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
seed,
|
||||
} = state.generation;
|
||||
const { positivePrompt, negativePrompt } = state.controlLayers.present;
|
||||
@ -114,6 +113,8 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<Graph>
|
||||
g.addEdge(clipSkip, 'clip', negCond, 'clip');
|
||||
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
|
||||
g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
|
||||
g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
|
||||
g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
|
||||
g.addEdge(noise, 'noise', denoise, 'noise');
|
||||
g.addEdge(denoise, 'latents', l2i, 'latents');
|
||||
|
||||
@ -135,20 +136,22 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<Graph>
|
||||
clip_skip: skipped_layers,
|
||||
});
|
||||
MetadataUtil.setMetadataReceivingNode(g, l2i);
|
||||
g.validate();
|
||||
|
||||
const didAddInitialImage = addInitialImageToGenerationTabGraph(state, g, denoise, noise);
|
||||
const terminalModelLoader = addSeamlessToGenerationTabGraph(state, g, denoise, modelLoader);
|
||||
const i2l = addGenerationTabInitialImage(state, g, denoise, noise);
|
||||
g.validate();
|
||||
const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader);
|
||||
g.validate();
|
||||
addGenerationTabVAE(state, g, modelLoader, l2i, i2l, seamless);
|
||||
g.validate();
|
||||
addGenerationTabLoRAs(state, g, denoise, seamless ?? modelLoader, clipSkip, posCond, negCond);
|
||||
g.validate();
|
||||
|
||||
// optionally add custom VAE
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
await addControlLayersToGraph(state, graph, DENOISE_LATENTS);
|
||||
await addGenerationTabControlLayers(state, g, denoise, posCond, negCond, posCondCollect, negCondCollect);
|
||||
g.validate();
|
||||
|
||||
// High resolution fix.
|
||||
if (state.hrf.hrfEnabled && !didAddInitialImage) {
|
||||
if (state.hrf.hrfEnabled && !i2l) {
|
||||
addHrfToGraph(state, graph);
|
||||
}
|
||||
|
||||
@ -163,5 +166,5 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<Graph>
|
||||
addWatermarkerToGraph(state, graph);
|
||||
}
|
||||
|
||||
return graph;
|
||||
return g.getGraph();
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user