From c79d9b9ecf74260a2930ed00b20e5c6ea60263c0 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Fri, 14 Jun 2024 16:04:16 +0530 Subject: [PATCH 1/4] wip: Add Initial support for select SD3 models in UI --- .../ModelManagerPanel/ModelBaseBadge.tsx | 1 + .../ModelPanel/Fields/BaseModelSelect.tsx | 1 + .../features/nodes/store/util/testUtils.ts | 5 +- .../web/src/features/nodes/types/common.ts | 5 +- .../web/src/features/nodes/types/v2/common.ts | 2 +- .../components/Advanced/ParamClipSkip.tsx | 2 +- .../features/parameters/types/constants.ts | 7 + .../frontend/web/src/services/api/schema.ts | 359 +++++++++--------- .../frontend/web/src/services/api/types.ts | 2 +- 9 files changed, 200 insertions(+), 184 deletions(-) diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx index bf07bad58c..f1c39bf162 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx @@ -11,6 +11,7 @@ const BASE_COLOR_MAP: Record = { any: 'base', 'sd-1': 'green', 'sd-2': 'teal', + 'sd-3': 'purple', sdxl: 'invokeBlue', 'sdxl-refiner': 'invokeBlue', }; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx index e07714a827..0a912f1c14 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx @@ -10,6 +10,7 @@ import type { UpdateModelArg } from 'services/api/endpoints/models'; const options: ComboboxOption[] = [ { value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] }, { value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] }, + { value: 'sd-3', label: MODEL_TYPE_MAP['sd-3'] }, { value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] }, { value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] }, ]; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts index 83988d55ea..ca2a24efee 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts @@ -839,7 +839,7 @@ export const schema = { }, BaseModelType: { description: 'Base model type.', - enum: ['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'], + enum: ['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner'], title: 'BaseModelType', type: 'string', }, @@ -855,8 +855,11 @@ export const schema = { 'unet', 'text_encoder', 'text_encoder_2', + 'text_encoder_3', 'tokenizer', 'tokenizer_2', + 'tokenizer_3', + 'transformer', 'vae', 'vae_decoder', 'vae_encoder', diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index 54e126af3a..cf15f98528 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -55,7 +55,7 @@ export type SchedulerField = z.infer; // #endregion // #region Model-related schemas -const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); +const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner']); const zModelType = z.enum([ 'main', 'vae', @@ -71,8 +71,11 @@ const zSubModelType = z.enum([ 'unet', 'text_encoder', 'text_encoder_2', + 'text_encoder_3', 'tokenizer', 'tokenizer_2', + 'tokenizer_3', + 'transformer', 'vae', 'vae_decoder', 'vae_encoder', diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/common.ts b/invokeai/frontend/web/src/features/nodes/types/v2/common.ts index 8613076132..3f75d9c1b0 100644 --- a/invokeai/frontend/web/src/features/nodes/types/v2/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v2/common.ts @@ -44,7 +44,7 @@ export const zSchedulerField = z.enum([ // #endregion // #region Model-related schemas -const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); +const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner']); const zModelName = z.string().min(3); export const zModelIdentifier = z.object({ model_name: zModelName, diff --git a/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamClipSkip.tsx b/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamClipSkip.tsx index c23d541613..5fc257d8a8 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamClipSkip.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamClipSkip.tsx @@ -39,7 +39,7 @@ const ParamClipSkip = () => { return CLIP_SKIP_MAP[model.base].markers; }, [model]); - if (model?.base === 'sdxl') { + if (model?.base === 'sdxl' || model?.base === 'sd-3') { return null; } diff --git a/invokeai/frontend/web/src/features/parameters/types/constants.ts b/invokeai/frontend/web/src/features/parameters/types/constants.ts index 6d7b4f9248..554c20ea51 100644 --- a/invokeai/frontend/web/src/features/parameters/types/constants.ts +++ b/invokeai/frontend/web/src/features/parameters/types/constants.ts @@ -7,6 +7,7 @@ export const MODEL_TYPE_MAP = { any: 'Any', 'sd-1': 'Stable Diffusion 1.x', 'sd-2': 'Stable Diffusion 2.x', + 'sd-3': 'Stable Diffusion 3.x', sdxl: 'Stable Diffusion XL', 'sdxl-refiner': 'Stable Diffusion XL Refiner', }; @@ -18,6 +19,7 @@ export const MODEL_TYPE_SHORT_MAP = { any: 'Any', 'sd-1': 'SD1.X', 'sd-2': 'SD2.X', + 'sd-3': 'SD3.X', sdxl: 'SDXL', 'sdxl-refiner': 'SDXLR', }; @@ -38,6 +40,11 @@ export const CLIP_SKIP_MAP = { maxClip: 24, markers: [0, 1, 2, 3, 5, 10, 15, 20, 24], }, + // TODO: Update this when we have more details on how CLIP SKIP works with SD3 + 'sd-3': { + maxClip: 24, + markers: [0, 1, 2, 3, 5, 10, 15, 20, 24], + }, sdxl: { maxClip: 24, markers: [0, 1, 2, 3, 5, 10, 15, 20, 24], diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 5482b57c0b..5f9dc0adbe 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -757,7 +757,7 @@ export type components = { * @description Base model type. * @enum {string} */ - BaseModelType: "any" | "sd-1" | "sd-2" | "sdxl" | "sdxl-refiner"; + BaseModelType: "any" | "sd-1" | "sd-2" | "sdxl" | "sdxl-refiner" | "sd-3"; /** Batch */ Batch: { /** @@ -3788,23 +3788,6 @@ export type components = { * @description Class to monitor and control a model download request. */ DownloadJob: { - /** - * Source - * Format: uri - * @description Where to download from. Specific types specified in child classes. - */ - source: string; - /** - * Dest - * Format: path - * @description Destination of downloaded model on local disk; a directory or file path - */ - dest: string; - /** - * Access Token - * @description authorization token for protected resources - */ - access_token?: string | null; /** * Id * @description Numeric ID of this job @@ -3812,36 +3795,21 @@ export type components = { */ id?: number; /** - * Priority - * @description Queue priority; lower values are higher priority - * @default 10 + * Dest + * Format: path + * @description Initial destination of downloaded model on local disk; a directory or file path */ - priority?: number; + dest: string; + /** + * Download Path + * @description Final location of downloaded file or directory + */ + download_path?: string | null; /** * @description Status of the download * @default waiting */ status?: components["schemas"]["DownloadJobStatus"]; - /** - * Download Path - * @description Final location of downloaded file - */ - download_path?: string | null; - /** - * Job Started - * @description Timestamp for when the download job started - */ - job_started?: string | null; - /** - * Job Ended - * @description Timestamp for when the download job ende1d (completed or errored) - */ - job_ended?: string | null; - /** - * Content Type - * @description Content type of downloaded file - */ - content_type?: string | null; /** * Bytes * @description Bytes downloaded so far @@ -3864,6 +3832,38 @@ export type components = { * @description Traceback of the exception that caused an error */ error?: string | null; + /** + * Source + * Format: uri + * @description Where to download from. Specific types specified in child classes. + */ + source: string; + /** + * Access Token + * @description authorization token for protected resources + */ + access_token?: string | null; + /** + * Priority + * @description Queue priority; lower values are higher priority + * @default 10 + */ + priority?: number; + /** + * Job Started + * @description Timestamp for when the download job started + */ + job_started?: string | null; + /** + * Job Ended + * @description Timestamp for when the download job ende1d (completed or errored) + */ + job_ended?: string | null; + /** + * Content Type + * @description Content type of downloaded file + */ + content_type?: string | null; }; /** * DownloadJobStatus @@ -7276,144 +7276,144 @@ export type components = { project_id: string | null; }; InvocationOutputMap: { - pidi_image_processor: components["schemas"]["ImageOutput"]; - image_mask_to_tensor: components["schemas"]["MaskOutput"]; - vae_loader: components["schemas"]["VAEOutput"]; - collect: components["schemas"]["CollectInvocationOutput"]; - string_join_three: components["schemas"]["StringOutput"]; - content_shuffle_image_processor: components["schemas"]["ImageOutput"]; - random_range: components["schemas"]["IntegerCollectionOutput"]; - ip_adapter: components["schemas"]["IPAdapterOutput"]; - step_param_easing: components["schemas"]["FloatCollectionOutput"]; - core_metadata: components["schemas"]["MetadataOutput"]; - main_model_loader: components["schemas"]["ModelLoaderOutput"]; - leres_image_processor: components["schemas"]["ImageOutput"]; - calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"]; - color_correct: components["schemas"]["ImageOutput"]; - calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"]; - float_range: components["schemas"]["FloatCollectionOutput"]; - infill_cv2: components["schemas"]["ImageOutput"]; - img_channel_multiply: components["schemas"]["ImageOutput"]; - img_pad_crop: components["schemas"]["ImageOutput"]; - sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"]; - face_mask_detection: components["schemas"]["FaceMaskOutput"]; - infill_lama: components["schemas"]["ImageOutput"]; - mask_combine: components["schemas"]["ImageOutput"]; - sdxl_compel_prompt: components["schemas"]["ConditioningOutput"]; - segment_anything_processor: components["schemas"]["ImageOutput"]; - merge_metadata: components["schemas"]["MetadataOutput"]; - img_ilerp: components["schemas"]["ImageOutput"]; - heuristic_resize: components["schemas"]["ImageOutput"]; - cv_inpaint: components["schemas"]["ImageOutput"]; - div: components["schemas"]["IntegerOutput"]; - pair_tile_image: components["schemas"]["PairTileImageOutput"]; - float_math: components["schemas"]["FloatOutput"]; - img_channel_offset: components["schemas"]["ImageOutput"]; - canvas_paste_back: components["schemas"]["ImageOutput"]; - canny_image_processor: components["schemas"]["ImageOutput"]; - integer_collection: components["schemas"]["IntegerCollectionOutput"]; - freeu: components["schemas"]["UNetOutput"]; - lresize: components["schemas"]["LatentsOutput"]; - range_of_size: components["schemas"]["IntegerCollectionOutput"]; - depth_anything_image_processor: components["schemas"]["ImageOutput"]; - float_to_int: components["schemas"]["IntegerOutput"]; - rand_int: components["schemas"]["IntegerOutput"]; - lineart_anime_image_processor: components["schemas"]["ImageOutput"]; - string_split: components["schemas"]["String2Output"]; - img_nsfw: components["schemas"]["ImageOutput"]; - string: components["schemas"]["StringOutput"]; - mask_edge: components["schemas"]["ImageOutput"]; - i2l: components["schemas"]["LatentsOutput"]; - face_identifier: components["schemas"]["ImageOutput"]; - compel: components["schemas"]["ConditioningOutput"]; - esrgan: components["schemas"]["ImageOutput"]; - seamless: components["schemas"]["SeamlessModeOutput"]; - mask_from_id: components["schemas"]["ImageOutput"]; - invert_tensor_mask: components["schemas"]["MaskOutput"]; - rectangle_mask: components["schemas"]["MaskOutput"]; - conditioning: components["schemas"]["ConditioningOutput"]; - t2i_adapter: components["schemas"]["T2IAdapterOutput"]; - string_collection: components["schemas"]["StringCollectionOutput"]; - show_image: components["schemas"]["ImageOutput"]; - dw_openpose_image_processor: components["schemas"]["ImageOutput"]; - string_split_neg: components["schemas"]["StringPosNegOutput"]; - conditioning_collection: components["schemas"]["ConditioningCollectionOutput"]; - infill_patchmatch: components["schemas"]["ImageOutput"]; - img_conv: components["schemas"]["ImageOutput"]; - unsharp_mask: components["schemas"]["ImageOutput"]; - metadata_item: components["schemas"]["MetadataItemOutput"]; - image: components["schemas"]["ImageOutput"]; - image_collection: components["schemas"]["ImageCollectionOutput"]; - tile_to_properties: components["schemas"]["TileToPropertiesOutput"]; - lblend: components["schemas"]["LatentsOutput"]; - float: components["schemas"]["FloatOutput"]; - boolean_collection: components["schemas"]["BooleanCollectionOutput"]; - color: components["schemas"]["ColorOutput"]; - midas_depth_image_processor: components["schemas"]["ImageOutput"]; - zoe_depth_image_processor: components["schemas"]["ImageOutput"]; - infill_rgba: components["schemas"]["ImageOutput"]; - mlsd_image_processor: components["schemas"]["ImageOutput"]; - merge_tiles_to_image: components["schemas"]["ImageOutput"]; - prompt_from_file: components["schemas"]["StringCollectionOutput"]; - boolean: components["schemas"]["BooleanOutput"]; - create_gradient_mask: components["schemas"]["GradientMaskOutput"]; - rand_float: components["schemas"]["FloatOutput"]; - img_mul: components["schemas"]["ImageOutput"]; - controlnet: components["schemas"]["ControlOutput"]; - latents_collection: components["schemas"]["LatentsCollectionOutput"]; - img_lerp: components["schemas"]["ImageOutput"]; - noise: components["schemas"]["NoiseOutput"]; - iterate: components["schemas"]["IterateInvocationOutput"]; - lineart_image_processor: components["schemas"]["ImageOutput"]; - tomask: components["schemas"]["ImageOutput"]; - integer: components["schemas"]["IntegerOutput"]; - create_denoise_mask: components["schemas"]["DenoiseMaskOutput"]; - clip_skip: components["schemas"]["CLIPSkipInvocationOutput"]; - denoise_latents: components["schemas"]["LatentsOutput"]; - string_join: components["schemas"]["StringOutput"]; - scheduler: components["schemas"]["SchedulerOutput"]; - model_identifier: components["schemas"]["ModelIdentifierOutput"]; - normalbae_image_processor: components["schemas"]["ImageOutput"]; - face_off: components["schemas"]["FaceOffOutput"]; - hed_image_processor: components["schemas"]["ImageOutput"]; - img_paste: components["schemas"]["ImageOutput"]; - img_chan: components["schemas"]["ImageOutput"]; - img_watermark: components["schemas"]["ImageOutput"]; - l2i: components["schemas"]["ImageOutput"]; - string_replace: components["schemas"]["StringOutput"]; - color_map_image_processor: components["schemas"]["ImageOutput"]; - tile_image_processor: components["schemas"]["ImageOutput"]; - crop_latents: components["schemas"]["LatentsOutput"]; - sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"]; - add: components["schemas"]["IntegerOutput"]; - sub: components["schemas"]["IntegerOutput"]; - img_scale: components["schemas"]["ImageOutput"]; - range: components["schemas"]["IntegerCollectionOutput"]; - dynamic_prompt: components["schemas"]["StringCollectionOutput"]; - img_crop: components["schemas"]["ImageOutput"]; - infill_tile: components["schemas"]["ImageOutput"]; - img_resize: components["schemas"]["ImageOutput"]; - mediapipe_face_processor: components["schemas"]["ImageOutput"]; - sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"]; - lora_selector: components["schemas"]["LoRASelectorOutput"]; - img_hue_adjust: components["schemas"]["ImageOutput"]; - latents: components["schemas"]["LatentsOutput"]; - lora_collection_loader: components["schemas"]["LoRALoaderOutput"]; - img_blur: components["schemas"]["ImageOutput"]; - ideal_size: components["schemas"]["IdealSizeOutput"]; - float_collection: components["schemas"]["FloatCollectionOutput"]; - blank_image: components["schemas"]["ImageOutput"]; - integer_math: components["schemas"]["IntegerOutput"]; - lora_loader: components["schemas"]["LoRALoaderOutput"]; - metadata: components["schemas"]["MetadataOutput"]; - sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"]; - round_float: components["schemas"]["FloatOutput"]; sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"]; - mul: components["schemas"]["IntegerOutput"]; - alpha_mask_to_tensor: components["schemas"]["MaskOutput"]; - lscale: components["schemas"]["LatentsOutput"]; + hed_image_processor: components["schemas"]["ImageOutput"]; + freeu: components["schemas"]["UNetOutput"]; + pidi_image_processor: components["schemas"]["ImageOutput"]; + sub: components["schemas"]["IntegerOutput"]; + crop_latents: components["schemas"]["LatentsOutput"]; + step_param_easing: components["schemas"]["FloatCollectionOutput"]; + img_hue_adjust: components["schemas"]["ImageOutput"]; + sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"]; + metadata: components["schemas"]["MetadataOutput"]; + dynamic_prompt: components["schemas"]["StringCollectionOutput"]; + boolean_collection: components["schemas"]["BooleanCollectionOutput"]; + img_crop: components["schemas"]["ImageOutput"]; save_image: components["schemas"]["ImageOutput"]; + tile_image_processor: components["schemas"]["ImageOutput"]; + t2i_adapter: components["schemas"]["T2IAdapterOutput"]; + float_to_int: components["schemas"]["IntegerOutput"]; + prompt_from_file: components["schemas"]["StringCollectionOutput"]; + lora_collection_loader: components["schemas"]["LoRALoaderOutput"]; + integer_collection: components["schemas"]["IntegerCollectionOutput"]; + string_split: components["schemas"]["String2Output"]; + i2l: components["schemas"]["LatentsOutput"]; + img_nsfw: components["schemas"]["ImageOutput"]; + calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"]; + zoe_depth_image_processor: components["schemas"]["ImageOutput"]; + img_mul: components["schemas"]["ImageOutput"]; + image_mask_to_tensor: components["schemas"]["MaskOutput"]; + rand_int: components["schemas"]["IntegerOutput"]; + lscale: components["schemas"]["LatentsOutput"]; + img_conv: components["schemas"]["ImageOutput"]; + random_range: components["schemas"]["IntegerCollectionOutput"]; + img_channel_offset: components["schemas"]["ImageOutput"]; + denoise_latents: components["schemas"]["LatentsOutput"]; + lresize: components["schemas"]["LatentsOutput"]; + alpha_mask_to_tensor: components["schemas"]["MaskOutput"]; + show_image: components["schemas"]["ImageOutput"]; + conditioning: components["schemas"]["ConditioningOutput"]; + canny_image_processor: components["schemas"]["ImageOutput"]; + create_denoise_mask: components["schemas"]["DenoiseMaskOutput"]; + boolean: components["schemas"]["BooleanOutput"]; + image: components["schemas"]["ImageOutput"]; calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"]; + infill_patchmatch: components["schemas"]["ImageOutput"]; + mediapipe_face_processor: components["schemas"]["ImageOutput"]; + face_mask_detection: components["schemas"]["FaceMaskOutput"]; + string_split_neg: components["schemas"]["StringPosNegOutput"]; + normalbae_image_processor: components["schemas"]["ImageOutput"]; + create_gradient_mask: components["schemas"]["GradientMaskOutput"]; + blank_image: components["schemas"]["ImageOutput"]; + calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"]; + range: components["schemas"]["IntegerCollectionOutput"]; + float_collection: components["schemas"]["FloatCollectionOutput"]; + scheduler: components["schemas"]["SchedulerOutput"]; + latents_collection: components["schemas"]["LatentsCollectionOutput"]; + color_correct: components["schemas"]["ImageOutput"]; + midas_depth_image_processor: components["schemas"]["ImageOutput"]; + rand_float: components["schemas"]["FloatOutput"]; + lblend: components["schemas"]["LatentsOutput"]; + compel: components["schemas"]["ConditioningOutput"]; + leres_image_processor: components["schemas"]["ImageOutput"]; + float: components["schemas"]["FloatOutput"]; + img_paste: components["schemas"]["ImageOutput"]; + metadata_item: components["schemas"]["MetadataItemOutput"]; + color: components["schemas"]["ColorOutput"]; + string_collection: components["schemas"]["StringCollectionOutput"]; + string_join: components["schemas"]["StringOutput"]; + conditioning_collection: components["schemas"]["ConditioningCollectionOutput"]; + segment_anything_processor: components["schemas"]["ImageOutput"]; + mul: components["schemas"]["IntegerOutput"]; + cv_inpaint: components["schemas"]["ImageOutput"]; + tile_to_properties: components["schemas"]["TileToPropertiesOutput"]; + sdxl_compel_prompt: components["schemas"]["ConditioningOutput"]; + model_identifier: components["schemas"]["ModelIdentifierOutput"]; + canvas_paste_back: components["schemas"]["ImageOutput"]; + string: components["schemas"]["StringOutput"]; + latents: components["schemas"]["LatentsOutput"]; + img_ilerp: components["schemas"]["ImageOutput"]; + collect: components["schemas"]["CollectInvocationOutput"]; + face_identifier: components["schemas"]["ImageOutput"]; + img_lerp: components["schemas"]["ImageOutput"]; + l2i: components["schemas"]["ImageOutput"]; + float_math: components["schemas"]["FloatOutput"]; + unsharp_mask: components["schemas"]["ImageOutput"]; + clip_skip: components["schemas"]["CLIPSkipInvocationOutput"]; + esrgan: components["schemas"]["ImageOutput"]; + image_collection: components["schemas"]["ImageCollectionOutput"]; + vae_loader: components["schemas"]["VAEOutput"]; + mask_combine: components["schemas"]["ImageOutput"]; + infill_lama: components["schemas"]["ImageOutput"]; + integer_math: components["schemas"]["IntegerOutput"]; + core_metadata: components["schemas"]["MetadataOutput"]; + sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"]; + color_map_image_processor: components["schemas"]["ImageOutput"]; + mask_from_id: components["schemas"]["ImageOutput"]; + depth_anything_image_processor: components["schemas"]["ImageOutput"]; + noise: components["schemas"]["NoiseOutput"]; + mask_edge: components["schemas"]["ImageOutput"]; + controlnet: components["schemas"]["ControlOutput"]; + merge_metadata: components["schemas"]["MetadataOutput"]; + string_join_three: components["schemas"]["StringOutput"]; + mlsd_image_processor: components["schemas"]["ImageOutput"]; + rectangle_mask: components["schemas"]["MaskOutput"]; + img_resize: components["schemas"]["ImageOutput"]; + range_of_size: components["schemas"]["IntegerCollectionOutput"]; + infill_rgba: components["schemas"]["ImageOutput"]; + heuristic_resize: components["schemas"]["ImageOutput"]; + img_pad_crop: components["schemas"]["ImageOutput"]; + lineart_image_processor: components["schemas"]["ImageOutput"]; + infill_cv2: components["schemas"]["ImageOutput"]; + ip_adapter: components["schemas"]["IPAdapterOutput"]; + ideal_size: components["schemas"]["IdealSizeOutput"]; + div: components["schemas"]["IntegerOutput"]; + float_range: components["schemas"]["FloatCollectionOutput"]; + seamless: components["schemas"]["SeamlessModeOutput"]; + pair_tile_image: components["schemas"]["PairTileImageOutput"]; + invert_tensor_mask: components["schemas"]["MaskOutput"]; + add: components["schemas"]["IntegerOutput"]; + main_model_loader: components["schemas"]["ModelLoaderOutput"]; + face_off: components["schemas"]["FaceOffOutput"]; + integer: components["schemas"]["IntegerOutput"]; + img_blur: components["schemas"]["ImageOutput"]; + img_watermark: components["schemas"]["ImageOutput"]; + lora_selector: components["schemas"]["LoRASelectorOutput"]; + dw_openpose_image_processor: components["schemas"]["ImageOutput"]; + img_chan: components["schemas"]["ImageOutput"]; + string_replace: components["schemas"]["StringOutput"]; + sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"]; + content_shuffle_image_processor: components["schemas"]["ImageOutput"]; + lora_loader: components["schemas"]["LoRALoaderOutput"]; + infill_tile: components["schemas"]["ImageOutput"]; + tomask: components["schemas"]["ImageOutput"]; + sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"]; + merge_tiles_to_image: components["schemas"]["ImageOutput"]; + lineart_anime_image_processor: components["schemas"]["ImageOutput"]; + round_float: components["schemas"]["FloatOutput"]; + img_channel_multiply: components["schemas"]["ImageOutput"]; + img_scale: components["schemas"]["ImageOutput"]; + iterate: components["schemas"]["IterateInvocationOutput"]; }; /** * InvocationStartedEvent @@ -10671,8 +10671,9 @@ export type components = { /** * Size * @description The size of this file, in bytes + * @default 0 */ - size: number; + size?: number | null; /** * Sha256 * @description SHA256 hash of this model (not always available) @@ -12301,7 +12302,7 @@ export type components = { * @description Submodel type. * @enum {string} */ - SubModelType: "unet" | "text_encoder" | "text_encoder_2" | "tokenizer" | "tokenizer_2" | "vae" | "vae_decoder" | "vae_encoder" | "scheduler" | "safety_checker"; + SubModelType: "unet" | "text_encoder" | "text_encoder_2" | "text_encoder_3" | "tokenizer" | "tokenizer_2" | "tokenizer_3" | "transformer" | "vae" | "vae_decoder" | "vae_encoder" | "scheduler" | "safety_checker"; /** * Subtract Integers * @description Subtracts two numbers diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 90ddf3cca1..7a8a5c0ddd 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -109,7 +109,7 @@ export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is Ma }; export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => { - return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2'); + return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2' || config.base === 'sd-3'); }; export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConfig => { From 0c970bc8802059e76f1351142db49bf220b680bb Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Fri, 14 Jun 2024 22:21:09 +0530 Subject: [PATCH 2/4] wip: add SD3 Model Loader Invocation --- invokeai/app/invocations/fields.py | 3 + invokeai/app/invocations/sd3.py | 54 +++ .../Invocation/fields/InputFieldRenderer.tsx | 7 + .../SD3MainModelFieldInputComponent.tsx | 55 +++ .../web/src/features/nodes/types/constants.ts | 2 + .../web/src/features/nodes/types/field.ts | 31 ++ .../features/nodes/types/v1/fieldTypeMap.ts | 5 + .../src/features/nodes/types/v1/workflowV1.ts | 7 + .../web/src/features/nodes/types/v2/field.ts | 17 + .../util/schema/buildFieldInputInstance.ts | 1 + .../util/schema/buildFieldInputTemplate.ts | 16 + .../nodes/util/workflow/validateWorkflow.ts | 1 + .../src/services/api/hooks/modelsByType.ts | 2 + .../frontend/web/src/services/api/schema.ts | 353 +++++++++++------- .../frontend/web/src/services/api/types.ts | 8 + 15 files changed, 426 insertions(+), 136 deletions(-) create mode 100644 invokeai/app/invocations/sd3.py create mode 100644 invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SD3MainModelFieldInputComponent.tsx diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 0fa0216f1c..5803696c9f 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -42,6 +42,7 @@ class UIType(str, Enum, metaclass=MetaEnum): MainModel = "MainModelField" SDXLMainModel = "SDXLMainModelField" SDXLRefinerModel = "SDXLRefinerModelField" + SD3MainModel = "SD3MainModelField" ONNXModel = "ONNXModelField" VAEModel = "VAEModelField" LoRAModel = "LoRAModelField" @@ -125,6 +126,7 @@ class FieldDescriptions: noise = "Noise tensor" clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count" unet = "UNet (scheduler, LoRAs)" + transformer = "Transformer" vae = "VAE" cond = "Conditioning tensor" controlnet_model = "ControlNet model to load" @@ -133,6 +135,7 @@ class FieldDescriptions: main_model = "Main model (UNet, VAE, CLIP) to load" sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load" + sd3_main_model = "SD3 Main Model (Transformer, CLIP1, CLIP2, CLIP3, VAE) to load" onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" lora_weight = "The weight at which the LoRA is applied to each model" compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor" diff --git a/invokeai/app/invocations/sd3.py b/invokeai/app/invocations/sd3.py new file mode 100644 index 0000000000..72089f05f0 --- /dev/null +++ b/invokeai/app/invocations/sd3.py @@ -0,0 +1,54 @@ +from pydantic import BaseModel, Field + +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output +from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType +from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, VAEField +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager.config import SubModelType + + +class TransformerField(BaseModel): + transformer: ModelIdentifierField = Field(description="Info to load unet submodel") + scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel") + + +@invocation_output("sd3_model_loader_output") +class SD3ModelLoaderOutput(BaseInvocationOutput): + """Stable Diffuion 3 base model loader output""" + + transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer") + clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1") + clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2") + clip3: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 3") + vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE") + + +@invocation("sd3_model_loader", title="SD3 Main Model", tags=["model", "sd3"], category="model", version="1.0.0") +class SD3ModelLoaderInvocation(BaseInvocation): + """Loads an SD3 base model, outputting its submodels.""" + + model: ModelIdentifierField = InputField(description=FieldDescriptions.sd3_main_model, ui_type=UIType.SD3MainModel) + + def invoke(self, context: InvocationContext) -> SD3ModelLoaderOutput: + model_key = self.model.key + + if not context.models.exists(model_key): + raise Exception(f"Unknown model: {model_key}") + + transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer}) + scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler}) + tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer}) + text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder}) + tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2}) + text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2}) + tokenizer3 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3}) + text_encoder3 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3}) + vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE}) + + return SD3ModelLoaderOutput( + transformer=TransformerField(transformer=transformer, scheduler=scheduler), + clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0), + clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0), + clip3=CLIPField(tokenizer=tokenizer3, text_encoder=text_encoder3, loras=[], skipped_layers=0), + vae=VAEField(vae=vae), + ) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index 99937ceec4..810ec3ffff 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -28,6 +28,8 @@ import { isModelIdentifierFieldInputTemplate, isSchedulerFieldInputInstance, isSchedulerFieldInputTemplate, + isSD3MainModelFieldInputInstance, + isSD3MainModelFieldInputTemplate, isSDXLMainModelFieldInputInstance, isSDXLMainModelFieldInputTemplate, isSDXLRefinerModelFieldInputInstance, @@ -53,6 +55,7 @@ import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent' import NumberFieldInputComponent from './inputs/NumberFieldInputComponent'; import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent'; import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent'; +import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent'; import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent'; import StringFieldInputComponent from './inputs/StringFieldInputComponent'; import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent'; @@ -133,6 +136,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { return ; } + if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) { + return ; + } + if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) { return ; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SD3MainModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SD3MainModelFieldInputComponent.tsx new file mode 100644 index 0000000000..95feb08ae9 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SD3MainModelFieldInputComponent.tsx @@ -0,0 +1,55 @@ +import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; +import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field'; +import { memo, useCallback } from 'react'; +import { useSD3Models } from 'services/api/hooks/modelsByType'; +import type { MainModelConfig } from 'services/api/types'; + +import type { FieldComponentProps } from './types'; + +type Props = FieldComponentProps; + +const SD3MainModelFieldInputComponent = (props: Props) => { + const { nodeId, field } = props; + const dispatch = useAppDispatch(); + const [modelConfigs, { isLoading }] = useSD3Models(); + const _onChange = useCallback( + (value: MainModelConfig | null) => { + if (!value) { + return; + } + dispatch( + fieldMainModelValueChanged({ + nodeId, + fieldName: field.name, + value, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ + modelConfigs, + onChange: _onChange, + isLoading, + selectedModel: field.value, + }); + + return ( + + + + + + ); +}; + +export default memo(SD3MainModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 4ede5cd479..5ba3733571 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -32,6 +32,7 @@ export const MODEL_TYPES = [ 'LoRAModelField', 'MainModelField', 'SDXLMainModelField', + 'SD3MainModelField', 'SDXLRefinerModelField', 'VaeModelField', 'UNetField', @@ -62,6 +63,7 @@ export const FIELD_COLORS: { [key: string]: string } = { MainModelField: 'teal.500', SDXLMainModelField: 'teal.500', SDXLRefinerModelField: 'teal.500', + SD3MainModelField: 'teal.500', StringField: 'yellow.500', T2IAdapterField: 'teal.500', T2IAdapterModelField: 'teal.500', diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index e2a84e3390..ae0d9edb01 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -119,6 +119,10 @@ const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ name: z.literal('SDXLRefinerModelField'), originalType: zStatelessFieldType.optional(), }); +const zSD3MainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SD3MainModelField'), + originalType: zStatelessFieldType.optional(), +}); const zVAEModelFieldType = zFieldTypeBase.extend({ name: z.literal('VAEModelField'), originalType: zStatelessFieldType.optional(), @@ -155,6 +159,7 @@ const zStatefulFieldType = z.union([ zMainModelFieldType, zSDXLMainModelFieldType, zSDXLRefinerModelFieldType, + zSD3MainModelFieldType, zVAEModelFieldType, zLoRAModelFieldType, zControlNetModelFieldType, @@ -466,6 +471,28 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR zSDXLRefinerModelFieldInputTemplate.safeParse(val).success; // #endregion +// #region SD3MainModelField + +const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only. +const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + value: zSD3MainModelFieldValue, +}); +const zSD3MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSD3MainModelFieldType, + originalType: zFieldType.optional(), + default: zSD3MainModelFieldValue, +}); +const zSD3MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSD3MainModelFieldType, +}); +export type SD3MainModelFieldInputInstance = z.infer; +export type SD3MainModelFieldInputTemplate = z.infer; +export const isSD3MainModelFieldInputInstance = (val: unknown): val is SD3MainModelFieldInputInstance => + zSD3MainModelFieldInputInstance.safeParse(val).success; +export const isSD3MainModelFieldInputTemplate = (val: unknown): val is SD3MainModelFieldInputTemplate => + zSD3MainModelFieldInputTemplate.safeParse(val).success; +// #endregion + // #region VAEModelField export const zVAEModelFieldValue = zModelIdentifierField.optional(); @@ -662,6 +689,7 @@ export const zStatefulFieldValue = z.union([ zMainModelFieldValue, zSDXLMainModelFieldValue, zSDXLRefinerModelFieldValue, + zSD3MainModelFieldValue, zVAEModelFieldValue, zLoRAModelFieldValue, zControlNetModelFieldValue, @@ -689,6 +717,7 @@ const zStatefulFieldInputInstance = z.union([ zMainModelFieldInputInstance, zSDXLMainModelFieldInputInstance, zSDXLRefinerModelFieldInputInstance, + zSD3MainModelFieldInputInstance, zVAEModelFieldInputInstance, zLoRAModelFieldInputInstance, zControlNetModelFieldInputInstance, @@ -717,6 +746,7 @@ const zStatefulFieldInputTemplate = z.union([ zMainModelFieldInputTemplate, zSDXLMainModelFieldInputTemplate, zSDXLRefinerModelFieldInputTemplate, + zSD3MainModelFieldInputTemplate, zVAEModelFieldInputTemplate, zLoRAModelFieldInputTemplate, zControlNetModelFieldInputTemplate, @@ -746,6 +776,7 @@ const zStatefulFieldOutputTemplate = z.union([ zMainModelFieldOutputTemplate, zSDXLMainModelFieldOutputTemplate, zSDXLRefinerModelFieldOutputTemplate, + zSD3MainModelFieldOutputTemplate, zVAEModelFieldOutputTemplate, zLoRAModelFieldOutputTemplate, zControlNetModelFieldOutputTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts b/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts index f1d4e61300..00f3ccb67d 100644 --- a/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts @@ -124,6 +124,11 @@ const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: { isCollection: false, isCollectionOrScalar: false, }, + SD3MainModelField: { + name: 'SD3MainModelField', + isCollection: false, + isCollectionOrScalar: false, + }, string: { name: 'StringField', isCollection: false, diff --git a/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts b/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts index c7a50b20e4..f433ad640c 100644 --- a/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts @@ -90,6 +90,7 @@ const zFieldTypeV1 = z.enum([ 'Scheduler', 'SDXLMainModelField', 'SDXLRefinerModelField', + 'SD3MainModelField', 'string', 'StringCollection', 'StringPolymorphic', @@ -422,6 +423,11 @@ const zSDXLRefinerModelInputFieldValue = zInputFieldValueBase.extend({ value: zMainOrOnnxModel.optional(), // TODO: should narrow this down to a refiner model }); +const zSD3MainModelInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('SD3MainModelField'), + value: zMainOrOnnxModel.optional(), +}); + const zVaeModelField = zModelIdentifier; const zVaeModelInputFieldValue = zInputFieldValueBase.extend({ @@ -573,6 +579,7 @@ const zInputFieldValue = z.discriminatedUnion('type', [ zSchedulerInputFieldValue, zSDXLMainModelInputFieldValue, zSDXLRefinerModelInputFieldValue, + zSD3MainModelInputFieldValue, zStringCollectionInputFieldValue, zStringPolymorphicInputFieldValue, zStringInputFieldValue, diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/field.ts b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts index 4b680d1de3..15df9db85b 100644 --- a/invokeai/frontend/web/src/features/nodes/types/v2/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts @@ -217,6 +217,20 @@ const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ }); // #endregion +// #region SDXLMainModelField +const zSD3MainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SD3MainModelField'), +}); +const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only. +const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSD3MainModelFieldType, + value: zSD3MainModelFieldValue, +}); +const zSD3MainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zSD3MainModelFieldType, +}); +// #endregion + // #region VAEModelField const zVAEModelFieldType = zFieldTypeBase.extend({ name: z.literal('VAEModelField'), @@ -339,6 +353,7 @@ const zStatefulFieldType = z.union([ zMainModelFieldType, zSDXLMainModelFieldType, zSDXLRefinerModelFieldType, + zSD3MainModelFieldType, zVAEModelFieldType, zLoRAModelFieldType, zControlNetModelFieldType, @@ -378,6 +393,7 @@ const zStatefulFieldInputInstance = z.union([ zMainModelFieldInputInstance, zSDXLMainModelFieldInputInstance, zSDXLRefinerModelFieldInputInstance, + zSD3MainModelFieldInputInstance, zVAEModelFieldInputInstance, zLoRAModelFieldInputInstance, zControlNetModelFieldInputInstance, @@ -402,6 +418,7 @@ const zStatefulFieldOutputInstance = z.union([ zMainModelFieldOutputInstance, zSDXLMainModelFieldOutputInstance, zSDXLRefinerModelFieldOutputInstance, + zSD3MainModelFieldOutputInstance, zVAEModelFieldOutputInstance, zLoRAModelFieldOutputInstance, zControlNetModelFieldOutputInstance, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index 597779fd61..ecee28f802 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -15,6 +15,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record = MainModelField: undefined, SchedulerField: 'euler', SDXLMainModelField: undefined, + SD3MainModelField: undefined, SDXLRefinerModelField: undefined, StringField: '', T2IAdapterModelField: undefined, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts index 2b77274526..12d150ab12 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts @@ -15,6 +15,7 @@ import type { MainModelFieldInputTemplate, ModelIdentifierFieldInputTemplate, SchedulerFieldInputTemplate, + SD3MainModelFieldInputTemplate, SDXLMainModelFieldInputTemplate, SDXLRefinerModelFieldInputTemplate, StatefulFieldType, @@ -193,6 +194,20 @@ const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder = ({ + schemaObject, + baseField, + fieldType, +}) => { + const template: SD3MainModelFieldInputTemplate = { + ...baseField, + type: fieldType, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, @@ -375,6 +390,7 @@ export const TEMPLATE_BUILDER_MAP: Record { + return config.type === 'main' && config.base === 'sd-3'; +}; + +export const isNonSD3MainModelConfig = (config: AnyModelConfig): config is MainModelConfig => { + return config.type === 'main' && !(config.base === 'sd-3'); +}; + export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConfig => { return config.type === 'embedding'; }; From ddbd2ebd9d1f2fdceb2e7efe86f5b832e3807dd1 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Fri, 14 Jun 2024 22:25:26 +0530 Subject: [PATCH 3/4] wip: add Transformer Field to Node UI --- invokeai/app/invocations/model.py | 13 +- invokeai/app/invocations/sd3.py | 9 +- .../web/src/features/nodes/types/constants.ts | 2 + .../features/nodes/types/v1/fieldTypeMap.ts | 5 + .../src/features/nodes/types/v1/workflowV1.ts | 13 + .../frontend/web/src/services/api/schema.ts | 262 +++++++++--------- invokeai/invocation_api/__init__.py | 2 + 7 files changed, 160 insertions(+), 146 deletions(-) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 94a6136fcb..f8450c90a5 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -8,13 +8,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.shared.models import FreeUConfig from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType -from .baseinvocation import ( - BaseInvocation, - BaseInvocationOutput, - Classification, - invocation, - invocation_output, -) +from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output class ModelIdentifierField(BaseModel): @@ -54,6 +48,11 @@ class UNetField(BaseModel): freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration") +class TransformerField(BaseModel): + transformer: ModelIdentifierField = Field(description="Info to load unet submodel") + scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel") + + class CLIPField(BaseModel): tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel") text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel") diff --git a/invokeai/app/invocations/sd3.py b/invokeai/app/invocations/sd3.py index 72089f05f0..dbc10d2e8f 100644 --- a/invokeai/app/invocations/sd3.py +++ b/invokeai/app/invocations/sd3.py @@ -1,17 +1,10 @@ -from pydantic import BaseModel, Field - from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType -from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, VAEField +from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, TransformerField, VAEField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_manager.config import SubModelType -class TransformerField(BaseModel): - transformer: ModelIdentifierField = Field(description="Info to load unet submodel") - scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel") - - @invocation_output("sd3_model_loader_output") class SD3ModelLoaderOutput(BaseInvocationOutput): """Stable Diffuion 3 base model loader output""" diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 5ba3733571..ccb7cae736 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -36,6 +36,7 @@ export const MODEL_TYPES = [ 'SDXLRefinerModelField', 'VaeModelField', 'UNetField', + 'TransformerField', 'VAEField', 'CLIPField', 'T2IAdapterModelField', @@ -68,6 +69,7 @@ export const FIELD_COLORS: { [key: string]: string } = { T2IAdapterField: 'teal.500', T2IAdapterModelField: 'teal.500', UNetField: 'red.500', + TransformerField: 'red.500', VAEField: 'blue.500', VAEModelField: 'teal.500', }; diff --git a/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts b/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts index 00f3ccb67d..0c69d987c6 100644 --- a/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts @@ -298,6 +298,11 @@ const FIELD_TYPE_V1_TO_STATELESS_FIELD_TYPE_V2: { isCollection: false, isCollectionOrScalar: false, }, + TransformerField: { + name: 'TransformerField', + isCollection: false, + isCollectionOrScalar: false, + }, VaeField: { name: 'VaeField', isCollection: false, diff --git a/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts b/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts index f433ad640c..bfacf86d65 100644 --- a/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts @@ -99,6 +99,7 @@ const zFieldTypeV1 = z.enum([ 'T2IAdapterModelField', 'T2IAdapterPolymorphic', 'UNetField', + 'TransformerField', 'VaeField', 'VaeModelField', ]); @@ -367,6 +368,17 @@ const zUNetInputFieldValue = zInputFieldValueBase.extend({ value: zUNetField.optional(), }); +const zTransformerField = z.object({ + unet: zModelInfo, + scheduler: zModelInfo, + loras: z.array(zLoraInfo), +}); + +const zTransformerInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('TransformerField'), + value: zTransformerField.optional(), +}); + const zClipField = z.object({ tokenizer: zModelInfo, text_encoder: zModelInfo, @@ -588,6 +600,7 @@ const zInputFieldValue = z.discriminatedUnion('type', [ zT2IAdapterCollectionInputFieldValue, zT2IAdapterPolymorphicInputFieldValue, zUNetInputFieldValue, + zTransformerInputFieldValue, zVaeInputFieldValue, zVaeModelInputFieldValue, zMetadataItemInputFieldValue, diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 4e27679c0f..dc39e63920 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -7276,145 +7276,145 @@ export type components = { project_id: string | null; }; InvocationOutputMap: { - float_to_int: components["schemas"]["IntegerOutput"]; - range_of_size: components["schemas"]["IntegerCollectionOutput"]; - img_hue_adjust: components["schemas"]["ImageOutput"]; - hed_image_processor: components["schemas"]["ImageOutput"]; - img_blur: components["schemas"]["ImageOutput"]; infill_tile: components["schemas"]["ImageOutput"]; - conditioning_collection: components["schemas"]["ConditioningCollectionOutput"]; - div: components["schemas"]["IntegerOutput"]; - color_correct: components["schemas"]["ImageOutput"]; - calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"]; - boolean_collection: components["schemas"]["BooleanCollectionOutput"]; - image_collection: components["schemas"]["ImageCollectionOutput"]; - img_mul: components["schemas"]["ImageOutput"]; - sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"]; - img_lerp: components["schemas"]["ImageOutput"]; - l2i: components["schemas"]["ImageOutput"]; - string_collection: components["schemas"]["StringCollectionOutput"]; - face_off: components["schemas"]["FaceOffOutput"]; - sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"]; - normalbae_image_processor: components["schemas"]["ImageOutput"]; - dynamic_prompt: components["schemas"]["StringCollectionOutput"]; - float_math: components["schemas"]["FloatOutput"]; - ideal_size: components["schemas"]["IdealSizeOutput"]; - sub: components["schemas"]["IntegerOutput"]; - string: components["schemas"]["StringOutput"]; - core_metadata: components["schemas"]["MetadataOutput"]; - latents: components["schemas"]["LatentsOutput"]; - crop_latents: components["schemas"]["LatentsOutput"]; - denoise_latents: components["schemas"]["LatentsOutput"]; - range: components["schemas"]["IntegerCollectionOutput"]; - unsharp_mask: components["schemas"]["ImageOutput"]; - pidi_image_processor: components["schemas"]["ImageOutput"]; - float_collection: components["schemas"]["FloatCollectionOutput"]; - i2l: components["schemas"]["LatentsOutput"]; - face_identifier: components["schemas"]["ImageOutput"]; - step_param_easing: components["schemas"]["FloatCollectionOutput"]; - img_pad_crop: components["schemas"]["ImageOutput"]; - lineart_image_processor: components["schemas"]["ImageOutput"]; - infill_rgba: components["schemas"]["ImageOutput"]; - lblend: components["schemas"]["LatentsOutput"]; - mlsd_image_processor: components["schemas"]["ImageOutput"]; - lresize: components["schemas"]["LatentsOutput"]; + model_identifier: components["schemas"]["ModelIdentifierOutput"]; + tile_image_processor: components["schemas"]["ImageOutput"]; mask_combine: components["schemas"]["ImageOutput"]; - string_replace: components["schemas"]["StringOutput"]; - conditioning: components["schemas"]["ConditioningOutput"]; - scheduler: components["schemas"]["SchedulerOutput"]; - add: components["schemas"]["IntegerOutput"]; - metadata: components["schemas"]["MetadataOutput"]; - random_range: components["schemas"]["IntegerCollectionOutput"]; - img_ilerp: components["schemas"]["ImageOutput"]; - canvas_paste_back: components["schemas"]["ImageOutput"]; - mask_from_id: components["schemas"]["ImageOutput"]; - tile_to_properties: components["schemas"]["TileToPropertiesOutput"]; - sdxl_compel_prompt: components["schemas"]["ConditioningOutput"]; - img_resize: components["schemas"]["ImageOutput"]; - mul: components["schemas"]["IntegerOutput"]; - integer_collection: components["schemas"]["IntegerCollectionOutput"]; - infill_patchmatch: components["schemas"]["ImageOutput"]; - t2i_adapter: components["schemas"]["T2IAdapterOutput"]; - lora_loader: components["schemas"]["LoRALoaderOutput"]; - iterate: components["schemas"]["IterateInvocationOutput"]; - depth_anything_image_processor: components["schemas"]["ImageOutput"]; - content_shuffle_image_processor: components["schemas"]["ImageOutput"]; - string_join: components["schemas"]["StringOutput"]; - esrgan: components["schemas"]["ImageOutput"]; - dw_openpose_image_processor: components["schemas"]["ImageOutput"]; - round_float: components["schemas"]["FloatOutput"]; - noise: components["schemas"]["NoiseOutput"]; - img_channel_offset: components["schemas"]["ImageOutput"]; - calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"]; - cv_inpaint: components["schemas"]["ImageOutput"]; - lineart_anime_image_processor: components["schemas"]["ImageOutput"]; - lora_selector: components["schemas"]["LoRASelectorOutput"]; + invert_tensor_mask: components["schemas"]["MaskOutput"]; + img_chan: components["schemas"]["ImageOutput"]; + sub: components["schemas"]["IntegerOutput"]; + mediapipe_face_processor: components["schemas"]["ImageOutput"]; + compel: components["schemas"]["ConditioningOutput"]; + sd3_model_loader: components["schemas"]["SD3ModelLoaderOutput"]; + rand_float: components["schemas"]["FloatOutput"]; + zoe_depth_image_processor: components["schemas"]["ImageOutput"]; + infill_rgba: components["schemas"]["ImageOutput"]; + color_map_image_processor: components["schemas"]["ImageOutput"]; + img_hue_adjust: components["schemas"]["ImageOutput"]; + lineart_image_processor: components["schemas"]["ImageOutput"]; + metadata_item: components["schemas"]["MetadataItemOutput"]; float: components["schemas"]["FloatOutput"]; - merge_metadata: components["schemas"]["MetadataOutput"]; + create_gradient_mask: components["schemas"]["GradientMaskOutput"]; + crop_latents: components["schemas"]["LatentsOutput"]; + segment_anything_processor: components["schemas"]["ImageOutput"]; + sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"]; + string_join: components["schemas"]["StringOutput"]; + heuristic_resize: components["schemas"]["ImageOutput"]; + lblend: components["schemas"]["LatentsOutput"]; + lineart_anime_image_processor: components["schemas"]["ImageOutput"]; + string_split_neg: components["schemas"]["StringPosNegOutput"]; + alpha_mask_to_tensor: components["schemas"]["MaskOutput"]; + infill_lama: components["schemas"]["ImageOutput"]; + float_collection: components["schemas"]["FloatCollectionOutput"]; + conditioning_collection: components["schemas"]["ConditioningCollectionOutput"]; + lscale: components["schemas"]["LatentsOutput"]; + clip_skip: components["schemas"]["CLIPSkipInvocationOutput"]; + float_to_int: components["schemas"]["IntegerOutput"]; + float_math: components["schemas"]["FloatOutput"]; + collect: components["schemas"]["CollectInvocationOutput"]; + boolean: components["schemas"]["BooleanOutput"]; + latents: components["schemas"]["LatentsOutput"]; + blank_image: components["schemas"]["ImageOutput"]; + vae_loader: components["schemas"]["VAEOutput"]; + denoise_latents: components["schemas"]["LatentsOutput"]; + dw_openpose_image_processor: components["schemas"]["ImageOutput"]; + range_of_size: components["schemas"]["IntegerCollectionOutput"]; + face_mask_detection: components["schemas"]["FaceMaskOutput"]; + tomask: components["schemas"]["ImageOutput"]; + rectangle_mask: components["schemas"]["MaskOutput"]; + controlnet: components["schemas"]["ControlOutput"]; + seamless: components["schemas"]["SeamlessModeOutput"]; + pair_tile_image: components["schemas"]["PairTileImageOutput"]; + unsharp_mask: components["schemas"]["ImageOutput"]; + hed_image_processor: components["schemas"]["ImageOutput"]; + metadata: components["schemas"]["MetadataOutput"]; + freeu: components["schemas"]["UNetOutput"]; + image_collection: components["schemas"]["ImageCollectionOutput"]; + dynamic_prompt: components["schemas"]["StringCollectionOutput"]; + face_off: components["schemas"]["FaceOffOutput"]; + sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"]; + show_image: components["schemas"]["ImageOutput"]; + img_nsfw: components["schemas"]["ImageOutput"]; + round_float: components["schemas"]["FloatOutput"]; + string: components["schemas"]["StringOutput"]; + calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"]; + img_crop: components["schemas"]["ImageOutput"]; + mask_edge: components["schemas"]["ImageOutput"]; + normalbae_image_processor: components["schemas"]["ImageOutput"]; + save_image: components["schemas"]["ImageOutput"]; + add: components["schemas"]["IntegerOutput"]; + main_model_loader: components["schemas"]["ModelLoaderOutput"]; + color: components["schemas"]["ColorOutput"]; + string_replace: components["schemas"]["StringOutput"]; + img_lerp: components["schemas"]["ImageOutput"]; + midas_depth_image_processor: components["schemas"]["ImageOutput"]; + infill_patchmatch: components["schemas"]["ImageOutput"]; + noise: components["schemas"]["NoiseOutput"]; + img_watermark: components["schemas"]["ImageOutput"]; + depth_anything_image_processor: components["schemas"]["ImageOutput"]; + i2l: components["schemas"]["LatentsOutput"]; + tile_to_properties: components["schemas"]["TileToPropertiesOutput"]; + canvas_paste_back: components["schemas"]["ImageOutput"]; + mul: components["schemas"]["IntegerOutput"]; + pidi_image_processor: components["schemas"]["ImageOutput"]; + sdxl_compel_prompt: components["schemas"]["ConditioningOutput"]; + img_conv: components["schemas"]["ImageOutput"]; + sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"]; + mask_from_id: components["schemas"]["ImageOutput"]; + lora_loader: components["schemas"]["LoRALoaderOutput"]; + step_param_easing: components["schemas"]["FloatCollectionOutput"]; + face_identifier: components["schemas"]["ImageOutput"]; + calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"]; + esrgan: components["schemas"]["ImageOutput"]; + color_correct: components["schemas"]["ImageOutput"]; + lora_selector: components["schemas"]["LoRASelectorOutput"]; + cv_inpaint: components["schemas"]["ImageOutput"]; + img_pad_crop: components["schemas"]["ImageOutput"]; + merge_tiles_to_image: components["schemas"]["ImageOutput"]; + img_channel_offset: components["schemas"]["ImageOutput"]; + string_collection: components["schemas"]["StringCollectionOutput"]; + scheduler: components["schemas"]["SchedulerOutput"]; + conditioning: components["schemas"]["ConditioningOutput"]; + string_split: components["schemas"]["String2Output"]; + string_join_three: components["schemas"]["StringOutput"]; + img_ilerp: components["schemas"]["ImageOutput"]; + lora_collection_loader: components["schemas"]["LoRALoaderOutput"]; + core_metadata: components["schemas"]["MetadataOutput"]; float_range: components["schemas"]["FloatCollectionOutput"]; + random_range: components["schemas"]["IntegerCollectionOutput"]; + rand_int: components["schemas"]["IntegerOutput"]; + canny_image_processor: components["schemas"]["ImageOutput"]; + merge_metadata: components["schemas"]["MetadataOutput"]; + latents_collection: components["schemas"]["LatentsCollectionOutput"]; + range: components["schemas"]["IntegerCollectionOutput"]; + iterate: components["schemas"]["IterateInvocationOutput"]; + img_scale: components["schemas"]["ImageOutput"]; + img_blur: components["schemas"]["ImageOutput"]; + img_channel_multiply: components["schemas"]["ImageOutput"]; + integer_math: components["schemas"]["IntegerOutput"]; + calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"]; + img_mul: components["schemas"]["ImageOutput"]; + mlsd_image_processor: components["schemas"]["ImageOutput"]; ip_adapter: components["schemas"]["IPAdapterOutput"]; + sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"]; + content_shuffle_image_processor: components["schemas"]["ImageOutput"]; + infill_cv2: components["schemas"]["ImageOutput"]; + prompt_from_file: components["schemas"]["StringCollectionOutput"]; + image: components["schemas"]["ImageOutput"]; + img_resize: components["schemas"]["ImageOutput"]; + boolean_collection: components["schemas"]["BooleanCollectionOutput"]; + lresize: components["schemas"]["LatentsOutput"]; + l2i: components["schemas"]["ImageOutput"]; + integer_collection: components["schemas"]["IntegerCollectionOutput"]; + t2i_adapter: components["schemas"]["T2IAdapterOutput"]; + div: components["schemas"]["IntegerOutput"]; + leres_image_processor: components["schemas"]["ImageOutput"]; + sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"]; + ideal_size: components["schemas"]["IdealSizeOutput"]; + integer: components["schemas"]["IntegerOutput"]; create_denoise_mask: components["schemas"]["DenoiseMaskOutput"]; img_paste: components["schemas"]["ImageOutput"]; - save_image: components["schemas"]["ImageOutput"]; - color_map_image_processor: components["schemas"]["ImageOutput"]; - rand_float: components["schemas"]["FloatOutput"]; - midas_depth_image_processor: components["schemas"]["ImageOutput"]; - blank_image: components["schemas"]["ImageOutput"]; - sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"]; - rectangle_mask: components["schemas"]["MaskOutput"]; - collect: components["schemas"]["CollectInvocationOutput"]; - tomask: components["schemas"]["ImageOutput"]; - model_identifier: components["schemas"]["ModelIdentifierOutput"]; - lora_collection_loader: components["schemas"]["LoRALoaderOutput"]; - rand_int: components["schemas"]["IntegerOutput"]; - sd3_model_loader: components["schemas"]["SD3ModelLoaderOutput"]; - infill_lama: components["schemas"]["ImageOutput"]; - heuristic_resize: components["schemas"]["ImageOutput"]; - latents_collection: components["schemas"]["LatentsCollectionOutput"]; - face_mask_detection: components["schemas"]["FaceMaskOutput"]; - vae_loader: components["schemas"]["VAEOutput"]; - invert_tensor_mask: components["schemas"]["MaskOutput"]; - integer: components["schemas"]["IntegerOutput"]; - img_channel_multiply: components["schemas"]["ImageOutput"]; - clip_skip: components["schemas"]["CLIPSkipInvocationOutput"]; - tile_image_processor: components["schemas"]["ImageOutput"]; - freeu: components["schemas"]["UNetOutput"]; - boolean: components["schemas"]["BooleanOutput"]; - sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"]; - mediapipe_face_processor: components["schemas"]["ImageOutput"]; - prompt_from_file: components["schemas"]["StringCollectionOutput"]; - img_nsfw: components["schemas"]["ImageOutput"]; - string_split_neg: components["schemas"]["StringPosNegOutput"]; - img_chan: components["schemas"]["ImageOutput"]; - seamless: components["schemas"]["SeamlessModeOutput"]; - img_scale: components["schemas"]["ImageOutput"]; - sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"]; - mask_edge: components["schemas"]["ImageOutput"]; - alpha_mask_to_tensor: components["schemas"]["MaskOutput"]; - create_gradient_mask: components["schemas"]["GradientMaskOutput"]; - controlnet: components["schemas"]["ControlOutput"]; - leres_image_processor: components["schemas"]["ImageOutput"]; - main_model_loader: components["schemas"]["ModelLoaderOutput"]; - calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"]; - string_split: components["schemas"]["String2Output"]; - img_watermark: components["schemas"]["ImageOutput"]; - merge_tiles_to_image: components["schemas"]["ImageOutput"]; - img_conv: components["schemas"]["ImageOutput"]; - segment_anything_processor: components["schemas"]["ImageOutput"]; image_mask_to_tensor: components["schemas"]["MaskOutput"]; - zoe_depth_image_processor: components["schemas"]["ImageOutput"]; - show_image: components["schemas"]["ImageOutput"]; - string_join_three: components["schemas"]["StringOutput"]; - pair_tile_image: components["schemas"]["PairTileImageOutput"]; - infill_cv2: components["schemas"]["ImageOutput"]; - integer_math: components["schemas"]["IntegerOutput"]; - color: components["schemas"]["ColorOutput"]; - canny_image_processor: components["schemas"]["ImageOutput"]; - img_crop: components["schemas"]["ImageOutput"]; - lscale: components["schemas"]["LatentsOutput"]; - metadata_item: components["schemas"]["MetadataItemOutput"]; - image: components["schemas"]["ImageOutput"]; - compel: components["schemas"]["ConditioningOutput"]; }; /** * InvocationStartedEvent diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py index 97260c4dfe..f81016125e 100644 --- a/invokeai/invocation_api/__init__.py +++ b/invokeai/invocation_api/__init__.py @@ -39,6 +39,7 @@ from invokeai.app.invocations.model import ( ModelIdentifierField, ModelLoaderOutput, SDXLLoRALoaderOutput, + TransformerField, UNetField, UNetOutput, VAEField, @@ -117,6 +118,7 @@ __all__ = [ # invokeai.app.invocations.model "ModelIdentifierField", "UNetField", + "TransformerField", "CLIPField", "VAEField", "UNetOutput", From 41236031b22c293a9d37ebc402877f1c67541aa9 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sat, 15 Jun 2024 00:00:44 +0530 Subject: [PATCH 4/4] chore: remove unrequired changes to v1 workflow field types --- .../features/nodes/types/v1/fieldTypeMap.ts | 10 ---------- .../src/features/nodes/types/v1/workflowV1.ts | 20 ------------------- 2 files changed, 30 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts b/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts index 0c69d987c6..f1d4e61300 100644 --- a/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts @@ -124,11 +124,6 @@ const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: { isCollection: false, isCollectionOrScalar: false, }, - SD3MainModelField: { - name: 'SD3MainModelField', - isCollection: false, - isCollectionOrScalar: false, - }, string: { name: 'StringField', isCollection: false, @@ -298,11 +293,6 @@ const FIELD_TYPE_V1_TO_STATELESS_FIELD_TYPE_V2: { isCollection: false, isCollectionOrScalar: false, }, - TransformerField: { - name: 'TransformerField', - isCollection: false, - isCollectionOrScalar: false, - }, VaeField: { name: 'VaeField', isCollection: false, diff --git a/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts b/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts index bfacf86d65..c7a50b20e4 100644 --- a/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts @@ -90,7 +90,6 @@ const zFieldTypeV1 = z.enum([ 'Scheduler', 'SDXLMainModelField', 'SDXLRefinerModelField', - 'SD3MainModelField', 'string', 'StringCollection', 'StringPolymorphic', @@ -99,7 +98,6 @@ const zFieldTypeV1 = z.enum([ 'T2IAdapterModelField', 'T2IAdapterPolymorphic', 'UNetField', - 'TransformerField', 'VaeField', 'VaeModelField', ]); @@ -368,17 +366,6 @@ const zUNetInputFieldValue = zInputFieldValueBase.extend({ value: zUNetField.optional(), }); -const zTransformerField = z.object({ - unet: zModelInfo, - scheduler: zModelInfo, - loras: z.array(zLoraInfo), -}); - -const zTransformerInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('TransformerField'), - value: zTransformerField.optional(), -}); - const zClipField = z.object({ tokenizer: zModelInfo, text_encoder: zModelInfo, @@ -435,11 +422,6 @@ const zSDXLRefinerModelInputFieldValue = zInputFieldValueBase.extend({ value: zMainOrOnnxModel.optional(), // TODO: should narrow this down to a refiner model }); -const zSD3MainModelInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('SD3MainModelField'), - value: zMainOrOnnxModel.optional(), -}); - const zVaeModelField = zModelIdentifier; const zVaeModelInputFieldValue = zInputFieldValueBase.extend({ @@ -591,7 +573,6 @@ const zInputFieldValue = z.discriminatedUnion('type', [ zSchedulerInputFieldValue, zSDXLMainModelInputFieldValue, zSDXLRefinerModelInputFieldValue, - zSD3MainModelInputFieldValue, zStringCollectionInputFieldValue, zStringPolymorphicInputFieldValue, zStringInputFieldValue, @@ -600,7 +581,6 @@ const zInputFieldValue = z.discriminatedUnion('type', [ zT2IAdapterCollectionInputFieldValue, zT2IAdapterPolymorphicInputFieldValue, zUNetInputFieldValue, - zTransformerInputFieldValue, zVaeInputFieldValue, zVaeModelInputFieldValue, zMetadataItemInputFieldValue,