From 8342f32f2e2a293ff4b839bb7ae66391fb08117b Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Tue, 7 May 2024 15:14:47 +1000
Subject: [PATCH] refactor(ui): rewrite all types as zod schemas

This change prepares for safe metadata recall.
---
 .../src/features/controlLayers/store/types.ts | 177 +++++++-----
 .../util/controlAdapters.test.ts              |  68 ++++-
 .../controlLayers/util/controlAdapters.ts     | 268 +++++++++++++-----
 .../parameters/types/parameterSchemas.ts      |   8 +-
 4 files changed, 361 insertions(+), 160 deletions(-)

diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts
index d469506c60..11266c8049 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts
@@ -1,88 +1,121 @@
-import type {
-  ControlNetConfigV2,
-  ImageWithDims,
-  IPAdapterConfigV2,
-  T2IAdapterConfigV2,
+import {
+  zControlNetConfigV2,
+  zImageWithDims,
+  zIPAdapterConfigV2,
+  zT2IAdapterConfigV2,
 } from 'features/controlLayers/util/controlAdapters';
 import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
-import type {
-  ParameterAutoNegative,
-  ParameterHeight,
-  ParameterNegativePrompt,
-  ParameterNegativeStylePromptSDXL,
-  ParameterPositivePrompt,
-  ParameterPositiveStylePromptSDXL,
-  ParameterWidth,
+import {
+  type ParameterHeight,
+  type ParameterNegativePrompt,
+  type ParameterNegativeStylePromptSDXL,
+  type ParameterPositivePrompt,
+  type ParameterPositiveStylePromptSDXL,
+  type ParameterWidth,
+  zAutoNegative,
+  zParameterNegativePrompt,
+  zParameterPositivePrompt,
+  zParameterStrength,
 } from 'features/parameters/types/parameterSchemas';
-import type { IRect } from 'konva/lib/types';
-import type { RgbColor } from 'react-colorful';
+import { z } from 'zod';
 
-export type DrawingTool = 'brush' | 'eraser';
+export const zTool = z.enum(['brush', 'eraser', 'move', 'rect']);
+export type Tool = z.infer<typeof zTool>;
+export const zDrawingTool = zTool.extract(['brush', 'eraser']);
+export type DrawingTool = z.infer<typeof zDrawingTool>;
 
-export type Tool = DrawingTool | 'move' | 'rect';
+const zPoints = z.array(z.number()).refine((points) => points.length % 2 === 0, {
+  message: 'Must have an even number of points',
+});
+export const zVectorMaskLine = z.object({
+  id: z.string(),
+  type: z.literal('vector_mask_line'),
+  tool: zDrawingTool,
+  strokeWidth: z.number().min(1),
+  points: zPoints,
+});
+export type VectorMaskLine = z.infer<typeof zVectorMaskLine>;
 
-export type VectorMaskLine = {
-  id: string;
-  type: 'vector_mask_line';
-  tool: DrawingTool;
-  strokeWidth: number;
-  points: number[];
-};
+export const zVectorMaskRect = z.object({
+  id: z.string(),
+  type: z.literal('vector_mask_rect'),
+  x: z.number(),
+  y: z.number(),
+  width: z.number().min(1),
+  height: z.number().min(1),
+});
+export type VectorMaskRect = z.infer<typeof zVectorMaskRect>;
 
-export type VectorMaskRect = {
-  id: string;
-  type: 'vector_mask_rect';
-  x: number;
-  y: number;
-  width: number;
-  height: number;
-};
+const zLayerBase = z.object({
+  id: z.string(),
+  isEnabled: z.boolean(),
+});
 
-type LayerBase = {
-  id: string;
-  isEnabled: boolean;
-};
+const zRect = z.object({
+  x: z.number(),
+  y: z.number(),
+  width: z.number().min(1),
+  height: z.number().min(1),
+});
+const zRenderableLayerBase = zLayerBase.extend({
+  x: z.number(),
+  y: z.number(),
+  bbox: zRect.nullable(),
+  bboxNeedsUpdate: z.boolean(),
+  isSelected: z.boolean(),
+});
 
-type RenderableLayerBase = LayerBase & {
-  x: number;
-  y: number;
-  bbox: IRect | null;
-  bboxNeedsUpdate: boolean;
-  isSelected: boolean;
-};
+const zControlAdapterLayer = zRenderableLayerBase.extend({
+  type: z.literal('control_adapter_layer'),
+  opacity: z.number().gte(0).lte(1),
+  isFilterEnabled: z.boolean(),
+  controlAdapter: z.discriminatedUnion('type', [zControlNetConfigV2, zT2IAdapterConfigV2]),
+});
+export type ControlAdapterLayer = z.infer<typeof zControlAdapterLayer>;
 
-export type ControlAdapterLayer = RenderableLayerBase & {
-  type: 'control_adapter_layer'; // technically, also t2i adapter layer
-  opacity: number;
-  isFilterEnabled: boolean;
-  controlAdapter: ControlNetConfigV2 | T2IAdapterConfigV2;
-};
+const zIPAdapterLayer = zLayerBase.extend({
+  type: z.literal('ip_adapter_layer'),
+  ipAdapter: zIPAdapterConfigV2,
+});
+export type IPAdapterLayer = z.infer<typeof zIPAdapterLayer>;
 
-export type IPAdapterLayer = LayerBase & {
-  type: 'ip_adapter_layer';
-  ipAdapter: IPAdapterConfigV2;
-};
+const zRgbColor = z.object({
+  r: z.number().int().min(0).max(255),
+  g: z.number().int().min(0).max(255),
+  b: z.number().int().min(0).max(255),
+});
+const zRegionalGuidanceLayer = zRenderableLayerBase.extend({
+  type: z.literal('regional_guidance_layer'),
+  maskObjects: z.array(z.discriminatedUnion('type', [zVectorMaskLine, zVectorMaskRect])),
+  positivePrompt: zParameterPositivePrompt.nullable(),
+  negativePrompt: zParameterNegativePrompt.nullable(),
+  ipAdapters: z.array(zIPAdapterConfigV2),
+  previewColor: zRgbColor,
+  autoNegative: zAutoNegative,
+  needsPixelBbox: z
+    .boolean()
+    .describe(
+      'Whether the layer needs the slower pixel-based bbox calculation. Set to true when there is an eraser object.'
+    ),
+  uploadedMaskImage: zImageWithDims.nullable(),
+});
+export type RegionalGuidanceLayer = z.infer<typeof zRegionalGuidanceLayer>;
 
-export type RegionalGuidanceLayer = RenderableLayerBase & {
-  type: 'regional_guidance_layer';
-  maskObjects: (VectorMaskLine | VectorMaskRect)[];
-  positivePrompt: ParameterPositivePrompt | null;
-  negativePrompt: ParameterNegativePrompt | null; // Up to one text prompt per mask
-  ipAdapters: IPAdapterConfigV2[]; // Any number of image prompts
-  previewColor: RgbColor;
-  autoNegative: ParameterAutoNegative;
-  needsPixelBbox: boolean; // Needs the slower pixel-based bbox calculation - set to true when an there is an eraser object
-  uploadedMaskImage: ImageWithDims | null;
-};
+const zInitialImageLayer = zRenderableLayerBase.extend({
+  type: z.literal('initial_image_layer'),
+  opacity: z.number().gte(0).lte(1),
+  image: zImageWithDims.nullable(),
+  denoisingStrength: zParameterStrength,
+});
+export type InitialImageLayer = z.infer<typeof zInitialImageLayer>;
 
-export type InitialImageLayer = RenderableLayerBase & {
-  type: 'initial_image_layer';
-  opacity: number;
-  image: ImageWithDims | null;
-  denoisingStrength: number;
-};
-
-export type Layer = RegionalGuidanceLayer | ControlAdapterLayer | IPAdapterLayer | InitialImageLayer;
+export const zLayer = z.discriminatedUnion('type', [
+  zRegionalGuidanceLayer,
+  zControlAdapterLayer,
+  zIPAdapterLayer,
+  zInitialImageLayer,
+]);
+export type Layer = z.infer<typeof zLayer>;
 
 export type ControlLayersState = {
   _version: 2;
diff --git a/invokeai/frontend/web/src/features/controlLayers/util/controlAdapters.test.ts b/invokeai/frontend/web/src/features/controlLayers/util/controlAdapters.test.ts
index 880514bf7c..31eb54e730 100644
--- a/invokeai/frontend/web/src/features/controlLayers/util/controlAdapters.test.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/util/controlAdapters.test.ts
@@ -4,20 +4,74 @@ import { assert } from 'tsafe';
 import { describe, test } from 'vitest';
 
 import type {
+  _CannyProcessorConfig,
+  _ColorMapProcessorConfig,
+  _ContentShuffleProcessorConfig,
+  _DepthAnythingProcessorConfig,
+  _DWOpenposeProcessorConfig,
+  _HedProcessorConfig,
+  _LineartAnimeProcessorConfig,
+  _LineartProcessorConfig,
+  _MediapipeFaceProcessorConfig,
+  _MidasDepthProcessorConfig,
+  _MlsdProcessorConfig,
+  _NormalbaeProcessorConfig,
+  _PidiProcessorConfig,
+  _ZoeDepthProcessorConfig,
+  CannyProcessorConfig,
   CLIPVisionModelV2,
+  ColorMapProcessorConfig,
+  ContentShuffleProcessorConfig,
   ControlModeV2,
   DepthAnythingModelSize,
+  DepthAnythingProcessorConfig,
+  DWOpenposeProcessorConfig,
+  HedProcessorConfig,
   IPMethodV2,
+  LineartAnimeProcessorConfig,
+  LineartProcessorConfig,
+  MediapipeFaceProcessorConfig,
+  MidasDepthProcessorConfig,
+  MlsdProcessorConfig,
+  NormalbaeProcessorConfig,
+  PidiProcessorConfig,
   ProcessorConfig,
   ProcessorTypeV2,
+  ZoeDepthProcessorConfig,
 } from './controlAdapters';
 
 describe('Control Adapter Types', () => {
-  test('ProcessorType', () => assert<Equals<ProcessorConfig['type'], ProcessorTypeV2>>());
-  test('IP Adapter Method', () => assert<Equals<NonNullable<S['IPAdapterInvocation']['method']>, IPMethodV2>>());
-  test('CLIP Vision Model', () =>
-    assert<Equals<NonNullable<S['IPAdapterInvocation']['clip_vision_model']>, CLIPVisionModelV2>>());
-  test('Control Mode', () => assert<Equals<NonNullable<S['ControlNetInvocation']['control_mode']>, ControlModeV2>>());
-  test('DepthAnything Model Size', () =>
-    assert<Equals<NonNullable<S['DepthAnythingImageProcessorInvocation']['model_size']>, DepthAnythingModelSize>>());
+  test('ProcessorType', () => {
+    assert<Equals<ProcessorConfig['type'], ProcessorTypeV2>>();
+  });
+  test('IP Adapter Method', () => {
+    assert<Equals<NonNullable<S['IPAdapterInvocation']['method']>, IPMethodV2>>();
+  });
+  test('CLIP Vision Model', () => {
+    assert<Equals<NonNullable<S['IPAdapterInvocation']['clip_vision_model']>, CLIPVisionModelV2>>();
+  });
+  test('Control Mode', () => {
+    assert<Equals<NonNullable<S['ControlNetInvocation']['control_mode']>, ControlModeV2>>();
+  });
+  test('DepthAnything Model Size', () => {
+    assert<Equals<NonNullable<S['DepthAnythingImageProcessorInvocation']['model_size']>, DepthAnythingModelSize>>();
+  });
+  test('Processor Configs', () => {
+    // The processor configs are manually modeled zod schemas. This test ensures that the inferred types are correct.
+    // The types prefixed with `_` are types generated from OpenAPI, while the types without the prefix are manually modeled.
+    assert<Equals<CannyProcessorConfig, _CannyProcessorConfig>>();
+    assert<Equals<ColorMapProcessorConfig, _ColorMapProcessorConfig>>();
+    assert<Equals<ContentShuffleProcessorConfig, _ContentShuffleProcessorConfig>>();
+    assert<Equals<DepthAnythingProcessorConfig, _DepthAnythingProcessorConfig>>();
+    assert<Equals<DWOpenposeProcessorConfig, _DWOpenposeProcessorConfig>>();
+    assert<Equals<HedProcessorConfig, _HedProcessorConfig>>();
+    assert<Equals<LineartAnimeProcessorConfig, _LineartAnimeProcessorConfig>>();
+    assert<Equals<LineartProcessorConfig, _LineartProcessorConfig>>();
+    assert<Equals<MediapipeFaceProcessorConfig, _MediapipeFaceProcessorConfig>>();
+    assert<Equals<MidasDepthProcessorConfig, _MidasDepthProcessorConfig>>();
+    assert<Equals<MlsdProcessorConfig, _MlsdProcessorConfig>>();
+    assert<Equals<NormalbaeProcessorConfig, _NormalbaeProcessorConfig>>();
+    assert<Equals<PidiProcessorConfig, _PidiProcessorConfig>>();
+    assert<Equals<ZoeDepthProcessorConfig, _ZoeDepthProcessorConfig>>();
+  });
 });
diff --git a/invokeai/frontend/web/src/features/controlLayers/util/controlAdapters.ts b/invokeai/frontend/web/src/features/controlLayers/util/controlAdapters.ts
index 617b527475..9e885c56e2 100644
--- a/invokeai/frontend/web/src/features/controlLayers/util/controlAdapters.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/util/controlAdapters.ts
@@ -1,9 +1,5 @@
 import { deepClone } from 'common/util/deepClone';
-import type {
-  ParameterControlNetModel,
-  ParameterIPAdapterModel,
-  ParameterT2IAdapterModel,
-} from 'features/parameters/types/parameterSchemas';
+import { zModelIdentifierField } from 'features/nodes/types/common';
 import { merge, omit } from 'lodash-es';
 import type {
   BaseModelType,
@@ -28,90 +24,207 @@
 } from 'services/api/types';
 import { z } from 'zod';
 
+const zId = z.string().min(1);
+
+const zCannyProcessorConfig = z.object({
+  id: zId,
+  type: z.literal('canny_image_processor'),
+  low_threshold: z.number().int().gte(0).lte(255),
+  high_threshold: z.number().int().gte(0).lte(255),
+});
+export type _CannyProcessorConfig = Required<
+  Pick<CannyImageProcessorInvocation, 'id' | 'type' | 'low_threshold' | 'high_threshold'>
+>;
+export type CannyProcessorConfig = z.infer<typeof zCannyProcessorConfig>;
+
+const zColorMapProcessorConfig = z.object({
+  id: zId,
+  type: z.literal('color_map_image_processor'),
+  color_map_tile_size: z.number().int().gte(1),
+});
+export type _ColorMapProcessorConfig = Required<
+  Pick<ColorMapImageProcessorInvocation, 'id' | 'type' | 'color_map_tile_size'>
+>;
+export type ColorMapProcessorConfig = z.infer<typeof zColorMapProcessorConfig>;
+
+const zContentShuffleProcessorConfig = z.object({
+  id: zId,
+  type: z.literal('content_shuffle_image_processor'),
+  w: z.number().int().gte(0),
+  h: z.number().int().gte(0),
+  f: z.number().int().gte(0),
+});
+export type _ContentShuffleProcessorConfig = Required<
+  Pick<ContentShuffleImageProcessorInvocation, 'id' | 'type' | 'w' | 'h' | 'f'>
+>;
+export type ContentShuffleProcessorConfig = z.infer<typeof zContentShuffleProcessorConfig>;
+
 const zDepthAnythingModelSize = z.enum(['large', 'base', 'small']);
 export type DepthAnythingModelSize = z.infer<typeof zDepthAnythingModelSize>;
 export const isDepthAnythingModelSize = (v: unknown): v is DepthAnythingModelSize =>
   zDepthAnythingModelSize.safeParse(v).success;
-
-export type CannyProcessorConfig = Required<
-  Pick<CannyImageProcessorInvocation, 'id' | 'type' | 'low_threshold' | 'high_threshold'>
->;
-export type ColorMapProcessorConfig = Required<
-  Pick<ColorMapImageProcessorInvocation, 'id' | 'type' | 'color_map_tile_size'>
->;
-export type ContentShuffleProcessorConfig = Required<
-  Pick<ContentShuffleImageProcessorInvocation, 'id' | 'type' | 'w' | 'h' | 'f'>
->;
-export type DepthAnythingProcessorConfig = Required<
+const zDepthAnythingProcessorConfig = z.object({
+  id: zId,
+  type: z.literal('depth_anything_image_processor'),
+  model_size: zDepthAnythingModelSize,
+});
+export type _DepthAnythingProcessorConfig = Required<
   Pick<DepthAnythingImageProcessorInvocation, 'id' | 'type' | 'model_size'>
 >;
-export type HedProcessorConfig = Required<Pick<HedImageProcessorInvocation, 'id' | 'type' | 'scribble'>>;
-type LineartAnimeProcessorConfig = Required<Pick<LineartAnimeImageProcessorInvocation, 'id' | 'type'>>;
-export type LineartProcessorConfig = Required<Pick<LineartImageProcessorInvocation, 'id' | 'type' | 'coarse'>>;
-export type MediapipeFaceProcessorConfig = Required<
+export type DepthAnythingProcessorConfig = z.infer<typeof zDepthAnythingProcessorConfig>;
+
+const zHedProcessorConfig = z.object({
+  id: zId,
+  type: z.literal('hed_image_processor'),
+  scribble: z.boolean(),
+});
+export type _HedProcessorConfig = Required<Pick<HedImageProcessorInvocation, 'id' | 'type' | 'scribble'>>;
+export type HedProcessorConfig = z.infer<typeof zHedProcessorConfig>;
+
+const zLineartAnimeProcessorConfig = z.object({
+  id: zId,
+  type: z.literal('lineart_anime_image_processor'),
+});
+export type _LineartAnimeProcessorConfig = Required<Pick<LineartAnimeImageProcessorInvocation, 'id' | 'type'>>;
+export type LineartAnimeProcessorConfig = z.infer<typeof zLineartAnimeProcessorConfig>;
+
+const zLineartProcessorConfig = z.object({
+  id: zId,
+  type: z.literal('lineart_image_processor'),
+  coarse: z.boolean(),
+});
+export type _LineartProcessorConfig = Required<Pick<LineartImageProcessorInvocation, 'id' | 'type' | 'coarse'>>;
+export type LineartProcessorConfig = z.infer<typeof zLineartProcessorConfig>;
+
+const zMediapipeFaceProcessorConfig = z.object({
+  id: zId,
+  type: z.literal('mediapipe_face_processor'),
+  max_faces: z.number().int().gte(1),
+  min_confidence: z.number().gte(0).lte(1),
+});
+export type _MediapipeFaceProcessorConfig = Required<
   Pick<MediapipeFaceProcessorInvocation, 'id' | 'type' | 'max_faces' | 'min_confidence'>
 >;
-export type MidasDepthProcessorConfig = Required<
+export type MediapipeFaceProcessorConfig = z.infer<typeof zMediapipeFaceProcessorConfig>;
+
+const zMidasDepthProcessorConfig = z.object({
+  id: zId,
+  type: z.literal('midas_depth_image_processor'),
+  a_mult: z.number().gte(0),
+  bg_th: z.number().gte(0),
+});
+export type _MidasDepthProcessorConfig = Required<
   Pick<MidasDepthImageProcessorInvocation, 'id' | 'type' | 'a_mult' | 'bg_th'>
 >;
-export type MlsdProcessorConfig = Required<Pick<MlsdImageProcessorInvocation, 'id' | 'type' | 'thr_v' | 'thr_d'>>;
-type NormalbaeProcessorConfig = Required<Pick<NormalbaeImageProcessorInvocation, 'id' | 'type'>>;
-export type DWOpenposeProcessorConfig = Required<
+export type MidasDepthProcessorConfig = z.infer<typeof zMidasDepthProcessorConfig>;
+
+const zMlsdProcessorConfig = z.object({
+  id: zId,
+  type: z.literal('mlsd_image_processor'),
+  thr_v: z.number().gte(0),
+  thr_d: z.number().gte(0),
+});
+export type _MlsdProcessorConfig = Required<Pick<MlsdImageProcessorInvocation, 'id' | 'type' | 'thr_v' | 'thr_d'>>;
+export type MlsdProcessorConfig = z.infer<typeof zMlsdProcessorConfig>;
+
+const zNormalbaeProcessorConfig = z.object({
+  id: zId,
+  type: z.literal('normalbae_image_processor'),
+});
+export type _NormalbaeProcessorConfig = Required<Pick<NormalbaeImageProcessorInvocation, 'id' | 'type'>>;
+export type NormalbaeProcessorConfig = z.infer<typeof zNormalbaeProcessorConfig>;
+
+const zDWOpenposeProcessorConfig = z.object({
+  id: zId,
+  type: z.literal('dw_openpose_image_processor'),
+  draw_body: z.boolean(),
+  draw_face: z.boolean(),
+  draw_hands: z.boolean(),
+});
+export type _DWOpenposeProcessorConfig = Required<
   Pick<DWOpenposeImageProcessorInvocation, 'id' | 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
 >;
-export type PidiProcessorConfig = Required<Pick<PidiImageProcessorInvocation, 'id' | 'type' | 'safe' | 'scribble'>>;
-type ZoeDepthProcessorConfig = Required<Pick<ZoeDepthImageProcessorInvocation, 'id' | 'type'>>;
+export type DWOpenposeProcessorConfig = z.infer<typeof zDWOpenposeProcessorConfig>;
 
-export type ProcessorConfig =
-  | CannyProcessorConfig
-  | ColorMapProcessorConfig
-  | ContentShuffleProcessorConfig
-  | DepthAnythingProcessorConfig
-  | HedProcessorConfig
-  | LineartAnimeProcessorConfig
-  | LineartProcessorConfig
-  | MediapipeFaceProcessorConfig
-  | MidasDepthProcessorConfig
-  | MlsdProcessorConfig
-  | NormalbaeProcessorConfig
-  | DWOpenposeProcessorConfig
-  | PidiProcessorConfig
-  | ZoeDepthProcessorConfig;
+const zPidiProcessorConfig = z.object({
+  id: zId,
+  type: z.literal('pidi_image_processor'),
+  safe: z.boolean(),
+  scribble: z.boolean(),
+});
+export type _PidiProcessorConfig = Required<Pick<PidiImageProcessorInvocation, 'id' | 'type' | 'safe' | 'scribble'>>;
+export type PidiProcessorConfig = z.infer<typeof zPidiProcessorConfig>;
 
-export type ImageWithDims = {
-  name: string;
-  width: number;
-  height: number;
-};
+const zZoeDepthProcessorConfig = z.object({
+  id: zId,
+  type: z.literal('zoe_depth_image_processor'),
+});
+export type _ZoeDepthProcessorConfig = Required<Pick<ZoeDepthImageProcessorInvocation, 'id' | 'type'>>;
+export type ZoeDepthProcessorConfig = z.infer<typeof zZoeDepthProcessorConfig>;
 
-type ControlAdapterBase = {
-  id: string;
-  weight: number;
-  image: ImageWithDims | null;
-  processedImage: ImageWithDims | null;
-  isProcessingImage: boolean;
-  processorConfig: ProcessorConfig | null;
-  beginEndStepPct: [number, number];
-};
+export const zProcessorConfig = z.discriminatedUnion('type', [
+  zCannyProcessorConfig,
+  zColorMapProcessorConfig,
+  zContentShuffleProcessorConfig,
+  zDepthAnythingProcessorConfig,
+  zHedProcessorConfig,
+  zLineartAnimeProcessorConfig,
+  zLineartProcessorConfig,
+  zMediapipeFaceProcessorConfig,
+  zMidasDepthProcessorConfig,
+  zMlsdProcessorConfig,
+  zNormalbaeProcessorConfig,
+  zDWOpenposeProcessorConfig,
+  zPidiProcessorConfig,
+  zZoeDepthProcessorConfig,
+]);
+export type ProcessorConfig = z.infer<typeof zProcessorConfig>;
+
+export const zImageWithDims = z.object({
+  name: z.string(),
+  width: z.number().int().positive(),
+  height: z.number().int().positive(),
+});
+export type ImageWithDims = z.infer<typeof zImageWithDims>;
+
+const zBeginEndStepPct = z
+  .tuple([z.number().gte(0).lte(1), z.number().gte(0).lte(1)])
+  .refine(([begin, end]) => begin < end, {
+    message: 'Begin must be less than end',
+  });
+
+const zControlAdapterBase = z.object({
+  id: zId,
+  weight: z.number().gte(-1).lte(2),
+  image: zImageWithDims.nullable(),
+  processedImage: zImageWithDims.nullable(),
+  isProcessingImage: z.boolean(),
+  processorConfig: zProcessorConfig.nullable(),
+  beginEndStepPct: zBeginEndStepPct,
+});
 
 const zControlModeV2 = z.enum(['balanced', 'more_prompt', 'more_control', 'unbalanced']);
 export type ControlModeV2 = z.infer<typeof zControlModeV2>;
 export const isControlModeV2 = (v: unknown): v is ControlModeV2 => zControlModeV2.safeParse(v).success;
 
-export type ControlNetConfigV2 = ControlAdapterBase & {
-  type: 'controlnet';
-  model: ParameterControlNetModel | null;
-  controlMode: ControlModeV2;
-};
-export const isControlNetConfigV2 = (ca: ControlNetConfigV2 | T2IAdapterConfigV2): ca is ControlNetConfigV2 =>
-  ca.type === 'controlnet';
+export const zControlNetConfigV2 = zControlAdapterBase.extend({
+  type: z.literal('controlnet'),
+  model: zModelIdentifierField.nullable(),
+  controlMode: zControlModeV2,
+});
+export type ControlNetConfigV2 = z.infer<typeof zControlNetConfigV2>;
+
+export const isControlNetConfigV2 = (ca: ControlNetConfigV2 | T2IAdapterConfigV2): ca is ControlNetConfigV2 =>
+  zControlNetConfigV2.safeParse(ca).success;
+
+export const zT2IAdapterConfigV2 = zControlAdapterBase.extend({
+  type: z.literal('t2i_adapter'),
+  model: zModelIdentifierField.nullable(),
+});
+export type T2IAdapterConfigV2 = z.infer<typeof zT2IAdapterConfigV2>;
 
-export type T2IAdapterConfigV2 = ControlAdapterBase & {
-  type: 't2i_adapter';
-  model: ParameterT2IAdapterModel | null;
-};
 export const isT2IAdapterConfigV2 = (ca: ControlNetConfigV2 | T2IAdapterConfigV2): ca is T2IAdapterConfigV2 =>
-  ca.type === 't2i_adapter';
+  zT2IAdapterConfigV2.safeParse(ca).success;
 
 const zCLIPVisionModelV2 = z.enum(['ViT-H', 'ViT-G']);
 export type CLIPVisionModelV2 = z.infer<typeof zCLIPVisionModelV2>;
@@ -121,16 +234,17 @@
 const zIPMethodV2 = z.enum(['full', 'style', 'composition']);
 export type IPMethodV2 = z.infer<typeof zIPMethodV2>;
 export const isIPMethodV2 = (v: unknown): v is IPMethodV2 => zIPMethodV2.safeParse(v).success;
-export type IPAdapterConfigV2 = {
-  id: string;
-  type: 'ip_adapter';
-  weight: number;
-  method: IPMethodV2;
-  image: ImageWithDims | null;
-  model: ParameterIPAdapterModel | null;
-  clipVisionModel: CLIPVisionModelV2;
-  beginEndStepPct: [number, number];
-};
+export const zIPAdapterConfigV2 = z.object({
+  id: zId,
+  type: z.literal('ip_adapter'),
+  weight: z.number().gte(-1).lte(2),
+  method: zIPMethodV2,
+  image: zImageWithDims.nullable(),
+  model: zModelIdentifierField.nullable(),
+  clipVisionModel: zCLIPVisionModelV2,
+  beginEndStepPct: zBeginEndStepPct,
+});
+export type IPAdapterConfigV2 = z.infer<typeof zIPAdapterConfigV2>;
 
 const zProcessorTypeV2 = z.enum([
   'canny_image_processor',
diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts
index a18cc7f86d..8a808ed0c5 100644
--- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts
+++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts
@@ -16,14 +16,14 @@ import { z } from 'zod';
  */
 
 // #region Positive prompt
-const zParameterPositivePrompt = z.string();
+export const zParameterPositivePrompt = z.string();
 export type ParameterPositivePrompt = z.infer<typeof zParameterPositivePrompt>;
 export const isParameterPositivePrompt = (val: unknown): val is ParameterPositivePrompt =>
   zParameterPositivePrompt.safeParse(val).success;
 // #endregion
 
 // #region Negative prompt
-const zParameterNegativePrompt = z.string();
+export const zParameterNegativePrompt = z.string();
 export type ParameterNegativePrompt = z.infer<typeof zParameterNegativePrompt>;
 export const isParameterNegativePrompt = (val: unknown): val is ParameterNegativePrompt =>
   zParameterNegativePrompt.safeParse(val).success;
 // #endregion
@@ -127,7 +127,7 @@ export type ParameterT2IAdapterModel = z.infer<typeof zParameterT2IAdapterModel>
 // #endregion
 
 // #region Strength (l2l strength)
-const zParameterStrength = z.number().min(0).max(1);
+export const zParameterStrength = z.number().min(0).max(1);
 export type ParameterStrength = z.infer<typeof zParameterStrength>;
 export const isParameterStrength = (val: unknown): val is ParameterStrength =>
   zParameterStrength.safeParse(val).success;
@@ -198,6 +198,6 @@ export const isParameterLoRAWeight = (val: unknown): val is ParameterLoRAWeight
 // #endregion
 
 // #region Regional Prompts AutoNegative
-const zAutoNegative = z.enum(['off', 'invert']);
+export const zAutoNegative = z.enum(['off', 'invert']);
 export type ParameterAutoNegative = z.infer<typeof zAutoNegative>;
 // #endregion
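
Editor's note: the commit message says this refactor "prepares for safe metadata recall". The sketch below is illustrative only and is not part of the patch. It assumes recalled metadata arrives as an unknown value that should be an array of layers; the helper name recallLayers and that shape are hypothetical, while zLayer and Layer are the real exports introduced by the change to store/types.ts. Because zLayer is a discriminated union, safeParse rejects stale or unknown layer kinds without any hand-written narrowing.

// Illustrative sketch (not part of this patch): validating untrusted metadata
// with the new zod schemas before loading it back into the app state.
// Assumption: recalled metadata should be a Layer[]; recallLayers is hypothetical.
import { zLayer, type Layer } from 'features/controlLayers/store/types';

const recallLayers = (metadata: unknown): Layer[] => {
  if (!Array.isArray(metadata)) {
    return [];
  }
  const layers: Layer[] = [];
  for (const candidate of metadata) {
    // safeParse never throws; anything written by an older build that no
    // longer matches the current schema is simply skipped.
    const result = zLayer.safeParse(candidate);
    if (result.success) {
      layers.push(result.data);
    }
  }
  return layers;
};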