tidy(ui): use Invocation<> helper type in canvas graph builders, elsewhere

This commit is contained in:
psychedelicious 2024-05-14 20:36:01 +10:00
parent 67dbe6d949
commit 0ff0290735
19 changed files with 80 additions and 175 deletions

View File

@ -1,6 +1,6 @@
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils'; import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
import type { ESRGANInvocation, Graph, NonNullableGraph } from 'services/api/types'; import type { Graph, Invocation, NonNullableGraph } from 'services/api/types';
import { addCoreMetadataNode, upsertMetadata } from './canvas/metadata'; import { addCoreMetadataNode, upsertMetadata } from './canvas/metadata';
import { ESRGAN } from './constants'; import { ESRGAN } from './constants';
@ -13,7 +13,7 @@ type Arg = {
export const buildAdHocUpscaleGraph = ({ image_name, state }: Arg): Graph => { export const buildAdHocUpscaleGraph = ({ image_name, state }: Arg): Graph => {
const { esrganModelName } = state.postprocessing; const { esrganModelName } = state.postprocessing;
const realesrganNode: ESRGANInvocation = { const realesrganNode: Invocation<'esrgan'> = {
id: ESRGAN, id: ESRGAN,
type: 'esrgan', type: 'esrgan',
image: { image_name }, image: { image_name },

View File

@ -5,13 +5,7 @@ import type { ImageField } from 'features/nodes/types/common';
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata'; import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
import { CONTROL_NET_COLLECT } from 'features/nodes/util/graph/constants'; import { CONTROL_NET_COLLECT } from 'features/nodes/util/graph/constants';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import type { import type { Invocation, NonNullableGraph, S } from 'services/api/types';
CollectInvocation,
ControlNetInvocation,
CoreMetadataInvocation,
NonNullableGraph,
S,
} from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
export const addControlNetToLinearGraph = async ( export const addControlNetToLinearGraph = async (
@ -19,7 +13,7 @@ export const addControlNetToLinearGraph = async (
graph: NonNullableGraph, graph: NonNullableGraph,
baseNodeId: string baseNodeId: string
): Promise<void> => { ): Promise<void> => {
const controlNetMetadata: CoreMetadataInvocation['controlnets'] = []; const controlNetMetadata: S['CoreMetadataInvocation']['controlnets'] = [];
const controlNets = selectValidControlNets(state.controlAdapters).filter( const controlNets = selectValidControlNets(state.controlAdapters).filter(
({ model, processedControlImage, processorType, controlImage, isEnabled }) => { ({ model, processedControlImage, processorType, controlImage, isEnabled }) => {
const hasModel = Boolean(model); const hasModel = Boolean(model);
@ -36,7 +30,7 @@ export const addControlNetToLinearGraph = async (
if (controlNets.length) { if (controlNets.length) {
// Even though denoise_latents' control input is collection or scalar, keep it simple and always use a collect // Even though denoise_latents' control input is collection or scalar, keep it simple and always use a collect
const controlNetIterateNode: CollectInvocation = { const controlNetIterateNode: Invocation<'collect'> = {
id: CONTROL_NET_COLLECT, id: CONTROL_NET_COLLECT,
type: 'collect', type: 'collect',
is_intermediate: true, is_intermediate: true,
@ -67,7 +61,7 @@ export const addControlNetToLinearGraph = async (
weight, weight,
} = controlNet; } = controlNet;
const controlNetNode: ControlNetInvocation = { const controlNetNode: Invocation<'controlnet'> = {
id: `control_net_${id}`, id: `control_net_${id}`,
type: 'controlnet', type: 'controlnet',
is_intermediate: true, is_intermediate: true,

View File

@ -5,13 +5,7 @@ import type { ImageField } from 'features/nodes/types/common';
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata'; import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
import { IP_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants'; import { IP_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import type { import type { Invocation, NonNullableGraph, S } from 'services/api/types';
CollectInvocation,
CoreMetadataInvocation,
IPAdapterInvocation,
NonNullableGraph,
S,
} from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
export const addIPAdapterToLinearGraph = async ( export const addIPAdapterToLinearGraph = async (
@ -32,7 +26,7 @@ export const addIPAdapterToLinearGraph = async (
if (ipAdapters.length) { if (ipAdapters.length) {
// Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect // Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect
const ipAdapterCollectNode: CollectInvocation = { const ipAdapterCollectNode: Invocation<'collect'> = {
id: IP_ADAPTER_COLLECT, id: IP_ADAPTER_COLLECT,
type: 'collect', type: 'collect',
is_intermediate: true, is_intermediate: true,
@ -46,7 +40,7 @@ export const addIPAdapterToLinearGraph = async (
}, },
}); });
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = []; const ipAdapterMetdata: S['CoreMetadataInvocation']['ipAdapters'] = [];
for (const ipAdapter of ipAdapters) { for (const ipAdapter of ipAdapters) {
if (!ipAdapter.model) { if (!ipAdapter.model) {
@ -56,7 +50,7 @@ export const addIPAdapterToLinearGraph = async (
assert(controlImage, 'IP Adapter image is required'); assert(controlImage, 'IP Adapter image is required');
const ipAdapterNode: IPAdapterInvocation = { const ipAdapterNode: Invocation<'ip_adapter'> = {
id: `ip_adapter_${id}`, id: `ip_adapter_${id}`,
type: 'ip_adapter', type: 'ip_adapter',
is_intermediate: true, is_intermediate: true,

View File

@ -9,7 +9,7 @@ import {
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
} from 'features/nodes/util/graph/constants'; } from 'features/nodes/util/graph/constants';
import { filter, size } from 'lodash-es'; import { filter, size } from 'lodash-es';
import type { CoreMetadataInvocation, LoRALoaderInvocation, NonNullableGraph } from 'services/api/types'; import type { Invocation, NonNullableGraph, S } from 'services/api/types';
export const addLoRAsToGraph = async ( export const addLoRAsToGraph = async (
state: RootState, state: RootState,
@ -43,7 +43,7 @@ export const addLoRAsToGraph = async (
// we need to remember the last lora so we can chain from it // we need to remember the last lora so we can chain from it
let lastLoraNodeId = ''; let lastLoraNodeId = '';
let currentLoraIndex = 0; let currentLoraIndex = 0;
const loraMetadata: CoreMetadataInvocation['loras'] = []; const loraMetadata: S['CoreMetadataInvocation']['loras'] = [];
enabledLoRAs.forEach(async (lora) => { enabledLoRAs.forEach(async (lora) => {
const { weight } = lora; const { weight } = lora;
@ -51,7 +51,7 @@ export const addLoRAsToGraph = async (
const currentLoraNodeId = `${LORA_LOADER}_${key}`; const currentLoraNodeId = `${LORA_LOADER}_${key}`;
const parsedModel = zModelIdentifierField.parse(lora.model); const parsedModel = zModelIdentifierField.parse(lora.model);
const loraLoaderNode: LoRALoaderInvocation = { const loraLoaderNode: Invocation<'lora_loader'> = {
type: 'lora_loader', type: 'lora_loader',
id: currentLoraNodeId, id: currentLoraNodeId,
is_intermediate: true, is_intermediate: true,

View File

@ -1,14 +1,14 @@
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { LATENTS_TO_IMAGE, NSFW_CHECKER } from 'features/nodes/util/graph/constants'; import { LATENTS_TO_IMAGE, NSFW_CHECKER } from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageNSFWBlurInvocation, LatentsToImageInvocation, NonNullableGraph } from 'services/api/types'; import type { Invocation, NonNullableGraph } from 'services/api/types';
export const addNSFWCheckerToGraph = ( export const addNSFWCheckerToGraph = (
state: RootState, state: RootState,
graph: NonNullableGraph, graph: NonNullableGraph,
nodeIdToAddTo = LATENTS_TO_IMAGE nodeIdToAddTo = LATENTS_TO_IMAGE
): void => { ): void => {
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as LatentsToImageInvocation | undefined; const nodeToAddTo = graph.nodes[nodeIdToAddTo] as Invocation<'l2i'> | undefined;
if (!nodeToAddTo) { if (!nodeToAddTo) {
// something has gone terribly awry // something has gone terribly awry
@ -18,14 +18,14 @@ export const addNSFWCheckerToGraph = (
nodeToAddTo.is_intermediate = true; nodeToAddTo.is_intermediate = true;
nodeToAddTo.use_cache = true; nodeToAddTo.use_cache = true;
const nsfwCheckerNode: ImageNSFWBlurInvocation = { const nsfwCheckerNode: Invocation<'img_nsfw'> = {
id: NSFW_CHECKER, id: NSFW_CHECKER,
type: 'img_nsfw', type: 'img_nsfw',
is_intermediate: getIsIntermediate(state), is_intermediate: getIsIntermediate(state),
board: getBoardField(state), board: getBoardField(state),
}; };
graph.nodes[NSFW_CHECKER] = nsfwCheckerNode as ImageNSFWBlurInvocation; graph.nodes[NSFW_CHECKER] = nsfwCheckerNode;
graph.edges.push({ graph.edges.push({
source: { source: {
node_id: nodeIdToAddTo, node_id: nodeIdToAddTo,

View File

@ -10,7 +10,7 @@ import {
SEAMLESS, SEAMLESS,
} from 'features/nodes/util/graph/constants'; } from 'features/nodes/util/graph/constants';
import { filter, size } from 'lodash-es'; import { filter, size } from 'lodash-es';
import type { CoreMetadataInvocation, NonNullableGraph, SDXLLoRALoaderInvocation } from 'services/api/types'; import type { Invocation, NonNullableGraph, S } from 'services/api/types';
export const addSDXLLoRAsToGraph = async ( export const addSDXLLoRAsToGraph = async (
state: RootState, state: RootState,
@ -34,7 +34,7 @@ export const addSDXLLoRAsToGraph = async (
return; return;
} }
const loraMetadata: CoreMetadataInvocation['loras'] = []; const loraMetadata: S['CoreMetadataInvocation']['loras'] = [];
// Handle Seamless Plugs // Handle Seamless Plugs
const unetLoaderId = modelLoaderNodeId; const unetLoaderId = modelLoaderNodeId;
@ -60,7 +60,7 @@ export const addSDXLLoRAsToGraph = async (
const currentLoraNodeId = `${LORA_LOADER}_${lora.model.key}`; const currentLoraNodeId = `${LORA_LOADER}_${lora.model.key}`;
const parsedModel = zModelIdentifierField.parse(lora.model); const parsedModel = zModelIdentifierField.parse(lora.model);
const loraLoaderNode: SDXLLoRALoaderInvocation = { const loraLoaderNode: Invocation<'sdxl_lora_loader'> = {
type: 'sdxl_lora_loader', type: 'sdxl_lora_loader',
id: currentLoraNodeId, id: currentLoraNodeId,
is_intermediate: true, is_intermediate: true,

View File

@ -17,7 +17,7 @@ import {
SDXL_REFINER_SEAMLESS, SDXL_REFINER_SEAMLESS,
} from 'features/nodes/util/graph/constants'; } from 'features/nodes/util/graph/constants';
import { getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils'; import { getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import type { NonNullableGraph, SeamlessModeInvocation } from 'services/api/types'; import type { NonNullableGraph } from 'services/api/types';
import { isRefinerMainModelModelConfig } from 'services/api/types'; import { isRefinerMainModelModelConfig } from 'services/api/types';
export const addSDXLRefinerToGraph = async ( export const addSDXLRefinerToGraph = async (
@ -100,7 +100,7 @@ export const addSDXLRefinerToGraph = async (
type: 'seamless', type: 'seamless',
seamless_x: seamlessXAxis, seamless_x: seamlessXAxis,
seamless_y: seamlessYAxis, seamless_y: seamlessYAxis,
} as SeamlessModeInvocation; };
graph.edges.push( graph.edges.push(
{ {

View File

@ -11,7 +11,7 @@ import {
SEAMLESS, SEAMLESS,
VAE_LOADER, VAE_LOADER,
} from 'features/nodes/util/graph/constants'; } from 'features/nodes/util/graph/constants';
import type { NonNullableGraph, SeamlessModeInvocation } from 'services/api/types'; import type { NonNullableGraph } from 'services/api/types';
export const addSeamlessToLinearGraph = ( export const addSeamlessToLinearGraph = (
state: RootState, state: RootState,
@ -27,7 +27,7 @@ export const addSeamlessToLinearGraph = (
type: 'seamless', type: 'seamless',
seamless_x: seamlessXAxis, seamless_x: seamlessXAxis,
seamless_y: seamlessYAxis, seamless_y: seamlessYAxis,
} as SeamlessModeInvocation; };
if (!isAutoVae) { if (!isAutoVae) {
graph.nodes[VAE_LOADER] = { graph.nodes[VAE_LOADER] = {

View File

@ -5,13 +5,7 @@ import type { ImageField } from 'features/nodes/types/common';
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata'; import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
import { T2I_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants'; import { T2I_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import type { import type { Invocation, NonNullableGraph, S } from 'services/api/types';
CollectInvocation,
CoreMetadataInvocation,
NonNullableGraph,
S,
T2IAdapterInvocation,
} from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
export const addT2IAdaptersToLinearGraph = async ( export const addT2IAdaptersToLinearGraph = async (
@ -35,7 +29,7 @@ export const addT2IAdaptersToLinearGraph = async (
if (t2iAdapters.length) { if (t2iAdapters.length) {
// Even though denoise_latents' t2i adapter input is collection or scalar, keep it simple and always use a collect // Even though denoise_latents' t2i adapter input is collection or scalar, keep it simple and always use a collect
const t2iAdapterCollectNode: CollectInvocation = { const t2iAdapterCollectNode: Invocation<'collect'> = {
id: T2I_ADAPTER_COLLECT, id: T2I_ADAPTER_COLLECT,
type: 'collect', type: 'collect',
is_intermediate: true, is_intermediate: true,
@ -49,7 +43,7 @@ export const addT2IAdaptersToLinearGraph = async (
}, },
}); });
const t2iAdapterMetadata: CoreMetadataInvocation['t2iAdapters'] = []; const t2iAdapterMetadata: S['CoreMetadataInvocation']['t2iAdapters'] = [];
for (const t2iAdapter of t2iAdapters) { for (const t2iAdapter of t2iAdapters) {
if (!t2iAdapter.model) { if (!t2iAdapter.model) {
@ -67,7 +61,7 @@ export const addT2IAdaptersToLinearGraph = async (
weight, weight,
} = t2iAdapter; } = t2iAdapter;
const t2iAdapterNode: T2IAdapterInvocation = { const t2iAdapterNode: Invocation<'t2i_adapter'> = {
id: `t2i_adapter_${id}`, id: `t2i_adapter_${id}`,
type: 't2i_adapter', type: 't2i_adapter',
is_intermediate: true, is_intermediate: true,

View File

@ -1,28 +1,23 @@
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from 'features/nodes/util/graph/constants'; import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { import type { Invocation, NonNullableGraph } from 'services/api/types';
ImageNSFWBlurInvocation,
ImageWatermarkInvocation,
LatentsToImageInvocation,
NonNullableGraph,
} from 'services/api/types';
export const addWatermarkerToGraph = ( export const addWatermarkerToGraph = (
state: RootState, state: RootState,
graph: NonNullableGraph, graph: NonNullableGraph,
nodeIdToAddTo = LATENTS_TO_IMAGE nodeIdToAddTo = LATENTS_TO_IMAGE
): void => { ): void => {
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as LatentsToImageInvocation | undefined; const nodeToAddTo = graph.nodes[nodeIdToAddTo] as Invocation<'l2i'> | undefined;
const nsfwCheckerNode = graph.nodes[NSFW_CHECKER] as ImageNSFWBlurInvocation | undefined; const nsfwCheckerNode = graph.nodes[NSFW_CHECKER] as Invocation<'img_nsfw'> | undefined;
if (!nodeToAddTo) { if (!nodeToAddTo) {
// something has gone terribly awry // something has gone terribly awry
return; return;
} }
const watermarkerNode: ImageWatermarkInvocation = { const watermarkerNode: Invocation<'img_watermark'> = {
id: WATERMARKER, id: WATERMARKER,
type: 'img_watermark', type: 'img_watermark',
is_intermediate: getIsIntermediate(state), is_intermediate: getIsIntermediate(state),

View File

@ -17,12 +17,8 @@ import {
SEAMLESS, SEAMLESS,
} from 'features/nodes/util/graph/constants'; } from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import { import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
type ImageDTO, import { isNonRefinerMainModelConfig } from 'services/api/types';
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -300,7 +296,7 @@ export const buildCanvasImageToImageGraph = async (
use_cache: false, use_cache: false,
}; };
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = initialImage; (graph.nodes[IMAGE_TO_LATENTS] as Invocation<'i2l'>).image = initialImage;
graph.edges.push({ graph.edges.push({
source: { source: {

View File

@ -19,14 +19,7 @@ import {
SEAMLESS, SEAMLESS,
} from 'features/nodes/util/graph/constants'; } from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
CanvasPasteBackInvocation,
CreateGradientMaskInvocation,
ImageDTO,
ImageToLatentsInvocation,
NoiseInvocation,
NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -316,8 +309,8 @@ export const buildCanvasInpaintGraph = async (
height: height, height: height,
}; };
(graph.nodes[NOISE] as NoiseInvocation).width = scaledWidth; (graph.nodes[NOISE] as Invocation<'noise'>).width = scaledWidth;
(graph.nodes[NOISE] as NoiseInvocation).height = scaledHeight; (graph.nodes[NOISE] as Invocation<'noise'>).height = scaledHeight;
// Connect Nodes // Connect Nodes
graph.edges.push( graph.edges.push(
@ -397,22 +390,22 @@ export const buildCanvasInpaintGraph = async (
); );
} else { } else {
// Add Images To Nodes // Add Images To Nodes
(graph.nodes[NOISE] as NoiseInvocation).width = width; (graph.nodes[NOISE] as Invocation<'noise'>).width = width;
(graph.nodes[NOISE] as NoiseInvocation).height = height; (graph.nodes[NOISE] as Invocation<'noise'>).height = height;
graph.nodes[INPAINT_IMAGE] = { graph.nodes[INPAINT_IMAGE] = {
...(graph.nodes[INPAINT_IMAGE] as ImageToLatentsInvocation), ...(graph.nodes[INPAINT_IMAGE] as Invocation<'i2l'>),
image: canvasInitImage, image: canvasInitImage,
}; };
graph.nodes[INPAINT_CREATE_MASK] = { graph.nodes[INPAINT_CREATE_MASK] = {
...(graph.nodes[INPAINT_CREATE_MASK] as CreateGradientMaskInvocation), ...(graph.nodes[INPAINT_CREATE_MASK] as Invocation<'create_gradient_mask'>),
mask: canvasMaskImage, mask: canvasMaskImage,
}; };
// Paste Back // Paste Back
graph.nodes[CANVAS_OUTPUT] = { graph.nodes[CANVAS_OUTPUT] = {
...(graph.nodes[CANVAS_OUTPUT] as CanvasPasteBackInvocation), ...(graph.nodes[CANVAS_OUTPUT] as Invocation<'canvas_paste_back'>),
mask: canvasMaskImage, mask: canvasMaskImage,
}; };

View File

@ -23,14 +23,7 @@ import {
SEAMLESS, SEAMLESS,
} from 'features/nodes/util/graph/constants'; } from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
ImageDTO,
ImageToLatentsInvocation,
InfillPatchMatchInvocation,
InfillTileInvocation,
NoiseInvocation,
NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -437,8 +430,8 @@ export const buildCanvasOutpaintGraph = async (
height: height, height: height,
}; };
(graph.nodes[NOISE] as NoiseInvocation).width = scaledWidth; (graph.nodes[NOISE] as Invocation<'noise'>).width = scaledWidth;
(graph.nodes[NOISE] as NoiseInvocation).height = scaledHeight; (graph.nodes[NOISE] as Invocation<'noise'>).height = scaledHeight;
// Connect Nodes // Connect Nodes
graph.edges.push( graph.edges.push(
@ -540,15 +533,15 @@ export const buildCanvasOutpaintGraph = async (
} else { } else {
// Add Images To Nodes // Add Images To Nodes
graph.nodes[INPAINT_INFILL] = { graph.nodes[INPAINT_INFILL] = {
...(graph.nodes[INPAINT_INFILL] as InfillTileInvocation | InfillPatchMatchInvocation), ...(graph.nodes[INPAINT_INFILL] as Invocation<'infill_tile'> | Invocation<'infill_patchmatch'>),
image: canvasInitImage, image: canvasInitImage,
}; };
(graph.nodes[NOISE] as NoiseInvocation).width = width; (graph.nodes[NOISE] as Invocation<'noise'>).width = width;
(graph.nodes[NOISE] as NoiseInvocation).height = height; (graph.nodes[NOISE] as Invocation<'noise'>).height = height;
graph.nodes[INPAINT_IMAGE] = { graph.nodes[INPAINT_IMAGE] = {
...(graph.nodes[INPAINT_IMAGE] as ImageToLatentsInvocation), ...(graph.nodes[INPAINT_IMAGE] as Invocation<'i2l'>),
image: canvasInitImage, image: canvasInitImage,
}; };

View File

@ -17,12 +17,8 @@ import {
SEAMLESS, SEAMLESS,
} from 'features/nodes/util/graph/constants'; } from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils'; import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import { import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
type ImageDTO, import { isNonRefinerMainModelConfig } from 'services/api/types';
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -301,7 +297,7 @@ export const buildCanvasSDXLImageToImageGraph = async (
use_cache: false, use_cache: false,
}; };
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = initialImage; (graph.nodes[IMAGE_TO_LATENTS] as Invocation<'i2l'>).image = initialImage;
graph.edges.push({ graph.edges.push({
source: { source: {

View File

@ -19,14 +19,7 @@ import {
SEAMLESS, SEAMLESS,
} from 'features/nodes/util/graph/constants'; } from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils'; import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import type { import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
CanvasPasteBackInvocation,
CreateGradientMaskInvocation,
ImageDTO,
ImageToLatentsInvocation,
NoiseInvocation,
NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -327,8 +320,8 @@ export const buildCanvasSDXLInpaintGraph = async (
height: height, height: height,
}; };
(graph.nodes[NOISE] as NoiseInvocation).width = scaledWidth; (graph.nodes[NOISE] as Invocation<'noise'>).width = scaledWidth;
(graph.nodes[NOISE] as NoiseInvocation).height = scaledHeight; (graph.nodes[NOISE] as Invocation<'noise'>).height = scaledHeight;
// Connect Nodes // Connect Nodes
graph.edges.push( graph.edges.push(
@ -408,22 +401,22 @@ export const buildCanvasSDXLInpaintGraph = async (
); );
} else { } else {
// Add Images To Nodes // Add Images To Nodes
(graph.nodes[NOISE] as NoiseInvocation).width = width; (graph.nodes[NOISE] as Invocation<'noise'>).width = width;
(graph.nodes[NOISE] as NoiseInvocation).height = height; (graph.nodes[NOISE] as Invocation<'noise'>).height = height;
graph.nodes[INPAINT_IMAGE] = { graph.nodes[INPAINT_IMAGE] = {
...(graph.nodes[INPAINT_IMAGE] as ImageToLatentsInvocation), ...(graph.nodes[INPAINT_IMAGE] as Invocation<'i2l'>),
image: canvasInitImage, image: canvasInitImage,
}; };
graph.nodes[INPAINT_CREATE_MASK] = { graph.nodes[INPAINT_CREATE_MASK] = {
...(graph.nodes[INPAINT_CREATE_MASK] as CreateGradientMaskInvocation), ...(graph.nodes[INPAINT_CREATE_MASK] as Invocation<'create_gradient_mask'>),
mask: canvasMaskImage, mask: canvasMaskImage,
}; };
// Paste Back // Paste Back
graph.nodes[CANVAS_OUTPUT] = { graph.nodes[CANVAS_OUTPUT] = {
...(graph.nodes[CANVAS_OUTPUT] as CanvasPasteBackInvocation), ...(graph.nodes[CANVAS_OUTPUT] as Invocation<'canvas_paste_back'>),
mask: canvasMaskImage, mask: canvasMaskImage,
}; };

View File

@ -23,14 +23,7 @@ import {
SEAMLESS, SEAMLESS,
} from 'features/nodes/util/graph/constants'; } from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils'; import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import type { import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
ImageDTO,
ImageToLatentsInvocation,
InfillPatchMatchInvocation,
InfillTileInvocation,
NoiseInvocation,
NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -446,8 +439,8 @@ export const buildCanvasSDXLOutpaintGraph = async (
height: height, height: height,
}; };
(graph.nodes[NOISE] as NoiseInvocation).width = scaledWidth; (graph.nodes[NOISE] as Invocation<'noise'>).width = scaledWidth;
(graph.nodes[NOISE] as NoiseInvocation).height = scaledHeight; (graph.nodes[NOISE] as Invocation<'noise'>).height = scaledHeight;
// Connect Nodes // Connect Nodes
graph.edges.push( graph.edges.push(
@ -549,15 +542,15 @@ export const buildCanvasSDXLOutpaintGraph = async (
} else { } else {
// Add Images To Nodes // Add Images To Nodes
graph.nodes[INPAINT_INFILL] = { graph.nodes[INPAINT_INFILL] = {
...(graph.nodes[INPAINT_INFILL] as InfillTileInvocation | InfillPatchMatchInvocation), ...(graph.nodes[INPAINT_INFILL] as Invocation<'infill_tile'> | Invocation<'infill_patchmatch'>),
image: canvasInitImage, image: canvasInitImage,
}; };
(graph.nodes[NOISE] as NoiseInvocation).width = width; (graph.nodes[NOISE] as Invocation<'noise'>).width = width;
(graph.nodes[NOISE] as NoiseInvocation).height = height; (graph.nodes[NOISE] as Invocation<'noise'>).height = height;
graph.nodes[INPAINT_IMAGE] = { graph.nodes[INPAINT_IMAGE] = {
...(graph.nodes[INPAINT_IMAGE] as ImageToLatentsInvocation), ...(graph.nodes[INPAINT_IMAGE] as Invocation<'i2l'>),
image: canvasInitImage, image: canvasInitImage,
}; };

View File

@ -1,11 +1,11 @@
import type { JSONObject } from 'common/types'; import type { JSONObject } from 'common/types';
import type { ModelIdentifierField } from 'features/nodes/types/common'; import type { ModelIdentifierField } from 'features/nodes/types/common';
import { METADATA } from 'features/nodes/util/graph/constants'; import { METADATA } from 'features/nodes/util/graph/constants';
import type { AnyModelConfig, CoreMetadataInvocation, NonNullableGraph } from 'services/api/types'; import type { AnyModelConfig, NonNullableGraph, S } from 'services/api/types';
export const addCoreMetadataNode = ( export const addCoreMetadataNode = (
graph: NonNullableGraph, graph: NonNullableGraph,
metadata: Partial<CoreMetadataInvocation>, metadata: Partial<S['CoreMetadataInvocation']>,
nodeId: string nodeId: string
): void => { ): void => {
graph.nodes[METADATA] = { graph.nodes[METADATA] = {
@ -30,9 +30,9 @@ export const addCoreMetadataNode = (
export const upsertMetadata = ( export const upsertMetadata = (
graph: NonNullableGraph, graph: NonNullableGraph,
metadata: Partial<CoreMetadataInvocation> | JSONObject metadata: Partial<S['CoreMetadataInvocation']> | JSONObject
): void => { ): void => {
const metadataNode = graph.nodes[METADATA] as CoreMetadataInvocation | undefined; const metadataNode = graph.nodes[METADATA] as S['CoreMetadataInvocation'] | undefined;
if (!metadataNode) { if (!metadataNode) {
return; return;
@ -41,8 +41,8 @@ export const upsertMetadata = (
Object.assign(metadataNode, metadata); Object.assign(metadataNode, metadata);
}; };
export const removeMetadata = (graph: NonNullableGraph, key: keyof CoreMetadataInvocation): void => { export const removeMetadata = (graph: NonNullableGraph, key: keyof S['CoreMetadataInvocation']): void => {
const metadataNode = graph.nodes[METADATA] as CoreMetadataInvocation | undefined; const metadataNode = graph.nodes[METADATA] as S['CoreMetadataInvocation'] | undefined;
if (!metadataNode) { if (!metadataNode) {
return; return;
@ -52,7 +52,7 @@ export const removeMetadata = (graph: NonNullableGraph, key: keyof CoreMetadataI
}; };
export const getHasMetadata = (graph: NonNullableGraph): boolean => { export const getHasMetadata = (graph: NonNullableGraph): boolean => {
const metadataNode = graph.nodes[METADATA] as CoreMetadataInvocation | undefined; const metadataNode = graph.nodes[METADATA] as S['CoreMetadataInvocation'] | undefined;
return Boolean(metadataNode); return Boolean(metadataNode);
}; };

View File

@ -7,11 +7,11 @@ import type {
AnyInvocationInputField, AnyInvocationInputField,
AnyInvocationOutputField, AnyInvocationOutputField,
AnyModelConfig, AnyModelConfig,
CoreMetadataInvocation,
InputFields, InputFields,
Invocation, Invocation,
InvocationType, InvocationType,
OutputFields, OutputFields,
S,
} from 'services/api/types'; } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
@ -335,13 +335,13 @@ export class Graph {
* INTERNAL: Get the metadata node. If it does not exist, it is created. * INTERNAL: Get the metadata node. If it does not exist, it is created.
* @returns The metadata node. * @returns The metadata node.
*/ */
_getMetadataNode(): CoreMetadataInvocation { _getMetadataNode(): S['CoreMetadataInvocation'] {
try { try {
const node = this.getNode(METADATA) as AnyInvocationIncMetadata; const node = this.getNode(METADATA) as AnyInvocationIncMetadata;
assert(node.type === 'core_metadata'); assert(node.type === 'core_metadata');
return node; return node;
} catch { } catch {
const node: CoreMetadataInvocation = { id: METADATA, type: 'core_metadata' }; const node: S['CoreMetadataInvocation'] = { id: METADATA, type: 'core_metadata' };
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing // @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
return this.addNode(node); return this.addNode(node);
} }
@ -353,7 +353,7 @@ export class Graph {
* @param metadata The metadata to add. * @param metadata The metadata to add.
* @returns The metadata node. * @returns The metadata node.
*/ */
upsertMetadata(metadata: Partial<CoreMetadataInvocation>): CoreMetadataInvocation { upsertMetadata(metadata: Partial<S['CoreMetadataInvocation']>): S['CoreMetadataInvocation'] {
const node = this._getMetadataNode(); const node = this._getMetadataNode();
Object.assign(node, metadata); Object.assign(node, metadata);
return node; return node;
@ -364,7 +364,7 @@ export class Graph {
* @param keys The keys of the metadata to remove * @param keys The keys of the metadata to remove
* @returns The metadata node * @returns The metadata node
*/ */
removeMetadata(keys: string[]): CoreMetadataInvocation { removeMetadata(keys: string[]): S['CoreMetadataInvocation'] {
const metadataNode = this._getMetadataNode(); const metadataNode = this._getMetadataNode();
for (const k of keys) { for (const k of keys) {
unset(metadataNode, k); unset(metadataNode, k);

View File

@ -154,42 +154,6 @@ export type OutputFields<T extends AnyInvocation> = Extract<
AnyInvocationOutputField AnyInvocationOutputField
>; >;
// General nodes
export type CollectInvocation = Invocation<'collect'>;
export type InfillPatchMatchInvocation = Invocation<'infill_patchmatch'>;
export type InfillTileInvocation = Invocation<'infill_tile'>;
export type CreateGradientMaskInvocation = Invocation<'create_gradient_mask'>;
export type CanvasPasteBackInvocation = Invocation<'canvas_paste_back'>;
export type NoiseInvocation = Invocation<'noise'>;
export type SDXLLoRALoaderInvocation = Invocation<'sdxl_lora_loader'>;
export type ImageToLatentsInvocation = Invocation<'i2l'>;
export type LatentsToImageInvocation = Invocation<'l2i'>;
export type LoRALoaderInvocation = Invocation<'lora_loader'>;
export type ESRGANInvocation = Invocation<'esrgan'>;
export type ImageNSFWBlurInvocation = Invocation<'img_nsfw'>;
export type ImageWatermarkInvocation = Invocation<'img_watermark'>;
export type SeamlessModeInvocation = Invocation<'seamless'>;
export type CoreMetadataInvocation = Extract<Graph['nodes'][string], { type: 'core_metadata' }>;
// ControlNet Nodes
export type ControlNetInvocation = Invocation<'controlnet'>;
export type T2IAdapterInvocation = Invocation<'t2i_adapter'>;
export type IPAdapterInvocation = Invocation<'ip_adapter'>;
export type CannyImageProcessorInvocation = Invocation<'canny_image_processor'>;
export type ColorMapImageProcessorInvocation = Invocation<'color_map_image_processor'>;
export type ContentShuffleImageProcessorInvocation = Invocation<'content_shuffle_image_processor'>;
export type DepthAnythingImageProcessorInvocation = Invocation<'depth_anything_image_processor'>;
export type HedImageProcessorInvocation = Invocation<'hed_image_processor'>;
export type LineartAnimeImageProcessorInvocation = Invocation<'lineart_anime_image_processor'>;
export type LineartImageProcessorInvocation = Invocation<'lineart_image_processor'>;
export type MediapipeFaceProcessorInvocation = Invocation<'mediapipe_face_processor'>;
export type MidasDepthImageProcessorInvocation = Invocation<'midas_depth_image_processor'>;
export type MlsdImageProcessorInvocation = Invocation<'mlsd_image_processor'>;
export type NormalbaeImageProcessorInvocation = Invocation<'normalbae_image_processor'>;
export type DWOpenposeImageProcessorInvocation = Invocation<'dw_openpose_image_processor'>;
export type PidiImageProcessorInvocation = Invocation<'pidi_image_processor'>;
export type ZoeDepthImageProcessorInvocation = Invocation<'zoe_depth_image_processor'>;
// Node Outputs // Node Outputs
export type ImageOutput = S['ImageOutput']; export type ImageOutput = S['ImageOutput'];