From 0c970bc8802059e76f1351142db49bf220b680bb Mon Sep 17 00:00:00 2001
From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com>
Date: Fri, 14 Jun 2024 22:21:09 +0530
Subject: [PATCH] wip: add SD3 Model Loader Invocation
---
invokeai/app/invocations/fields.py | 3 +
invokeai/app/invocations/sd3.py | 54 +++
.../Invocation/fields/InputFieldRenderer.tsx | 7 +
.../SD3MainModelFieldInputComponent.tsx | 55 +++
.../web/src/features/nodes/types/constants.ts | 2 +
.../web/src/features/nodes/types/field.ts | 31 ++
.../features/nodes/types/v1/fieldTypeMap.ts | 5 +
.../src/features/nodes/types/v1/workflowV1.ts | 7 +
.../web/src/features/nodes/types/v2/field.ts | 17 +
.../util/schema/buildFieldInputInstance.ts | 1 +
.../util/schema/buildFieldInputTemplate.ts | 16 +
.../nodes/util/workflow/validateWorkflow.ts | 1 +
.../src/services/api/hooks/modelsByType.ts | 2 +
.../frontend/web/src/services/api/schema.ts | 353 +++++++++++-------
.../frontend/web/src/services/api/types.ts | 8 +
15 files changed, 426 insertions(+), 136 deletions(-)
create mode 100644 invokeai/app/invocations/sd3.py
create mode 100644 invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SD3MainModelFieldInputComponent.tsx
diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py
index 0fa0216f1c..5803696c9f 100644
--- a/invokeai/app/invocations/fields.py
+++ b/invokeai/app/invocations/fields.py
@@ -42,6 +42,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
MainModel = "MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
+ SD3MainModel = "SD3MainModelField"
ONNXModel = "ONNXModelField"
VAEModel = "VAEModelField"
LoRAModel = "LoRAModelField"
@@ -125,6 +126,7 @@ class FieldDescriptions:
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
unet = "UNet (scheduler, LoRAs)"
+ transformer = "Transformer"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
@@ -133,6 +135,7 @@ class FieldDescriptions:
main_model = "Main model (UNet, VAE, CLIP) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
+ sd3_main_model = "SD3 Main Model (Transformer, CLIP1, CLIP2, CLIP3, VAE) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
diff --git a/invokeai/app/invocations/sd3.py b/invokeai/app/invocations/sd3.py
new file mode 100644
index 0000000000..72089f05f0
--- /dev/null
+++ b/invokeai/app/invocations/sd3.py
@@ -0,0 +1,54 @@
+from pydantic import BaseModel, Field
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
+from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
+from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, VAEField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.config import SubModelType
+
+
+class TransformerField(BaseModel):  # Identifiers for the transformer submodel and its paired scheduler submodel.
+    transformer: ModelIdentifierField = Field(description="Info to load transformer submodel")
+    scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
+
+
+@invocation_output("sd3_model_loader_output")
+class SD3ModelLoaderOutput(BaseInvocationOutput):
+    """Stable Diffusion 3 base model loader output"""
+
+    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
+    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
+    clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
+    clip3: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 3")
+    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
+
+
+@invocation("sd3_model_loader", title="SD3 Main Model", tags=["model", "sd3"], category="model", version="1.0.0")
+class SD3ModelLoaderInvocation(BaseInvocation):
+    """Loads an SD3 base model, outputting its submodels."""
+
+    model: ModelIdentifierField = InputField(description=FieldDescriptions.sd3_main_model, ui_type=UIType.SD3MainModel)
+
+    def invoke(self, context: InvocationContext) -> SD3ModelLoaderOutput:
+        """Return identifiers for every SD3 submodel; raises Exception if the model key is unknown."""
+        model_key = self.model.key
+        if not context.models.exists(model_key):
+            raise Exception(f"Unknown model: {model_key}")
+
+        def submodel(kind: SubModelType) -> ModelIdentifierField:
+            # Same model identifier, retargeted at one submodel.
+            return self.model.model_copy(update={"submodel_type": kind})
+
+        def clip(tok: SubModelType, enc: SubModelType) -> CLIPField:
+            # Each CLIP output is built with an empty LoRA list and zero skipped layers.
+            return CLIPField(tokenizer=submodel(tok), text_encoder=submodel(enc), loras=[], skipped_layers=0)
+
+        return SD3ModelLoaderOutput(
+            transformer=TransformerField(
+                transformer=submodel(SubModelType.Transformer), scheduler=submodel(SubModelType.Scheduler)
+            ),
+            clip=clip(SubModelType.Tokenizer, SubModelType.TextEncoder),
+            clip2=clip(SubModelType.Tokenizer2, SubModelType.TextEncoder2),
+            clip3=clip(SubModelType.Tokenizer3, SubModelType.TextEncoder3),
+            vae=VAEField(vae=submodel(SubModelType.VAE)),
+        )
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx
index 99937ceec4..810ec3ffff 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx
@@ -28,6 +28,8 @@ import {
isModelIdentifierFieldInputTemplate,
isSchedulerFieldInputInstance,
isSchedulerFieldInputTemplate,
+ isSD3MainModelFieldInputInstance,
+ isSD3MainModelFieldInputTemplate,
isSDXLMainModelFieldInputInstance,
isSDXLMainModelFieldInputTemplate,
isSDXLRefinerModelFieldInputInstance,
@@ -53,6 +55,7 @@ import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent'
import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
+import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent';
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
@@ -133,6 +136,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return ;
}
+ if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
+ return ;
+ }
+
if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) {
return ;
}
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SD3MainModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SD3MainModelFieldInputComponent.tsx
new file mode 100644
index 0000000000..95feb08ae9
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SD3MainModelFieldInputComponent.tsx
@@ -0,0 +1,55 @@
+import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
+import { useAppDispatch } from 'app/store/storeHooks';
+import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
+import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
+import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field';
+import { memo, useCallback } from 'react';
+import { useSD3Models } from 'services/api/hooks/modelsByType';
+import type { MainModelConfig } from 'services/api/types';
+
+import type { FieldComponentProps } from './types';
+
+type Props = FieldComponentProps;
+
+const SD3MainModelFieldInputComponent = (props: Props) => {
+ const { nodeId, field } = props;
+ const dispatch = useAppDispatch();
+ const [modelConfigs, { isLoading }] = useSD3Models();
+ const _onChange = useCallback(
+ (value: MainModelConfig | null) => {
+ if (!value) {
+ return;
+ }
+ dispatch(
+ fieldMainModelValueChanged({
+ nodeId,
+ fieldName: field.name,
+ value,
+ })
+ );
+ },
+ [dispatch, field.name, nodeId]
+ );
+ const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
+ modelConfigs,
+ onChange: _onChange,
+ isLoading,
+ selectedModel: field.value,
+ });
+
+ return (
+
+
+
+
+
+ );
+};
+
+export default memo(SD3MainModelFieldInputComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts
index 4ede5cd479..5ba3733571 100644
--- a/invokeai/frontend/web/src/features/nodes/types/constants.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts
@@ -32,6 +32,7 @@ export const MODEL_TYPES = [
'LoRAModelField',
'MainModelField',
'SDXLMainModelField',
+ 'SD3MainModelField',
'SDXLRefinerModelField',
'VaeModelField',
'UNetField',
@@ -62,6 +63,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
MainModelField: 'teal.500',
SDXLMainModelField: 'teal.500',
SDXLRefinerModelField: 'teal.500',
+ SD3MainModelField: 'teal.500',
StringField: 'yellow.500',
T2IAdapterField: 'teal.500',
T2IAdapterModelField: 'teal.500',
diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts
index e2a84e3390..ae0d9edb01 100644
--- a/invokeai/frontend/web/src/features/nodes/types/field.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/field.ts
@@ -119,6 +119,10 @@ const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'),
originalType: zStatelessFieldType.optional(),
});
+const zSD3MainModelFieldType = zFieldTypeBase.extend({ // Field type tagged with the 'SD3MainModelField' name literal.
+ name: z.literal('SD3MainModelField'),
+ originalType: zStatelessFieldType.optional(),
+});
const zVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('VAEModelField'),
originalType: zStatelessFieldType.optional(),
@@ -155,6 +159,7 @@ const zStatefulFieldType = z.union([
zMainModelFieldType,
zSDXLMainModelFieldType,
zSDXLRefinerModelFieldType,
+ zSD3MainModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
zControlNetModelFieldType,
@@ -466,6 +471,28 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR
zSDXLRefinerModelFieldInputTemplate.safeParse(val).success;
// #endregion
+// #region SD3MainModelField
+// Value, instance and template schemas for SD3 main-model fields, plus their inferred types and guards.
+const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only.
+const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
+  value: zSD3MainModelFieldValue,
+});
+const zSD3MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
+  type: zSD3MainModelFieldType,
+  originalType: zFieldType.optional(),
+  default: zSD3MainModelFieldValue,
+});
+const zSD3MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
+  type: zSD3MainModelFieldType,
+});
+export type SD3MainModelFieldInputInstance = z.infer<typeof zSD3MainModelFieldInputInstance>;
+export type SD3MainModelFieldInputTemplate = z.infer<typeof zSD3MainModelFieldInputTemplate>;
+export const isSD3MainModelFieldInputInstance = (val: unknown): val is SD3MainModelFieldInputInstance =>
+  zSD3MainModelFieldInputInstance.safeParse(val).success;
+export const isSD3MainModelFieldInputTemplate = (val: unknown): val is SD3MainModelFieldInputTemplate =>
+  zSD3MainModelFieldInputTemplate.safeParse(val).success;
+// #endregion
+
// #region VAEModelField
export const zVAEModelFieldValue = zModelIdentifierField.optional();
@@ -662,6 +689,7 @@ export const zStatefulFieldValue = z.union([
zMainModelFieldValue,
zSDXLMainModelFieldValue,
zSDXLRefinerModelFieldValue,
+ zSD3MainModelFieldValue,
zVAEModelFieldValue,
zLoRAModelFieldValue,
zControlNetModelFieldValue,
@@ -689,6 +717,7 @@ const zStatefulFieldInputInstance = z.union([
zMainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
+ zSD3MainModelFieldInputInstance,
zVAEModelFieldInputInstance,
zLoRAModelFieldInputInstance,
zControlNetModelFieldInputInstance,
@@ -717,6 +746,7 @@ const zStatefulFieldInputTemplate = z.union([
zMainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate,
zSDXLRefinerModelFieldInputTemplate,
+ zSD3MainModelFieldInputTemplate,
zVAEModelFieldInputTemplate,
zLoRAModelFieldInputTemplate,
zControlNetModelFieldInputTemplate,
@@ -746,6 +776,7 @@ const zStatefulFieldOutputTemplate = z.union([
zMainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate,
zSDXLRefinerModelFieldOutputTemplate,
+ zSD3MainModelFieldOutputTemplate,
zVAEModelFieldOutputTemplate,
zLoRAModelFieldOutputTemplate,
zControlNetModelFieldOutputTemplate,
diff --git a/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts b/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts
index f1d4e61300..00f3ccb67d 100644
--- a/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts
@@ -124,6 +124,11 @@ const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: {
isCollection: false,
isCollectionOrScalar: false,
},
+ SD3MainModelField: {
+ name: 'SD3MainModelField',
+ isCollection: false,
+ isCollectionOrScalar: false,
+ },
string: {
name: 'StringField',
isCollection: false,
diff --git a/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts b/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts
index c7a50b20e4..f433ad640c 100644
--- a/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts
@@ -90,6 +90,7 @@ const zFieldTypeV1 = z.enum([
'Scheduler',
'SDXLMainModelField',
'SDXLRefinerModelField',
+ 'SD3MainModelField',
'string',
'StringCollection',
'StringPolymorphic',
@@ -422,6 +423,11 @@ const zSDXLRefinerModelInputFieldValue = zInputFieldValueBase.extend({
value: zMainOrOnnxModel.optional(), // TODO: should narrow this down to a refiner model
});
+const zSD3MainModelInputFieldValue = zInputFieldValueBase.extend({
+ type: z.literal('SD3MainModelField'),
+ value: zMainOrOnnxModel.optional(), // TODO: should narrow this down to an SD3 model
+});
+
const zVaeModelField = zModelIdentifier;
const zVaeModelInputFieldValue = zInputFieldValueBase.extend({
@@ -573,6 +579,7 @@ const zInputFieldValue = z.discriminatedUnion('type', [
zSchedulerInputFieldValue,
zSDXLMainModelInputFieldValue,
zSDXLRefinerModelInputFieldValue,
+ zSD3MainModelInputFieldValue,
zStringCollectionInputFieldValue,
zStringPolymorphicInputFieldValue,
zStringInputFieldValue,
diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/field.ts b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts
index 4b680d1de3..15df9db85b 100644
--- a/invokeai/frontend/web/src/features/nodes/types/v2/field.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts
@@ -217,6 +217,20 @@ const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
});
// #endregion
+// #region SD3MainModelField
+const zSD3MainModelFieldType = zFieldTypeBase.extend({
+ name: z.literal('SD3MainModelField'),
+});
+const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only.
+const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
+ type: zSD3MainModelFieldType,
+ value: zSD3MainModelFieldValue,
+});
+const zSD3MainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
+ type: zSD3MainModelFieldType,
+});
+// #endregion
+
// #region VAEModelField
const zVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('VAEModelField'),
@@ -339,6 +353,7 @@ const zStatefulFieldType = z.union([
zMainModelFieldType,
zSDXLMainModelFieldType,
zSDXLRefinerModelFieldType,
+ zSD3MainModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
zControlNetModelFieldType,
@@ -378,6 +393,7 @@ const zStatefulFieldInputInstance = z.union([
zMainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
+ zSD3MainModelFieldInputInstance,
zVAEModelFieldInputInstance,
zLoRAModelFieldInputInstance,
zControlNetModelFieldInputInstance,
@@ -402,6 +418,7 @@ const zStatefulFieldOutputInstance = z.union([
zMainModelFieldOutputInstance,
zSDXLMainModelFieldOutputInstance,
zSDXLRefinerModelFieldOutputInstance,
+ zSD3MainModelFieldOutputInstance,
zVAEModelFieldOutputInstance,
zLoRAModelFieldOutputInstance,
zControlNetModelFieldOutputInstance,
diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts
index 597779fd61..ecee28f802 100644
--- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts
@@ -15,6 +15,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record =
MainModelField: undefined,
SchedulerField: 'euler',
SDXLMainModelField: undefined,
+ SD3MainModelField: undefined,
SDXLRefinerModelField: undefined,
StringField: '',
T2IAdapterModelField: undefined,
diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts
index 2b77274526..12d150ab12 100644
--- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts
@@ -15,6 +15,7 @@ import type {
MainModelFieldInputTemplate,
ModelIdentifierFieldInputTemplate,
SchedulerFieldInputTemplate,
+ SD3MainModelFieldInputTemplate,
SDXLMainModelFieldInputTemplate,
SDXLRefinerModelFieldInputTemplate,
StatefulFieldType,
@@ -193,6 +194,20 @@ const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder = ({
+ schemaObject,
+ baseField,
+ fieldType,
+}) => {
+ const template: SD3MainModelFieldInputTemplate = {
+ ...baseField,
+ type: fieldType,
+ default: schemaObject.default ?? undefined,
+ };
+
+ return template;
+};
+
const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder = ({
schemaObject,
baseField,
@@ -375,6 +390,7 @@ export const TEMPLATE_BUILDER_MAP: Record {
+ return config.type === 'main' && config.base === 'sd-3';
+};
+
+export const isNonSD3MainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
+  return config.type === 'main' && config.base !== 'sd-3'; // main models whose base is anything but 'sd-3'
+};
+
export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'embedding';
};