feat(nodes,ui): add t2i to linear UI

- Update backend metadata for t2i adapter
- Fix typo in `T2IAdapterInvocation`: `ip_adapter_model` -> `t2i_adapter_model`
- Update linear graphs to use t2i adapter
- Add client metadata recall for t2i adapter
- Fix bug with controlnet metadata recall - processor should be set to 'none' when recalling a control adapter (sketched below)
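For context on that last fix: image metadata stores only the processed control image, so a recalled control adapter must not re-run (or re-derive) a processor. A minimal sketch of the rule in TypeScript; the RecalledControl shape is illustrative, not a type in this diff:

    // Sketch: on recall, pin the processor to 'none' and treat the stored
    // image as both the raw and the processed control image.
    type RecalledControl = {
      controlImage: string | null;
      processedControlImage: string | null;
      processorType: 'none';
    };

    const toRecalledControl = (imageName?: string): RecalledControl => ({
      controlImage: imageName ?? null,
      processedControlImage: imageName ?? null,
      processorType: 'none', // never re-derived from the model name
    });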
psychedelicious 2023-10-06 20:16:00 +11:00
parent 1a9d2f1701
commit 078c9b6964
24 changed files with 2035 additions and 890 deletions

View File

@@ -15,6 +15,7 @@ from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from ...version import __version__
@@ -63,6 +64,7 @@ class CoreMetadata(BaseModelExcludeNull):
model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
ipAdapters: list[IPAdapterMetadataField] = Field(description="The IP Adapters used for inference")
t2iAdapters: list[T2IAdapterField] = Field(description="The T2I Adapters used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
vae: Optional[VAEModelField] = Field(
default=None,
@@ -139,6 +141,7 @@ class MetadataAccumulatorInvocation(BaseInvocation):
model: MainModelField = InputField(description="The main model used for inference")
controlnets: list[ControlField] = InputField(description="The ControlNets used for inference")
ipAdapters: list[IPAdapterMetadataField] = InputField(description="The IP Adapters used for inference")
t2iAdapters: list[T2IAdapterField] = InputField(description="The T2I Adapters used for inference")
loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference")
strength: Optional[float] = InputField(
default=None,

View File

@@ -50,7 +50,7 @@ class T2IAdapterInvocation(BaseInvocation):
# Inputs
image: ImageField = InputField(description="The T2I-Adapter input image.")
ip_adapter_model: T2IAdapterModelField = InputField(
t2i_adapter_model: T2IAdapterModelField = InputField(
description="The T2I-Adapter model.",
title="T2I-Adapter Model",
input=Input.Direct,
@@ -74,7 +74,7 @@ class T2IAdapterInvocation(BaseInvocation):
return T2IAdapterOutput(
t2i_adapter=T2IAdapterField(
image=self.image,
t2i_adapter_model=self.ip_adapter_model,
t2i_adapter_model=self.t2i_adapter_model,
weight=self.weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,

View File

@@ -31,6 +31,10 @@ export const addControlNetImageProcessedListener = () => {
return;
}
if (ca.processorType === 'none' || ca.processorNode.type === 'none') {
return;
}
// ControlNet one-off processing graph is just the processor node, no edges.
// Also we need to grab the image.
const graph: Graph = {

View File

@@ -155,22 +155,24 @@ export type RequiredZoeDepthImageProcessorInvocation = O.Required<
/**
* Any ControlNet Processor node, with its parameters flagged as required
*/
export type RequiredControlAdapterProcessorNode = O.Required<
| RequiredCannyImageProcessorInvocation
| RequiredColorMapImageProcessorInvocation
| RequiredContentShuffleImageProcessorInvocation
| RequiredHedImageProcessorInvocation
| RequiredLineartAnimeImageProcessorInvocation
| RequiredLineartImageProcessorInvocation
| RequiredMediapipeFaceProcessorInvocation
| RequiredMidasDepthImageProcessorInvocation
| RequiredMlsdImageProcessorInvocation
| RequiredNormalbaeImageProcessorInvocation
| RequiredOpenposeImageProcessorInvocation
| RequiredPidiImageProcessorInvocation
| RequiredZoeDepthImageProcessorInvocation,
'id'
>;
export type RequiredControlAdapterProcessorNode =
| O.Required<
| RequiredCannyImageProcessorInvocation
| RequiredColorMapImageProcessorInvocation
| RequiredContentShuffleImageProcessorInvocation
| RequiredHedImageProcessorInvocation
| RequiredLineartAnimeImageProcessorInvocation
| RequiredLineartImageProcessorInvocation
| RequiredMediapipeFaceProcessorInvocation
| RequiredMidasDepthImageProcessorInvocation
| RequiredMlsdImageProcessorInvocation
| RequiredNormalbaeImageProcessorInvocation
| RequiredOpenposeImageProcessorInvocation
| RequiredPidiImageProcessorInvocation
| RequiredZoeDepthImageProcessorInvocation,
'id'
>
| { type: 'none' };
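With { type: 'none' } added to the union, consumers can no longer assume a processor node carries parameters and must narrow on the type tag first, as the listener above now does. A sketch, assuming the types in this file:

    // Sketch: narrow before use. After the check, every remaining union
    // member is a fully-parameterized processor invocation.
    declare const node: RequiredControlAdapterProcessorNode;

    if (node.type === 'none') {
      // nothing to run - the control image is used as-is
    } else {
      console.log(node.type); // safe: a concrete processor invocation type
    }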
/**
* Type guard for CannyImageProcessorInvocation

View File

@@ -3,6 +3,7 @@ import {
CoreMetadata,
LoRAMetadataItem,
IPAdapterMetadataItem,
T2IAdapterMetadataItem,
} from 'features/nodes/types/types';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { memo, useMemo, useCallback } from 'react';
@@ -10,6 +11,7 @@ import { useTranslation } from 'react-i18next';
import {
isValidControlNetModel,
isValidLoRAModel,
isValidT2IAdapterModel,
} from '../../../parameters/types/parameterSchemas';
import ImageMetadataItem from './ImageMetadataItem';
@@ -36,6 +38,7 @@ const ImageMetadataActions = (props: Props) => {
recallLoRA,
recallControlNet,
recallIPAdapter,
recallT2IAdapter,
} = useRecallParameters();
const handleRecallPositivePrompt = useCallback(() => {
@@ -99,6 +102,13 @@ const ImageMetadataActions = (props: Props) => {
[recallIPAdapter]
);
const handleRecallT2IAdapter = useCallback(
(t2iAdapter: T2IAdapterMetadataItem) => {
recallT2IAdapter(t2iAdapter);
},
[recallT2IAdapter]
);
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
return metadata?.controlnets
? metadata.controlnets.filter((controlnet) =>
@@ -115,6 +125,14 @@ const ImageMetadataActions = (props: Props) => {
: [];
}, [metadata?.ipAdapters]);
const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => {
return metadata?.t2iAdapters
? metadata.t2iAdapters.filter((t2iAdapter) =>
isValidT2IAdapterModel(t2iAdapter.t2i_adapter_model)
)
: [];
}, [metadata?.t2iAdapters]);
if (!metadata || Object.keys(metadata).length === 0) {
return null;
}
@@ -236,6 +254,14 @@ const ImageMetadataActions = (props: Props) => {
onClick={() => handleRecallIPAdapter(ipAdapter)}
/>
))}
{validT2IAdapters.map((t2iAdapter, index) => (
<ImageMetadataItem
key={index}
label="T2I Adapter"
value={`${t2iAdapter.t2i_adapter_model?.model_name} - ${t2iAdapter.weight}`}
onClick={() => handleRecallT2IAdapter(t2iAdapter)}
/>
))}
</>
);
};

View File

@@ -1275,6 +1275,10 @@ const zIPAdapterMetadataItem = zIPAdapterField.deepPartial();
export type IPAdapterMetadataItem = z.infer<typeof zIPAdapterMetadataItem>;
const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial();
export type T2IAdapterMetadataItem = z.infer<typeof zT2IAdapterMetadataItem>;
export const zCoreMetadata = z
.object({
app_version: z.string().nullish().catch(null),
@@ -1296,6 +1300,7 @@ export const zCoreMetadata = z
.catch(null),
controlnets: z.array(zControlNetMetadataItem).nullish().catch(null),
ipAdapters: z.array(zIPAdapterMetadataItem).nullish().catch(null),
t2iAdapters: z.array(zT2IAdapterMetadataItem).nullish().catch(null),
loras: z.array(zLoRAMetadataItem).nullish().catch(null),
vae: zVaeModelField.nullish().catch(null),
strength: z.number().nullish().catch(null),
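Because every field is wrapped in .nullish().catch(null), one malformed entry degrades to null rather than failing the whole parse. A sketch with illustrative input:

    // Sketch: a bad t2iAdapters value is caught and nulled; the rest of the
    // metadata still parses.
    const parsed = zCoreMetadata.parse({
      app_version: '3.0.0', // illustrative
      t2iAdapters: 'not-an-array', // invalid -> becomes null via .catch(null)
    });
    console.log(parsed.t2iAdapters); // null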

View File

@@ -0,0 +1,116 @@
import { RootState } from 'app/store/store';
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import { omit } from 'lodash-es';
import {
CollectInvocation,
MetadataAccumulatorInvocation,
T2IAdapterInvocation,
} from 'services/api/types';
import { NonNullableGraph, T2IAdapterField } from '../../types/types';
import {
CANVAS_COHERENCE_DENOISE_LATENTS,
METADATA_ACCUMULATOR,
T2I_ADAPTER_COLLECT,
} from './constants';
export const addT2IAdaptersToLinearGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): void => {
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters);
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (validT2IAdapters.length) {
// Even though denoise_latents' t2i_adapter input is polymorphic, keep it simple and always use a collect
const t2iAdapterCollectNode: CollectInvocation = {
id: T2I_ADAPTER_COLLECT,
type: 'collect',
is_intermediate: true,
};
graph.nodes[T2I_ADAPTER_COLLECT] = t2iAdapterCollectNode;
graph.edges.push({
source: { node_id: T2I_ADAPTER_COLLECT, field: 'collection' },
destination: {
node_id: baseNodeId,
field: 't2i_adapter',
},
});
validT2IAdapters.forEach((t2iAdapter) => {
if (!t2iAdapter.model) {
return;
}
const {
id,
controlImage,
processedControlImage,
beginStepPct,
endStepPct,
resizeMode,
model,
processorType,
weight,
} = t2iAdapter;
const t2iAdapterNode: T2IAdapterInvocation = {
id: `t2i_adapter_${id}`,
type: 't2i_adapter',
is_intermediate: true,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
t2i_adapter_model: model,
weight: weight,
};
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
t2iAdapterNode.image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// No in-app processing needed; the control image is used as-is
t2iAdapterNode.image = {
image_name: controlImage,
};
} else {
// Skip T2I adapters without a control image - should never happen if everything is working correctly
return;
}
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode as T2IAdapterInvocation;
if (metadataAccumulator?.t2iAdapters) {
// metadata accumulator only needs a t2i adapter field - not the whole node
// extract what we need and add to the accumulator
const t2iAdapterField = omit(t2iAdapterNode, [
'id',
'type',
]) as T2IAdapterField;
metadataAccumulator.t2iAdapters.push(t2iAdapterField);
}
graph.edges.push({
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
destination: {
node_id: T2I_ADAPTER_COLLECT,
field: 'item',
},
});
if (CANVAS_COHERENCE_DENOISE_LATENTS in graph.nodes) {
graph.edges.push({
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
destination: {
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
field: 't2i_adapter',
},
});
}
});
}
};
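For each enabled adapter this produces a fan-in subgraph: every t2i_adapter node feeds the shared collect node, whose collection feeds the base denoise node (plus a parallel edge into the canvas coherence pass when present). A sketch of the resulting edges, with illustrative node IDs:

    // t2i_adapter_<id> --t2i_adapter--> t2i_adapter_collect --collection--> <baseNodeId>
    const exampleEdges = [
      {
        source: { node_id: 't2i_adapter_abc123', field: 't2i_adapter' },
        destination: { node_id: 't2i_adapter_collect', field: 'item' },
      },
      {
        source: { node_id: 't2i_adapter_collect', field: 'collection' },
        destination: { node_id: 'denoise_latents', field: 't2i_adapter' },
      },
    ];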

View File

@@ -8,6 +8,7 @@ import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@@ -328,6 +329,7 @@ export const buildCanvasImageToImageGraph = (
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
t2iAdapters: [],
clip_skip: clipSkip,
strength,
init_image: initialImage.image_name,
@@ -350,6 +352,7 @@
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@@ -13,7 +13,9 @@ import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@@ -40,7 +42,6 @@ import {
POSITIVE_CONDITIONING,
SEAMLESS,
} from './constants';
import { addSaveImageNode } from './addSaveImageNode';
/**
* Builds the Canvas tab's Inpaint graph.
@@ -653,7 +654,7 @@ export const buildCanvasInpaintGraph = (
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!

View File

@@ -14,6 +14,7 @@ import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@@ -756,6 +757,8 @@ export const buildCanvasOutpaintGraph = (
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!

View File

@@ -27,6 +27,7 @@ import {
SEAMLESS,
} from './constants';
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
/**
* Builds the Canvas tab's Image to Image graph.
@@ -339,6 +340,7 @@ export const buildCanvasSDXLImageToImageGraph = (
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
t2iAdapters: [],
strength,
init_image: initialImage.image_name,
};
@@ -384,6 +386,7 @@ export const buildCanvasSDXLImageToImageGraph = (
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@@ -16,6 +16,7 @@ import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@@ -682,6 +683,7 @@ export const buildCanvasSDXLInpaintGraph = (
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@@ -15,6 +15,7 @@ import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@@ -785,6 +786,8 @@ export const buildCanvasSDXLOutpaintGraph = (
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!

View File

@@ -12,6 +12,7 @@ import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@@ -321,6 +322,7 @@ export const buildCanvasSDXLTextToImageGraph = (
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
t2iAdapters: [],
};
graph.edges.push({
@@ -364,6 +366,7 @@ export const buildCanvasSDXLTextToImageGraph = (
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@@ -11,6 +11,7 @@ import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@@ -309,6 +310,7 @@ export const buildCanvasTextToImageGraph = (
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
t2iAdapters: [],
clip_skip: clipSkip,
};
@@ -340,6 +342,7 @@
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@@ -11,6 +11,7 @@ import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@@ -329,6 +330,7 @@ export const buildLinearImageToImageGraph = (
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
t2iAdapters: [], // populated in addT2IAdapterToLinearGraph
clip_skip: clipSkip,
strength,
init_image: initialImage.imageName,
@@ -362,6 +364,7 @@
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@@ -12,6 +12,7 @@ import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@@ -349,6 +350,7 @@ export const buildLinearSDXLImageToImageGraph = (
controlnets: [],
loras: [],
ipAdapters: [],
t2iAdapters: [],
strength: strength,
init_image: initialImage.imageName,
positive_style_prompt: positiveStylePrompt,
@@ -392,6 +394,8 @@
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!

View File

@@ -8,6 +8,7 @@ import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@@ -243,6 +244,7 @@ export const buildLinearSDXLTextToImageGraph = (
controlnets: [],
loras: [],
ipAdapters: [],
t2iAdapters: [],
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
};
@@ -284,6 +286,8 @@
// add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!

View File

@@ -11,6 +11,7 @@ import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@@ -251,6 +252,7 @@ export const buildLinearTextToImageGraph = (
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
t2iAdapters: [], // populated in addT2IAdapterToLinearGraph
clip_skip: clipSkip,
};
@@ -283,6 +285,8 @@
// add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
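All of the builders above share the same tail: control adapters are added after the core graph is assembled, and the NSFW checker and watermarker must come last. A condensed sketch (the SDXL builders pass SDXL_DENOISE_LATENTS; shouldUseWatermarker is assumed to mirror shouldUseNSFWChecker):

    addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
    addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
    addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);

    // NSFW & watermark - must be last thing added to graph
    if (state.system.shouldUseNSFWChecker) {
      addNSFWCheckerToGraph(state, graph);
    }
    if (state.system.shouldUseWatermarker) {
      addWatermarkerToGraph(state, graph);
    }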

View File

@@ -46,6 +46,7 @@ export const MASK_RESIZE_DOWN = 'mask_resize_down';
export const COLOR_CORRECT = 'color_correct';
export const PASTE_IMAGE = 'img_paste';
export const CONTROL_NET_COLLECT = 'control_net_collect';
export const T2I_ADAPTER_COLLECT = 't2i_adapter_collect';
export const IP_ADAPTER = 'ip_adapter';
export const DYNAMIC_PROMPT = 'dynamic_prompt';
export const IMAGE_COLLECTION = 'image_collection';

View File

@@ -3,10 +3,7 @@ import { useAppToaster } from 'app/components/Toaster';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import {
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
CONTROLNET_PROCESSORS,
} from 'features/controlAdapters/store/constants';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import {
controlAdapterRecalled,
controlAdaptersReset,
@@ -14,16 +11,19 @@ import {
import {
ControlNetConfig,
IPAdapterConfig,
T2IAdapterConfig,
} from 'features/controlAdapters/store/types';
import {
initialControlNet,
initialIPAdapter,
initialT2IAdapter,
} from 'features/controlAdapters/util/buildControlAdapter';
import {
ControlNetMetadataItem,
CoreMetadata,
IPAdapterMetadataItem,
LoRAMetadataItem,
T2IAdapterMetadataItem,
} from 'features/nodes/types/types';
import {
refinerModelChanged,
@@ -44,9 +44,11 @@ import {
controlNetModelsAdapter,
ipAdapterModelsAdapter,
loraModelsAdapter,
t2iAdapterModelsAdapter,
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetLoRAModelsQuery,
useGetT2IAdapterModelsQuery,
} from '../../../services/api/endpoints/models';
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
import { initialImageSelected, modelSelected } from '../store/actions';
@@ -363,7 +365,7 @@ export const useRecallParameters = () => {
* Recall LoRA with toast
*/
const { data: loras } = useGetLoRAModelsQuery(undefined);
const { data: loraModels } = useGetLoRAModelsQuery(undefined);
const prepareLoRAMetadataItem = useCallback(
(loraMetadataItem: LoRAMetadataItem) => {
@@ -373,10 +375,10 @@ export const useRecallParameters = () => {
const { base_model, model_name } = loraMetadataItem.lora;
const matchingLoRA = loras
const matchingLoRA = loraModels
? loraModelsAdapter
.getSelectors()
.selectById(loras, `${base_model}/lora/${model_name}`)
.selectById(loraModels, `${base_model}/lora/${model_name}`)
: undefined;
if (!matchingLoRA) {
@@ -395,7 +397,7 @@ export const useRecallParameters = () => {
return { lora: matchingLoRA, error: null };
},
[loras, model?.base_model]
[loraModels, model?.base_model]
);
const recallLoRA = useCallback(
@@ -420,7 +422,7 @@
* Recall ControlNet with toast
*/
const { data: controlNets } = useGetControlNetModelsQuery(undefined);
const { data: controlNetModels } = useGetControlNetModelsQuery(undefined);
const prepareControlNetMetadataItem = useCallback(
(controlnetMetadataItem: ControlNetMetadataItem) => {
@ -438,11 +440,11 @@ export const useRecallParameters = () => {
resize_mode,
} = controlnetMetadataItem;
const matchingControlNetModel = controlNets
const matchingControlNetModel = controlNetModels
? controlNetModelsAdapter
.getSelectors()
.selectById(
controlNets,
controlNetModels,
`${control_model.base_model}/controlnet/${control_model.model_name}`
)
: undefined;
@@ -461,16 +463,9 @@ export const useRecallParameters = () => {
};
}
let processorType = initialControlNet.processorType;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (matchingControlNetModel.model_name.includes(modelSubstring)) {
processorType =
CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring] ||
initialControlNet.processorType;
break;
}
}
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
// We don't save the original image that was processed into a control image, only the processed image
const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const controlnet: ControlNetConfig = {
type: 'controlnet',
@@ -487,17 +482,14 @@
controlImage: image?.image_name || null,
processedControlImage: image?.image_name || null,
processorType,
processorNode:
processorNode.type !== 'none'
? processorNode
: initialControlNet.processorNode,
processorNode,
shouldAutoConfig: true,
id: uuidv4(),
};
return { controlnet, error: null };
},
[controlNets, model?.base_model]
[controlNetModels, model?.base_model]
);
const recallControlNet = useCallback(
@@ -521,11 +513,101 @@
]
);
/**
* Recall T2I Adapter with toast
*/
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(undefined);
const prepareT2IAdapterMetadataItem = useCallback(
(t2iAdapterMetadataItem: T2IAdapterMetadataItem) => {
if (!isValidT2IAdapterModel(t2iAdapterMetadataItem.t2i_adapter_model)) {
return { t2iAdapter: null, error: 'Invalid T2I Adapter model' };
}
const {
image,
t2i_adapter_model,
weight,
begin_step_percent,
end_step_percent,
resize_mode,
} = t2iAdapterMetadataItem;
const matchingT2IAdapterModel = t2iAdapterModels
? t2iAdapterModelsAdapter
.getSelectors()
.selectById(
t2iAdapterModels,
`${t2i_adapter_model.base_model}/t2i_adapter/${t2i_adapter_model.model_name}`
)
: undefined;
if (!matchingT2IAdapterModel) {
return { t2iAdapter: null, error: 'T2I Adapter model is not installed' };
}
const isCompatibleBaseModel =
matchingT2IAdapterModel?.base_model === model?.base_model;
if (!isCompatibleBaseModel) {
return {
t2iAdapter: null,
error: 'T2I Adapter incompatible with currently-selected model',
};
}
// We don't save the original image that was processed into a control image, only the processed image
const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const t2iAdapter: T2IAdapterConfig = {
type: 't2i_adapter',
isEnabled: true,
model: matchingT2IAdapterModel,
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
beginStepPct: begin_step_percent || initialT2IAdapter.beginStepPct,
endStepPct: end_step_percent || initialT2IAdapter.endStepPct,
resizeMode: resize_mode || initialT2IAdapter.resizeMode,
controlImage: image?.image_name || null,
processedControlImage: image?.image_name || null,
processorType,
processorNode,
shouldAutoConfig: true,
id: uuidv4(),
};
return { t2iAdapter, error: null };
},
[model?.base_model, t2iAdapterModels]
);
const recallT2IAdapter = useCallback(
(t2iAdapterMetadataItem: T2IAdapterMetadataItem) => {
const result = prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem);
if (!result.t2iAdapter) {
parameterNotSetToast(result.error);
return;
}
dispatch(controlAdapterRecalled(result.t2iAdapter));
parameterSetToast();
},
[
prepareT2IAdapterMetadataItem,
dispatch,
parameterSetToast,
parameterNotSetToast,
]
);
/**
* Recall IP Adapter with toast
*/
const { data: ipAdapters } = useGetIPAdapterModelsQuery(undefined);
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(undefined);
const prepareIPAdapterMetadataItem = useCallback(
(ipAdapterMetadataItem: IPAdapterMetadataItem) => {
@@ -541,11 +623,11 @@ export const useRecallParameters = () => {
end_step_percent,
} = ipAdapterMetadataItem;
const matchingIPAdapterModel = ipAdapters
const matchingIPAdapterModel = ipAdapterModels
? ipAdapterModelsAdapter
.getSelectors()
.selectById(
ipAdapters,
ipAdapterModels,
`${ip_adapter_model.base_model}/ip_adapter/${ip_adapter_model.model_name}`
)
: undefined;
@@ -577,7 +659,7 @@ export const useRecallParameters = () => {
return { ipAdapter, error: null };
},
[ipAdapters, model?.base_model]
[ipAdapterModels, model?.base_model]
);
const recallIPAdapter = useCallback(
@@ -641,6 +723,7 @@ export const useRecallParameters = () => {
loras,
controlnets,
ipAdapters,
t2iAdapters,
} = metadata;
if (isValidCfgScale(cfg_scale)) {
@@ -745,15 +828,23 @@ export const useRecallParameters = () => {
}
});
t2iAdapters?.forEach((t2iAdapter) => {
const result = prepareT2IAdapterMetadataItem(t2iAdapter);
if (result.t2iAdapter) {
dispatch(controlAdapterRecalled(result.t2iAdapter));
}
});
allParameterSetToast();
},
[
allParameterNotSetToast,
allParameterSetToast,
dispatch,
allParameterSetToast,
allParameterNotSetToast,
prepareLoRAMetadataItem,
prepareControlNetMetadataItem,
prepareIPAdapterMetadataItem,
prepareT2IAdapterMetadataItem,
]
);
@@ -774,6 +865,7 @@ export const useRecallParameters = () => {
recallLoRA,
recallControlNet,
recallIPAdapter,
recallT2IAdapter,
recallAllParameters,
sendToImageToImage,
};
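Callers use the new callback like the existing recall callbacks. A usage sketch, assuming a component with a CoreMetadata value in scope:

    // Sketch: recall every T2I adapter recorded in an image's metadata.
    const { recallT2IAdapter } = useRecallParameters();
    metadata?.t2iAdapters?.forEach((item) => recallT2IAdapter(item));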

View File

@@ -347,6 +347,10 @@ export const zT2IAdapterModel = z.object({
model_name: z.string().min(1),
base_model: zBaseModel,
});
export const isValidT2IAdapterModel = (
val: unknown
): val is T2IAdapterModelParam => zT2IAdapterModel.safeParse(val).success;
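This guard is what ImageMetadataActions uses to filter recallable adapters. A usage sketch with an illustrative candidate value:

    const candidate: unknown = {
      model_name: 't2iadapter_canny_sd15v2', // illustrative
      base_model: 'sd-1',
    };

    if (isValidT2IAdapterModel(candidate)) {
      // candidate is narrowed to T2IAdapterModelParam here
      console.log(candidate.model_name);
    }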
/**
* Type alias for model parameter, inferred from its zod schema
*/

File diff suppressed because one or more lines are too long

View File

@@ -155,6 +155,7 @@ export type SaveImageInvocation = s['SaveImageInvocation'];
// ControlNet Nodes
export type ControlNetInvocation = s['ControlNetInvocation'];
export type T2IAdapterInvocation = s['T2IAdapterInvocation'];
export type IPAdapterInvocation = s['IPAdapterInvocation'];
export type CannyImageProcessorInvocation = s['CannyImageProcessorInvocation'];
export type ColorMapImageProcessorInvocation =