mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
WIP, sd1.5 works
This commit is contained in:
parent
dbe22be598
commit
f8042ffb41
@ -1,7 +1,7 @@
|
|||||||
import { enqueueRequested } from 'app/store/actions';
|
import { enqueueRequested } from 'app/store/actions';
|
||||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||||
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||||
import { buildGenerationTabGraph } from 'features/nodes/util/graph/buildGenerationTabGraph';
|
import { buildGenerationTabGraph2 } from 'features/nodes/util/graph/buildGenerationTabGraph2';
|
||||||
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/buildGenerationTabSDXLGraph';
|
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/buildGenerationTabSDXLGraph';
|
||||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||||
import { queueApi } from 'services/api/endpoints/queue';
|
import { queueApi } from 'services/api/endpoints/queue';
|
||||||
@ -21,7 +21,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
|||||||
if (model && model.base === 'sdxl') {
|
if (model && model.base === 'sdxl') {
|
||||||
graph = await buildGenerationTabSDXLGraph(state);
|
graph = await buildGenerationTabSDXLGraph(state);
|
||||||
} else {
|
} else {
|
||||||
graph = await buildGenerationTabGraph(state);
|
graph = await buildGenerationTabGraph2(state);
|
||||||
}
|
}
|
||||||
|
|
||||||
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
|
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
|
||||||
|
@ -289,4 +289,80 @@ describe('Graph', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('deleteEdgesFrom', () => {
|
||||||
|
it('should delete edges from the provided node', () => {
|
||||||
|
const g = new Graph();
|
||||||
|
const n1 = g.addNode({
|
||||||
|
id: 'n1',
|
||||||
|
type: 'img_resize',
|
||||||
|
});
|
||||||
|
const n2 = g.addNode({
|
||||||
|
id: 'n2',
|
||||||
|
type: 'add',
|
||||||
|
});
|
||||||
|
const _e1 = g.addEdge(n1, 'height', n2, 'a');
|
||||||
|
const _e2 = g.addEdge(n1, 'width', n2, 'b');
|
||||||
|
g.deleteEdgesFrom(n1);
|
||||||
|
expect(g.getEdgesFrom(n1)).toEqual([]);
|
||||||
|
});
|
||||||
|
it('should delete edges from the provided node, with the provided field', () => {
|
||||||
|
const g = new Graph();
|
||||||
|
const n1 = g.addNode({
|
||||||
|
id: 'n1',
|
||||||
|
type: 'img_resize',
|
||||||
|
});
|
||||||
|
const n2 = g.addNode({
|
||||||
|
id: 'n2',
|
||||||
|
type: 'add',
|
||||||
|
});
|
||||||
|
const n3 = g.addNode({
|
||||||
|
id: 'n3',
|
||||||
|
type: 'add',
|
||||||
|
});
|
||||||
|
const _e1 = g.addEdge(n1, 'height', n2, 'a');
|
||||||
|
const e2 = g.addEdge(n1, 'width', n2, 'b');
|
||||||
|
const e3 = g.addEdge(n1, 'width', n3, 'b');
|
||||||
|
g.deleteEdgesFrom(n1, 'height');
|
||||||
|
expect(g.getEdgesFrom(n1)).toEqual([e2, e3]);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('deleteEdgesTo', () => {
|
||||||
|
it('should delete edges to the provided node', () => {
|
||||||
|
const g = new Graph();
|
||||||
|
const n1 = g.addNode({
|
||||||
|
id: 'n1',
|
||||||
|
type: 'img_resize',
|
||||||
|
});
|
||||||
|
const n2 = g.addNode({
|
||||||
|
id: 'n2',
|
||||||
|
type: 'add',
|
||||||
|
});
|
||||||
|
const _e1 = g.addEdge(n1, 'height', n2, 'a');
|
||||||
|
const _e2 = g.addEdge(n1, 'width', n2, 'b');
|
||||||
|
g.deleteEdgesTo(n2);
|
||||||
|
expect(g.getEdgesTo(n2)).toEqual([]);
|
||||||
|
});
|
||||||
|
it('should delete edges to the provided node, with the provided field', () => {
|
||||||
|
const g = new Graph();
|
||||||
|
const n1 = g.addNode({
|
||||||
|
id: 'n1',
|
||||||
|
type: 'img_resize',
|
||||||
|
});
|
||||||
|
const n2 = g.addNode({
|
||||||
|
id: 'n2',
|
||||||
|
type: 'img_resize',
|
||||||
|
});
|
||||||
|
const n3 = g.addNode({
|
||||||
|
id: 'n3',
|
||||||
|
type: 'add',
|
||||||
|
});
|
||||||
|
const _e1 = g.addEdge(n1, 'height', n3, 'a');
|
||||||
|
const e2 = g.addEdge(n1, 'width', n3, 'b');
|
||||||
|
const _e3 = g.addEdge(n2, 'width', n3, 'a');
|
||||||
|
g.deleteEdgesTo(n3, 'a');
|
||||||
|
expect(g.getEdgesTo(n3)).toEqual([e2]);
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { isEqual } from 'lodash-es';
|
import { forEach, groupBy, isEqual, values } from 'lodash-es';
|
||||||
import type {
|
import type {
|
||||||
AnyInvocation,
|
AnyInvocation,
|
||||||
AnyInvocationInputField,
|
AnyInvocationInputField,
|
||||||
@ -22,7 +22,7 @@ type Edge = {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
type GraphType = { id: string; nodes: Record<string, AnyInvocation>; edges: Edge[] };
|
export type GraphType = { id: string; nodes: Record<string, AnyInvocation>; edges: Edge[] };
|
||||||
|
|
||||||
export class Graph {
|
export class Graph {
|
||||||
_graph: GraphType;
|
_graph: GraphType;
|
||||||
@ -130,6 +130,31 @@ export class Graph {
|
|||||||
return edge;
|
return edge;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add an edge to the graph. If an edge with the same source and destination already exists, an `AssertionError` is raised.
|
||||||
|
* If providing node ids, provide the from and to node types as generics to get type hints for from and to field names.
|
||||||
|
* @param fromNode The source node or id of the source node.
|
||||||
|
* @param fromField The field of the source node.
|
||||||
|
* @param toNode The source node or id of the destination node.
|
||||||
|
* @param toField The field of the destination node.
|
||||||
|
* @returns The added edge.
|
||||||
|
* @raises `AssertionError` if an edge with the same source and destination already exists.
|
||||||
|
*/
|
||||||
|
addEdgeFromObj(edge: Edge): Edge {
|
||||||
|
const edgeAlreadyExists = this._graph.edges.some((e) => isEqual(e, edge));
|
||||||
|
assert(
|
||||||
|
!edgeAlreadyExists,
|
||||||
|
Graph.getEdgeAlreadyExistsMsg(
|
||||||
|
edge.source.node_id,
|
||||||
|
edge.source.field,
|
||||||
|
edge.destination.node_id,
|
||||||
|
edge.destination.field
|
||||||
|
)
|
||||||
|
);
|
||||||
|
this._graph.edges.push(edge);
|
||||||
|
return edge;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get an edge from the graph. If the edge does not exist, an `AssertionError` is raised.
|
* Get an edge from the graph. If the edge does not exist, an `AssertionError` is raised.
|
||||||
* Provide the from and to node types as generics to get type hints for from and to field names.
|
* Provide the from and to node types as generics to get type hints for from and to field names.
|
||||||
@ -255,6 +280,24 @@ export class Graph {
|
|||||||
for (const edge of this._graph.edges) {
|
for (const edge of this._graph.edges) {
|
||||||
this.getNode(edge.source.node_id);
|
this.getNode(edge.source.node_id);
|
||||||
this.getNode(edge.destination.node_id);
|
this.getNode(edge.destination.node_id);
|
||||||
|
assert(
|
||||||
|
!this._graph.edges.filter((e) => e !== edge).find((e) => isEqual(e, edge)),
|
||||||
|
`Duplicate edge: ${Graph.edgeToString(edge)}`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
for (const node of values(this._graph.nodes)) {
|
||||||
|
const edgesTo = this.getEdgesTo(node);
|
||||||
|
// Validate that no node has multiple incoming edges with the same field
|
||||||
|
forEach(groupBy(edgesTo, 'destination.field'), (group, field) => {
|
||||||
|
if (node.type === 'collect' && field === 'item') {
|
||||||
|
// Collectors' item field accepts multiple incoming edges
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
assert(
|
||||||
|
group.length === 1,
|
||||||
|
`Node ${node.id} has multiple incoming edges with field ${field}: ${group.map(Graph.edgeToString).join(', ')}`
|
||||||
|
);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -299,6 +342,10 @@ export class Graph {
|
|||||||
return `Edge from ${fromNodeId}.${fromField} to ${toNodeId}.${toField} already exists`;
|
return `Edge from ${fromNodeId}.${fromField} to ${toNodeId}.${toField} already exists`;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static edgeToString(edge: Edge): string {
|
||||||
|
return `${edge.source.node_id}.${edge.source.field} -> ${edge.destination.node_id}.${edge.destination.field}`;
|
||||||
|
}
|
||||||
|
|
||||||
static uuid = uuidv4;
|
static uuid = uuidv4;
|
||||||
//#endregion
|
//#endregion
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,539 @@
|
|||||||
|
import { getStore } from 'app/store/nanostores/store';
|
||||||
|
import type { RootState } from 'app/store/store';
|
||||||
|
import { deepClone } from 'common/util/deepClone';
|
||||||
|
import {
|
||||||
|
isControlAdapterLayer,
|
||||||
|
isIPAdapterLayer,
|
||||||
|
isRegionalGuidanceLayer,
|
||||||
|
rgLayerMaskImageUploaded,
|
||||||
|
} from 'features/controlLayers/store/controlLayersSlice';
|
||||||
|
import type { RegionalGuidanceLayer } from 'features/controlLayers/store/types';
|
||||||
|
import {
|
||||||
|
type ControlNetConfigV2,
|
||||||
|
type ImageWithDims,
|
||||||
|
type IPAdapterConfigV2,
|
||||||
|
isControlNetConfigV2,
|
||||||
|
isT2IAdapterConfigV2,
|
||||||
|
type ProcessorConfig,
|
||||||
|
type T2IAdapterConfigV2,
|
||||||
|
} from 'features/controlLayers/util/controlAdapters';
|
||||||
|
import { getRegionalPromptLayerBlobs } from 'features/controlLayers/util/getLayerBlobs';
|
||||||
|
import type { ImageField } from 'features/nodes/types/common';
|
||||||
|
import {
|
||||||
|
CONTROL_NET_COLLECT,
|
||||||
|
IP_ADAPTER_COLLECT,
|
||||||
|
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
|
||||||
|
PROMPT_REGION_MASK_TO_TENSOR_PREFIX,
|
||||||
|
PROMPT_REGION_NEGATIVE_COND_PREFIX,
|
||||||
|
PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX,
|
||||||
|
PROMPT_REGION_POSITIVE_COND_PREFIX,
|
||||||
|
T2I_ADAPTER_COLLECT,
|
||||||
|
} from 'features/nodes/util/graph/constants';
|
||||||
|
import type { Graph } from 'features/nodes/util/graph/Graph';
|
||||||
|
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
||||||
|
import { size } from 'lodash-es';
|
||||||
|
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
|
||||||
|
import type { ImageDTO, Invocation, S } from 'services/api/types';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
|
const buildControlImage = (
|
||||||
|
image: ImageWithDims | null,
|
||||||
|
processedImage: ImageWithDims | null,
|
||||||
|
processorConfig: ProcessorConfig | null
|
||||||
|
): ImageField => {
|
||||||
|
if (processedImage && processorConfig) {
|
||||||
|
// We've processed the image in the app - use it for the control image.
|
||||||
|
return {
|
||||||
|
image_name: processedImage.imageName,
|
||||||
|
};
|
||||||
|
} else if (image) {
|
||||||
|
// No processor selected, and we have an image - the user provided a processed image, use it for the control image.
|
||||||
|
return {
|
||||||
|
image_name: image.imageName,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
assert(false, 'Attempted to add unprocessed control image');
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildControlNetMetadata = (controlNet: ControlNetConfigV2): S['ControlNetMetadataField'] => {
|
||||||
|
const { beginEndStepPct, controlMode, image, model, processedImage, processorConfig, weight } = controlNet;
|
||||||
|
|
||||||
|
assert(model, 'ControlNet model is required');
|
||||||
|
assert(image, 'ControlNet image is required');
|
||||||
|
|
||||||
|
const processed_image =
|
||||||
|
processedImage && processorConfig
|
||||||
|
? {
|
||||||
|
image_name: processedImage.imageName,
|
||||||
|
}
|
||||||
|
: null;
|
||||||
|
|
||||||
|
return {
|
||||||
|
control_model: model,
|
||||||
|
control_weight: weight,
|
||||||
|
control_mode: controlMode,
|
||||||
|
begin_step_percent: beginEndStepPct[0],
|
||||||
|
end_step_percent: beginEndStepPct[1],
|
||||||
|
resize_mode: 'just_resize',
|
||||||
|
image: {
|
||||||
|
image_name: image.imageName,
|
||||||
|
},
|
||||||
|
processed_image,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => {
|
||||||
|
try {
|
||||||
|
// Attempt to retrieve the collector
|
||||||
|
const controlNetCollect = g.getNode(CONTROL_NET_COLLECT);
|
||||||
|
assert(controlNetCollect.type === 'collect');
|
||||||
|
return controlNetCollect;
|
||||||
|
} catch {
|
||||||
|
// Add the ControlNet collector
|
||||||
|
const controlNetCollect = g.addNode({
|
||||||
|
id: CONTROL_NET_COLLECT,
|
||||||
|
type: 'collect',
|
||||||
|
});
|
||||||
|
g.addEdge(controlNetCollect, 'collection', denoise, 'control');
|
||||||
|
return controlNetCollect;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const addGlobalControlNetsToGraph = (
|
||||||
|
controlNetConfigs: ControlNetConfigV2[],
|
||||||
|
g: Graph,
|
||||||
|
denoise: Invocation<'denoise_latents'>
|
||||||
|
): void => {
|
||||||
|
if (controlNetConfigs.length === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const controlNetMetadata: S['ControlNetMetadataField'][] = [];
|
||||||
|
const controlNetCollect = addControlNetCollectorSafe(g, denoise);
|
||||||
|
|
||||||
|
for (const controlNetConfig of controlNetConfigs) {
|
||||||
|
if (!controlNetConfig.model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const { id, beginEndStepPct, controlMode, image, model, processedImage, processorConfig, weight } =
|
||||||
|
controlNetConfig;
|
||||||
|
|
||||||
|
const controlNet = g.addNode({
|
||||||
|
id: `control_net_${id}`,
|
||||||
|
type: 'controlnet',
|
||||||
|
begin_step_percent: beginEndStepPct[0],
|
||||||
|
end_step_percent: beginEndStepPct[1],
|
||||||
|
control_mode: controlMode,
|
||||||
|
resize_mode: 'just_resize',
|
||||||
|
control_model: model,
|
||||||
|
control_weight: weight,
|
||||||
|
image: buildControlImage(image, processedImage, processorConfig),
|
||||||
|
});
|
||||||
|
|
||||||
|
g.addEdge(controlNet, 'control', controlNetCollect, 'item');
|
||||||
|
|
||||||
|
controlNetMetadata.push(buildControlNetMetadata(controlNetConfig));
|
||||||
|
}
|
||||||
|
MetadataUtil.add(g, { controlnets: controlNetMetadata });
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildT2IAdapterMetadata = (t2iAdapter: T2IAdapterConfigV2): S['T2IAdapterMetadataField'] => {
|
||||||
|
const { beginEndStepPct, image, model, processedImage, processorConfig, weight } = t2iAdapter;
|
||||||
|
|
||||||
|
assert(model, 'T2I Adapter model is required');
|
||||||
|
assert(image, 'T2I Adapter image is required');
|
||||||
|
|
||||||
|
const processed_image =
|
||||||
|
processedImage && processorConfig
|
||||||
|
? {
|
||||||
|
image_name: processedImage.imageName,
|
||||||
|
}
|
||||||
|
: null;
|
||||||
|
|
||||||
|
return {
|
||||||
|
t2i_adapter_model: model,
|
||||||
|
weight,
|
||||||
|
begin_step_percent: beginEndStepPct[0],
|
||||||
|
end_step_percent: beginEndStepPct[1],
|
||||||
|
resize_mode: 'just_resize',
|
||||||
|
image: {
|
||||||
|
image_name: image.imageName,
|
||||||
|
},
|
||||||
|
processed_image,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const addT2IAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => {
|
||||||
|
try {
|
||||||
|
// You see, we've already got one!
|
||||||
|
const t2iAdapterCollect = g.getNode(T2I_ADAPTER_COLLECT);
|
||||||
|
assert(t2iAdapterCollect.type === 'collect');
|
||||||
|
return t2iAdapterCollect;
|
||||||
|
} catch {
|
||||||
|
const t2iAdapterCollect = g.addNode({
|
||||||
|
id: T2I_ADAPTER_COLLECT,
|
||||||
|
type: 'collect',
|
||||||
|
});
|
||||||
|
|
||||||
|
g.addEdge(t2iAdapterCollect, 'collection', denoise, 't2i_adapter');
|
||||||
|
|
||||||
|
return t2iAdapterCollect;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const addGlobalT2IAdaptersToGraph = (
|
||||||
|
t2iAdapterConfigs: T2IAdapterConfigV2[],
|
||||||
|
g: Graph,
|
||||||
|
denoise: Invocation<'denoise_latents'>
|
||||||
|
): void => {
|
||||||
|
if (t2iAdapterConfigs.length === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const t2iAdapterMetadata: S['T2IAdapterMetadataField'][] = [];
|
||||||
|
const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise);
|
||||||
|
|
||||||
|
for (const t2iAdapterConfig of t2iAdapterConfigs) {
|
||||||
|
if (!t2iAdapterConfig.model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const { id, beginEndStepPct, image, model, processedImage, processorConfig, weight } = t2iAdapterConfig;
|
||||||
|
|
||||||
|
const t2iAdapter = g.addNode({
|
||||||
|
id: `t2i_adapter_${id}`,
|
||||||
|
type: 't2i_adapter',
|
||||||
|
begin_step_percent: beginEndStepPct[0],
|
||||||
|
end_step_percent: beginEndStepPct[1],
|
||||||
|
resize_mode: 'just_resize',
|
||||||
|
t2i_adapter_model: model,
|
||||||
|
weight: weight,
|
||||||
|
image: buildControlImage(image, processedImage, processorConfig),
|
||||||
|
});
|
||||||
|
|
||||||
|
g.addEdge(t2iAdapter, 't2i_adapter', t2iAdapterCollect, 'item');
|
||||||
|
|
||||||
|
t2iAdapterMetadata.push(buildT2IAdapterMetadata(t2iAdapterConfig));
|
||||||
|
}
|
||||||
|
|
||||||
|
MetadataUtil.add(g, { t2iAdapters: t2iAdapterMetadata });
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfigV2): S['IPAdapterMetadataField'] => {
|
||||||
|
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
|
||||||
|
|
||||||
|
assert(model, 'IP Adapter model is required');
|
||||||
|
assert(image, 'IP Adapter image is required');
|
||||||
|
|
||||||
|
return {
|
||||||
|
ip_adapter_model: model,
|
||||||
|
clip_vision_model: clipVisionModel,
|
||||||
|
weight,
|
||||||
|
method,
|
||||||
|
begin_step_percent: beginEndStepPct[0],
|
||||||
|
end_step_percent: beginEndStepPct[1],
|
||||||
|
image: {
|
||||||
|
image_name: image.imageName,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const addIPAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => {
|
||||||
|
try {
|
||||||
|
// You see, we've already got one!
|
||||||
|
const ipAdapterCollect = g.getNode(IP_ADAPTER_COLLECT);
|
||||||
|
assert(ipAdapterCollect.type === 'collect');
|
||||||
|
return ipAdapterCollect;
|
||||||
|
} catch {
|
||||||
|
const ipAdapterCollect = g.addNode({
|
||||||
|
id: IP_ADAPTER_COLLECT,
|
||||||
|
type: 'collect',
|
||||||
|
});
|
||||||
|
g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
|
||||||
|
return ipAdapterCollect;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const addGlobalIPAdaptersToGraph = (
|
||||||
|
ipAdapterConfigs: IPAdapterConfigV2[],
|
||||||
|
g: Graph,
|
||||||
|
denoise: Invocation<'denoise_latents'>
|
||||||
|
): void => {
|
||||||
|
if (ipAdapterConfigs.length === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const ipAdapterMetdata: S['IPAdapterMetadataField'][] = [];
|
||||||
|
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
|
||||||
|
|
||||||
|
for (const ipAdapterConfig of ipAdapterConfigs) {
|
||||||
|
const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapterConfig;
|
||||||
|
assert(image, 'IP Adapter image is required');
|
||||||
|
assert(model, 'IP Adapter model is required');
|
||||||
|
|
||||||
|
const ipAdapter = g.addNode({
|
||||||
|
id: `ip_adapter_${id}`,
|
||||||
|
type: 'ip_adapter',
|
||||||
|
weight,
|
||||||
|
method,
|
||||||
|
ip_adapter_model: model,
|
||||||
|
clip_vision_model: clipVisionModel,
|
||||||
|
begin_step_percent: beginEndStepPct[0],
|
||||||
|
end_step_percent: beginEndStepPct[1],
|
||||||
|
image: {
|
||||||
|
image_name: image.imageName,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item');
|
||||||
|
ipAdapterMetdata.push(buildIPAdapterMetadata(ipAdapterConfig));
|
||||||
|
}
|
||||||
|
|
||||||
|
MetadataUtil.add(g, { ipAdapters: ipAdapterMetdata });
|
||||||
|
};
|
||||||
|
|
||||||
|
export const addGenerationTabControlLayers = async (
|
||||||
|
state: RootState,
|
||||||
|
g: Graph,
|
||||||
|
denoise: Invocation<'denoise_latents'>,
|
||||||
|
posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
|
||||||
|
negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
|
||||||
|
posCondCollect: Invocation<'collect'>,
|
||||||
|
negCondCollect: Invocation<'collect'>
|
||||||
|
) => {
|
||||||
|
const mainModel = state.generation.model;
|
||||||
|
assert(mainModel, 'Missing main model when building graph');
|
||||||
|
const isSDXL = mainModel.base === 'sdxl';
|
||||||
|
|
||||||
|
// Add global control adapters
|
||||||
|
const globalControlNetConfigs = state.controlLayers.present.layers
|
||||||
|
// Must be a CA layer
|
||||||
|
.filter(isControlAdapterLayer)
|
||||||
|
// Must be enabled
|
||||||
|
.filter((l) => l.isEnabled)
|
||||||
|
// We want the CAs themselves
|
||||||
|
.map((l) => l.controlAdapter)
|
||||||
|
// Must be a ControlNet
|
||||||
|
.filter(isControlNetConfigV2)
|
||||||
|
.filter((ca) => {
|
||||||
|
const hasModel = Boolean(ca.model);
|
||||||
|
const modelMatchesBase = ca.model?.base === mainModel.base;
|
||||||
|
const hasControlImage = ca.image || (ca.processedImage && ca.processorConfig);
|
||||||
|
return hasModel && modelMatchesBase && hasControlImage;
|
||||||
|
});
|
||||||
|
addGlobalControlNetsToGraph(globalControlNetConfigs, g, denoise);
|
||||||
|
|
||||||
|
const globalT2IAdapterConfigs = state.controlLayers.present.layers
|
||||||
|
// Must be a CA layer
|
||||||
|
.filter(isControlAdapterLayer)
|
||||||
|
// Must be enabled
|
||||||
|
.filter((l) => l.isEnabled)
|
||||||
|
// We want the CAs themselves
|
||||||
|
.map((l) => l.controlAdapter)
|
||||||
|
// Must have a ControlNet CA
|
||||||
|
.filter(isT2IAdapterConfigV2)
|
||||||
|
.filter((ca) => {
|
||||||
|
const hasModel = Boolean(ca.model);
|
||||||
|
const modelMatchesBase = ca.model?.base === mainModel.base;
|
||||||
|
const hasControlImage = ca.image || (ca.processedImage && ca.processorConfig);
|
||||||
|
return hasModel && modelMatchesBase && hasControlImage;
|
||||||
|
});
|
||||||
|
addGlobalT2IAdaptersToGraph(globalT2IAdapterConfigs, g, denoise);
|
||||||
|
|
||||||
|
const globalIPAdapterConfigs = state.controlLayers.present.layers
|
||||||
|
// Must be an IP Adapter layer
|
||||||
|
.filter(isIPAdapterLayer)
|
||||||
|
// Must be enabled
|
||||||
|
.filter((l) => l.isEnabled)
|
||||||
|
// We want the IP Adapters themselves
|
||||||
|
.map((l) => l.ipAdapter)
|
||||||
|
.filter((ca) => {
|
||||||
|
const hasModel = Boolean(ca.model);
|
||||||
|
const modelMatchesBase = ca.model?.base === mainModel.base;
|
||||||
|
const hasControlImage = Boolean(ca.image);
|
||||||
|
return hasModel && modelMatchesBase && hasControlImage;
|
||||||
|
});
|
||||||
|
addGlobalIPAdaptersToGraph(globalIPAdapterConfigs, g, denoise);
|
||||||
|
|
||||||
|
const rgLayers = state.controlLayers.present.layers
|
||||||
|
// Only RG layers are get masks
|
||||||
|
.filter(isRegionalGuidanceLayer)
|
||||||
|
// Only visible layers are rendered on the canvas
|
||||||
|
.filter((l) => l.isEnabled)
|
||||||
|
// Only layers with prompts get added to the graph
|
||||||
|
.filter((l) => {
|
||||||
|
const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt);
|
||||||
|
const hasIPAdapter = l.ipAdapters.length !== 0;
|
||||||
|
return hasTextPrompt || hasIPAdapter;
|
||||||
|
});
|
||||||
|
|
||||||
|
const layerIds = rgLayers.map((l) => l.id);
|
||||||
|
const blobs = await getRegionalPromptLayerBlobs(layerIds);
|
||||||
|
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
|
||||||
|
|
||||||
|
for (const layer of rgLayers) {
|
||||||
|
const blob = blobs[layer.id];
|
||||||
|
assert(blob, `Blob for layer ${layer.id} not found`);
|
||||||
|
// Upload the mask image, or get the cached image if it exists
|
||||||
|
const { image_name } = await getMaskImage(layer, blob);
|
||||||
|
|
||||||
|
// The main mask-to-tensor node
|
||||||
|
const maskToTensor = g.addNode({
|
||||||
|
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${layer.id}`,
|
||||||
|
type: 'alpha_mask_to_tensor',
|
||||||
|
image: {
|
||||||
|
image_name,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (layer.positivePrompt) {
|
||||||
|
// The main positive conditioning node
|
||||||
|
const regionalPosCond = g.addNode(
|
||||||
|
isSDXL
|
||||||
|
? {
|
||||||
|
type: 'sdxl_compel_prompt',
|
||||||
|
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
|
||||||
|
prompt: layer.positivePrompt,
|
||||||
|
style: layer.positivePrompt, // TODO: Should we put the positive prompt in both fields?
|
||||||
|
}
|
||||||
|
: {
|
||||||
|
type: 'compel',
|
||||||
|
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
|
||||||
|
prompt: layer.positivePrompt,
|
||||||
|
}
|
||||||
|
);
|
||||||
|
// Connect the mask to the conditioning
|
||||||
|
g.addEdge(maskToTensor, 'mask', regionalPosCond, 'mask');
|
||||||
|
// Connect the conditioning to the collector
|
||||||
|
g.addEdge(regionalPosCond, 'conditioning', posCondCollect, 'item');
|
||||||
|
// Copy the connections to the "global" positive conditioning node to the regional cond
|
||||||
|
for (const edge of g.getEdgesTo(posCond)) {
|
||||||
|
console.log(edge);
|
||||||
|
if (edge.destination.field !== 'prompt') {
|
||||||
|
// Clone the edge, but change the destination node to the regional conditioning node
|
||||||
|
const clone = deepClone(edge);
|
||||||
|
clone.destination.node_id = regionalPosCond.id;
|
||||||
|
g.addEdgeFromObj(clone);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (layer.negativePrompt) {
|
||||||
|
// The main negative conditioning node
|
||||||
|
const regionalNegCond = g.addNode(
|
||||||
|
isSDXL
|
||||||
|
? {
|
||||||
|
type: 'sdxl_compel_prompt',
|
||||||
|
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
|
||||||
|
prompt: layer.negativePrompt,
|
||||||
|
style: layer.negativePrompt,
|
||||||
|
}
|
||||||
|
: {
|
||||||
|
type: 'compel',
|
||||||
|
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
|
||||||
|
prompt: layer.negativePrompt,
|
||||||
|
}
|
||||||
|
);
|
||||||
|
// Connect the mask to the conditioning
|
||||||
|
g.addEdge(maskToTensor, 'mask', regionalNegCond, 'mask');
|
||||||
|
// Connect the conditioning to the collector
|
||||||
|
g.addEdge(regionalNegCond, 'conditioning', negCondCollect, 'item');
|
||||||
|
// Copy the connections to the "global" negative conditioning node to the regional cond
|
||||||
|
for (const edge of g.getEdgesTo(negCond)) {
|
||||||
|
if (edge.destination.field !== 'prompt') {
|
||||||
|
// Clone the edge, but change the destination node to the regional conditioning node
|
||||||
|
const clone = deepClone(edge);
|
||||||
|
clone.destination.node_id = regionalNegCond.id;
|
||||||
|
g.addEdgeFromObj(clone);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
|
||||||
|
if (layer.autoNegative === 'invert' && layer.positivePrompt) {
|
||||||
|
// We re-use the mask image, but invert it when converting to tensor
|
||||||
|
const invertTensorMask = g.addNode({
|
||||||
|
id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${layer.id}`,
|
||||||
|
type: 'invert_tensor_mask',
|
||||||
|
});
|
||||||
|
// Connect the OG mask image to the inverted mask-to-tensor node
|
||||||
|
g.addEdge(maskToTensor, 'mask', invertTensorMask, 'mask');
|
||||||
|
// Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the positive prompt
|
||||||
|
const regionalPosCondInverted = g.addNode(
|
||||||
|
isSDXL
|
||||||
|
? {
|
||||||
|
type: 'sdxl_compel_prompt',
|
||||||
|
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
|
||||||
|
prompt: layer.positivePrompt,
|
||||||
|
style: layer.positivePrompt,
|
||||||
|
}
|
||||||
|
: {
|
||||||
|
type: 'compel',
|
||||||
|
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
|
||||||
|
prompt: layer.positivePrompt,
|
||||||
|
}
|
||||||
|
);
|
||||||
|
// Connect the inverted mask to the conditioning
|
||||||
|
g.addEdge(invertTensorMask, 'mask', regionalPosCondInverted, 'mask');
|
||||||
|
// Connect the conditioning to the negative collector
|
||||||
|
g.addEdge(regionalPosCondInverted, 'conditioning', negCondCollect, 'item');
|
||||||
|
// Copy the connections to the "global" positive conditioning node to our regional node
|
||||||
|
for (const edge of g.getEdgesTo(posCond)) {
|
||||||
|
if (edge.destination.field !== 'prompt') {
|
||||||
|
// Clone the edge, but change the destination node to the regional conditioning node
|
||||||
|
const clone = deepClone(edge);
|
||||||
|
clone.destination.node_id = regionalPosCondInverted.id;
|
||||||
|
g.addEdgeFromObj(clone);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(psyche): For some reason, I have to explicitly annotate regionalIPAdapters here. Not sure why.
|
||||||
|
const regionalIPAdapters: IPAdapterConfigV2[] = layer.ipAdapters.filter((ipAdapter) => {
|
||||||
|
const hasModel = Boolean(ipAdapter.model);
|
||||||
|
const modelMatchesBase = ipAdapter.model?.base === mainModel.base;
|
||||||
|
const hasControlImage = Boolean(ipAdapter.image);
|
||||||
|
return hasModel && modelMatchesBase && hasControlImage;
|
||||||
|
});
|
||||||
|
|
||||||
|
for (const ipAdapterConfig of regionalIPAdapters) {
|
||||||
|
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
|
||||||
|
const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapterConfig;
|
||||||
|
assert(model, 'IP Adapter model is required');
|
||||||
|
assert(image, 'IP Adapter image is required');
|
||||||
|
|
||||||
|
const ipAdapter = g.addNode({
|
||||||
|
id: `ip_adapter_${id}`,
|
||||||
|
type: 'ip_adapter',
|
||||||
|
weight,
|
||||||
|
method,
|
||||||
|
ip_adapter_model: model,
|
||||||
|
clip_vision_model: clipVisionModel,
|
||||||
|
begin_step_percent: beginEndStepPct[0],
|
||||||
|
end_step_percent: beginEndStepPct[1],
|
||||||
|
image: {
|
||||||
|
image_name: image.imageName,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Connect the mask to the conditioning
|
||||||
|
g.addEdge(maskToTensor, 'mask', ipAdapter, 'mask');
|
||||||
|
g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
|
||||||
|
if (layer.uploadedMaskImage) {
|
||||||
|
const imageDTO = await getImageDTO(layer.uploadedMaskImage.imageName);
|
||||||
|
if (imageDTO) {
|
||||||
|
return imageDTO;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const { dispatch } = getStore();
|
||||||
|
// No cached mask, or the cached image no longer exists - we need to upload the mask image
|
||||||
|
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
|
||||||
|
const req = dispatch(
|
||||||
|
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
|
||||||
|
);
|
||||||
|
req.reset();
|
||||||
|
|
||||||
|
const imageDTO = await req.unwrap();
|
||||||
|
dispatch(rgLayerMaskImageUploaded({ layerId: layer.id, imageDTO }));
|
||||||
|
return imageDTO;
|
||||||
|
};
|
@ -15,12 +15,12 @@ import { IMAGE_TO_LATENTS, RESIZE } from './constants';
|
|||||||
* @param noise The noise node in the graph
|
* @param noise The noise node in the graph
|
||||||
* @returns Whether the initial image was added to the graph
|
* @returns Whether the initial image was added to the graph
|
||||||
*/
|
*/
|
||||||
export const addInitialImageToGenerationTabGraph = (
|
export const addGenerationTabInitialImage = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
g: Graph,
|
g: Graph,
|
||||||
denoise: Invocation<'denoise_latents'>,
|
denoise: Invocation<'denoise_latents'>,
|
||||||
noise: Invocation<'noise'>
|
noise: Invocation<'noise'>
|
||||||
): boolean => {
|
): Invocation<'i2l'> | null => {
|
||||||
// Remove Existing UNet Connections
|
// Remove Existing UNet Connections
|
||||||
const { img2imgStrength, vaePrecision, model } = state.generation;
|
const { img2imgStrength, vaePrecision, model } = state.generation;
|
||||||
const { refinerModel, refinerStart } = state.sdxl;
|
const { refinerModel, refinerStart } = state.sdxl;
|
||||||
@ -29,7 +29,7 @@ export const addInitialImageToGenerationTabGraph = (
|
|||||||
const initialImage = initialImageLayer?.isEnabled ? initialImageLayer?.image : null;
|
const initialImage = initialImageLayer?.isEnabled ? initialImageLayer?.image : null;
|
||||||
|
|
||||||
if (!initialImage) {
|
if (!initialImage) {
|
||||||
return false;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
const isSDXL = model?.base === 'sdxl';
|
const isSDXL = model?.base === 'sdxl';
|
||||||
@ -75,5 +75,5 @@ export const addInitialImageToGenerationTabGraph = (
|
|||||||
init_image: initialImage.imageName,
|
init_image: initialImage.imageName,
|
||||||
});
|
});
|
||||||
|
|
||||||
return true;
|
return i2l;
|
||||||
};
|
};
|
@ -0,0 +1,94 @@
|
|||||||
|
import type { RootState } from 'app/store/store';
|
||||||
|
import { deepClone } from 'common/util/deepClone';
|
||||||
|
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||||
|
import { Graph } from 'features/nodes/util/graph/Graph';
|
||||||
|
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
||||||
|
import { filter, size } from 'lodash-es';
|
||||||
|
import type { Invocation, S } from 'services/api/types';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
|
import { LORA_LOADER } from './constants';
|
||||||
|
|
||||||
|
export const addGenerationTabLoRAs = (
|
||||||
|
state: RootState,
|
||||||
|
g: Graph,
|
||||||
|
denoise: Invocation<'denoise_latents'>,
|
||||||
|
unetSource: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> | Invocation<'seamless'>,
|
||||||
|
clipSkip: Invocation<'clip_skip'>,
|
||||||
|
posCond: Invocation<'compel'>,
|
||||||
|
negCond: Invocation<'compel'>
|
||||||
|
): void => {
|
||||||
|
/**
|
||||||
|
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
|
||||||
|
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
|
||||||
|
* or to the inference/conditioning nodes.
|
||||||
|
*
|
||||||
|
* So we need to inject a LoRA chain into the graph.
|
||||||
|
*/
|
||||||
|
|
||||||
|
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
|
||||||
|
const loraCount = size(enabledLoRAs);
|
||||||
|
|
||||||
|
if (loraCount === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove modelLoaderNodeId unet connection to feed it to LoRAs
|
||||||
|
console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e)));
|
||||||
|
g.deleteEdgesFrom(unetSource, 'unet');
|
||||||
|
console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e)));
|
||||||
|
if (clipSkip) {
|
||||||
|
// Remove CLIP_SKIP connections to conditionings to feed it through LoRAs
|
||||||
|
g.deleteEdgesFrom(clipSkip, 'clip');
|
||||||
|
}
|
||||||
|
console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e)));
|
||||||
|
|
||||||
|
// we need to remember the last lora so we can chain from it
|
||||||
|
let lastLoRALoader: Invocation<'lora_loader'> | null = null;
|
||||||
|
let currentLoraIndex = 0;
|
||||||
|
const loraMetadata: S['LoRAMetadataField'][] = [];
|
||||||
|
|
||||||
|
for (const lora of enabledLoRAs) {
|
||||||
|
const { weight } = lora;
|
||||||
|
const { key } = lora.model;
|
||||||
|
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
||||||
|
const parsedModel = zModelIdentifierField.parse(lora.model);
|
||||||
|
|
||||||
|
const currentLoRALoader = g.addNode({
|
||||||
|
type: 'lora_loader',
|
||||||
|
id: currentLoraNodeId,
|
||||||
|
lora: parsedModel,
|
||||||
|
weight,
|
||||||
|
});
|
||||||
|
|
||||||
|
loraMetadata.push({
|
||||||
|
model: parsedModel,
|
||||||
|
weight,
|
||||||
|
});
|
||||||
|
|
||||||
|
// add to graph
|
||||||
|
if (currentLoraIndex === 0) {
|
||||||
|
// first lora = start the lora chain, attach directly to model loader
|
||||||
|
g.addEdge(unetSource, 'unet', currentLoRALoader, 'unet');
|
||||||
|
g.addEdge(clipSkip, 'clip', currentLoRALoader, 'clip');
|
||||||
|
} else {
|
||||||
|
assert(lastLoRALoader !== null);
|
||||||
|
// we are in the middle of the lora chain, instead connect to the previous lora
|
||||||
|
g.addEdge(lastLoRALoader, 'unet', currentLoRALoader, 'unet');
|
||||||
|
g.addEdge(lastLoRALoader, 'clip', currentLoRALoader, 'clip');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (currentLoraIndex === loraCount - 1) {
|
||||||
|
// final lora, end the lora chain - we need to connect up to inference and conditioning nodes
|
||||||
|
g.addEdge(currentLoRALoader, 'unet', denoise, 'unet');
|
||||||
|
g.addEdge(currentLoRALoader, 'clip', posCond, 'clip');
|
||||||
|
g.addEdge(currentLoRALoader, 'clip', negCond, 'clip');
|
||||||
|
}
|
||||||
|
|
||||||
|
// increment the lora for the next one in the chain
|
||||||
|
lastLoRALoader = currentLoRALoader;
|
||||||
|
currentLoraIndex += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
MetadataUtil.add(g, { loras: loraMetadata });
|
||||||
|
};
|
@ -15,26 +15,28 @@ import { SEAMLESS, VAE_LOADER } from './constants';
|
|||||||
* @param modelLoader The model loader node in the graph
|
* @param modelLoader The model loader node in the graph
|
||||||
* @returns The terminal model loader node in the graph
|
* @returns The terminal model loader node in the graph
|
||||||
*/
|
*/
|
||||||
export const addSeamlessToGenerationTabGraph = (
|
export const addGenerationTabSeamless = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
g: Graph,
|
g: Graph,
|
||||||
denoise: Invocation<'denoise_latents'>,
|
denoise: Invocation<'denoise_latents'>,
|
||||||
modelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'>
|
modelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'>
|
||||||
): Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> | Invocation<'seamless'> => {
|
): Invocation<'seamless'> | null => {
|
||||||
const { seamlessXAxis, seamlessYAxis, vae } = state.generation;
|
const { seamlessXAxis: seamless_x, seamlessYAxis: seamless_y, vae } = state.generation;
|
||||||
|
|
||||||
if (!seamlessXAxis && !seamlessYAxis) {
|
if (!seamless_x && !seamless_y) {
|
||||||
return modelLoader;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
const seamless = g.addNode({
|
const seamless = g.addNode({
|
||||||
id: SEAMLESS,
|
id: SEAMLESS,
|
||||||
type: 'seamless',
|
type: 'seamless',
|
||||||
seamless_x: seamlessXAxis,
|
seamless_x,
|
||||||
seamless_y: seamlessYAxis,
|
seamless_y,
|
||||||
});
|
});
|
||||||
|
|
||||||
const vaeLoader = vae
|
// The VAE helper also adds the VAE loader - so we need to check if it's already there
|
||||||
|
const shouldAddVAELoader = !g.hasNode(VAE_LOADER) && vae;
|
||||||
|
const vaeLoader = shouldAddVAELoader
|
||||||
? g.addNode({
|
? g.addNode({
|
||||||
type: 'vae_loader',
|
type: 'vae_loader',
|
||||||
id: VAE_LOADER,
|
id: VAE_LOADER,
|
||||||
@ -42,29 +44,18 @@ export const addSeamlessToGenerationTabGraph = (
|
|||||||
})
|
})
|
||||||
: null;
|
: null;
|
||||||
|
|
||||||
let terminalModelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> | Invocation<'seamless'> =
|
MetadataUtil.add(g, {
|
||||||
modelLoader;
|
seamless_x: seamless_x || undefined,
|
||||||
|
seamless_y: seamless_y || undefined,
|
||||||
if (seamlessXAxis) {
|
});
|
||||||
MetadataUtil.add(g, {
|
|
||||||
seamless_x: seamlessXAxis,
|
|
||||||
});
|
|
||||||
terminalModelLoader = seamless;
|
|
||||||
}
|
|
||||||
if (seamlessYAxis) {
|
|
||||||
MetadataUtil.add(g, {
|
|
||||||
seamless_y: seamlessYAxis,
|
|
||||||
});
|
|
||||||
terminalModelLoader = seamless;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Seamless slots into the graph between the model loader and the denoise node
|
// Seamless slots into the graph between the model loader and the denoise node
|
||||||
g.deleteEdgesFrom(modelLoader, 'unet');
|
g.deleteEdgesFrom(modelLoader, 'unet');
|
||||||
g.deleteEdgesFrom(modelLoader, 'clip');
|
g.deleteEdgesFrom(modelLoader, 'vae');
|
||||||
|
|
||||||
g.addEdge(modelLoader, 'unet', seamless, 'unet');
|
g.addEdge(modelLoader, 'unet', seamless, 'unet');
|
||||||
g.addEdge(vaeLoader ?? modelLoader, 'vae', seamless, 'unet');
|
g.addEdge(vaeLoader ?? modelLoader, 'vae', seamless, 'vae');
|
||||||
g.addEdge(seamless, 'unet', denoise, 'unet');
|
g.addEdge(seamless, 'unet', denoise, 'unet');
|
||||||
|
|
||||||
return terminalModelLoader;
|
return seamless;
|
||||||
};
|
};
|
@ -0,0 +1,37 @@
|
|||||||
|
import type { RootState } from 'app/store/store';
|
||||||
|
import type { Graph } from 'features/nodes/util/graph/Graph';
|
||||||
|
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
||||||
|
import type { Invocation } from 'services/api/types';
|
||||||
|
|
||||||
|
import { VAE_LOADER } from './constants';
|
||||||
|
|
||||||
|
export const addGenerationTabVAE = (
|
||||||
|
state: RootState,
|
||||||
|
g: Graph,
|
||||||
|
modelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'>,
|
||||||
|
l2i: Invocation<'l2i'>,
|
||||||
|
i2l: Invocation<'i2l'> | null,
|
||||||
|
seamless: Invocation<'seamless'> | null
|
||||||
|
): void => {
|
||||||
|
const { vae } = state.generation;
|
||||||
|
|
||||||
|
// The seamless helper also adds the VAE loader... so we need to check if it's already there
|
||||||
|
const shouldAddVAELoader = !g.hasNode(VAE_LOADER) && vae;
|
||||||
|
const vaeLoader = shouldAddVAELoader
|
||||||
|
? g.addNode({
|
||||||
|
type: 'vae_loader',
|
||||||
|
id: VAE_LOADER,
|
||||||
|
vae_model: vae,
|
||||||
|
})
|
||||||
|
: null;
|
||||||
|
|
||||||
|
const vaeSource = seamless ? seamless : vaeLoader ? vaeLoader : modelLoader;
|
||||||
|
g.addEdge(vaeSource, 'vae', l2i, 'vae');
|
||||||
|
if (i2l) {
|
||||||
|
g.addEdge(vaeSource, 'vae', i2l, 'vae');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (vae) {
|
||||||
|
MetadataUtil.add(g, { vae });
|
||||||
|
}
|
||||||
|
};
|
@ -1,18 +1,19 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
|
import { addGenerationTabControlLayers } from 'features/nodes/util/graph/addGenerationTabControlLayers';
|
||||||
import { addInitialImageToGenerationTabGraph } from 'features/nodes/util/graph/addInitialImageToGenerationTabGraph';
|
import { addGenerationTabInitialImage } from 'features/nodes/util/graph/addGenerationTabInitialImage';
|
||||||
import { addSeamlessToGenerationTabGraph } from 'features/nodes/util/graph/addSeamlessToGenerationTabGraph';
|
import { addGenerationTabLoRAs } from 'features/nodes/util/graph/addGenerationTabLoRAs';
|
||||||
|
import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless';
|
||||||
|
import { addGenerationTabVAE } from 'features/nodes/util/graph/addGenerationTabVAE';
|
||||||
|
import type { GraphType } from 'features/nodes/util/graph/Graph';
|
||||||
import { Graph } from 'features/nodes/util/graph/Graph';
|
import { Graph } from 'features/nodes/util/graph/Graph';
|
||||||
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
|
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||||
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
||||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import { addHrfToGraph } from './addHrfToGraph';
|
import { addHrfToGraph } from './addHrfToGraph';
|
||||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
|
||||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||||
import { addVAEToGraph } from './addVAEToGraph';
|
|
||||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||||
import {
|
import {
|
||||||
CLIP_SKIP,
|
CLIP_SKIP,
|
||||||
@ -29,7 +30,7 @@ import {
|
|||||||
import { getModelMetadataField } from './metadata';
|
import { getModelMetadataField } from './metadata';
|
||||||
|
|
||||||
const log = logger('nodes');
|
const log = logger('nodes');
|
||||||
export const buildGenerationTabGraph = async (state: RootState): Promise<Graph> => {
|
export const buildGenerationTabGraph2 = async (state: RootState): Promise<GraphType> => {
|
||||||
const {
|
const {
|
||||||
model,
|
model,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
@ -39,8 +40,6 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<Graph>
|
|||||||
clipSkip: skipped_layers,
|
clipSkip: skipped_layers,
|
||||||
shouldUseCpuNoise,
|
shouldUseCpuNoise,
|
||||||
vaePrecision,
|
vaePrecision,
|
||||||
seamlessXAxis,
|
|
||||||
seamlessYAxis,
|
|
||||||
seed,
|
seed,
|
||||||
} = state.generation;
|
} = state.generation;
|
||||||
const { positivePrompt, negativePrompt } = state.controlLayers.present;
|
const { positivePrompt, negativePrompt } = state.controlLayers.present;
|
||||||
@ -114,6 +113,8 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<Graph>
|
|||||||
g.addEdge(clipSkip, 'clip', negCond, 'clip');
|
g.addEdge(clipSkip, 'clip', negCond, 'clip');
|
||||||
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
|
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
|
||||||
g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
|
g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
|
||||||
|
g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
|
||||||
|
g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
|
||||||
g.addEdge(noise, 'noise', denoise, 'noise');
|
g.addEdge(noise, 'noise', denoise, 'noise');
|
||||||
g.addEdge(denoise, 'latents', l2i, 'latents');
|
g.addEdge(denoise, 'latents', l2i, 'latents');
|
||||||
|
|
||||||
@ -135,20 +136,22 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<Graph>
|
|||||||
clip_skip: skipped_layers,
|
clip_skip: skipped_layers,
|
||||||
});
|
});
|
||||||
MetadataUtil.setMetadataReceivingNode(g, l2i);
|
MetadataUtil.setMetadataReceivingNode(g, l2i);
|
||||||
|
g.validate();
|
||||||
|
|
||||||
const didAddInitialImage = addInitialImageToGenerationTabGraph(state, g, denoise, noise);
|
const i2l = addGenerationTabInitialImage(state, g, denoise, noise);
|
||||||
const terminalModelLoader = addSeamlessToGenerationTabGraph(state, g, denoise, modelLoader);
|
g.validate();
|
||||||
|
const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader);
|
||||||
|
g.validate();
|
||||||
|
addGenerationTabVAE(state, g, modelLoader, l2i, i2l, seamless);
|
||||||
|
g.validate();
|
||||||
|
addGenerationTabLoRAs(state, g, denoise, seamless ?? modelLoader, clipSkip, posCond, negCond);
|
||||||
|
g.validate();
|
||||||
|
|
||||||
// optionally add custom VAE
|
await addGenerationTabControlLayers(state, g, denoise, posCond, negCond, posCondCollect, negCondCollect);
|
||||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
g.validate();
|
||||||
|
|
||||||
// add LoRA support
|
|
||||||
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
|
||||||
|
|
||||||
await addControlLayersToGraph(state, graph, DENOISE_LATENTS);
|
|
||||||
|
|
||||||
// High resolution fix.
|
// High resolution fix.
|
||||||
if (state.hrf.hrfEnabled && !didAddInitialImage) {
|
if (state.hrf.hrfEnabled && !i2l) {
|
||||||
addHrfToGraph(state, graph);
|
addHrfToGraph(state, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -163,5 +166,5 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<Graph>
|
|||||||
addWatermarkerToGraph(state, graph);
|
addWatermarkerToGraph(state, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
return graph;
|
return g.getGraph();
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user