mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes,ui): add t2i to linear UI
- Update backend metadata for t2i adapter - Fix typo in `T2IAdapterInvocation`: `ip_adapter_model` -> `t2i_adapter_model` - Update linear graphs to use t2i adapter - Add client metadata recall for t2i adapter - Fix bug with controlnet metadata recall - processor should be set to 'none' when recalling a control adapter
This commit is contained in:
parent
1a9d2f1701
commit
078c9b6964
@ -15,6 +15,7 @@ from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
|
||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
from ...version import __version__
|
||||
@ -63,6 +64,7 @@ class CoreMetadata(BaseModelExcludeNull):
|
||||
model: MainModelField = Field(description="The main model used for inference")
|
||||
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
||||
ipAdapters: list[IPAdapterMetadataField] = Field(description="The IP Adapters used for inference")
|
||||
t2iAdapters: list[T2IAdapterField] = Field(description="The IP Adapters used for inference")
|
||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||
vae: Optional[VAEModelField] = Field(
|
||||
default=None,
|
||||
@ -139,6 +141,7 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
model: MainModelField = InputField(description="The main model used for inference")
|
||||
controlnets: list[ControlField] = InputField(description="The ControlNets used for inference")
|
||||
ipAdapters: list[IPAdapterMetadataField] = InputField(description="The IP Adapters used for inference")
|
||||
t2iAdapters: list[T2IAdapterField] = Field(description="The IP Adapters used for inference")
|
||||
loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference")
|
||||
strength: Optional[float] = InputField(
|
||||
default=None,
|
||||
|
@ -50,7 +50,7 @@ class T2IAdapterInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The IP-Adapter image prompt.")
|
||||
ip_adapter_model: T2IAdapterModelField = InputField(
|
||||
t2i_adapter_model: T2IAdapterModelField = InputField(
|
||||
description="The T2I-Adapter model.",
|
||||
title="T2I-Adapter Model",
|
||||
input=Input.Direct,
|
||||
@ -74,7 +74,7 @@ class T2IAdapterInvocation(BaseInvocation):
|
||||
return T2IAdapterOutput(
|
||||
t2i_adapter=T2IAdapterField(
|
||||
image=self.image,
|
||||
t2i_adapter_model=self.ip_adapter_model,
|
||||
t2i_adapter_model=self.t2i_adapter_model,
|
||||
weight=self.weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
|
@ -31,6 +31,10 @@ export const addControlNetImageProcessedListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
if (ca.processorType === 'none' || ca.processorNode.type === 'none') {
|
||||
return;
|
||||
}
|
||||
|
||||
// ControlNet one-off procressing graph is just the processor node, no edges.
|
||||
// Also we need to grab the image.
|
||||
const graph: Graph = {
|
||||
|
@ -155,22 +155,24 @@ export type RequiredZoeDepthImageProcessorInvocation = O.Required<
|
||||
/**
|
||||
* Any ControlNet Processor node, with its parameters flagged as required
|
||||
*/
|
||||
export type RequiredControlAdapterProcessorNode = O.Required<
|
||||
| RequiredCannyImageProcessorInvocation
|
||||
| RequiredColorMapImageProcessorInvocation
|
||||
| RequiredContentShuffleImageProcessorInvocation
|
||||
| RequiredHedImageProcessorInvocation
|
||||
| RequiredLineartAnimeImageProcessorInvocation
|
||||
| RequiredLineartImageProcessorInvocation
|
||||
| RequiredMediapipeFaceProcessorInvocation
|
||||
| RequiredMidasDepthImageProcessorInvocation
|
||||
| RequiredMlsdImageProcessorInvocation
|
||||
| RequiredNormalbaeImageProcessorInvocation
|
||||
| RequiredOpenposeImageProcessorInvocation
|
||||
| RequiredPidiImageProcessorInvocation
|
||||
| RequiredZoeDepthImageProcessorInvocation,
|
||||
'id'
|
||||
>;
|
||||
export type RequiredControlAdapterProcessorNode =
|
||||
| O.Required<
|
||||
| RequiredCannyImageProcessorInvocation
|
||||
| RequiredColorMapImageProcessorInvocation
|
||||
| RequiredContentShuffleImageProcessorInvocation
|
||||
| RequiredHedImageProcessorInvocation
|
||||
| RequiredLineartAnimeImageProcessorInvocation
|
||||
| RequiredLineartImageProcessorInvocation
|
||||
| RequiredMediapipeFaceProcessorInvocation
|
||||
| RequiredMidasDepthImageProcessorInvocation
|
||||
| RequiredMlsdImageProcessorInvocation
|
||||
| RequiredNormalbaeImageProcessorInvocation
|
||||
| RequiredOpenposeImageProcessorInvocation
|
||||
| RequiredPidiImageProcessorInvocation
|
||||
| RequiredZoeDepthImageProcessorInvocation,
|
||||
'id'
|
||||
>
|
||||
| { type: 'none' };
|
||||
|
||||
/**
|
||||
* Type guard for CannyImageProcessorInvocation
|
||||
|
@ -3,6 +3,7 @@ import {
|
||||
CoreMetadata,
|
||||
LoRAMetadataItem,
|
||||
IPAdapterMetadataItem,
|
||||
T2IAdapterMetadataItem,
|
||||
} from 'features/nodes/types/types';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { memo, useMemo, useCallback } from 'react';
|
||||
@ -10,6 +11,7 @@ import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
isValidControlNetModel,
|
||||
isValidLoRAModel,
|
||||
isValidT2IAdapterModel,
|
||||
} from '../../../parameters/types/parameterSchemas';
|
||||
import ImageMetadataItem from './ImageMetadataItem';
|
||||
|
||||
@ -36,6 +38,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
recallLoRA,
|
||||
recallControlNet,
|
||||
recallIPAdapter,
|
||||
recallT2IAdapter,
|
||||
} = useRecallParameters();
|
||||
|
||||
const handleRecallPositivePrompt = useCallback(() => {
|
||||
@ -99,6 +102,13 @@ const ImageMetadataActions = (props: Props) => {
|
||||
[recallIPAdapter]
|
||||
);
|
||||
|
||||
const handleRecallT2IAdapter = useCallback(
|
||||
(ipAdapter: T2IAdapterMetadataItem) => {
|
||||
recallT2IAdapter(ipAdapter);
|
||||
},
|
||||
[recallT2IAdapter]
|
||||
);
|
||||
|
||||
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
|
||||
return metadata?.controlnets
|
||||
? metadata.controlnets.filter((controlnet) =>
|
||||
@ -115,6 +125,14 @@ const ImageMetadataActions = (props: Props) => {
|
||||
: [];
|
||||
}, [metadata?.ipAdapters]);
|
||||
|
||||
const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => {
|
||||
return metadata?.t2iAdapters
|
||||
? metadata.t2iAdapters.filter((t2iAdapter) =>
|
||||
isValidT2IAdapterModel(t2iAdapter.t2i_adapter_model)
|
||||
)
|
||||
: [];
|
||||
}, [metadata?.t2iAdapters]);
|
||||
|
||||
if (!metadata || Object.keys(metadata).length === 0) {
|
||||
return null;
|
||||
}
|
||||
@ -236,6 +254,14 @@ const ImageMetadataActions = (props: Props) => {
|
||||
onClick={() => handleRecallIPAdapter(ipAdapter)}
|
||||
/>
|
||||
))}
|
||||
{validT2IAdapters.map((t2iAdapter, index) => (
|
||||
<ImageMetadataItem
|
||||
key={index}
|
||||
label="T2I Adapter"
|
||||
value={`${t2iAdapter.t2i_adapter_model?.model_name} - ${t2iAdapter.weight}`}
|
||||
onClick={() => handleRecallT2IAdapter(t2iAdapter)}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
@ -1275,6 +1275,10 @@ const zIPAdapterMetadataItem = zIPAdapterField.deepPartial();
|
||||
|
||||
export type IPAdapterMetadataItem = z.infer<typeof zIPAdapterMetadataItem>;
|
||||
|
||||
const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial();
|
||||
|
||||
export type T2IAdapterMetadataItem = z.infer<typeof zT2IAdapterMetadataItem>;
|
||||
|
||||
export const zCoreMetadata = z
|
||||
.object({
|
||||
app_version: z.string().nullish().catch(null),
|
||||
@ -1296,6 +1300,7 @@ export const zCoreMetadata = z
|
||||
.catch(null),
|
||||
controlnets: z.array(zControlNetMetadataItem).nullish().catch(null),
|
||||
ipAdapters: z.array(zIPAdapterMetadataItem).nullish().catch(null),
|
||||
t2iAdapters: z.array(zT2IAdapterMetadataItem).nullish().catch(null),
|
||||
loras: z.array(zLoRAMetadataItem).nullish().catch(null),
|
||||
vae: zVaeModelField.nullish().catch(null),
|
||||
strength: z.number().nullish().catch(null),
|
||||
|
@ -0,0 +1,116 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { omit } from 'lodash-es';
|
||||
import {
|
||||
CollectInvocation,
|
||||
MetadataAccumulatorInvocation,
|
||||
T2IAdapterInvocation,
|
||||
} from 'services/api/types';
|
||||
import { NonNullableGraph, T2IAdapterField } from '../../types/types';
|
||||
import {
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
METADATA_ACCUMULATOR,
|
||||
T2I_ADAPTER_COLLECT,
|
||||
} from './constants';
|
||||
|
||||
export const addT2IAdaptersToLinearGraph = (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string
|
||||
): void => {
|
||||
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters);
|
||||
|
||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
| MetadataAccumulatorInvocation
|
||||
| undefined;
|
||||
|
||||
if (validT2IAdapters.length) {
|
||||
// Even though denoise_latents' control input is polymorphic, keep it simple and always use a collect
|
||||
const t2iAdapterCollectNode: CollectInvocation = {
|
||||
id: T2I_ADAPTER_COLLECT,
|
||||
type: 'collect',
|
||||
is_intermediate: true,
|
||||
};
|
||||
graph.nodes[T2I_ADAPTER_COLLECT] = t2iAdapterCollectNode;
|
||||
graph.edges.push({
|
||||
source: { node_id: T2I_ADAPTER_COLLECT, field: 'collection' },
|
||||
destination: {
|
||||
node_id: baseNodeId,
|
||||
field: 't2i_adapter',
|
||||
},
|
||||
});
|
||||
|
||||
validT2IAdapters.forEach((t2iAdapter) => {
|
||||
if (!t2iAdapter.model) {
|
||||
return;
|
||||
}
|
||||
const {
|
||||
id,
|
||||
controlImage,
|
||||
processedControlImage,
|
||||
beginStepPct,
|
||||
endStepPct,
|
||||
resizeMode,
|
||||
model,
|
||||
processorType,
|
||||
weight,
|
||||
} = t2iAdapter;
|
||||
|
||||
const t2iAdapterNode: T2IAdapterInvocation = {
|
||||
id: `t2i_adapter_${id}`,
|
||||
type: 't2i_adapter',
|
||||
is_intermediate: true,
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
resize_mode: resizeMode,
|
||||
t2i_adapter_model: model,
|
||||
weight: weight,
|
||||
};
|
||||
|
||||
if (processedControlImage && processorType !== 'none') {
|
||||
// We've already processed the image in the app, so we can just use the processed image
|
||||
t2iAdapterNode.image = {
|
||||
image_name: processedControlImage,
|
||||
};
|
||||
} else if (controlImage) {
|
||||
// The control image is preprocessed
|
||||
t2iAdapterNode.image = {
|
||||
image_name: controlImage,
|
||||
};
|
||||
} else {
|
||||
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
|
||||
return;
|
||||
}
|
||||
|
||||
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode as T2IAdapterInvocation;
|
||||
|
||||
if (metadataAccumulator?.ipAdapters) {
|
||||
// metadata accumulator only needs a control field - not the whole node
|
||||
// extract what we need and add to the accumulator
|
||||
const t2iAdapterField = omit(t2iAdapterNode, [
|
||||
'id',
|
||||
'type',
|
||||
]) as T2IAdapterField;
|
||||
metadataAccumulator.t2iAdapters.push(t2iAdapterField);
|
||||
}
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
|
||||
destination: {
|
||||
node_id: T2I_ADAPTER_COLLECT,
|
||||
field: 'item',
|
||||
},
|
||||
});
|
||||
|
||||
if (CANVAS_COHERENCE_DENOISE_LATENTS in graph.nodes) {
|
||||
graph.edges.push({
|
||||
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
field: 't2i_adapter',
|
||||
},
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
@ -8,6 +8,7 @@ import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSaveImageNode } from './addSaveImageNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
@ -328,6 +329,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
controlnets: [], // populated in addControlNetToLinearGraph
|
||||
loras: [], // populated in addLoRAsToGraph
|
||||
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
||||
t2iAdapters: [],
|
||||
clip_skip: clipSkip,
|
||||
strength,
|
||||
init_image: initialImage.image_name,
|
||||
@ -350,6 +352,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -13,7 +13,9 @@ import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSaveImageNode } from './addSaveImageNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
@ -40,7 +42,6 @@ import {
|
||||
POSITIVE_CONDITIONING,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { addSaveImageNode } from './addSaveImageNode';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Inpaint graph.
|
||||
@ -653,7 +654,7 @@ export const buildCanvasInpaintGraph = (
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
// must add before watermarker!
|
||||
|
@ -14,6 +14,7 @@ import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSaveImageNode } from './addSaveImageNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
@ -756,6 +757,8 @@ export const buildCanvasOutpaintGraph = (
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
// must add before watermarker!
|
||||
|
@ -27,6 +27,7 @@ import {
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Image to Image graph.
|
||||
@ -339,6 +340,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
||||
controlnets: [], // populated in addControlNetToLinearGraph
|
||||
loras: [], // populated in addLoRAsToGraph
|
||||
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
||||
t2iAdapters: [],
|
||||
strength,
|
||||
init_image: initialImage.image_name,
|
||||
};
|
||||
@ -384,6 +386,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -16,6 +16,7 @@ import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
import { addSaveImageNode } from './addSaveImageNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
@ -682,6 +683,7 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -15,6 +15,7 @@ import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
import { addSaveImageNode } from './addSaveImageNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
@ -785,6 +786,8 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
// must add before watermarker!
|
||||
|
@ -12,6 +12,7 @@ import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
import { addSaveImageNode } from './addSaveImageNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
@ -321,6 +322,7 @@ export const buildCanvasSDXLTextToImageGraph = (
|
||||
controlnets: [], // populated in addControlNetToLinearGraph
|
||||
loras: [], // populated in addLoRAsToGraph
|
||||
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
||||
t2iAdapters: [],
|
||||
};
|
||||
|
||||
graph.edges.push({
|
||||
@ -364,6 +366,7 @@ export const buildCanvasSDXLTextToImageGraph = (
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -11,6 +11,7 @@ import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSaveImageNode } from './addSaveImageNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
@ -309,6 +310,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
controlnets: [], // populated in addControlNetToLinearGraph
|
||||
loras: [], // populated in addLoRAsToGraph
|
||||
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
||||
t2iAdapters: [],
|
||||
clip_skip: clipSkip,
|
||||
};
|
||||
|
||||
@ -340,6 +342,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -11,6 +11,7 @@ import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSaveImageNode } from './addSaveImageNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
@ -329,6 +330,7 @@ export const buildLinearImageToImageGraph = (
|
||||
controlnets: [], // populated in addControlNetToLinearGraph
|
||||
loras: [], // populated in addLoRAsToGraph
|
||||
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
||||
t2iAdapters: [], // populated in addT2IAdapterToLinearGraph
|
||||
clip_skip: clipSkip,
|
||||
strength,
|
||||
init_image: initialImage.imageName,
|
||||
@ -362,6 +364,7 @@ export const buildLinearImageToImageGraph = (
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -12,6 +12,7 @@ import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
import { addSaveImageNode } from './addSaveImageNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
@ -349,6 +350,7 @@ export const buildLinearSDXLImageToImageGraph = (
|
||||
controlnets: [],
|
||||
loras: [],
|
||||
ipAdapters: [],
|
||||
t2iAdapters: [],
|
||||
strength: strength,
|
||||
init_image: initialImage.imageName,
|
||||
positive_style_prompt: positiveStylePrompt,
|
||||
@ -392,6 +394,8 @@ export const buildLinearSDXLImageToImageGraph = (
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
// must add before watermarker!
|
||||
|
@ -8,6 +8,7 @@ import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
import { addSaveImageNode } from './addSaveImageNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
@ -243,6 +244,7 @@ export const buildLinearSDXLTextToImageGraph = (
|
||||
controlnets: [],
|
||||
loras: [],
|
||||
ipAdapters: [],
|
||||
t2iAdapters: [],
|
||||
positive_style_prompt: positiveStylePrompt,
|
||||
negative_style_prompt: negativeStylePrompt,
|
||||
};
|
||||
@ -284,6 +286,8 @@ export const buildLinearSDXLTextToImageGraph = (
|
||||
// add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
// must add before watermarker!
|
||||
|
@ -11,6 +11,7 @@ import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSaveImageNode } from './addSaveImageNode';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
@ -251,6 +252,7 @@ export const buildLinearTextToImageGraph = (
|
||||
controlnets: [], // populated in addControlNetToLinearGraph
|
||||
loras: [], // populated in addLoRAsToGraph
|
||||
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
||||
t2iAdapters: [], // populated in addT2IAdapterToLinearGraph
|
||||
clip_skip: clipSkip,
|
||||
};
|
||||
|
||||
@ -283,6 +285,8 @@ export const buildLinearTextToImageGraph = (
|
||||
// add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
// must add before watermarker!
|
||||
|
@ -46,6 +46,7 @@ export const MASK_RESIZE_DOWN = 'mask_resize_down';
|
||||
export const COLOR_CORRECT = 'color_correct';
|
||||
export const PASTE_IMAGE = 'img_paste';
|
||||
export const CONTROL_NET_COLLECT = 'control_net_collect';
|
||||
export const T2I_ADAPTER_COLLECT = 't2i_adapter_collect';
|
||||
export const IP_ADAPTER = 'ip_adapter';
|
||||
export const DYNAMIC_PROMPT = 'dynamic_prompt';
|
||||
export const IMAGE_COLLECTION = 'image_collection';
|
||||
|
@ -3,10 +3,7 @@ import { useAppToaster } from 'app/components/Toaster';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import {
|
||||
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
|
||||
CONTROLNET_PROCESSORS,
|
||||
} from 'features/controlAdapters/store/constants';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import {
|
||||
controlAdapterRecalled,
|
||||
controlAdaptersReset,
|
||||
@ -14,16 +11,19 @@ import {
|
||||
import {
|
||||
ControlNetConfig,
|
||||
IPAdapterConfig,
|
||||
T2IAdapterConfig,
|
||||
} from 'features/controlAdapters/store/types';
|
||||
import {
|
||||
initialControlNet,
|
||||
initialIPAdapter,
|
||||
initialT2IAdapter,
|
||||
} from 'features/controlAdapters/util/buildControlAdapter';
|
||||
import {
|
||||
ControlNetMetadataItem,
|
||||
CoreMetadata,
|
||||
IPAdapterMetadataItem,
|
||||
LoRAMetadataItem,
|
||||
T2IAdapterMetadataItem,
|
||||
} from 'features/nodes/types/types';
|
||||
import {
|
||||
refinerModelChanged,
|
||||
@ -44,9 +44,11 @@ import {
|
||||
controlNetModelsAdapter,
|
||||
ipAdapterModelsAdapter,
|
||||
loraModelsAdapter,
|
||||
t2iAdapterModelsAdapter,
|
||||
useGetControlNetModelsQuery,
|
||||
useGetIPAdapterModelsQuery,
|
||||
useGetLoRAModelsQuery,
|
||||
useGetT2IAdapterModelsQuery,
|
||||
} from '../../../services/api/endpoints/models';
|
||||
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
|
||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||
@ -363,7 +365,7 @@ export const useRecallParameters = () => {
|
||||
* Recall LoRA with toast
|
||||
*/
|
||||
|
||||
const { data: loras } = useGetLoRAModelsQuery(undefined);
|
||||
const { data: loraModels } = useGetLoRAModelsQuery(undefined);
|
||||
|
||||
const prepareLoRAMetadataItem = useCallback(
|
||||
(loraMetadataItem: LoRAMetadataItem) => {
|
||||
@ -373,10 +375,10 @@ export const useRecallParameters = () => {
|
||||
|
||||
const { base_model, model_name } = loraMetadataItem.lora;
|
||||
|
||||
const matchingLoRA = loras
|
||||
const matchingLoRA = loraModels
|
||||
? loraModelsAdapter
|
||||
.getSelectors()
|
||||
.selectById(loras, `${base_model}/lora/${model_name}`)
|
||||
.selectById(loraModels, `${base_model}/lora/${model_name}`)
|
||||
: undefined;
|
||||
|
||||
if (!matchingLoRA) {
|
||||
@ -395,7 +397,7 @@ export const useRecallParameters = () => {
|
||||
|
||||
return { lora: matchingLoRA, error: null };
|
||||
},
|
||||
[loras, model?.base_model]
|
||||
[loraModels, model?.base_model]
|
||||
);
|
||||
|
||||
const recallLoRA = useCallback(
|
||||
@ -420,7 +422,7 @@ export const useRecallParameters = () => {
|
||||
* Recall ControlNet with toast
|
||||
*/
|
||||
|
||||
const { data: controlNets } = useGetControlNetModelsQuery(undefined);
|
||||
const { data: controlNetModels } = useGetControlNetModelsQuery(undefined);
|
||||
|
||||
const prepareControlNetMetadataItem = useCallback(
|
||||
(controlnetMetadataItem: ControlNetMetadataItem) => {
|
||||
@ -438,11 +440,11 @@ export const useRecallParameters = () => {
|
||||
resize_mode,
|
||||
} = controlnetMetadataItem;
|
||||
|
||||
const matchingControlNetModel = controlNets
|
||||
const matchingControlNetModel = controlNetModels
|
||||
? controlNetModelsAdapter
|
||||
.getSelectors()
|
||||
.selectById(
|
||||
controlNets,
|
||||
controlNetModels,
|
||||
`${control_model.base_model}/controlnet/${control_model.model_name}`
|
||||
)
|
||||
: undefined;
|
||||
@ -461,16 +463,9 @@ export const useRecallParameters = () => {
|
||||
};
|
||||
}
|
||||
|
||||
let processorType = initialControlNet.processorType;
|
||||
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||
if (matchingControlNetModel.model_name.includes(modelSubstring)) {
|
||||
processorType =
|
||||
CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring] ||
|
||||
initialControlNet.processorType;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
|
||||
// We don't save the original image that was processed into a control image, only the processed image
|
||||
const processorType = 'none';
|
||||
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
||||
|
||||
const controlnet: ControlNetConfig = {
|
||||
type: 'controlnet',
|
||||
@ -487,17 +482,14 @@ export const useRecallParameters = () => {
|
||||
controlImage: image?.image_name || null,
|
||||
processedControlImage: image?.image_name || null,
|
||||
processorType,
|
||||
processorNode:
|
||||
processorNode.type !== 'none'
|
||||
? processorNode
|
||||
: initialControlNet.processorNode,
|
||||
processorNode,
|
||||
shouldAutoConfig: true,
|
||||
id: uuidv4(),
|
||||
};
|
||||
|
||||
return { controlnet, error: null };
|
||||
},
|
||||
[controlNets, model?.base_model]
|
||||
[controlNetModels, model?.base_model]
|
||||
);
|
||||
|
||||
const recallControlNet = useCallback(
|
||||
@ -521,11 +513,101 @@ export const useRecallParameters = () => {
|
||||
]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall T2I Adapter with toast
|
||||
*/
|
||||
|
||||
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(undefined);
|
||||
|
||||
const prepareT2IAdapterMetadataItem = useCallback(
|
||||
(t2iAdapterMetadataItem: T2IAdapterMetadataItem) => {
|
||||
if (!isValidControlNetModel(t2iAdapterMetadataItem.t2i_adapter_model)) {
|
||||
return { controlnet: null, error: 'Invalid ControlNet model' };
|
||||
}
|
||||
|
||||
const {
|
||||
image,
|
||||
t2i_adapter_model,
|
||||
weight,
|
||||
begin_step_percent,
|
||||
end_step_percent,
|
||||
resize_mode,
|
||||
} = t2iAdapterMetadataItem;
|
||||
|
||||
const matchingT2IAdapterModel = t2iAdapterModels
|
||||
? t2iAdapterModelsAdapter
|
||||
.getSelectors()
|
||||
.selectById(
|
||||
t2iAdapterModels,
|
||||
`${t2i_adapter_model.base_model}/t2i_adapter/${t2i_adapter_model.model_name}`
|
||||
)
|
||||
: undefined;
|
||||
|
||||
if (!matchingT2IAdapterModel) {
|
||||
return { controlnet: null, error: 'ControlNet model is not installed' };
|
||||
}
|
||||
|
||||
const isCompatibleBaseModel =
|
||||
matchingT2IAdapterModel?.base_model === model?.base_model;
|
||||
|
||||
if (!isCompatibleBaseModel) {
|
||||
return {
|
||||
t2iAdapter: null,
|
||||
error: 'ControlNet incompatible with currently-selected model',
|
||||
};
|
||||
}
|
||||
|
||||
// We don't save the original image that was processed into a control image, only the processed image
|
||||
const processorType = 'none';
|
||||
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
||||
|
||||
const t2iAdapter: T2IAdapterConfig = {
|
||||
type: 't2i_adapter',
|
||||
isEnabled: true,
|
||||
model: matchingT2IAdapterModel,
|
||||
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
|
||||
beginStepPct: begin_step_percent || initialT2IAdapter.beginStepPct,
|
||||
endStepPct: end_step_percent || initialT2IAdapter.endStepPct,
|
||||
resizeMode: resize_mode || initialT2IAdapter.resizeMode,
|
||||
controlImage: image?.image_name || null,
|
||||
processedControlImage: image?.image_name || null,
|
||||
processorType,
|
||||
processorNode,
|
||||
shouldAutoConfig: true,
|
||||
id: uuidv4(),
|
||||
};
|
||||
|
||||
return { t2iAdapter, error: null };
|
||||
},
|
||||
[model?.base_model, t2iAdapterModels]
|
||||
);
|
||||
|
||||
const recallT2IAdapter = useCallback(
|
||||
(t2iAdapterMetadataItem: T2IAdapterMetadataItem) => {
|
||||
const result = prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem);
|
||||
|
||||
if (!result.t2iAdapter) {
|
||||
parameterNotSetToast(result.error);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(controlAdapterRecalled(result.t2iAdapter));
|
||||
|
||||
parameterSetToast();
|
||||
},
|
||||
[
|
||||
prepareT2IAdapterMetadataItem,
|
||||
dispatch,
|
||||
parameterSetToast,
|
||||
parameterNotSetToast,
|
||||
]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall IP Adapter with toast
|
||||
*/
|
||||
|
||||
const { data: ipAdapters } = useGetIPAdapterModelsQuery(undefined);
|
||||
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(undefined);
|
||||
|
||||
const prepareIPAdapterMetadataItem = useCallback(
|
||||
(ipAdapterMetadataItem: IPAdapterMetadataItem) => {
|
||||
@ -541,11 +623,11 @@ export const useRecallParameters = () => {
|
||||
end_step_percent,
|
||||
} = ipAdapterMetadataItem;
|
||||
|
||||
const matchingIPAdapterModel = ipAdapters
|
||||
const matchingIPAdapterModel = ipAdapterModels
|
||||
? ipAdapterModelsAdapter
|
||||
.getSelectors()
|
||||
.selectById(
|
||||
ipAdapters,
|
||||
ipAdapterModels,
|
||||
`${ip_adapter_model.base_model}/ip_adapter/${ip_adapter_model.model_name}`
|
||||
)
|
||||
: undefined;
|
||||
@ -577,7 +659,7 @@ export const useRecallParameters = () => {
|
||||
|
||||
return { ipAdapter, error: null };
|
||||
},
|
||||
[ipAdapters, model?.base_model]
|
||||
[ipAdapterModels, model?.base_model]
|
||||
);
|
||||
|
||||
const recallIPAdapter = useCallback(
|
||||
@ -641,6 +723,7 @@ export const useRecallParameters = () => {
|
||||
loras,
|
||||
controlnets,
|
||||
ipAdapters,
|
||||
t2iAdapters,
|
||||
} = metadata;
|
||||
|
||||
if (isValidCfgScale(cfg_scale)) {
|
||||
@ -745,15 +828,23 @@ export const useRecallParameters = () => {
|
||||
}
|
||||
});
|
||||
|
||||
t2iAdapters?.forEach((t2iAdapter) => {
|
||||
const result = prepareT2IAdapterMetadataItem(t2iAdapter);
|
||||
if (result.t2iAdapter) {
|
||||
dispatch(controlAdapterRecalled(result.t2iAdapter));
|
||||
}
|
||||
});
|
||||
|
||||
allParameterSetToast();
|
||||
},
|
||||
[
|
||||
allParameterNotSetToast,
|
||||
allParameterSetToast,
|
||||
dispatch,
|
||||
allParameterSetToast,
|
||||
allParameterNotSetToast,
|
||||
prepareLoRAMetadataItem,
|
||||
prepareControlNetMetadataItem,
|
||||
prepareIPAdapterMetadataItem,
|
||||
prepareT2IAdapterMetadataItem,
|
||||
]
|
||||
);
|
||||
|
||||
@ -774,6 +865,7 @@ export const useRecallParameters = () => {
|
||||
recallLoRA,
|
||||
recallControlNet,
|
||||
recallIPAdapter,
|
||||
recallT2IAdapter,
|
||||
recallAllParameters,
|
||||
sendToImageToImage,
|
||||
};
|
||||
|
@ -347,6 +347,10 @@ export const zT2IAdapterModel = z.object({
|
||||
model_name: z.string().min(1),
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
export const isValidT2IAdapterModel = (
|
||||
val: unknown
|
||||
): val is T2IAdapterModelParam => zT2IAdapterModel.safeParse(val).success;
|
||||
|
||||
/**
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
|
2531
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
2531
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
File diff suppressed because one or more lines are too long
@ -155,6 +155,7 @@ export type SaveImageInvocation = s['SaveImageInvocation'];
|
||||
|
||||
// ControlNet Nodes
|
||||
export type ControlNetInvocation = s['ControlNetInvocation'];
|
||||
export type T2IAdapterInvocation = s['T2IAdapterInvocation'];
|
||||
export type IPAdapterInvocation = s['IPAdapterInvocation'];
|
||||
export type CannyImageProcessorInvocation = s['CannyImageProcessorInvocation'];
|
||||
export type ColorMapImageProcessorInvocation =
|
||||
|
Loading…
Reference in New Issue
Block a user