diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts index 2d267b92b2..bbb77c9ac5 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts @@ -1,7 +1,7 @@ import { enqueueRequested } from 'app/store/actions'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice'; -import { buildGenerationTabGraph } from 'features/nodes/util/graph/buildGenerationTabGraph'; +import { buildGenerationTabGraph2 } from 'features/nodes/util/graph/buildGenerationTabGraph2'; import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/buildGenerationTabSDXLGraph'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; import { queueApi } from 'services/api/endpoints/queue'; @@ -21,7 +21,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) if (model && model.base === 'sdxl') { graph = await buildGenerationTabSDXLGraph(state); } else { - graph = await buildGenerationTabGraph(state); + graph = await buildGenerationTabGraph2(state); } const batchConfig = prepareLinearUIBatch(state, graph, prepend); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/Graph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/Graph.test.ts index 71bcd9331c..b11e16545f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/Graph.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/Graph.test.ts @@ -289,4 +289,80 @@ describe('Graph', () => { }); }); }); + + describe('deleteEdgesFrom', () => { + it('should delete edges from the provided node', () => { + const g = new Graph(); + const n1 = g.addNode({ + id: 'n1', + type: 'img_resize', + }); + const n2 = g.addNode({ + id: 'n2', + type: 'add', + }); + const _e1 = g.addEdge(n1, 'height', n2, 'a'); + const _e2 = g.addEdge(n1, 'width', n2, 'b'); + g.deleteEdgesFrom(n1); + expect(g.getEdgesFrom(n1)).toEqual([]); + }); + it('should delete edges from the provided node, with the provided field', () => { + const g = new Graph(); + const n1 = g.addNode({ + id: 'n1', + type: 'img_resize', + }); + const n2 = g.addNode({ + id: 'n2', + type: 'add', + }); + const n3 = g.addNode({ + id: 'n3', + type: 'add', + }); + const _e1 = g.addEdge(n1, 'height', n2, 'a'); + const e2 = g.addEdge(n1, 'width', n2, 'b'); + const e3 = g.addEdge(n1, 'width', n3, 'b'); + g.deleteEdgesFrom(n1, 'height'); + expect(g.getEdgesFrom(n1)).toEqual([e2, e3]); + }); + }); + + describe('deleteEdgesTo', () => { + it('should delete edges to the provided node', () => { + const g = new Graph(); + const n1 = g.addNode({ + id: 'n1', + type: 'img_resize', + }); + const n2 = g.addNode({ + id: 'n2', + type: 'add', + }); + const _e1 = g.addEdge(n1, 'height', n2, 'a'); + const _e2 = g.addEdge(n1, 'width', n2, 'b'); + g.deleteEdgesTo(n2); + expect(g.getEdgesTo(n2)).toEqual([]); + }); + it('should delete edges to the provided node, with the provided field', () => { + const g = new Graph(); + const n1 = g.addNode({ + id: 'n1', + type: 'img_resize', + }); + const n2 = g.addNode({ + id: 'n2', + type: 'img_resize', + }); + const n3 = g.addNode({ + id: 'n3', + type: 'add', + }); + const _e1 = g.addEdge(n1, 'height', n3, 'a'); + const e2 = g.addEdge(n1, 'width', n3, 'b'); + const _e3 = g.addEdge(n2, 'width', n3, 'a'); + g.deleteEdgesTo(n3, 'a'); + expect(g.getEdgesTo(n3)).toEqual([e2]); + }); + }); }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/Graph.ts index e25cbaa78d..b578c5b40a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/Graph.ts @@ -1,4 +1,4 @@ -import { isEqual } from 'lodash-es'; +import { forEach, groupBy, isEqual, values } from 'lodash-es'; import type { AnyInvocation, AnyInvocationInputField, @@ -22,7 +22,7 @@ type Edge = { }; }; -type GraphType = { id: string; nodes: Record; edges: Edge[] }; +export type GraphType = { id: string; nodes: Record; edges: Edge[] }; export class Graph { _graph: GraphType; @@ -130,6 +130,31 @@ export class Graph { return edge; } + /** + * Add an edge to the graph. If an edge with the same source and destination already exists, an `AssertionError` is raised. + * If providing node ids, provide the from and to node types as generics to get type hints for from and to field names. + * @param fromNode The source node or id of the source node. + * @param fromField The field of the source node. + * @param toNode The source node or id of the destination node. + * @param toField The field of the destination node. + * @returns The added edge. + * @raises `AssertionError` if an edge with the same source and destination already exists. + */ + addEdgeFromObj(edge: Edge): Edge { + const edgeAlreadyExists = this._graph.edges.some((e) => isEqual(e, edge)); + assert( + !edgeAlreadyExists, + Graph.getEdgeAlreadyExistsMsg( + edge.source.node_id, + edge.source.field, + edge.destination.node_id, + edge.destination.field + ) + ); + this._graph.edges.push(edge); + return edge; + } + /** * Get an edge from the graph. If the edge does not exist, an `AssertionError` is raised. * Provide the from and to node types as generics to get type hints for from and to field names. @@ -255,6 +280,24 @@ export class Graph { for (const edge of this._graph.edges) { this.getNode(edge.source.node_id); this.getNode(edge.destination.node_id); + assert( + !this._graph.edges.filter((e) => e !== edge).find((e) => isEqual(e, edge)), + `Duplicate edge: ${Graph.edgeToString(edge)}` + ); + } + for (const node of values(this._graph.nodes)) { + const edgesTo = this.getEdgesTo(node); + // Validate that no node has multiple incoming edges with the same field + forEach(groupBy(edgesTo, 'destination.field'), (group, field) => { + if (node.type === 'collect' && field === 'item') { + // Collectors' item field accepts multiple incoming edges + return; + } + assert( + group.length === 1, + `Node ${node.id} has multiple incoming edges with field ${field}: ${group.map(Graph.edgeToString).join(', ')}` + ); + }); } } @@ -299,6 +342,10 @@ export class Graph { return `Edge from ${fromNodeId}.${fromField} to ${toNodeId}.${toField} already exists`; } + static edgeToString(edge: Edge): string { + return `${edge.source.node_id}.${edge.source.field} -> ${edge.destination.node_id}.${edge.destination.field}`; + } + static uuid = uuidv4; //#endregion } diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabControlLayers.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabControlLayers.ts new file mode 100644 index 0000000000..0bc907e5e1 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabControlLayers.ts @@ -0,0 +1,539 @@ +import { getStore } from 'app/store/nanostores/store'; +import type { RootState } from 'app/store/store'; +import { deepClone } from 'common/util/deepClone'; +import { + isControlAdapterLayer, + isIPAdapterLayer, + isRegionalGuidanceLayer, + rgLayerMaskImageUploaded, +} from 'features/controlLayers/store/controlLayersSlice'; +import type { RegionalGuidanceLayer } from 'features/controlLayers/store/types'; +import { + type ControlNetConfigV2, + type ImageWithDims, + type IPAdapterConfigV2, + isControlNetConfigV2, + isT2IAdapterConfigV2, + type ProcessorConfig, + type T2IAdapterConfigV2, +} from 'features/controlLayers/util/controlAdapters'; +import { getRegionalPromptLayerBlobs } from 'features/controlLayers/util/getLayerBlobs'; +import type { ImageField } from 'features/nodes/types/common'; +import { + CONTROL_NET_COLLECT, + IP_ADAPTER_COLLECT, + PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX, + PROMPT_REGION_MASK_TO_TENSOR_PREFIX, + PROMPT_REGION_NEGATIVE_COND_PREFIX, + PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX, + PROMPT_REGION_POSITIVE_COND_PREFIX, + T2I_ADAPTER_COLLECT, +} from 'features/nodes/util/graph/constants'; +import type { Graph } from 'features/nodes/util/graph/Graph'; +import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil'; +import { size } from 'lodash-es'; +import { getImageDTO, imagesApi } from 'services/api/endpoints/images'; +import type { ImageDTO, Invocation, S } from 'services/api/types'; +import { assert } from 'tsafe'; + +const buildControlImage = ( + image: ImageWithDims | null, + processedImage: ImageWithDims | null, + processorConfig: ProcessorConfig | null +): ImageField => { + if (processedImage && processorConfig) { + // We've processed the image in the app - use it for the control image. + return { + image_name: processedImage.imageName, + }; + } else if (image) { + // No processor selected, and we have an image - the user provided a processed image, use it for the control image. + return { + image_name: image.imageName, + }; + } + assert(false, 'Attempted to add unprocessed control image'); +}; + +const buildControlNetMetadata = (controlNet: ControlNetConfigV2): S['ControlNetMetadataField'] => { + const { beginEndStepPct, controlMode, image, model, processedImage, processorConfig, weight } = controlNet; + + assert(model, 'ControlNet model is required'); + assert(image, 'ControlNet image is required'); + + const processed_image = + processedImage && processorConfig + ? { + image_name: processedImage.imageName, + } + : null; + + return { + control_model: model, + control_weight: weight, + control_mode: controlMode, + begin_step_percent: beginEndStepPct[0], + end_step_percent: beginEndStepPct[1], + resize_mode: 'just_resize', + image: { + image_name: image.imageName, + }, + processed_image, + }; +}; + +const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => { + try { + // Attempt to retrieve the collector + const controlNetCollect = g.getNode(CONTROL_NET_COLLECT); + assert(controlNetCollect.type === 'collect'); + return controlNetCollect; + } catch { + // Add the ControlNet collector + const controlNetCollect = g.addNode({ + id: CONTROL_NET_COLLECT, + type: 'collect', + }); + g.addEdge(controlNetCollect, 'collection', denoise, 'control'); + return controlNetCollect; + } +}; + +const addGlobalControlNetsToGraph = ( + controlNetConfigs: ControlNetConfigV2[], + g: Graph, + denoise: Invocation<'denoise_latents'> +): void => { + if (controlNetConfigs.length === 0) { + return; + } + const controlNetMetadata: S['ControlNetMetadataField'][] = []; + const controlNetCollect = addControlNetCollectorSafe(g, denoise); + + for (const controlNetConfig of controlNetConfigs) { + if (!controlNetConfig.model) { + return; + } + const { id, beginEndStepPct, controlMode, image, model, processedImage, processorConfig, weight } = + controlNetConfig; + + const controlNet = g.addNode({ + id: `control_net_${id}`, + type: 'controlnet', + begin_step_percent: beginEndStepPct[0], + end_step_percent: beginEndStepPct[1], + control_mode: controlMode, + resize_mode: 'just_resize', + control_model: model, + control_weight: weight, + image: buildControlImage(image, processedImage, processorConfig), + }); + + g.addEdge(controlNet, 'control', controlNetCollect, 'item'); + + controlNetMetadata.push(buildControlNetMetadata(controlNetConfig)); + } + MetadataUtil.add(g, { controlnets: controlNetMetadata }); +}; + +const buildT2IAdapterMetadata = (t2iAdapter: T2IAdapterConfigV2): S['T2IAdapterMetadataField'] => { + const { beginEndStepPct, image, model, processedImage, processorConfig, weight } = t2iAdapter; + + assert(model, 'T2I Adapter model is required'); + assert(image, 'T2I Adapter image is required'); + + const processed_image = + processedImage && processorConfig + ? { + image_name: processedImage.imageName, + } + : null; + + return { + t2i_adapter_model: model, + weight, + begin_step_percent: beginEndStepPct[0], + end_step_percent: beginEndStepPct[1], + resize_mode: 'just_resize', + image: { + image_name: image.imageName, + }, + processed_image, + }; +}; + +const addT2IAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => { + try { + // You see, we've already got one! + const t2iAdapterCollect = g.getNode(T2I_ADAPTER_COLLECT); + assert(t2iAdapterCollect.type === 'collect'); + return t2iAdapterCollect; + } catch { + const t2iAdapterCollect = g.addNode({ + id: T2I_ADAPTER_COLLECT, + type: 'collect', + }); + + g.addEdge(t2iAdapterCollect, 'collection', denoise, 't2i_adapter'); + + return t2iAdapterCollect; + } +}; + +const addGlobalT2IAdaptersToGraph = ( + t2iAdapterConfigs: T2IAdapterConfigV2[], + g: Graph, + denoise: Invocation<'denoise_latents'> +): void => { + if (t2iAdapterConfigs.length === 0) { + return; + } + const t2iAdapterMetadata: S['T2IAdapterMetadataField'][] = []; + const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise); + + for (const t2iAdapterConfig of t2iAdapterConfigs) { + if (!t2iAdapterConfig.model) { + return; + } + const { id, beginEndStepPct, image, model, processedImage, processorConfig, weight } = t2iAdapterConfig; + + const t2iAdapter = g.addNode({ + id: `t2i_adapter_${id}`, + type: 't2i_adapter', + begin_step_percent: beginEndStepPct[0], + end_step_percent: beginEndStepPct[1], + resize_mode: 'just_resize', + t2i_adapter_model: model, + weight: weight, + image: buildControlImage(image, processedImage, processorConfig), + }); + + g.addEdge(t2iAdapter, 't2i_adapter', t2iAdapterCollect, 'item'); + + t2iAdapterMetadata.push(buildT2IAdapterMetadata(t2iAdapterConfig)); + } + + MetadataUtil.add(g, { t2iAdapters: t2iAdapterMetadata }); +}; + +const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfigV2): S['IPAdapterMetadataField'] => { + const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter; + + assert(model, 'IP Adapter model is required'); + assert(image, 'IP Adapter image is required'); + + return { + ip_adapter_model: model, + clip_vision_model: clipVisionModel, + weight, + method, + begin_step_percent: beginEndStepPct[0], + end_step_percent: beginEndStepPct[1], + image: { + image_name: image.imageName, + }, + }; +}; + +const addIPAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => { + try { + // You see, we've already got one! + const ipAdapterCollect = g.getNode(IP_ADAPTER_COLLECT); + assert(ipAdapterCollect.type === 'collect'); + return ipAdapterCollect; + } catch { + const ipAdapterCollect = g.addNode({ + id: IP_ADAPTER_COLLECT, + type: 'collect', + }); + g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter'); + return ipAdapterCollect; + } +}; + +const addGlobalIPAdaptersToGraph = ( + ipAdapterConfigs: IPAdapterConfigV2[], + g: Graph, + denoise: Invocation<'denoise_latents'> +): void => { + if (ipAdapterConfigs.length === 0) { + return; + } + const ipAdapterMetdata: S['IPAdapterMetadataField'][] = []; + const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise); + + for (const ipAdapterConfig of ipAdapterConfigs) { + const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapterConfig; + assert(image, 'IP Adapter image is required'); + assert(model, 'IP Adapter model is required'); + + const ipAdapter = g.addNode({ + id: `ip_adapter_${id}`, + type: 'ip_adapter', + weight, + method, + ip_adapter_model: model, + clip_vision_model: clipVisionModel, + begin_step_percent: beginEndStepPct[0], + end_step_percent: beginEndStepPct[1], + image: { + image_name: image.imageName, + }, + }); + g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item'); + ipAdapterMetdata.push(buildIPAdapterMetadata(ipAdapterConfig)); + } + + MetadataUtil.add(g, { ipAdapters: ipAdapterMetdata }); +}; + +export const addGenerationTabControlLayers = async ( + state: RootState, + g: Graph, + denoise: Invocation<'denoise_latents'>, + posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>, + negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>, + posCondCollect: Invocation<'collect'>, + negCondCollect: Invocation<'collect'> +) => { + const mainModel = state.generation.model; + assert(mainModel, 'Missing main model when building graph'); + const isSDXL = mainModel.base === 'sdxl'; + + // Add global control adapters + const globalControlNetConfigs = state.controlLayers.present.layers + // Must be a CA layer + .filter(isControlAdapterLayer) + // Must be enabled + .filter((l) => l.isEnabled) + // We want the CAs themselves + .map((l) => l.controlAdapter) + // Must be a ControlNet + .filter(isControlNetConfigV2) + .filter((ca) => { + const hasModel = Boolean(ca.model); + const modelMatchesBase = ca.model?.base === mainModel.base; + const hasControlImage = ca.image || (ca.processedImage && ca.processorConfig); + return hasModel && modelMatchesBase && hasControlImage; + }); + addGlobalControlNetsToGraph(globalControlNetConfigs, g, denoise); + + const globalT2IAdapterConfigs = state.controlLayers.present.layers + // Must be a CA layer + .filter(isControlAdapterLayer) + // Must be enabled + .filter((l) => l.isEnabled) + // We want the CAs themselves + .map((l) => l.controlAdapter) + // Must have a ControlNet CA + .filter(isT2IAdapterConfigV2) + .filter((ca) => { + const hasModel = Boolean(ca.model); + const modelMatchesBase = ca.model?.base === mainModel.base; + const hasControlImage = ca.image || (ca.processedImage && ca.processorConfig); + return hasModel && modelMatchesBase && hasControlImage; + }); + addGlobalT2IAdaptersToGraph(globalT2IAdapterConfigs, g, denoise); + + const globalIPAdapterConfigs = state.controlLayers.present.layers + // Must be an IP Adapter layer + .filter(isIPAdapterLayer) + // Must be enabled + .filter((l) => l.isEnabled) + // We want the IP Adapters themselves + .map((l) => l.ipAdapter) + .filter((ca) => { + const hasModel = Boolean(ca.model); + const modelMatchesBase = ca.model?.base === mainModel.base; + const hasControlImage = Boolean(ca.image); + return hasModel && modelMatchesBase && hasControlImage; + }); + addGlobalIPAdaptersToGraph(globalIPAdapterConfigs, g, denoise); + + const rgLayers = state.controlLayers.present.layers + // Only RG layers are get masks + .filter(isRegionalGuidanceLayer) + // Only visible layers are rendered on the canvas + .filter((l) => l.isEnabled) + // Only layers with prompts get added to the graph + .filter((l) => { + const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt); + const hasIPAdapter = l.ipAdapters.length !== 0; + return hasTextPrompt || hasIPAdapter; + }); + + const layerIds = rgLayers.map((l) => l.id); + const blobs = await getRegionalPromptLayerBlobs(layerIds); + assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs'); + + for (const layer of rgLayers) { + const blob = blobs[layer.id]; + assert(blob, `Blob for layer ${layer.id} not found`); + // Upload the mask image, or get the cached image if it exists + const { image_name } = await getMaskImage(layer, blob); + + // The main mask-to-tensor node + const maskToTensor = g.addNode({ + id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${layer.id}`, + type: 'alpha_mask_to_tensor', + image: { + image_name, + }, + }); + + if (layer.positivePrompt) { + // The main positive conditioning node + const regionalPosCond = g.addNode( + isSDXL + ? { + type: 'sdxl_compel_prompt', + id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`, + prompt: layer.positivePrompt, + style: layer.positivePrompt, // TODO: Should we put the positive prompt in both fields? + } + : { + type: 'compel', + id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`, + prompt: layer.positivePrompt, + } + ); + // Connect the mask to the conditioning + g.addEdge(maskToTensor, 'mask', regionalPosCond, 'mask'); + // Connect the conditioning to the collector + g.addEdge(regionalPosCond, 'conditioning', posCondCollect, 'item'); + // Copy the connections to the "global" positive conditioning node to the regional cond + for (const edge of g.getEdgesTo(posCond)) { + console.log(edge); + if (edge.destination.field !== 'prompt') { + // Clone the edge, but change the destination node to the regional conditioning node + const clone = deepClone(edge); + clone.destination.node_id = regionalPosCond.id; + g.addEdgeFromObj(clone); + } + } + } + + if (layer.negativePrompt) { + // The main negative conditioning node + const regionalNegCond = g.addNode( + isSDXL + ? { + type: 'sdxl_compel_prompt', + id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`, + prompt: layer.negativePrompt, + style: layer.negativePrompt, + } + : { + type: 'compel', + id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`, + prompt: layer.negativePrompt, + } + ); + // Connect the mask to the conditioning + g.addEdge(maskToTensor, 'mask', regionalNegCond, 'mask'); + // Connect the conditioning to the collector + g.addEdge(regionalNegCond, 'conditioning', negCondCollect, 'item'); + // Copy the connections to the "global" negative conditioning node to the regional cond + for (const edge of g.getEdgesTo(negCond)) { + if (edge.destination.field !== 'prompt') { + // Clone the edge, but change the destination node to the regional conditioning node + const clone = deepClone(edge); + clone.destination.node_id = regionalNegCond.id; + g.addEdgeFromObj(clone); + } + } + } + + // If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node + if (layer.autoNegative === 'invert' && layer.positivePrompt) { + // We re-use the mask image, but invert it when converting to tensor + const invertTensorMask = g.addNode({ + id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${layer.id}`, + type: 'invert_tensor_mask', + }); + // Connect the OG mask image to the inverted mask-to-tensor node + g.addEdge(maskToTensor, 'mask', invertTensorMask, 'mask'); + // Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the positive prompt + const regionalPosCondInverted = g.addNode( + isSDXL + ? { + type: 'sdxl_compel_prompt', + id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`, + prompt: layer.positivePrompt, + style: layer.positivePrompt, + } + : { + type: 'compel', + id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`, + prompt: layer.positivePrompt, + } + ); + // Connect the inverted mask to the conditioning + g.addEdge(invertTensorMask, 'mask', regionalPosCondInverted, 'mask'); + // Connect the conditioning to the negative collector + g.addEdge(regionalPosCondInverted, 'conditioning', negCondCollect, 'item'); + // Copy the connections to the "global" positive conditioning node to our regional node + for (const edge of g.getEdgesTo(posCond)) { + if (edge.destination.field !== 'prompt') { + // Clone the edge, but change the destination node to the regional conditioning node + const clone = deepClone(edge); + clone.destination.node_id = regionalPosCondInverted.id; + g.addEdgeFromObj(clone); + } + } + } + + // TODO(psyche): For some reason, I have to explicitly annotate regionalIPAdapters here. Not sure why. + const regionalIPAdapters: IPAdapterConfigV2[] = layer.ipAdapters.filter((ipAdapter) => { + const hasModel = Boolean(ipAdapter.model); + const modelMatchesBase = ipAdapter.model?.base === mainModel.base; + const hasControlImage = Boolean(ipAdapter.image); + return hasModel && modelMatchesBase && hasControlImage; + }); + + for (const ipAdapterConfig of regionalIPAdapters) { + const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise); + const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapterConfig; + assert(model, 'IP Adapter model is required'); + assert(image, 'IP Adapter image is required'); + + const ipAdapter = g.addNode({ + id: `ip_adapter_${id}`, + type: 'ip_adapter', + weight, + method, + ip_adapter_model: model, + clip_vision_model: clipVisionModel, + begin_step_percent: beginEndStepPct[0], + end_step_percent: beginEndStepPct[1], + image: { + image_name: image.imageName, + }, + }); + + // Connect the mask to the conditioning + g.addEdge(maskToTensor, 'mask', ipAdapter, 'mask'); + g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item'); + } + } +}; + +const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise => { + if (layer.uploadedMaskImage) { + const imageDTO = await getImageDTO(layer.uploadedMaskImage.imageName); + if (imageDTO) { + return imageDTO; + } + } + const { dispatch } = getStore(); + // No cached mask, or the cached image no longer exists - we need to upload the mask image + const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' }); + const req = dispatch( + imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true }) + ); + req.reset(); + + const imageDTO = await req.unwrap(); + dispatch(rgLayerMaskImageUploaded({ layerId: layer.id, imageDTO })); + return imageDTO; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addInitialImageToGenerationTabGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabInitialImage.ts similarity index 96% rename from invokeai/frontend/web/src/features/nodes/util/graph/addInitialImageToGenerationTabGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabInitialImage.ts index e0cbea810f..3a6b124b30 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addInitialImageToGenerationTabGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabInitialImage.ts @@ -15,12 +15,12 @@ import { IMAGE_TO_LATENTS, RESIZE } from './constants'; * @param noise The noise node in the graph * @returns Whether the initial image was added to the graph */ -export const addInitialImageToGenerationTabGraph = ( +export const addGenerationTabInitialImage = ( state: RootState, g: Graph, denoise: Invocation<'denoise_latents'>, noise: Invocation<'noise'> -): boolean => { +): Invocation<'i2l'> | null => { // Remove Existing UNet Connections const { img2imgStrength, vaePrecision, model } = state.generation; const { refinerModel, refinerStart } = state.sdxl; @@ -29,7 +29,7 @@ export const addInitialImageToGenerationTabGraph = ( const initialImage = initialImageLayer?.isEnabled ? initialImageLayer?.image : null; if (!initialImage) { - return false; + return null; } const isSDXL = model?.base === 'sdxl'; @@ -75,5 +75,5 @@ export const addInitialImageToGenerationTabGraph = ( init_image: initialImage.imageName, }); - return true; + return i2l; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabLoRAs.ts new file mode 100644 index 0000000000..3cb43fd48d --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabLoRAs.ts @@ -0,0 +1,94 @@ +import type { RootState } from 'app/store/store'; +import { deepClone } from 'common/util/deepClone'; +import { zModelIdentifierField } from 'features/nodes/types/common'; +import { Graph } from 'features/nodes/util/graph/Graph'; +import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil'; +import { filter, size } from 'lodash-es'; +import type { Invocation, S } from 'services/api/types'; +import { assert } from 'tsafe'; + +import { LORA_LOADER } from './constants'; + +export const addGenerationTabLoRAs = ( + state: RootState, + g: Graph, + denoise: Invocation<'denoise_latents'>, + unetSource: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> | Invocation<'seamless'>, + clipSkip: Invocation<'clip_skip'>, + posCond: Invocation<'compel'>, + negCond: Invocation<'compel'> +): void => { + /** + * LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them. + * They then output the UNet and CLIP models references on to either the next LoRA in the chain, + * or to the inference/conditioning nodes. + * + * So we need to inject a LoRA chain into the graph. + */ + + const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false); + const loraCount = size(enabledLoRAs); + + if (loraCount === 0) { + return; + } + + // Remove modelLoaderNodeId unet connection to feed it to LoRAs + console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e))); + g.deleteEdgesFrom(unetSource, 'unet'); + console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e))); + if (clipSkip) { + // Remove CLIP_SKIP connections to conditionings to feed it through LoRAs + g.deleteEdgesFrom(clipSkip, 'clip'); + } + console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e))); + + // we need to remember the last lora so we can chain from it + let lastLoRALoader: Invocation<'lora_loader'> | null = null; + let currentLoraIndex = 0; + const loraMetadata: S['LoRAMetadataField'][] = []; + + for (const lora of enabledLoRAs) { + const { weight } = lora; + const { key } = lora.model; + const currentLoraNodeId = `${LORA_LOADER}_${key}`; + const parsedModel = zModelIdentifierField.parse(lora.model); + + const currentLoRALoader = g.addNode({ + type: 'lora_loader', + id: currentLoraNodeId, + lora: parsedModel, + weight, + }); + + loraMetadata.push({ + model: parsedModel, + weight, + }); + + // add to graph + if (currentLoraIndex === 0) { + // first lora = start the lora chain, attach directly to model loader + g.addEdge(unetSource, 'unet', currentLoRALoader, 'unet'); + g.addEdge(clipSkip, 'clip', currentLoRALoader, 'clip'); + } else { + assert(lastLoRALoader !== null); + // we are in the middle of the lora chain, instead connect to the previous lora + g.addEdge(lastLoRALoader, 'unet', currentLoRALoader, 'unet'); + g.addEdge(lastLoRALoader, 'clip', currentLoRALoader, 'clip'); + } + + if (currentLoraIndex === loraCount - 1) { + // final lora, end the lora chain - we need to connect up to inference and conditioning nodes + g.addEdge(currentLoRALoader, 'unet', denoise, 'unet'); + g.addEdge(currentLoRALoader, 'clip', posCond, 'clip'); + g.addEdge(currentLoRALoader, 'clip', negCond, 'clip'); + } + + // increment the lora for the next one in the chain + lastLoRALoader = currentLoRALoader; + currentLoraIndex += 1; + } + + MetadataUtil.add(g, { loras: loraMetadata }); +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addSeamlessToGenerationTabGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSeamless.ts similarity index 60% rename from invokeai/frontend/web/src/features/nodes/util/graph/addSeamlessToGenerationTabGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSeamless.ts index 7434058f7a..e56f37916c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addSeamlessToGenerationTabGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSeamless.ts @@ -15,26 +15,28 @@ import { SEAMLESS, VAE_LOADER } from './constants'; * @param modelLoader The model loader node in the graph * @returns The terminal model loader node in the graph */ -export const addSeamlessToGenerationTabGraph = ( +export const addGenerationTabSeamless = ( state: RootState, g: Graph, denoise: Invocation<'denoise_latents'>, modelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> -): Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> | Invocation<'seamless'> => { - const { seamlessXAxis, seamlessYAxis, vae } = state.generation; +): Invocation<'seamless'> | null => { + const { seamlessXAxis: seamless_x, seamlessYAxis: seamless_y, vae } = state.generation; - if (!seamlessXAxis && !seamlessYAxis) { - return modelLoader; + if (!seamless_x && !seamless_y) { + return null; } const seamless = g.addNode({ id: SEAMLESS, type: 'seamless', - seamless_x: seamlessXAxis, - seamless_y: seamlessYAxis, + seamless_x, + seamless_y, }); - const vaeLoader = vae + // The VAE helper also adds the VAE loader - so we need to check if it's already there + const shouldAddVAELoader = !g.hasNode(VAE_LOADER) && vae; + const vaeLoader = shouldAddVAELoader ? g.addNode({ type: 'vae_loader', id: VAE_LOADER, @@ -42,29 +44,18 @@ export const addSeamlessToGenerationTabGraph = ( }) : null; - let terminalModelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> | Invocation<'seamless'> = - modelLoader; - - if (seamlessXAxis) { - MetadataUtil.add(g, { - seamless_x: seamlessXAxis, - }); - terminalModelLoader = seamless; - } - if (seamlessYAxis) { - MetadataUtil.add(g, { - seamless_y: seamlessYAxis, - }); - terminalModelLoader = seamless; - } + MetadataUtil.add(g, { + seamless_x: seamless_x || undefined, + seamless_y: seamless_y || undefined, + }); // Seamless slots into the graph between the model loader and the denoise node g.deleteEdgesFrom(modelLoader, 'unet'); - g.deleteEdgesFrom(modelLoader, 'clip'); + g.deleteEdgesFrom(modelLoader, 'vae'); g.addEdge(modelLoader, 'unet', seamless, 'unet'); - g.addEdge(vaeLoader ?? modelLoader, 'vae', seamless, 'unet'); + g.addEdge(vaeLoader ?? modelLoader, 'vae', seamless, 'vae'); g.addEdge(seamless, 'unet', denoise, 'unet'); - return terminalModelLoader; + return seamless; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabVAE.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabVAE.ts new file mode 100644 index 0000000000..037924d5cb --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabVAE.ts @@ -0,0 +1,37 @@ +import type { RootState } from 'app/store/store'; +import type { Graph } from 'features/nodes/util/graph/Graph'; +import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil'; +import type { Invocation } from 'services/api/types'; + +import { VAE_LOADER } from './constants'; + +export const addGenerationTabVAE = ( + state: RootState, + g: Graph, + modelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'>, + l2i: Invocation<'l2i'>, + i2l: Invocation<'i2l'> | null, + seamless: Invocation<'seamless'> | null +): void => { + const { vae } = state.generation; + + // The seamless helper also adds the VAE loader... so we need to check if it's already there + const shouldAddVAELoader = !g.hasNode(VAE_LOADER) && vae; + const vaeLoader = shouldAddVAELoader + ? g.addNode({ + type: 'vae_loader', + id: VAE_LOADER, + vae_model: vae, + }) + : null; + + const vaeSource = seamless ? seamless : vaeLoader ? vaeLoader : modelLoader; + g.addEdge(vaeSource, 'vae', l2i, 'vae'); + if (i2l) { + g.addEdge(vaeSource, 'vae', i2l, 'vae'); + } + + if (vae) { + MetadataUtil.add(g, { vae }); + } +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildGenerationTabGraph2.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildGenerationTabGraph2.ts index be530095a3..328cccb98a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildGenerationTabGraph2.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildGenerationTabGraph2.ts @@ -1,18 +1,19 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; -import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph'; -import { addInitialImageToGenerationTabGraph } from 'features/nodes/util/graph/addInitialImageToGenerationTabGraph'; -import { addSeamlessToGenerationTabGraph } from 'features/nodes/util/graph/addSeamlessToGenerationTabGraph'; +import { addGenerationTabControlLayers } from 'features/nodes/util/graph/addGenerationTabControlLayers'; +import { addGenerationTabInitialImage } from 'features/nodes/util/graph/addGenerationTabInitialImage'; +import { addGenerationTabLoRAs } from 'features/nodes/util/graph/addGenerationTabLoRAs'; +import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless'; +import { addGenerationTabVAE } from 'features/nodes/util/graph/addGenerationTabVAE'; +import type { GraphType } from 'features/nodes/util/graph/Graph'; import { Graph } from 'features/nodes/util/graph/Graph'; import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils'; import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil'; import { isNonRefinerMainModelConfig } from 'services/api/types'; import { addHrfToGraph } from './addHrfToGraph'; -import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; -import { addVAEToGraph } from './addVAEToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { CLIP_SKIP, @@ -29,7 +30,7 @@ import { import { getModelMetadataField } from './metadata'; const log = logger('nodes'); -export const buildGenerationTabGraph = async (state: RootState): Promise => { +export const buildGenerationTabGraph2 = async (state: RootState): Promise => { const { model, cfgScale: cfg_scale, @@ -39,8 +40,6 @@ export const buildGenerationTabGraph = async (state: RootState): Promise clipSkip: skipped_layers, shouldUseCpuNoise, vaePrecision, - seamlessXAxis, - seamlessYAxis, seed, } = state.generation; const { positivePrompt, negativePrompt } = state.controlLayers.present; @@ -114,6 +113,8 @@ export const buildGenerationTabGraph = async (state: RootState): Promise g.addEdge(clipSkip, 'clip', negCond, 'clip'); g.addEdge(posCond, 'conditioning', posCondCollect, 'item'); g.addEdge(negCond, 'conditioning', negCondCollect, 'item'); + g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning'); + g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning'); g.addEdge(noise, 'noise', denoise, 'noise'); g.addEdge(denoise, 'latents', l2i, 'latents'); @@ -135,20 +136,22 @@ export const buildGenerationTabGraph = async (state: RootState): Promise clip_skip: skipped_layers, }); MetadataUtil.setMetadataReceivingNode(g, l2i); + g.validate(); - const didAddInitialImage = addInitialImageToGenerationTabGraph(state, g, denoise, noise); - const terminalModelLoader = addSeamlessToGenerationTabGraph(state, g, denoise, modelLoader); + const i2l = addGenerationTabInitialImage(state, g, denoise, noise); + g.validate(); + const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader); + g.validate(); + addGenerationTabVAE(state, g, modelLoader, l2i, i2l, seamless); + g.validate(); + addGenerationTabLoRAs(state, g, denoise, seamless ?? modelLoader, clipSkip, posCond, negCond); + g.validate(); - // optionally add custom VAE - await addVAEToGraph(state, graph, modelLoaderNodeId); - - // add LoRA support - await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId); - - await addControlLayersToGraph(state, graph, DENOISE_LATENTS); + await addGenerationTabControlLayers(state, g, denoise, posCond, negCond, posCondCollect, negCondCollect); + g.validate(); // High resolution fix. - if (state.hrf.hrfEnabled && !didAddInitialImage) { + if (state.hrf.hrfEnabled && !i2l) { addHrfToGraph(state, graph); } @@ -163,5 +166,5 @@ export const buildGenerationTabGraph = async (state: RootState): Promise addWatermarkerToGraph(state, graph); } - return graph; + return g.getGraph(); };