mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): add img2img via control layers to graph builders
This commit is contained in:
parent
c9886796f6
commit
dc81357152
@ -0,0 +1,117 @@
|
|||||||
|
import type { RootState } from 'app/store/store';
|
||||||
|
import { isInitialImageLayer } from 'features/controlLayers/store/controlLayersSlice';
|
||||||
|
import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
|
import { IMAGE_TO_LATENTS, NOISE, RESIZE } from './constants';
|
||||||
|
|
||||||
|
export const addInitialImageToLinearGraph = (
|
||||||
|
state: RootState,
|
||||||
|
graph: NonNullableGraph,
|
||||||
|
denoiseNodeId: string
|
||||||
|
): void => {
|
||||||
|
// Remove Existing UNet Connections
|
||||||
|
const { img2imgStrength, vaePrecision, model } = state.generation;
|
||||||
|
const { refinerModel, refinerStart } = state.sdxl;
|
||||||
|
const { width, height } = state.controlLayers.present.size;
|
||||||
|
const initialImageLayer = state.controlLayers.present.layers.find(isInitialImageLayer);
|
||||||
|
const initialImage = initialImageLayer?.isEnabled ? initialImageLayer?.image : null;
|
||||||
|
|
||||||
|
if (!initialImage) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const useRefinerStartEnd = model?.base === 'sdxl' && Boolean(refinerModel);
|
||||||
|
|
||||||
|
const denoiseNode = graph.nodes[denoiseNodeId];
|
||||||
|
assert(denoiseNode?.type === 'denoise_latents', `Missing denoise node or incorrect type: ${denoiseNode?.type}`);
|
||||||
|
|
||||||
|
denoiseNode.denoising_start = useRefinerStartEnd ? Math.min(refinerStart, 1 - img2imgStrength) : 1 - img2imgStrength;
|
||||||
|
denoiseNode.denoising_end = useRefinerStartEnd ? refinerStart : 1;
|
||||||
|
|
||||||
|
// We conditionally hook the image in depending on if a resize is needed
|
||||||
|
const i2lNode: ImageToLatentsInvocation = {
|
||||||
|
type: 'i2l',
|
||||||
|
id: IMAGE_TO_LATENTS,
|
||||||
|
is_intermediate: true,
|
||||||
|
use_cache: true,
|
||||||
|
fp32: vaePrecision === 'fp32',
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.nodes[i2lNode.id] = i2lNode;
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: IMAGE_TO_LATENTS,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: denoiseNode.id,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (initialImage.width !== width || initialImage.height !== height) {
|
||||||
|
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
|
||||||
|
|
||||||
|
// Create a resize node, explicitly setting its image
|
||||||
|
const resizeNode: ImageResizeInvocation = {
|
||||||
|
id: RESIZE,
|
||||||
|
type: 'img_resize',
|
||||||
|
image: {
|
||||||
|
image_name: initialImage.imageName,
|
||||||
|
},
|
||||||
|
is_intermediate: true,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.nodes[RESIZE] = resizeNode;
|
||||||
|
|
||||||
|
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RESIZE, field: 'image' },
|
||||||
|
destination: {
|
||||||
|
node_id: IMAGE_TO_LATENTS,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// The `RESIZE` node also passes its width and height to `NOISE`
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RESIZE, field: 'width' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'width',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RESIZE, field: 'height' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'height',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
|
||||||
|
i2lNode.image = {
|
||||||
|
image_name: initialImage.imageName,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Pass the image's dimensions to the `NOISE` node
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'width',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'height',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
@ -7,6 +7,7 @@ import {
|
|||||||
SDXL_CANVAS_INPAINT_GRAPH,
|
SDXL_CANVAS_INPAINT_GRAPH,
|
||||||
SDXL_CANVAS_OUTPAINT_GRAPH,
|
SDXL_CANVAS_OUTPAINT_GRAPH,
|
||||||
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
||||||
|
SDXL_CONTROL_LAYERS_GRAPH,
|
||||||
SDXL_DENOISE_LATENTS,
|
SDXL_DENOISE_LATENTS,
|
||||||
SDXL_IMAGE_TO_IMAGE_GRAPH,
|
SDXL_IMAGE_TO_IMAGE_GRAPH,
|
||||||
SDXL_TEXT_TO_IMAGE_GRAPH,
|
SDXL_TEXT_TO_IMAGE_GRAPH,
|
||||||
@ -54,6 +55,7 @@ export const addSeamlessToLinearGraph = (
|
|||||||
let denoisingNodeId = DENOISE_LATENTS;
|
let denoisingNodeId = DENOISE_LATENTS;
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
graph.id === SDXL_CONTROL_LAYERS_GRAPH ||
|
||||||
graph.id === SDXL_TEXT_TO_IMAGE_GRAPH ||
|
graph.id === SDXL_TEXT_TO_IMAGE_GRAPH ||
|
||||||
graph.id === SDXL_IMAGE_TO_IMAGE_GRAPH ||
|
graph.id === SDXL_IMAGE_TO_IMAGE_GRAPH ||
|
||||||
graph.id === SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH ||
|
graph.id === SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH ||
|
||||||
|
@ -7,6 +7,7 @@ import {
|
|||||||
CANVAS_OUTPAINT_GRAPH,
|
CANVAS_OUTPAINT_GRAPH,
|
||||||
CANVAS_OUTPUT,
|
CANVAS_OUTPUT,
|
||||||
CANVAS_TEXT_TO_IMAGE_GRAPH,
|
CANVAS_TEXT_TO_IMAGE_GRAPH,
|
||||||
|
CONTROL_LAYERS_GRAPH,
|
||||||
IMAGE_TO_IMAGE_GRAPH,
|
IMAGE_TO_IMAGE_GRAPH,
|
||||||
IMAGE_TO_LATENTS,
|
IMAGE_TO_LATENTS,
|
||||||
INPAINT_CREATE_MASK,
|
INPAINT_CREATE_MASK,
|
||||||
@ -17,11 +18,11 @@ import {
|
|||||||
SDXL_CANVAS_INPAINT_GRAPH,
|
SDXL_CANVAS_INPAINT_GRAPH,
|
||||||
SDXL_CANVAS_OUTPAINT_GRAPH,
|
SDXL_CANVAS_OUTPAINT_GRAPH,
|
||||||
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
||||||
|
SDXL_CONTROL_LAYERS_GRAPH,
|
||||||
SDXL_IMAGE_TO_IMAGE_GRAPH,
|
SDXL_IMAGE_TO_IMAGE_GRAPH,
|
||||||
SDXL_REFINER_SEAMLESS,
|
SDXL_REFINER_SEAMLESS,
|
||||||
SDXL_TEXT_TO_IMAGE_GRAPH,
|
SDXL_TEXT_TO_IMAGE_GRAPH,
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
TEXT_TO_IMAGE_GRAPH,
|
|
||||||
VAE_LOADER,
|
VAE_LOADER,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { upsertMetadata } from './metadata';
|
import { upsertMetadata } from './metadata';
|
||||||
@ -52,7 +53,8 @@ export const addVAEToGraph = async (
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
graph.id === TEXT_TO_IMAGE_GRAPH ||
|
graph.id === CONTROL_LAYERS_GRAPH ||
|
||||||
|
graph.id === SDXL_CONTROL_LAYERS_GRAPH ||
|
||||||
graph.id === IMAGE_TO_IMAGE_GRAPH ||
|
graph.id === IMAGE_TO_IMAGE_GRAPH ||
|
||||||
graph.id === SDXL_TEXT_TO_IMAGE_GRAPH ||
|
graph.id === SDXL_TEXT_TO_IMAGE_GRAPH ||
|
||||||
graph.id === SDXL_IMAGE_TO_IMAGE_GRAPH
|
graph.id === SDXL_IMAGE_TO_IMAGE_GRAPH
|
||||||
@ -100,6 +102,7 @@ export const addVAEToGraph = async (
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
graph.id === SDXL_CONTROL_LAYERS_GRAPH ||
|
||||||
graph.id === IMAGE_TO_IMAGE_GRAPH ||
|
graph.id === IMAGE_TO_IMAGE_GRAPH ||
|
||||||
graph.id === SDXL_IMAGE_TO_IMAGE_GRAPH ||
|
graph.id === SDXL_IMAGE_TO_IMAGE_GRAPH ||
|
||||||
graph.id === CANVAS_IMAGE_TO_IMAGE_GRAPH ||
|
graph.id === CANVAS_IMAGE_TO_IMAGE_GRAPH ||
|
||||||
|
@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
|
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
|
||||||
|
import { addInitialImageToLinearGraph } from 'features/nodes/util/graph/addInitialImageToLinearGraph';
|
||||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||||
|
|
||||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||||
@ -15,10 +16,10 @@ import {
|
|||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
|
SDXL_CONTROL_LAYERS_GRAPH,
|
||||||
SDXL_DENOISE_LATENTS,
|
SDXL_DENOISE_LATENTS,
|
||||||
SDXL_MODEL_LOADER,
|
SDXL_MODEL_LOADER,
|
||||||
SDXL_REFINER_SEAMLESS,
|
SDXL_REFINER_SEAMLESS,
|
||||||
SDXL_TEXT_TO_IMAGE_GRAPH,
|
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||||
@ -70,7 +71,7 @@ export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise
|
|||||||
|
|
||||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
id: SDXL_TEXT_TO_IMAGE_GRAPH,
|
id: SDXL_CONTROL_LAYERS_GRAPH,
|
||||||
nodes: {
|
nodes: {
|
||||||
[modelLoaderNodeId]: {
|
[modelLoaderNodeId]: {
|
||||||
type: 'sdxl_model_loader',
|
type: 'sdxl_model_loader',
|
||||||
@ -241,6 +242,8 @@ export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise
|
|||||||
LATENTS_TO_IMAGE
|
LATENTS_TO_IMAGE
|
||||||
);
|
);
|
||||||
|
|
||||||
|
addInitialImageToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
|
|
||||||
// Add Seamless To Graph
|
// Add Seamless To Graph
|
||||||
if (seamlessXAxis || seamlessYAxis) {
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
|
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
|
||||||
|
@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
|
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
|
||||||
|
import { addInitialImageToLinearGraph } from 'features/nodes/util/graph/addInitialImageToLinearGraph';
|
||||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||||
|
|
||||||
@ -13,6 +14,7 @@ import { addVAEToGraph } from './addVAEToGraph';
|
|||||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||||
import {
|
import {
|
||||||
CLIP_SKIP,
|
CLIP_SKIP,
|
||||||
|
CONTROL_LAYERS_GRAPH,
|
||||||
DENOISE_LATENTS,
|
DENOISE_LATENTS,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
MAIN_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
@ -20,7 +22,6 @@ import {
|
|||||||
NOISE,
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
TEXT_TO_IMAGE_GRAPH,
|
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||||
|
|
||||||
@ -66,7 +67,7 @@ export const buildLinearTextToImageGraph = async (state: RootState): Promise<Non
|
|||||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||||
|
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
id: TEXT_TO_IMAGE_GRAPH,
|
id: CONTROL_LAYERS_GRAPH,
|
||||||
nodes: {
|
nodes: {
|
||||||
[modelLoaderNodeId]: {
|
[modelLoaderNodeId]: {
|
||||||
type: 'main_model_loader',
|
type: 'main_model_loader',
|
||||||
@ -231,6 +232,8 @@ export const buildLinearTextToImageGraph = async (state: RootState): Promise<Non
|
|||||||
LATENTS_TO_IMAGE
|
LATENTS_TO_IMAGE
|
||||||
);
|
);
|
||||||
|
|
||||||
|
addInitialImageToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
|
|
||||||
// Add Seamless To Graph
|
// Add Seamless To Graph
|
||||||
if (seamlessXAxis || seamlessYAxis) {
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
|
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
|
||||||
|
@ -55,6 +55,8 @@ export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
|
|||||||
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';
|
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';
|
||||||
|
|
||||||
// friendly graph ids
|
// friendly graph ids
|
||||||
|
export const CONTROL_LAYERS_GRAPH = 'control_layers_graph';
|
||||||
|
export const SDXL_CONTROL_LAYERS_GRAPH = 'sdxl_control_layers_graph';
|
||||||
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
|
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
|
||||||
export const IMAGE_TO_IMAGE_GRAPH = 'image_to_image_graph';
|
export const IMAGE_TO_IMAGE_GRAPH = 'image_to_image_graph';
|
||||||
export const CANVAS_TEXT_TO_IMAGE_GRAPH = 'canvas_text_to_image_graph';
|
export const CANVAS_TEXT_TO_IMAGE_GRAPH = 'canvas_text_to_image_graph';
|
||||||
|
Loading…
Reference in New Issue
Block a user