From 0c970bc8802059e76f1351142db49bf220b680bb Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Fri, 14 Jun 2024 22:21:09 +0530 Subject: [PATCH] wip: add SD3 Model Loader Invocation --- invokeai/app/invocations/fields.py | 3 + invokeai/app/invocations/sd3.py | 54 +++ .../Invocation/fields/InputFieldRenderer.tsx | 7 + .../SD3MainModelFieldInputComponent.tsx | 55 +++ .../web/src/features/nodes/types/constants.ts | 2 + .../web/src/features/nodes/types/field.ts | 31 ++ .../features/nodes/types/v1/fieldTypeMap.ts | 5 + .../src/features/nodes/types/v1/workflowV1.ts | 7 + .../web/src/features/nodes/types/v2/field.ts | 17 + .../util/schema/buildFieldInputInstance.ts | 1 + .../util/schema/buildFieldInputTemplate.ts | 16 + .../nodes/util/workflow/validateWorkflow.ts | 1 + .../src/services/api/hooks/modelsByType.ts | 2 + .../frontend/web/src/services/api/schema.ts | 353 +++++++++++------- .../frontend/web/src/services/api/types.ts | 8 + 15 files changed, 426 insertions(+), 136 deletions(-) create mode 100644 invokeai/app/invocations/sd3.py create mode 100644 invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SD3MainModelFieldInputComponent.tsx diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 0fa0216f1c..5803696c9f 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -42,6 +42,7 @@ class UIType(str, Enum, metaclass=MetaEnum): MainModel = "MainModelField" SDXLMainModel = "SDXLMainModelField" SDXLRefinerModel = "SDXLRefinerModelField" + SD3MainModel = "SD3MainModelField" ONNXModel = "ONNXModelField" VAEModel = "VAEModelField" LoRAModel = "LoRAModelField" @@ -125,6 +126,7 @@ class FieldDescriptions: noise = "Noise tensor" clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count" unet = "UNet (scheduler, LoRAs)" + transformer = "Transformer" vae = "VAE" cond = "Conditioning tensor" controlnet_model = "ControlNet model to load" @@ -133,6 +135,7 @@ class FieldDescriptions: main_model = "Main model (UNet, VAE, CLIP) to load" sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load" + sd3_main_model = "SD3 Main Model (Transformer, CLIP1, CLIP2, CLIP3, VAE) to load" onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" lora_weight = "The weight at which the LoRA is applied to each model" compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor" diff --git a/invokeai/app/invocations/sd3.py b/invokeai/app/invocations/sd3.py new file mode 100644 index 0000000000..72089f05f0 --- /dev/null +++ b/invokeai/app/invocations/sd3.py @@ -0,0 +1,54 @@ +from pydantic import BaseModel, Field + +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output +from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType +from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, VAEField +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager.config import SubModelType + + +class TransformerField(BaseModel): + transformer: ModelIdentifierField = Field(description="Info to load unet submodel") + scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel") + + +@invocation_output("sd3_model_loader_output") +class SD3ModelLoaderOutput(BaseInvocationOutput): + """Stable Diffuion 3 base model loader output""" + + transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer") + clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1") + clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2") + clip3: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 3") + vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE") + + +@invocation("sd3_model_loader", title="SD3 Main Model", tags=["model", "sd3"], category="model", version="1.0.0") +class SD3ModelLoaderInvocation(BaseInvocation): + """Loads an SD3 base model, outputting its submodels.""" + + model: ModelIdentifierField = InputField(description=FieldDescriptions.sd3_main_model, ui_type=UIType.SD3MainModel) + + def invoke(self, context: InvocationContext) -> SD3ModelLoaderOutput: + model_key = self.model.key + + if not context.models.exists(model_key): + raise Exception(f"Unknown model: {model_key}") + + transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer}) + scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler}) + tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer}) + text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder}) + tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2}) + text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2}) + tokenizer3 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3}) + text_encoder3 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3}) + vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE}) + + return SD3ModelLoaderOutput( + transformer=TransformerField(transformer=transformer, scheduler=scheduler), + clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0), + clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0), + clip3=CLIPField(tokenizer=tokenizer3, text_encoder=text_encoder3, loras=[], skipped_layers=0), + vae=VAEField(vae=vae), + ) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index 99937ceec4..810ec3ffff 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -28,6 +28,8 @@ import { isModelIdentifierFieldInputTemplate, isSchedulerFieldInputInstance, isSchedulerFieldInputTemplate, + isSD3MainModelFieldInputInstance, + isSD3MainModelFieldInputTemplate, isSDXLMainModelFieldInputInstance, isSDXLMainModelFieldInputTemplate, isSDXLRefinerModelFieldInputInstance, @@ -53,6 +55,7 @@ import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent' import NumberFieldInputComponent from './inputs/NumberFieldInputComponent'; import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent'; import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent'; +import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent'; import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent'; import StringFieldInputComponent from './inputs/StringFieldInputComponent'; import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent'; @@ -133,6 +136,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { return ; } + if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) { + return ; + } + if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) { return ; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SD3MainModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SD3MainModelFieldInputComponent.tsx new file mode 100644 index 0000000000..95feb08ae9 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SD3MainModelFieldInputComponent.tsx @@ -0,0 +1,55 @@ +import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; +import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field'; +import { memo, useCallback } from 'react'; +import { useSD3Models } from 'services/api/hooks/modelsByType'; +import type { MainModelConfig } from 'services/api/types'; + +import type { FieldComponentProps } from './types'; + +type Props = FieldComponentProps; + +const SD3MainModelFieldInputComponent = (props: Props) => { + const { nodeId, field } = props; + const dispatch = useAppDispatch(); + const [modelConfigs, { isLoading }] = useSD3Models(); + const _onChange = useCallback( + (value: MainModelConfig | null) => { + if (!value) { + return; + } + dispatch( + fieldMainModelValueChanged({ + nodeId, + fieldName: field.name, + value, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ + modelConfigs, + onChange: _onChange, + isLoading, + selectedModel: field.value, + }); + + return ( + + + + + + ); +}; + +export default memo(SD3MainModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 4ede5cd479..5ba3733571 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -32,6 +32,7 @@ export const MODEL_TYPES = [ 'LoRAModelField', 'MainModelField', 'SDXLMainModelField', + 'SD3MainModelField', 'SDXLRefinerModelField', 'VaeModelField', 'UNetField', @@ -62,6 +63,7 @@ export const FIELD_COLORS: { [key: string]: string } = { MainModelField: 'teal.500', SDXLMainModelField: 'teal.500', SDXLRefinerModelField: 'teal.500', + SD3MainModelField: 'teal.500', StringField: 'yellow.500', T2IAdapterField: 'teal.500', T2IAdapterModelField: 'teal.500', diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index e2a84e3390..ae0d9edb01 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -119,6 +119,10 @@ const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ name: z.literal('SDXLRefinerModelField'), originalType: zStatelessFieldType.optional(), }); +const zSD3MainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SD3MainModelField'), + originalType: zStatelessFieldType.optional(), +}); const zVAEModelFieldType = zFieldTypeBase.extend({ name: z.literal('VAEModelField'), originalType: zStatelessFieldType.optional(), @@ -155,6 +159,7 @@ const zStatefulFieldType = z.union([ zMainModelFieldType, zSDXLMainModelFieldType, zSDXLRefinerModelFieldType, + zSD3MainModelFieldType, zVAEModelFieldType, zLoRAModelFieldType, zControlNetModelFieldType, @@ -466,6 +471,28 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR zSDXLRefinerModelFieldInputTemplate.safeParse(val).success; // #endregion +// #region SD3MainModelField + +const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only. +const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + value: zSD3MainModelFieldValue, +}); +const zSD3MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSD3MainModelFieldType, + originalType: zFieldType.optional(), + default: zSD3MainModelFieldValue, +}); +const zSD3MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSD3MainModelFieldType, +}); +export type SD3MainModelFieldInputInstance = z.infer; +export type SD3MainModelFieldInputTemplate = z.infer; +export const isSD3MainModelFieldInputInstance = (val: unknown): val is SD3MainModelFieldInputInstance => + zSD3MainModelFieldInputInstance.safeParse(val).success; +export const isSD3MainModelFieldInputTemplate = (val: unknown): val is SD3MainModelFieldInputTemplate => + zSD3MainModelFieldInputTemplate.safeParse(val).success; +// #endregion + // #region VAEModelField export const zVAEModelFieldValue = zModelIdentifierField.optional(); @@ -662,6 +689,7 @@ export const zStatefulFieldValue = z.union([ zMainModelFieldValue, zSDXLMainModelFieldValue, zSDXLRefinerModelFieldValue, + zSD3MainModelFieldValue, zVAEModelFieldValue, zLoRAModelFieldValue, zControlNetModelFieldValue, @@ -689,6 +717,7 @@ const zStatefulFieldInputInstance = z.union([ zMainModelFieldInputInstance, zSDXLMainModelFieldInputInstance, zSDXLRefinerModelFieldInputInstance, + zSD3MainModelFieldInputInstance, zVAEModelFieldInputInstance, zLoRAModelFieldInputInstance, zControlNetModelFieldInputInstance, @@ -717,6 +746,7 @@ const zStatefulFieldInputTemplate = z.union([ zMainModelFieldInputTemplate, zSDXLMainModelFieldInputTemplate, zSDXLRefinerModelFieldInputTemplate, + zSD3MainModelFieldInputTemplate, zVAEModelFieldInputTemplate, zLoRAModelFieldInputTemplate, zControlNetModelFieldInputTemplate, @@ -746,6 +776,7 @@ const zStatefulFieldOutputTemplate = z.union([ zMainModelFieldOutputTemplate, zSDXLMainModelFieldOutputTemplate, zSDXLRefinerModelFieldOutputTemplate, + zSD3MainModelFieldOutputTemplate, zVAEModelFieldOutputTemplate, zLoRAModelFieldOutputTemplate, zControlNetModelFieldOutputTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts b/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts index f1d4e61300..00f3ccb67d 100644 --- a/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts @@ -124,6 +124,11 @@ const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: { isCollection: false, isCollectionOrScalar: false, }, + SD3MainModelField: { + name: 'SD3MainModelField', + isCollection: false, + isCollectionOrScalar: false, + }, string: { name: 'StringField', isCollection: false, diff --git a/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts b/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts index c7a50b20e4..f433ad640c 100644 --- a/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts @@ -90,6 +90,7 @@ const zFieldTypeV1 = z.enum([ 'Scheduler', 'SDXLMainModelField', 'SDXLRefinerModelField', + 'SD3MainModelField', 'string', 'StringCollection', 'StringPolymorphic', @@ -422,6 +423,11 @@ const zSDXLRefinerModelInputFieldValue = zInputFieldValueBase.extend({ value: zMainOrOnnxModel.optional(), // TODO: should narrow this down to a refiner model }); +const zSD3MainModelInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('SD3MainModelField'), + value: zMainOrOnnxModel.optional(), +}); + const zVaeModelField = zModelIdentifier; const zVaeModelInputFieldValue = zInputFieldValueBase.extend({ @@ -573,6 +579,7 @@ const zInputFieldValue = z.discriminatedUnion('type', [ zSchedulerInputFieldValue, zSDXLMainModelInputFieldValue, zSDXLRefinerModelInputFieldValue, + zSD3MainModelInputFieldValue, zStringCollectionInputFieldValue, zStringPolymorphicInputFieldValue, zStringInputFieldValue, diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/field.ts b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts index 4b680d1de3..15df9db85b 100644 --- a/invokeai/frontend/web/src/features/nodes/types/v2/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts @@ -217,6 +217,20 @@ const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ }); // #endregion +// #region SDXLMainModelField +const zSD3MainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SD3MainModelField'), +}); +const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only. +const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSD3MainModelFieldType, + value: zSD3MainModelFieldValue, +}); +const zSD3MainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zSD3MainModelFieldType, +}); +// #endregion + // #region VAEModelField const zVAEModelFieldType = zFieldTypeBase.extend({ name: z.literal('VAEModelField'), @@ -339,6 +353,7 @@ const zStatefulFieldType = z.union([ zMainModelFieldType, zSDXLMainModelFieldType, zSDXLRefinerModelFieldType, + zSD3MainModelFieldType, zVAEModelFieldType, zLoRAModelFieldType, zControlNetModelFieldType, @@ -378,6 +393,7 @@ const zStatefulFieldInputInstance = z.union([ zMainModelFieldInputInstance, zSDXLMainModelFieldInputInstance, zSDXLRefinerModelFieldInputInstance, + zSD3MainModelFieldInputInstance, zVAEModelFieldInputInstance, zLoRAModelFieldInputInstance, zControlNetModelFieldInputInstance, @@ -402,6 +418,7 @@ const zStatefulFieldOutputInstance = z.union([ zMainModelFieldOutputInstance, zSDXLMainModelFieldOutputInstance, zSDXLRefinerModelFieldOutputInstance, + zSD3MainModelFieldOutputInstance, zVAEModelFieldOutputInstance, zLoRAModelFieldOutputInstance, zControlNetModelFieldOutputInstance, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index 597779fd61..ecee28f802 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -15,6 +15,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record = MainModelField: undefined, SchedulerField: 'euler', SDXLMainModelField: undefined, + SD3MainModelField: undefined, SDXLRefinerModelField: undefined, StringField: '', T2IAdapterModelField: undefined, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts index 2b77274526..12d150ab12 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts @@ -15,6 +15,7 @@ import type { MainModelFieldInputTemplate, ModelIdentifierFieldInputTemplate, SchedulerFieldInputTemplate, + SD3MainModelFieldInputTemplate, SDXLMainModelFieldInputTemplate, SDXLRefinerModelFieldInputTemplate, StatefulFieldType, @@ -193,6 +194,20 @@ const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder = ({ + schemaObject, + baseField, + fieldType, +}) => { + const template: SD3MainModelFieldInputTemplate = { + ...baseField, + type: fieldType, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, @@ -375,6 +390,7 @@ export const TEMPLATE_BUILDER_MAP: Record { + return config.type === 'main' && config.base === 'sd-3'; +}; + +export const isNonSD3MainModelConfig = (config: AnyModelConfig): config is MainModelConfig => { + return config.type === 'main' && !(config.base === 'sd-3'); +}; + export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConfig => { return config.type === 'embedding'; };