feat(nodes,ui): add t2i to linear UI

- Update backend metadata for t2i adapter
- Fix typo in `T2IAdapterInvocation`: `ip_adapter_model` -> `t2i_adapter_model`
- Update linear graphs to use t2i adapter
- Add client metadata recall for t2i adapter
- Fix bug with controlnet metadata recall - processor should be set to 'none' when recalling a control adapter
This commit is contained in:
psychedelicious
2023-10-06 20:16:00 +11:00
parent 1a9d2f1701
commit 078c9b6964
24 changed files with 2035 additions and 890 deletions

View File

@ -0,0 +1,116 @@
import { RootState } from 'app/store/store';
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import { omit } from 'lodash-es';
import {
CollectInvocation,
MetadataAccumulatorInvocation,
T2IAdapterInvocation,
} from 'services/api/types';
import { NonNullableGraph, T2IAdapterField } from '../../types/types';
import {
CANVAS_COHERENCE_DENOISE_LATENTS,
METADATA_ACCUMULATOR,
T2I_ADAPTER_COLLECT,
} from './constants';
export const addT2IAdaptersToLinearGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): void => {
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters);
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (validT2IAdapters.length) {
// Even though denoise_latents' control input is polymorphic, keep it simple and always use a collect
const t2iAdapterCollectNode: CollectInvocation = {
id: T2I_ADAPTER_COLLECT,
type: 'collect',
is_intermediate: true,
};
graph.nodes[T2I_ADAPTER_COLLECT] = t2iAdapterCollectNode;
graph.edges.push({
source: { node_id: T2I_ADAPTER_COLLECT, field: 'collection' },
destination: {
node_id: baseNodeId,
field: 't2i_adapter',
},
});
validT2IAdapters.forEach((t2iAdapter) => {
if (!t2iAdapter.model) {
return;
}
const {
id,
controlImage,
processedControlImage,
beginStepPct,
endStepPct,
resizeMode,
model,
processorType,
weight,
} = t2iAdapter;
const t2iAdapterNode: T2IAdapterInvocation = {
id: `t2i_adapter_${id}`,
type: 't2i_adapter',
is_intermediate: true,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
t2i_adapter_model: model,
weight: weight,
};
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
t2iAdapterNode.image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
t2iAdapterNode.image = {
image_name: controlImage,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
return;
}
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode as T2IAdapterInvocation;
if (metadataAccumulator?.ipAdapters) {
// metadata accumulator only needs a control field - not the whole node
// extract what we need and add to the accumulator
const t2iAdapterField = omit(t2iAdapterNode, [
'id',
'type',
]) as T2IAdapterField;
metadataAccumulator.t2iAdapters.push(t2iAdapterField);
}
graph.edges.push({
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
destination: {
node_id: T2I_ADAPTER_COLLECT,
field: 'item',
},
});
if (CANVAS_COHERENCE_DENOISE_LATENTS in graph.nodes) {
graph.edges.push({
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
destination: {
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
field: 't2i_adapter',
},
});
}
});
}
};

View File

@ -8,6 +8,7 @@ import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@ -328,6 +329,7 @@ export const buildCanvasImageToImageGraph = (
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
t2iAdapters: [],
clip_skip: clipSkip,
strength,
init_image: initialImage.image_name,
@ -350,6 +352,7 @@ export const buildCanvasImageToImageGraph = (
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -13,7 +13,9 @@ import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@ -40,7 +42,6 @@ import {
POSITIVE_CONDITIONING,
SEAMLESS,
} from './constants';
import { addSaveImageNode } from './addSaveImageNode';
/**
* Builds the Canvas tab's Inpaint graph.
@ -653,7 +654,7 @@ export const buildCanvasInpaintGraph = (
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!

View File

@ -14,6 +14,7 @@ import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@ -756,6 +757,8 @@ export const buildCanvasOutpaintGraph = (
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!

View File

@ -27,6 +27,7 @@ import {
SEAMLESS,
} from './constants';
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
/**
* Builds the Canvas tab's Image to Image graph.
@ -339,6 +340,7 @@ export const buildCanvasSDXLImageToImageGraph = (
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
t2iAdapters: [],
strength,
init_image: initialImage.image_name,
};
@ -384,6 +386,7 @@ export const buildCanvasSDXLImageToImageGraph = (
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -16,6 +16,7 @@ import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@ -682,6 +683,7 @@ export const buildCanvasSDXLInpaintGraph = (
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -15,6 +15,7 @@ import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@ -785,6 +786,8 @@ export const buildCanvasSDXLOutpaintGraph = (
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!

View File

@ -12,6 +12,7 @@ import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@ -321,6 +322,7 @@ export const buildCanvasSDXLTextToImageGraph = (
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
t2iAdapters: [],
};
graph.edges.push({
@ -364,6 +366,7 @@ export const buildCanvasSDXLTextToImageGraph = (
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -11,6 +11,7 @@ import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@ -309,6 +310,7 @@ export const buildCanvasTextToImageGraph = (
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
t2iAdapters: [],
clip_skip: clipSkip,
};
@ -340,6 +342,7 @@ export const buildCanvasTextToImageGraph = (
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -11,6 +11,7 @@ import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@ -329,6 +330,7 @@ export const buildLinearImageToImageGraph = (
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
t2iAdapters: [], // populated in addT2IAdapterToLinearGraph
clip_skip: clipSkip,
strength,
init_image: initialImage.imageName,
@ -362,6 +364,7 @@ export const buildLinearImageToImageGraph = (
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -12,6 +12,7 @@ import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@ -349,6 +350,7 @@ export const buildLinearSDXLImageToImageGraph = (
controlnets: [],
loras: [],
ipAdapters: [],
t2iAdapters: [],
strength: strength,
init_image: initialImage.imageName,
positive_style_prompt: positiveStylePrompt,
@ -392,6 +394,8 @@ export const buildLinearSDXLImageToImageGraph = (
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!

View File

@ -8,6 +8,7 @@ import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@ -243,6 +244,7 @@ export const buildLinearSDXLTextToImageGraph = (
controlnets: [],
loras: [],
ipAdapters: [],
t2iAdapters: [],
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
};
@ -284,6 +286,8 @@ export const buildLinearSDXLTextToImageGraph = (
// add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!

View File

@ -11,6 +11,7 @@ import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@ -251,6 +252,7 @@ export const buildLinearTextToImageGraph = (
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
t2iAdapters: [], // populated in addT2IAdapterToLinearGraph
clip_skip: clipSkip,
};
@ -283,6 +285,8 @@ export const buildLinearTextToImageGraph = (
// add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!

View File

@ -46,6 +46,7 @@ export const MASK_RESIZE_DOWN = 'mask_resize_down';
export const COLOR_CORRECT = 'color_correct';
export const PASTE_IMAGE = 'img_paste';
export const CONTROL_NET_COLLECT = 'control_net_collect';
export const T2I_ADAPTER_COLLECT = 't2i_adapter_collect';
export const IP_ADAPTER = 'ip_adapter';
export const DYNAMIC_PROMPT = 'dynamic_prompt';
export const IMAGE_COLLECTION = 'image_collection';