wip: add SD3 Model Loader Invocation

This commit is contained in:
blessedcoolant 2024-06-14 22:21:09 +05:30
parent c79d9b9ecf
commit 0c970bc880
15 changed files with 426 additions and 136 deletions

View File

@ -42,6 +42,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
MainModel = "MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
SD3MainModel = "SD3MainModelField"
ONNXModel = "ONNXModelField"
VAEModel = "VAEModelField"
LoRAModel = "LoRAModelField"
@ -125,6 +126,7 @@ class FieldDescriptions:
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
@ -133,6 +135,7 @@ class FieldDescriptions:
main_model = "Main model (UNet, VAE, CLIP) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
sd3_main_model = "SD3 Main Model (Transformer, CLIP1, CLIP2, CLIP3, VAE) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"

View File

@ -0,0 +1,54 @@
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import SubModelType
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load unet submodel")
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
@invocation_output("sd3_model_loader_output")
class SD3ModelLoaderOutput(BaseInvocationOutput):
"""Stable Diffuion 3 base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
clip3: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 3")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("sd3_model_loader", title="SD3 Main Model", tags=["model", "sd3"], category="model", version="1.0.0")
class SD3ModelLoaderInvocation(BaseInvocation):
"""Loads an SD3 base model, outputting its submodels."""
model: ModelIdentifierField = InputField(description=FieldDescriptions.sd3_main_model, ui_type=UIType.SD3MainModel)
def invoke(self, context: InvocationContext) -> SD3ModelLoaderOutput:
model_key = self.model.key
if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
tokenizer3 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
text_encoder3 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return SD3ModelLoaderOutput(
transformer=TransformerField(transformer=transformer, scheduler=scheduler),
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
clip3=CLIPField(tokenizer=tokenizer3, text_encoder=text_encoder3, loras=[], skipped_layers=0),
vae=VAEField(vae=vae),
)

View File

@ -28,6 +28,8 @@ import {
isModelIdentifierFieldInputTemplate,
isSchedulerFieldInputInstance,
isSchedulerFieldInputTemplate,
isSD3MainModelFieldInputInstance,
isSD3MainModelFieldInputTemplate,
isSDXLMainModelFieldInputInstance,
isSDXLMainModelFieldInputTemplate,
isSDXLRefinerModelFieldInputInstance,
@ -53,6 +55,7 @@ import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent'
import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent';
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
@ -133,6 +136,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) {
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

View File

@ -0,0 +1,55 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSD3Models } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate>;
const SD3MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSD3Models();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
);
};
export default memo(SD3MainModelFieldInputComponent);

View File

@ -32,6 +32,7 @@ export const MODEL_TYPES = [
'LoRAModelField',
'MainModelField',
'SDXLMainModelField',
'SD3MainModelField',
'SDXLRefinerModelField',
'VaeModelField',
'UNetField',
@ -62,6 +63,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
MainModelField: 'teal.500',
SDXLMainModelField: 'teal.500',
SDXLRefinerModelField: 'teal.500',
SD3MainModelField: 'teal.500',
StringField: 'yellow.500',
T2IAdapterField: 'teal.500',
T2IAdapterModelField: 'teal.500',

View File

@ -119,6 +119,10 @@ const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSD3MainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SD3MainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('VAEModelField'),
originalType: zStatelessFieldType.optional(),
@ -155,6 +159,7 @@ const zStatefulFieldType = z.union([
zMainModelFieldType,
zSDXLMainModelFieldType,
zSDXLRefinerModelFieldType,
zSD3MainModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
zControlNetModelFieldType,
@ -466,6 +471,28 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR
zSDXLRefinerModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SD3MainModelField
const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only.
const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zSD3MainModelFieldValue,
});
const zSD3MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSD3MainModelFieldType,
originalType: zFieldType.optional(),
default: zSD3MainModelFieldValue,
});
const zSD3MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSD3MainModelFieldType,
});
export type SD3MainModelFieldInputInstance = z.infer<typeof zSD3MainModelFieldInputInstance>;
export type SD3MainModelFieldInputTemplate = z.infer<typeof zSD3MainModelFieldInputTemplate>;
export const isSD3MainModelFieldInputInstance = (val: unknown): val is SD3MainModelFieldInputInstance =>
zSD3MainModelFieldInputInstance.safeParse(val).success;
export const isSD3MainModelFieldInputTemplate = (val: unknown): val is SD3MainModelFieldInputTemplate =>
zSD3MainModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region VAEModelField
export const zVAEModelFieldValue = zModelIdentifierField.optional();
@ -662,6 +689,7 @@ export const zStatefulFieldValue = z.union([
zMainModelFieldValue,
zSDXLMainModelFieldValue,
zSDXLRefinerModelFieldValue,
zSD3MainModelFieldValue,
zVAEModelFieldValue,
zLoRAModelFieldValue,
zControlNetModelFieldValue,
@ -689,6 +717,7 @@ const zStatefulFieldInputInstance = z.union([
zMainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
zSD3MainModelFieldInputInstance,
zVAEModelFieldInputInstance,
zLoRAModelFieldInputInstance,
zControlNetModelFieldInputInstance,
@ -717,6 +746,7 @@ const zStatefulFieldInputTemplate = z.union([
zMainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate,
zSDXLRefinerModelFieldInputTemplate,
zSD3MainModelFieldInputTemplate,
zVAEModelFieldInputTemplate,
zLoRAModelFieldInputTemplate,
zControlNetModelFieldInputTemplate,
@ -746,6 +776,7 @@ const zStatefulFieldOutputTemplate = z.union([
zMainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate,
zSDXLRefinerModelFieldOutputTemplate,
zSD3MainModelFieldOutputTemplate,
zVAEModelFieldOutputTemplate,
zLoRAModelFieldOutputTemplate,
zControlNetModelFieldOutputTemplate,

View File

@ -124,6 +124,11 @@ const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: {
isCollection: false,
isCollectionOrScalar: false,
},
SD3MainModelField: {
name: 'SD3MainModelField',
isCollection: false,
isCollectionOrScalar: false,
},
string: {
name: 'StringField',
isCollection: false,

View File

@ -90,6 +90,7 @@ const zFieldTypeV1 = z.enum([
'Scheduler',
'SDXLMainModelField',
'SDXLRefinerModelField',
'SD3MainModelField',
'string',
'StringCollection',
'StringPolymorphic',
@ -422,6 +423,11 @@ const zSDXLRefinerModelInputFieldValue = zInputFieldValueBase.extend({
value: zMainOrOnnxModel.optional(), // TODO: should narrow this down to a refiner model
});
const zSD3MainModelInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('SD3MainModelField'),
value: zMainOrOnnxModel.optional(),
});
const zVaeModelField = zModelIdentifier;
const zVaeModelInputFieldValue = zInputFieldValueBase.extend({
@ -573,6 +579,7 @@ const zInputFieldValue = z.discriminatedUnion('type', [
zSchedulerInputFieldValue,
zSDXLMainModelInputFieldValue,
zSDXLRefinerModelInputFieldValue,
zSD3MainModelInputFieldValue,
zStringCollectionInputFieldValue,
zStringPolymorphicInputFieldValue,
zStringInputFieldValue,

View File

@ -217,6 +217,20 @@ const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
});
// #endregion
// #region SDXLMainModelField
const zSD3MainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SD3MainModelField'),
});
const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only.
const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
type: zSD3MainModelFieldType,
value: zSD3MainModelFieldValue,
});
const zSD3MainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
type: zSD3MainModelFieldType,
});
// #endregion
// #region VAEModelField
const zVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('VAEModelField'),
@ -339,6 +353,7 @@ const zStatefulFieldType = z.union([
zMainModelFieldType,
zSDXLMainModelFieldType,
zSDXLRefinerModelFieldType,
zSD3MainModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
zControlNetModelFieldType,
@ -378,6 +393,7 @@ const zStatefulFieldInputInstance = z.union([
zMainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
zSD3MainModelFieldInputInstance,
zVAEModelFieldInputInstance,
zLoRAModelFieldInputInstance,
zControlNetModelFieldInputInstance,
@ -402,6 +418,7 @@ const zStatefulFieldOutputInstance = z.union([
zMainModelFieldOutputInstance,
zSDXLMainModelFieldOutputInstance,
zSDXLRefinerModelFieldOutputInstance,
zSD3MainModelFieldOutputInstance,
zVAEModelFieldOutputInstance,
zLoRAModelFieldOutputInstance,
zControlNetModelFieldOutputInstance,

View File

@ -15,6 +15,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
MainModelField: undefined,
SchedulerField: 'euler',
SDXLMainModelField: undefined,
SD3MainModelField: undefined,
SDXLRefinerModelField: undefined,
StringField: '',
T2IAdapterModelField: undefined,

View File

@ -15,6 +15,7 @@ import type {
MainModelFieldInputTemplate,
ModelIdentifierFieldInputTemplate,
SchedulerFieldInputTemplate,
SD3MainModelFieldInputTemplate,
SDXLMainModelFieldInputTemplate,
SDXLRefinerModelFieldInputTemplate,
StatefulFieldType,
@ -193,6 +194,20 @@ const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefiner
return template;
};
const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder<SD3MainModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: SD3MainModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder<VAEModelFieldInputTemplate> = ({
schemaObject,
baseField,
@ -375,6 +390,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
SchedulerField: buildSchedulerFieldInputTemplate,
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
SD3MainModelField: buildSD3MainModelFieldInputTemplate,
StringField: buildStringFieldInputTemplate,
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
VAEModelField: buildVAEModelFieldInputTemplate,

View File

@ -30,6 +30,7 @@ const MODEL_FIELD_TYPES = [
'MainModelField',
'SDXLMainModelField',
'SDXLRefinerModelField',
'SD3MainModelField',
'VAEModelField',
'LoRAModelField',
'ControlNetModelField',

View File

@ -10,6 +10,7 @@ import {
isNonRefinerMainModelConfig,
isNonSDXLMainModelConfig,
isRefinerMainModelModelConfig,
isSD3MainModelModelConfig,
isSDXLMainModelModelConfig,
isT2IAdapterModelConfig,
isTIModelConfig,
@ -35,6 +36,7 @@ export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);

File diff suppressed because one or more lines are too long

View File

@ -112,6 +112,14 @@ export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is Main
return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2' || config.base === 'sd-3');
};
export const isSD3MainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'sd-3';
};
export const isNonSD3MainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && !(config.base === 'sd-3');
};
export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'embedding';
};