tidy(ui): use Invocation<> helper type in canvas graph builders, elsewhere

This commit is contained in:
psychedelicious 2024-05-14 20:36:01 +10:00
parent 67dbe6d949
commit 0ff0290735
19 changed files with 80 additions and 175 deletions

View File

@ -1,6 +1,6 @@
import type { RootState } from 'app/store/store';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
import type { ESRGANInvocation, Graph, NonNullableGraph } from 'services/api/types';
import type { Graph, Invocation, NonNullableGraph } from 'services/api/types';
import { addCoreMetadataNode, upsertMetadata } from './canvas/metadata';
import { ESRGAN } from './constants';
@ -13,7 +13,7 @@ type Arg = {
export const buildAdHocUpscaleGraph = ({ image_name, state }: Arg): Graph => {
const { esrganModelName } = state.postprocessing;
const realesrganNode: ESRGANInvocation = {
const realesrganNode: Invocation<'esrgan'> = {
id: ESRGAN,
type: 'esrgan',
image: { image_name },

View File

@ -5,13 +5,7 @@ import type { ImageField } from 'features/nodes/types/common';
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
import { CONTROL_NET_COLLECT } from 'features/nodes/util/graph/constants';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import type {
CollectInvocation,
ControlNetInvocation,
CoreMetadataInvocation,
NonNullableGraph,
S,
} from 'services/api/types';
import type { Invocation, NonNullableGraph, S } from 'services/api/types';
import { assert } from 'tsafe';
export const addControlNetToLinearGraph = async (
@ -19,7 +13,7 @@ export const addControlNetToLinearGraph = async (
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const controlNetMetadata: CoreMetadataInvocation['controlnets'] = [];
const controlNetMetadata: S['CoreMetadataInvocation']['controlnets'] = [];
const controlNets = selectValidControlNets(state.controlAdapters).filter(
({ model, processedControlImage, processorType, controlImage, isEnabled }) => {
const hasModel = Boolean(model);
@ -36,7 +30,7 @@ export const addControlNetToLinearGraph = async (
if (controlNets.length) {
// Even though denoise_latents' control input is collection or scalar, keep it simple and always use a collect
const controlNetIterateNode: CollectInvocation = {
const controlNetIterateNode: Invocation<'collect'> = {
id: CONTROL_NET_COLLECT,
type: 'collect',
is_intermediate: true,
@ -67,7 +61,7 @@ export const addControlNetToLinearGraph = async (
weight,
} = controlNet;
const controlNetNode: ControlNetInvocation = {
const controlNetNode: Invocation<'controlnet'> = {
id: `control_net_${id}`,
type: 'controlnet',
is_intermediate: true,

View File

@ -5,13 +5,7 @@ import type { ImageField } from 'features/nodes/types/common';
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
import { IP_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import type {
CollectInvocation,
CoreMetadataInvocation,
IPAdapterInvocation,
NonNullableGraph,
S,
} from 'services/api/types';
import type { Invocation, NonNullableGraph, S } from 'services/api/types';
import { assert } from 'tsafe';
export const addIPAdapterToLinearGraph = async (
@ -32,7 +26,7 @@ export const addIPAdapterToLinearGraph = async (
if (ipAdapters.length) {
// Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect
const ipAdapterCollectNode: CollectInvocation = {
const ipAdapterCollectNode: Invocation<'collect'> = {
id: IP_ADAPTER_COLLECT,
type: 'collect',
is_intermediate: true,
@ -46,7 +40,7 @@ export const addIPAdapterToLinearGraph = async (
},
});
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
const ipAdapterMetdata: S['CoreMetadataInvocation']['ipAdapters'] = [];
for (const ipAdapter of ipAdapters) {
if (!ipAdapter.model) {
@ -56,7 +50,7 @@ export const addIPAdapterToLinearGraph = async (
assert(controlImage, 'IP Adapter image is required');
const ipAdapterNode: IPAdapterInvocation = {
const ipAdapterNode: Invocation<'ip_adapter'> = {
id: `ip_adapter_${id}`,
type: 'ip_adapter',
is_intermediate: true,

View File

@ -9,7 +9,7 @@ import {
POSITIVE_CONDITIONING,
} from 'features/nodes/util/graph/constants';
import { filter, size } from 'lodash-es';
import type { CoreMetadataInvocation, LoRALoaderInvocation, NonNullableGraph } from 'services/api/types';
import type { Invocation, NonNullableGraph, S } from 'services/api/types';
export const addLoRAsToGraph = async (
state: RootState,
@ -43,7 +43,7 @@ export const addLoRAsToGraph = async (
// we need to remember the last lora so we can chain from it
let lastLoraNodeId = '';
let currentLoraIndex = 0;
const loraMetadata: CoreMetadataInvocation['loras'] = [];
const loraMetadata: S['CoreMetadataInvocation']['loras'] = [];
enabledLoRAs.forEach(async (lora) => {
const { weight } = lora;
@ -51,7 +51,7 @@ export const addLoRAsToGraph = async (
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
const parsedModel = zModelIdentifierField.parse(lora.model);
const loraLoaderNode: LoRALoaderInvocation = {
const loraLoaderNode: Invocation<'lora_loader'> = {
type: 'lora_loader',
id: currentLoraNodeId,
is_intermediate: true,

View File

@ -1,14 +1,14 @@
import type { RootState } from 'app/store/store';
import { LATENTS_TO_IMAGE, NSFW_CHECKER } from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageNSFWBlurInvocation, LatentsToImageInvocation, NonNullableGraph } from 'services/api/types';
import type { Invocation, NonNullableGraph } from 'services/api/types';
export const addNSFWCheckerToGraph = (
state: RootState,
graph: NonNullableGraph,
nodeIdToAddTo = LATENTS_TO_IMAGE
): void => {
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as LatentsToImageInvocation | undefined;
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as Invocation<'l2i'> | undefined;
if (!nodeToAddTo) {
// something has gone terribly awry
@ -18,14 +18,14 @@ export const addNSFWCheckerToGraph = (
nodeToAddTo.is_intermediate = true;
nodeToAddTo.use_cache = true;
const nsfwCheckerNode: ImageNSFWBlurInvocation = {
const nsfwCheckerNode: Invocation<'img_nsfw'> = {
id: NSFW_CHECKER,
type: 'img_nsfw',
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
};
graph.nodes[NSFW_CHECKER] = nsfwCheckerNode as ImageNSFWBlurInvocation;
graph.nodes[NSFW_CHECKER] = nsfwCheckerNode;
graph.edges.push({
source: {
node_id: nodeIdToAddTo,

View File

@ -10,7 +10,7 @@ import {
SEAMLESS,
} from 'features/nodes/util/graph/constants';
import { filter, size } from 'lodash-es';
import type { CoreMetadataInvocation, NonNullableGraph, SDXLLoRALoaderInvocation } from 'services/api/types';
import type { Invocation, NonNullableGraph, S } from 'services/api/types';
export const addSDXLLoRAsToGraph = async (
state: RootState,
@ -34,7 +34,7 @@ export const addSDXLLoRAsToGraph = async (
return;
}
const loraMetadata: CoreMetadataInvocation['loras'] = [];
const loraMetadata: S['CoreMetadataInvocation']['loras'] = [];
// Handle Seamless Plugs
const unetLoaderId = modelLoaderNodeId;
@ -60,7 +60,7 @@ export const addSDXLLoRAsToGraph = async (
const currentLoraNodeId = `${LORA_LOADER}_${lora.model.key}`;
const parsedModel = zModelIdentifierField.parse(lora.model);
const loraLoaderNode: SDXLLoRALoaderInvocation = {
const loraLoaderNode: Invocation<'sdxl_lora_loader'> = {
type: 'sdxl_lora_loader',
id: currentLoraNodeId,
is_intermediate: true,

View File

@ -17,7 +17,7 @@ import {
SDXL_REFINER_SEAMLESS,
} from 'features/nodes/util/graph/constants';
import { getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import type { NonNullableGraph, SeamlessModeInvocation } from 'services/api/types';
import type { NonNullableGraph } from 'services/api/types';
import { isRefinerMainModelModelConfig } from 'services/api/types';
export const addSDXLRefinerToGraph = async (
@ -100,7 +100,7 @@ export const addSDXLRefinerToGraph = async (
type: 'seamless',
seamless_x: seamlessXAxis,
seamless_y: seamlessYAxis,
} as SeamlessModeInvocation;
};
graph.edges.push(
{

View File

@ -11,7 +11,7 @@ import {
SEAMLESS,
VAE_LOADER,
} from 'features/nodes/util/graph/constants';
import type { NonNullableGraph, SeamlessModeInvocation } from 'services/api/types';
import type { NonNullableGraph } from 'services/api/types';
export const addSeamlessToLinearGraph = (
state: RootState,
@ -27,7 +27,7 @@ export const addSeamlessToLinearGraph = (
type: 'seamless',
seamless_x: seamlessXAxis,
seamless_y: seamlessYAxis,
} as SeamlessModeInvocation;
};
if (!isAutoVae) {
graph.nodes[VAE_LOADER] = {

View File

@ -5,13 +5,7 @@ import type { ImageField } from 'features/nodes/types/common';
import { upsertMetadata } from 'features/nodes/util/graph/canvas/metadata';
import { T2I_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import type {
CollectInvocation,
CoreMetadataInvocation,
NonNullableGraph,
S,
T2IAdapterInvocation,
} from 'services/api/types';
import type { Invocation, NonNullableGraph, S } from 'services/api/types';
import { assert } from 'tsafe';
export const addT2IAdaptersToLinearGraph = async (
@ -35,7 +29,7 @@ export const addT2IAdaptersToLinearGraph = async (
if (t2iAdapters.length) {
// Even though denoise_latents' t2i adapter input is collection or scalar, keep it simple and always use a collect
const t2iAdapterCollectNode: CollectInvocation = {
const t2iAdapterCollectNode: Invocation<'collect'> = {
id: T2I_ADAPTER_COLLECT,
type: 'collect',
is_intermediate: true,
@ -49,7 +43,7 @@ export const addT2IAdaptersToLinearGraph = async (
},
});
const t2iAdapterMetadata: CoreMetadataInvocation['t2iAdapters'] = [];
const t2iAdapterMetadata: S['CoreMetadataInvocation']['t2iAdapters'] = [];
for (const t2iAdapter of t2iAdapters) {
if (!t2iAdapter.model) {
@ -67,7 +61,7 @@ export const addT2IAdaptersToLinearGraph = async (
weight,
} = t2iAdapter;
const t2iAdapterNode: T2IAdapterInvocation = {
const t2iAdapterNode: Invocation<'t2i_adapter'> = {
id: `t2i_adapter_${id}`,
type: 't2i_adapter',
is_intermediate: true,

View File

@ -1,28 +1,23 @@
import type { RootState } from 'app/store/store';
import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type {
ImageNSFWBlurInvocation,
ImageWatermarkInvocation,
LatentsToImageInvocation,
NonNullableGraph,
} from 'services/api/types';
import type { Invocation, NonNullableGraph } from 'services/api/types';
export const addWatermarkerToGraph = (
state: RootState,
graph: NonNullableGraph,
nodeIdToAddTo = LATENTS_TO_IMAGE
): void => {
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as LatentsToImageInvocation | undefined;
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as Invocation<'l2i'> | undefined;
const nsfwCheckerNode = graph.nodes[NSFW_CHECKER] as ImageNSFWBlurInvocation | undefined;
const nsfwCheckerNode = graph.nodes[NSFW_CHECKER] as Invocation<'img_nsfw'> | undefined;
if (!nodeToAddTo) {
// something has gone terribly awry
return;
}
const watermarkerNode: ImageWatermarkInvocation = {
const watermarkerNode: Invocation<'img_watermark'> = {
id: WATERMARKER,
type: 'img_watermark',
is_intermediate: getIsIntermediate(state),

View File

@ -17,12 +17,8 @@ import {
SEAMLESS,
} from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import {
type ImageDTO,
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -300,7 +296,7 @@ export const buildCanvasImageToImageGraph = async (
use_cache: false,
};
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = initialImage;
(graph.nodes[IMAGE_TO_LATENTS] as Invocation<'i2l'>).image = initialImage;
graph.edges.push({
source: {

View File

@ -19,14 +19,7 @@ import {
SEAMLESS,
} from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type {
CanvasPasteBackInvocation,
CreateGradientMaskInvocation,
ImageDTO,
ImageToLatentsInvocation,
NoiseInvocation,
NonNullableGraph,
} from 'services/api/types';
import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -316,8 +309,8 @@ export const buildCanvasInpaintGraph = async (
height: height,
};
(graph.nodes[NOISE] as NoiseInvocation).width = scaledWidth;
(graph.nodes[NOISE] as NoiseInvocation).height = scaledHeight;
(graph.nodes[NOISE] as Invocation<'noise'>).width = scaledWidth;
(graph.nodes[NOISE] as Invocation<'noise'>).height = scaledHeight;
// Connect Nodes
graph.edges.push(
@ -397,22 +390,22 @@ export const buildCanvasInpaintGraph = async (
);
} else {
// Add Images To Nodes
(graph.nodes[NOISE] as NoiseInvocation).width = width;
(graph.nodes[NOISE] as NoiseInvocation).height = height;
(graph.nodes[NOISE] as Invocation<'noise'>).width = width;
(graph.nodes[NOISE] as Invocation<'noise'>).height = height;
graph.nodes[INPAINT_IMAGE] = {
...(graph.nodes[INPAINT_IMAGE] as ImageToLatentsInvocation),
...(graph.nodes[INPAINT_IMAGE] as Invocation<'i2l'>),
image: canvasInitImage,
};
graph.nodes[INPAINT_CREATE_MASK] = {
...(graph.nodes[INPAINT_CREATE_MASK] as CreateGradientMaskInvocation),
...(graph.nodes[INPAINT_CREATE_MASK] as Invocation<'create_gradient_mask'>),
mask: canvasMaskImage,
};
// Paste Back
graph.nodes[CANVAS_OUTPUT] = {
...(graph.nodes[CANVAS_OUTPUT] as CanvasPasteBackInvocation),
...(graph.nodes[CANVAS_OUTPUT] as Invocation<'canvas_paste_back'>),
mask: canvasMaskImage,
};

View File

@ -23,14 +23,7 @@ import {
SEAMLESS,
} from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type {
ImageDTO,
ImageToLatentsInvocation,
InfillPatchMatchInvocation,
InfillTileInvocation,
NoiseInvocation,
NonNullableGraph,
} from 'services/api/types';
import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -437,8 +430,8 @@ export const buildCanvasOutpaintGraph = async (
height: height,
};
(graph.nodes[NOISE] as NoiseInvocation).width = scaledWidth;
(graph.nodes[NOISE] as NoiseInvocation).height = scaledHeight;
(graph.nodes[NOISE] as Invocation<'noise'>).width = scaledWidth;
(graph.nodes[NOISE] as Invocation<'noise'>).height = scaledHeight;
// Connect Nodes
graph.edges.push(
@ -540,15 +533,15 @@ export const buildCanvasOutpaintGraph = async (
} else {
// Add Images To Nodes
graph.nodes[INPAINT_INFILL] = {
...(graph.nodes[INPAINT_INFILL] as InfillTileInvocation | InfillPatchMatchInvocation),
...(graph.nodes[INPAINT_INFILL] as Invocation<'infill_tile'> | Invocation<'infill_patchmatch'>),
image: canvasInitImage,
};
(graph.nodes[NOISE] as NoiseInvocation).width = width;
(graph.nodes[NOISE] as NoiseInvocation).height = height;
(graph.nodes[NOISE] as Invocation<'noise'>).width = width;
(graph.nodes[NOISE] as Invocation<'noise'>).height = height;
graph.nodes[INPAINT_IMAGE] = {
...(graph.nodes[INPAINT_IMAGE] as ImageToLatentsInvocation),
...(graph.nodes[INPAINT_IMAGE] as Invocation<'i2l'>),
image: canvasInitImage,
};

View File

@ -17,12 +17,8 @@ import {
SEAMLESS,
} from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import {
type ImageDTO,
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -301,7 +297,7 @@ export const buildCanvasSDXLImageToImageGraph = async (
use_cache: false,
};
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = initialImage;
(graph.nodes[IMAGE_TO_LATENTS] as Invocation<'i2l'>).image = initialImage;
graph.edges.push({
source: {

View File

@ -19,14 +19,7 @@ import {
SEAMLESS,
} from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import type {
CanvasPasteBackInvocation,
CreateGradientMaskInvocation,
ImageDTO,
ImageToLatentsInvocation,
NoiseInvocation,
NonNullableGraph,
} from 'services/api/types';
import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -327,8 +320,8 @@ export const buildCanvasSDXLInpaintGraph = async (
height: height,
};
(graph.nodes[NOISE] as NoiseInvocation).width = scaledWidth;
(graph.nodes[NOISE] as NoiseInvocation).height = scaledHeight;
(graph.nodes[NOISE] as Invocation<'noise'>).width = scaledWidth;
(graph.nodes[NOISE] as Invocation<'noise'>).height = scaledHeight;
// Connect Nodes
graph.edges.push(
@ -408,22 +401,22 @@ export const buildCanvasSDXLInpaintGraph = async (
);
} else {
// Add Images To Nodes
(graph.nodes[NOISE] as NoiseInvocation).width = width;
(graph.nodes[NOISE] as NoiseInvocation).height = height;
(graph.nodes[NOISE] as Invocation<'noise'>).width = width;
(graph.nodes[NOISE] as Invocation<'noise'>).height = height;
graph.nodes[INPAINT_IMAGE] = {
...(graph.nodes[INPAINT_IMAGE] as ImageToLatentsInvocation),
...(graph.nodes[INPAINT_IMAGE] as Invocation<'i2l'>),
image: canvasInitImage,
};
graph.nodes[INPAINT_CREATE_MASK] = {
...(graph.nodes[INPAINT_CREATE_MASK] as CreateGradientMaskInvocation),
...(graph.nodes[INPAINT_CREATE_MASK] as Invocation<'create_gradient_mask'>),
mask: canvasMaskImage,
};
// Paste Back
graph.nodes[CANVAS_OUTPUT] = {
...(graph.nodes[CANVAS_OUTPUT] as CanvasPasteBackInvocation),
...(graph.nodes[CANVAS_OUTPUT] as Invocation<'canvas_paste_back'>),
mask: canvasMaskImage,
};

View File

@ -23,14 +23,7 @@ import {
SEAMLESS,
} from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import type {
ImageDTO,
ImageToLatentsInvocation,
InfillPatchMatchInvocation,
InfillTileInvocation,
NoiseInvocation,
NonNullableGraph,
} from 'services/api/types';
import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -446,8 +439,8 @@ export const buildCanvasSDXLOutpaintGraph = async (
height: height,
};
(graph.nodes[NOISE] as NoiseInvocation).width = scaledWidth;
(graph.nodes[NOISE] as NoiseInvocation).height = scaledHeight;
(graph.nodes[NOISE] as Invocation<'noise'>).width = scaledWidth;
(graph.nodes[NOISE] as Invocation<'noise'>).height = scaledHeight;
// Connect Nodes
graph.edges.push(
@ -549,15 +542,15 @@ export const buildCanvasSDXLOutpaintGraph = async (
} else {
// Add Images To Nodes
graph.nodes[INPAINT_INFILL] = {
...(graph.nodes[INPAINT_INFILL] as InfillTileInvocation | InfillPatchMatchInvocation),
...(graph.nodes[INPAINT_INFILL] as Invocation<'infill_tile'> | Invocation<'infill_patchmatch'>),
image: canvasInitImage,
};
(graph.nodes[NOISE] as NoiseInvocation).width = width;
(graph.nodes[NOISE] as NoiseInvocation).height = height;
(graph.nodes[NOISE] as Invocation<'noise'>).width = width;
(graph.nodes[NOISE] as Invocation<'noise'>).height = height;
graph.nodes[INPAINT_IMAGE] = {
...(graph.nodes[INPAINT_IMAGE] as ImageToLatentsInvocation),
...(graph.nodes[INPAINT_IMAGE] as Invocation<'i2l'>),
image: canvasInitImage,
};

View File

@ -1,11 +1,11 @@
import type { JSONObject } from 'common/types';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { METADATA } from 'features/nodes/util/graph/constants';
import type { AnyModelConfig, CoreMetadataInvocation, NonNullableGraph } from 'services/api/types';
import type { AnyModelConfig, NonNullableGraph, S } from 'services/api/types';
export const addCoreMetadataNode = (
graph: NonNullableGraph,
metadata: Partial<CoreMetadataInvocation>,
metadata: Partial<S['CoreMetadataInvocation']>,
nodeId: string
): void => {
graph.nodes[METADATA] = {
@ -30,9 +30,9 @@ export const addCoreMetadataNode = (
export const upsertMetadata = (
graph: NonNullableGraph,
metadata: Partial<CoreMetadataInvocation> | JSONObject
metadata: Partial<S['CoreMetadataInvocation']> | JSONObject
): void => {
const metadataNode = graph.nodes[METADATA] as CoreMetadataInvocation | undefined;
const metadataNode = graph.nodes[METADATA] as S['CoreMetadataInvocation'] | undefined;
if (!metadataNode) {
return;
@ -41,8 +41,8 @@ export const upsertMetadata = (
Object.assign(metadataNode, metadata);
};
export const removeMetadata = (graph: NonNullableGraph, key: keyof CoreMetadataInvocation): void => {
const metadataNode = graph.nodes[METADATA] as CoreMetadataInvocation | undefined;
export const removeMetadata = (graph: NonNullableGraph, key: keyof S['CoreMetadataInvocation']): void => {
const metadataNode = graph.nodes[METADATA] as S['CoreMetadataInvocation'] | undefined;
if (!metadataNode) {
return;
@ -52,7 +52,7 @@ export const removeMetadata = (graph: NonNullableGraph, key: keyof CoreMetadataI
};
export const getHasMetadata = (graph: NonNullableGraph): boolean => {
const metadataNode = graph.nodes[METADATA] as CoreMetadataInvocation | undefined;
const metadataNode = graph.nodes[METADATA] as S['CoreMetadataInvocation'] | undefined;
return Boolean(metadataNode);
};

View File

@ -7,11 +7,11 @@ import type {
AnyInvocationInputField,
AnyInvocationOutputField,
AnyModelConfig,
CoreMetadataInvocation,
InputFields,
Invocation,
InvocationType,
OutputFields,
S,
} from 'services/api/types';
import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid';
@ -335,13 +335,13 @@ export class Graph {
* INTERNAL: Get the metadata node. If it does not exist, it is created.
* @returns The metadata node.
*/
_getMetadataNode(): CoreMetadataInvocation {
_getMetadataNode(): S['CoreMetadataInvocation'] {
try {
const node = this.getNode(METADATA) as AnyInvocationIncMetadata;
assert(node.type === 'core_metadata');
return node;
} catch {
const node: CoreMetadataInvocation = { id: METADATA, type: 'core_metadata' };
const node: S['CoreMetadataInvocation'] = { id: METADATA, type: 'core_metadata' };
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
return this.addNode(node);
}
@ -353,7 +353,7 @@ export class Graph {
* @param metadata The metadata to add.
* @returns The metadata node.
*/
upsertMetadata(metadata: Partial<CoreMetadataInvocation>): CoreMetadataInvocation {
upsertMetadata(metadata: Partial<S['CoreMetadataInvocation']>): S['CoreMetadataInvocation'] {
const node = this._getMetadataNode();
Object.assign(node, metadata);
return node;
@ -364,7 +364,7 @@ export class Graph {
* @param keys The keys of the metadata to remove
* @returns The metadata node
*/
removeMetadata(keys: string[]): CoreMetadataInvocation {
removeMetadata(keys: string[]): S['CoreMetadataInvocation'] {
const metadataNode = this._getMetadataNode();
for (const k of keys) {
unset(metadataNode, k);

View File

@ -154,42 +154,6 @@ export type OutputFields<T extends AnyInvocation> = Extract<
AnyInvocationOutputField
>;
// General nodes
export type CollectInvocation = Invocation<'collect'>;
export type InfillPatchMatchInvocation = Invocation<'infill_patchmatch'>;
export type InfillTileInvocation = Invocation<'infill_tile'>;
export type CreateGradientMaskInvocation = Invocation<'create_gradient_mask'>;
export type CanvasPasteBackInvocation = Invocation<'canvas_paste_back'>;
export type NoiseInvocation = Invocation<'noise'>;
export type SDXLLoRALoaderInvocation = Invocation<'sdxl_lora_loader'>;
export type ImageToLatentsInvocation = Invocation<'i2l'>;
export type LatentsToImageInvocation = Invocation<'l2i'>;
export type LoRALoaderInvocation = Invocation<'lora_loader'>;
export type ESRGANInvocation = Invocation<'esrgan'>;
export type ImageNSFWBlurInvocation = Invocation<'img_nsfw'>;
export type ImageWatermarkInvocation = Invocation<'img_watermark'>;
export type SeamlessModeInvocation = Invocation<'seamless'>;
export type CoreMetadataInvocation = Extract<Graph['nodes'][string], { type: 'core_metadata' }>;
// ControlNet Nodes
export type ControlNetInvocation = Invocation<'controlnet'>;
export type T2IAdapterInvocation = Invocation<'t2i_adapter'>;
export type IPAdapterInvocation = Invocation<'ip_adapter'>;
export type CannyImageProcessorInvocation = Invocation<'canny_image_processor'>;
export type ColorMapImageProcessorInvocation = Invocation<'color_map_image_processor'>;
export type ContentShuffleImageProcessorInvocation = Invocation<'content_shuffle_image_processor'>;
export type DepthAnythingImageProcessorInvocation = Invocation<'depth_anything_image_processor'>;
export type HedImageProcessorInvocation = Invocation<'hed_image_processor'>;
export type LineartAnimeImageProcessorInvocation = Invocation<'lineart_anime_image_processor'>;
export type LineartImageProcessorInvocation = Invocation<'lineart_image_processor'>;
export type MediapipeFaceProcessorInvocation = Invocation<'mediapipe_face_processor'>;
export type MidasDepthImageProcessorInvocation = Invocation<'midas_depth_image_processor'>;
export type MlsdImageProcessorInvocation = Invocation<'mlsd_image_processor'>;
export type NormalbaeImageProcessorInvocation = Invocation<'normalbae_image_processor'>;
export type DWOpenposeImageProcessorInvocation = Invocation<'dw_openpose_image_processor'>;
export type PidiImageProcessorInvocation = Invocation<'pidi_image_processor'>;
export type ZoeDepthImageProcessorInvocation = Invocation<'zoe_depth_image_processor'>;
// Node Outputs
export type ImageOutput = S['ImageOutput'];