mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): add img2img via control layers to graph builders
This commit is contained in:
parent
c9886796f6
commit
dc81357152
@ -0,0 +1,117 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { isInitialImageLayer } from 'features/controlLayers/store/controlLayersSlice';
|
||||
import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { IMAGE_TO_LATENTS, NOISE, RESIZE } from './constants';
|
||||
|
||||
export const addInitialImageToLinearGraph = (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
denoiseNodeId: string
|
||||
): void => {
|
||||
// Remove Existing UNet Connections
|
||||
const { img2imgStrength, vaePrecision, model } = state.generation;
|
||||
const { refinerModel, refinerStart } = state.sdxl;
|
||||
const { width, height } = state.controlLayers.present.size;
|
||||
const initialImageLayer = state.controlLayers.present.layers.find(isInitialImageLayer);
|
||||
const initialImage = initialImageLayer?.isEnabled ? initialImageLayer?.image : null;
|
||||
|
||||
if (!initialImage) {
|
||||
return;
|
||||
}
|
||||
|
||||
const useRefinerStartEnd = model?.base === 'sdxl' && Boolean(refinerModel);
|
||||
|
||||
const denoiseNode = graph.nodes[denoiseNodeId];
|
||||
assert(denoiseNode?.type === 'denoise_latents', `Missing denoise node or incorrect type: ${denoiseNode?.type}`);
|
||||
|
||||
denoiseNode.denoising_start = useRefinerStartEnd ? Math.min(refinerStart, 1 - img2imgStrength) : 1 - img2imgStrength;
|
||||
denoiseNode.denoising_end = useRefinerStartEnd ? refinerStart : 1;
|
||||
|
||||
// We conditionally hook the image in depending on if a resize is needed
|
||||
const i2lNode: ImageToLatentsInvocation = {
|
||||
type: 'i2l',
|
||||
id: IMAGE_TO_LATENTS,
|
||||
is_intermediate: true,
|
||||
use_cache: true,
|
||||
fp32: vaePrecision === 'fp32',
|
||||
};
|
||||
|
||||
graph.nodes[i2lNode.id] = i2lNode;
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: denoiseNode.id,
|
||||
field: 'latents',
|
||||
},
|
||||
});
|
||||
|
||||
if (initialImage.width !== width || initialImage.height !== height) {
|
||||
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
|
||||
|
||||
// Create a resize node, explicitly setting its image
|
||||
const resizeNode: ImageResizeInvocation = {
|
||||
id: RESIZE,
|
||||
type: 'img_resize',
|
||||
image: {
|
||||
image_name: initialImage.imageName,
|
||||
},
|
||||
is_intermediate: true,
|
||||
width,
|
||||
height,
|
||||
};
|
||||
|
||||
graph.nodes[RESIZE] = resizeNode;
|
||||
|
||||
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'image' },
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
|
||||
// The `RESIZE` node also passes its width and height to `NOISE`
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'width' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'width',
|
||||
},
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'height' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'height',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
|
||||
i2lNode.image = {
|
||||
image_name: initialImage.imageName,
|
||||
};
|
||||
|
||||
// Pass the image's dimensions to the `NOISE` node
|
||||
graph.edges.push({
|
||||
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'width',
|
||||
},
|
||||
});
|
||||
graph.edges.push({
|
||||
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'height',
|
||||
},
|
||||
});
|
||||
}
|
||||
};
|
@ -7,6 +7,7 @@ import {
|
||||
SDXL_CANVAS_INPAINT_GRAPH,
|
||||
SDXL_CANVAS_OUTPAINT_GRAPH,
|
||||
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
||||
SDXL_CONTROL_LAYERS_GRAPH,
|
||||
SDXL_DENOISE_LATENTS,
|
||||
SDXL_IMAGE_TO_IMAGE_GRAPH,
|
||||
SDXL_TEXT_TO_IMAGE_GRAPH,
|
||||
@ -54,6 +55,7 @@ export const addSeamlessToLinearGraph = (
|
||||
let denoisingNodeId = DENOISE_LATENTS;
|
||||
|
||||
if (
|
||||
graph.id === SDXL_CONTROL_LAYERS_GRAPH ||
|
||||
graph.id === SDXL_TEXT_TO_IMAGE_GRAPH ||
|
||||
graph.id === SDXL_IMAGE_TO_IMAGE_GRAPH ||
|
||||
graph.id === SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH ||
|
||||
|
@ -7,6 +7,7 @@ import {
|
||||
CANVAS_OUTPAINT_GRAPH,
|
||||
CANVAS_OUTPUT,
|
||||
CANVAS_TEXT_TO_IMAGE_GRAPH,
|
||||
CONTROL_LAYERS_GRAPH,
|
||||
IMAGE_TO_IMAGE_GRAPH,
|
||||
IMAGE_TO_LATENTS,
|
||||
INPAINT_CREATE_MASK,
|
||||
@ -17,11 +18,11 @@ import {
|
||||
SDXL_CANVAS_INPAINT_GRAPH,
|
||||
SDXL_CANVAS_OUTPAINT_GRAPH,
|
||||
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
||||
SDXL_CONTROL_LAYERS_GRAPH,
|
||||
SDXL_IMAGE_TO_IMAGE_GRAPH,
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
SDXL_TEXT_TO_IMAGE_GRAPH,
|
||||
SEAMLESS,
|
||||
TEXT_TO_IMAGE_GRAPH,
|
||||
VAE_LOADER,
|
||||
} from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
@ -52,7 +53,8 @@ export const addVAEToGraph = async (
|
||||
}
|
||||
|
||||
if (
|
||||
graph.id === TEXT_TO_IMAGE_GRAPH ||
|
||||
graph.id === CONTROL_LAYERS_GRAPH ||
|
||||
graph.id === SDXL_CONTROL_LAYERS_GRAPH ||
|
||||
graph.id === IMAGE_TO_IMAGE_GRAPH ||
|
||||
graph.id === SDXL_TEXT_TO_IMAGE_GRAPH ||
|
||||
graph.id === SDXL_IMAGE_TO_IMAGE_GRAPH
|
||||
@ -100,6 +102,7 @@ export const addVAEToGraph = async (
|
||||
}
|
||||
|
||||
if (
|
||||
graph.id === SDXL_CONTROL_LAYERS_GRAPH ||
|
||||
graph.id === IMAGE_TO_IMAGE_GRAPH ||
|
||||
graph.id === SDXL_IMAGE_TO_IMAGE_GRAPH ||
|
||||
graph.id === CANVAS_IMAGE_TO_IMAGE_GRAPH ||
|
||||
|
@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
|
||||
import { addInitialImageToLinearGraph } from 'features/nodes/util/graph/addInitialImageToLinearGraph';
|
||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
@ -15,10 +16,10 @@ import {
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
SDXL_CONTROL_LAYERS_GRAPH,
|
||||
SDXL_DENOISE_LATENTS,
|
||||
SDXL_MODEL_LOADER,
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
SDXL_TEXT_TO_IMAGE_GRAPH,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
@ -70,7 +71,7 @@ export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
const graph: NonNullableGraph = {
|
||||
id: SDXL_TEXT_TO_IMAGE_GRAPH,
|
||||
id: SDXL_CONTROL_LAYERS_GRAPH,
|
||||
nodes: {
|
||||
[modelLoaderNodeId]: {
|
||||
type: 'sdxl_model_loader',
|
||||
@ -241,6 +242,8 @@ export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise
|
||||
LATENTS_TO_IMAGE
|
||||
);
|
||||
|
||||
addInitialImageToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// Add Seamless To Graph
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
|
||||
|
@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
|
||||
import { addInitialImageToLinearGraph } from 'features/nodes/util/graph/addInitialImageToLinearGraph';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||
|
||||
@ -13,6 +14,7 @@ import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
CLIP_SKIP,
|
||||
CONTROL_LAYERS_GRAPH,
|
||||
DENOISE_LATENTS,
|
||||
LATENTS_TO_IMAGE,
|
||||
MAIN_MODEL_LOADER,
|
||||
@ -20,7 +22,6 @@ import {
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
SEAMLESS,
|
||||
TEXT_TO_IMAGE_GRAPH,
|
||||
} from './constants';
|
||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||
|
||||
@ -66,7 +67,7 @@ export const buildLinearTextToImageGraph = async (state: RootState): Promise<Non
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
|
||||
const graph: NonNullableGraph = {
|
||||
id: TEXT_TO_IMAGE_GRAPH,
|
||||
id: CONTROL_LAYERS_GRAPH,
|
||||
nodes: {
|
||||
[modelLoaderNodeId]: {
|
||||
type: 'main_model_loader',
|
||||
@ -231,6 +232,8 @@ export const buildLinearTextToImageGraph = async (state: RootState): Promise<Non
|
||||
LATENTS_TO_IMAGE
|
||||
);
|
||||
|
||||
addInitialImageToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// Add Seamless To Graph
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
|
||||
|
@ -55,6 +55,8 @@ export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
|
||||
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';
|
||||
|
||||
// friendly graph ids
|
||||
export const CONTROL_LAYERS_GRAPH = 'control_layers_graph';
|
||||
export const SDXL_CONTROL_LAYERS_GRAPH = 'sdxl_control_layers_graph';
|
||||
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
|
||||
export const IMAGE_TO_IMAGE_GRAPH = 'image_to_image_graph';
|
||||
export const CANVAS_TEXT_TO_IMAGE_GRAPH = 'canvas_text_to_image_graph';
|
||||
|
Loading…
Reference in New Issue
Block a user