feat(ui): add and use type helpers for invocations and invocation outputs

This commit is contained in:
psychedelicious 2024-05-05 08:20:31 +10:00
parent 47b8153728
commit e3289856c0
4 changed files with 76 additions and 51 deletions

View File

@ -1,4 +1,4 @@
import type { S } from 'services/api/types';
import type { Invocation } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { describe, test } from 'vitest';
@ -45,16 +45,16 @@ describe('Control Adapter Types', () => {
assert<Equals<ProcessorConfig['type'], ProcessorTypeV2>>();
});
test('IP Adapter Method', () => {
assert<Equals<NonNullable<S['IPAdapterInvocation']['method']>, IPMethodV2>>();
assert<Equals<NonNullable<Invocation<'ip_adapter'>['method']>, IPMethodV2>>();
});
test('CLIP Vision Model', () => {
assert<Equals<NonNullable<S['IPAdapterInvocation']['clip_vision_model']>, CLIPVisionModelV2>>();
assert<Equals<NonNullable<Invocation<'ip_adapter'>['clip_vision_model']>, CLIPVisionModelV2>>();
});
test('Control Mode', () => {
assert<Equals<NonNullable<S['ControlNetInvocation']['control_mode']>, ControlModeV2>>();
assert<Equals<NonNullable<Invocation<'controlnet'>['control_mode']>, ControlModeV2>>();
});
test('DepthAnything Model Size', () => {
assert<Equals<NonNullable<S['DepthAnythingImageProcessorInvocation']['model_size']>, DepthAnythingModelSize>>();
assert<Equals<NonNullable<Invocation<'depth_anything_image_processor'>['model_size']>, DepthAnythingModelSize>>();
});
test('Processor Configs', () => {
// The processor configs are manually modeled zod schemas. This test ensures that the inferred types are correct.

View File

@ -11,7 +11,7 @@ import type {
SchedulerField,
T2IAdapterField,
} from 'features/nodes/types/common';
import type { S } from 'services/api/types';
import type { Invocation, S } from 'services/api/types';
import type { Equals, Extends } from 'tsafe';
import { assert } from 'tsafe';
import { describe, test } from 'vitest';
@ -26,7 +26,7 @@ describe('Common types', () => {
test('ImageField', () => assert<Equals<ImageField, S['ImageField']>>());
test('BoardField', () => assert<Equals<BoardField, S['BoardField']>>());
test('ColorField', () => assert<Equals<ColorField, S['ColorField']>>());
test('SchedulerField', () => assert<Equals<SchedulerField, NonNullable<S['SchedulerInvocation']['scheduler']>>>());
test('SchedulerField', () => assert<Equals<SchedulerField, NonNullable<Invocation<'scheduler'>['scheduler']>>>());
test('ControlField', () => assert<Equals<ControlField, S['ControlField']>>());
// @ts-expect-error TODO(psyche): fix types
test('IPAdapterField', () => assert<Extends<IPAdapterField, S['IPAdapterField']>>());

View File

@ -43,11 +43,9 @@ import type {
ControlNetInvocation,
Edge,
ImageDTO,
ImageResizeInvocation,
ImageToLatentsInvocation,
Invocation,
IPAdapterInvocation,
NonNullableGraph,
S,
T2IAdapterInvocation,
} from 'services/api/types';
import { assert } from 'tsafe';
@ -153,7 +151,7 @@ export const addControlLayersToGraph = async (
const { image_name } = await getMaskImage(layer, blob);
// The main mask-to-tensor node
const maskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
const maskToTensorNode: Invocation<'alpha_mask_to_tensor'> = {
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${layer.id}`,
type: 'alpha_mask_to_tensor',
image: {
@ -164,7 +162,7 @@ export const addControlLayersToGraph = async (
if (layer.positivePrompt) {
// The main positive conditioning node
const regionalPositiveCondNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
const regionalPositiveCondNode: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'> = isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
@ -203,7 +201,7 @@ export const addControlLayersToGraph = async (
if (layer.negativePrompt) {
// The main negative conditioning node
const regionalNegativeCondNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
const regionalNegativeCondNode: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'> = isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
@ -243,7 +241,7 @@ export const addControlLayersToGraph = async (
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
if (layer.autoNegative === 'invert' && layer.positivePrompt) {
// We re-use the mask image, but invert it when converting to tensor
const invertTensorMaskNode: S['InvertTensorMaskInvocation'] = {
const invertTensorMaskNode: Invocation<'invert_tensor_mask'> = {
id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${layer.id}`,
type: 'invert_tensor_mask',
};
@ -263,7 +261,7 @@ export const addControlLayersToGraph = async (
// Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the
// positive prompt
const regionalPositiveCondInvertedNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
const regionalPositiveCondInvertedNode: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'> = isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
@ -574,7 +572,7 @@ const addInitialImageLayerToGraph = (
: 1 - denoisingStrength;
denoiseNode.denoising_end = useRefinerStartEnd ? refinerStart : 1;
const i2lNode: ImageToLatentsInvocation = {
const i2lNode: Invocation<'i2l'> = {
type: 'i2l',
id: IMAGE_TO_LATENTS,
is_intermediate: true,
@ -598,7 +596,7 @@ const addInitialImageLayerToGraph = (
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
// Create a resize node, explicitly setting its image
const resizeNode: ImageResizeInvocation = {
const resizeNode: Invocation<'img_resize'> = {
id: RESIZE,
type: 'img_resize',
image: {

View File

@ -129,43 +129,70 @@ export type WorkflowRecordOrderBy = S['WorkflowRecordOrderBy'];
export type SQLiteDirection = S['SQLiteDirection'];
export type WorkflowRecordListItemDTO = S['WorkflowRecordListItemDTO'];
export type KeysOfUnion<T> = T extends T ? keyof T : never;
export type NonInputFields = 'id' | 'type' | 'is_intermediate' | 'use_cache';
export type NonOutputFields = 'type';
export type AnyInvocation = Graph['nodes'][string];
export type AnyInvocationExcludeCoreMetadata = Exclude<AnyInvocation, { type: 'core_metadata' }>;
export type InvocationType = AnyInvocation['type'];
export type InvocationTypeExcludeCoreMetadata = Exclude<InvocationType, 'core_metadata'>;
export type InvocationOutputMap = S['InvocationOutputMap'];
export type AnyInvocationOutput = InvocationOutputMap[InvocationType];
export type Invocation<T extends InvocationType> = Extract<AnyInvocation, { type: T }>;
export type InvocationExcludeCoreMetadata<T extends InvocationTypeExcludeCoreMetadata> = Extract<
AnyInvocation,
{ type: T }
>;
export type InvocationInputFields<T extends InvocationTypeExcludeCoreMetadata> = Exclude<
keyof Invocation<T>,
NonInputFields
>;
export type AnyInvocationInputField = Exclude<KeysOfUnion<AnyInvocationExcludeCoreMetadata>, NonInputFields>;
export type InvocationOutput<T extends InvocationType> = InvocationOutputMap[T];
export type InvocationOutputFields<T extends InvocationType> = Exclude<keyof InvocationOutput<T>, NonOutputFields>;
export type AnyInvocationOutputField = Exclude<KeysOfUnion<AnyInvocationOutput>, NonOutputFields>;
// General nodes
export type CollectInvocation = S['CollectInvocation'];
export type ImageResizeInvocation = S['ImageResizeInvocation'];
export type InfillPatchMatchInvocation = S['InfillPatchMatchInvocation'];
export type InfillTileInvocation = S['InfillTileInvocation'];
export type CreateGradientMaskInvocation = S['CreateGradientMaskInvocation'];
export type CanvasPasteBackInvocation = S['CanvasPasteBackInvocation'];
export type NoiseInvocation = S['NoiseInvocation'];
export type DenoiseLatentsInvocation = S['DenoiseLatentsInvocation'];
export type SDXLLoRALoaderInvocation = S['SDXLLoRALoaderInvocation'];
export type ImageToLatentsInvocation = S['ImageToLatentsInvocation'];
export type LatentsToImageInvocation = S['LatentsToImageInvocation'];
export type LoRALoaderInvocation = S['LoRALoaderInvocation'];
export type ESRGANInvocation = S['ESRGANInvocation'];
export type ImageNSFWBlurInvocation = S['ImageNSFWBlurInvocation'];
export type ImageWatermarkInvocation = S['ImageWatermarkInvocation'];
export type SeamlessModeInvocation = S['SeamlessModeInvocation'];
export type CoreMetadataInvocation = S['CoreMetadataInvocation'];
export type CollectInvocation = Invocation<'collect'>;
export type ImageResizeInvocation = Invocation<'img_resize'>;
export type InfillPatchMatchInvocation = Invocation<'infill_patchmatch'>;
export type InfillTileInvocation = Invocation<'infill_tile'>;
export type CreateGradientMaskInvocation = Invocation<'create_gradient_mask'>;
export type CanvasPasteBackInvocation = Invocation<'canvas_paste_back'>;
export type NoiseInvocation = Invocation<'noise'>;
export type DenoiseLatentsInvocation = Invocation<'denoise_latents'>;
export type SDXLLoRALoaderInvocation = Invocation<'sdxl_lora_loader'>;
export type ImageToLatentsInvocation = Invocation<'i2l'>;
export type LatentsToImageInvocation = Invocation<'l2i'>;
export type LoRALoaderInvocation = Invocation<'lora_loader'>;
export type ESRGANInvocation = Invocation<'esrgan'>;
export type ImageNSFWBlurInvocation = Invocation<'img_nsfw'>;
export type ImageWatermarkInvocation = Invocation<'img_watermark'>;
export type SeamlessModeInvocation = Invocation<'seamless'>;
export type CoreMetadataInvocation = Extract<Graph['nodes'][string], { type: 'core_metadata' }>;
// ControlNet Nodes
export type ControlNetInvocation = S['ControlNetInvocation'];
export type T2IAdapterInvocation = S['T2IAdapterInvocation'];
export type IPAdapterInvocation = S['IPAdapterInvocation'];
export type CannyImageProcessorInvocation = S['CannyImageProcessorInvocation'];
export type ColorMapImageProcessorInvocation = S['ColorMapImageProcessorInvocation'];
export type ContentShuffleImageProcessorInvocation = S['ContentShuffleImageProcessorInvocation'];
export type DepthAnythingImageProcessorInvocation = S['DepthAnythingImageProcessorInvocation'];
export type HedImageProcessorInvocation = S['HedImageProcessorInvocation'];
export type LineartAnimeImageProcessorInvocation = S['LineartAnimeImageProcessorInvocation'];
export type LineartImageProcessorInvocation = S['LineartImageProcessorInvocation'];
export type MediapipeFaceProcessorInvocation = S['MediapipeFaceProcessorInvocation'];
export type MidasDepthImageProcessorInvocation = S['MidasDepthImageProcessorInvocation'];
export type MlsdImageProcessorInvocation = S['MlsdImageProcessorInvocation'];
export type NormalbaeImageProcessorInvocation = S['NormalbaeImageProcessorInvocation'];
export type DWOpenposeImageProcessorInvocation = S['DWOpenposeImageProcessorInvocation'];
export type PidiImageProcessorInvocation = S['PidiImageProcessorInvocation'];
export type ZoeDepthImageProcessorInvocation = S['ZoeDepthImageProcessorInvocation'];
export type ControlNetInvocation = Invocation<'controlnet'>;
export type T2IAdapterInvocation = Invocation<'t2i_adapter'>;
export type IPAdapterInvocation = Invocation<'ip_adapter'>;
export type CannyImageProcessorInvocation = Invocation<'canny_image_processor'>;
export type ColorMapImageProcessorInvocation = Invocation<'color_map_image_processor'>;
export type ContentShuffleImageProcessorInvocation = Invocation<'content_shuffle_image_processor'>;
export type DepthAnythingImageProcessorInvocation = Invocation<'depth_anything_image_processor'>;
export type HedImageProcessorInvocation = Invocation<'hed_image_processor'>;
export type LineartAnimeImageProcessorInvocation = Invocation<'lineart_anime_image_processor'>;
export type LineartImageProcessorInvocation = Invocation<'lineart_image_processor'>;
export type MediapipeFaceProcessorInvocation = Invocation<'mediapipe_face_processor'>;
export type MidasDepthImageProcessorInvocation = Invocation<'midas_depth_image_processor'>;
export type MlsdImageProcessorInvocation = Invocation<'mlsd_image_processor'>;
export type NormalbaeImageProcessorInvocation = Invocation<'normalbae_image_processor'>;
export type DWOpenposeImageProcessorInvocation = Invocation<'dw_openpose_image_processor'>;
export type PidiImageProcessorInvocation = Invocation<'pidi_image_processor'>;
export type ZoeDepthImageProcessorInvocation = Invocation<'zoe_depth_image_processor'>;
// Node Outputs
export type ImageOutput = S['ImageOutput'];