mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'lstein/feat/sd3-model-loading' of github.com:invoke-ai/InvokeAI into lstein/feat/sd3-model-loading
This commit is contained in:
commit
ac0396e6f7
@ -42,6 +42,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
|||||||
MainModel = "MainModelField"
|
MainModel = "MainModelField"
|
||||||
SDXLMainModel = "SDXLMainModelField"
|
SDXLMainModel = "SDXLMainModelField"
|
||||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||||
|
SD3MainModel = "SD3MainModelField"
|
||||||
ONNXModel = "ONNXModelField"
|
ONNXModel = "ONNXModelField"
|
||||||
VAEModel = "VAEModelField"
|
VAEModel = "VAEModelField"
|
||||||
LoRAModel = "LoRAModelField"
|
LoRAModel = "LoRAModelField"
|
||||||
@ -125,6 +126,7 @@ class FieldDescriptions:
|
|||||||
noise = "Noise tensor"
|
noise = "Noise tensor"
|
||||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||||
unet = "UNet (scheduler, LoRAs)"
|
unet = "UNet (scheduler, LoRAs)"
|
||||||
|
transformer = "Transformer"
|
||||||
vae = "VAE"
|
vae = "VAE"
|
||||||
cond = "Conditioning tensor"
|
cond = "Conditioning tensor"
|
||||||
controlnet_model = "ControlNet model to load"
|
controlnet_model = "ControlNet model to load"
|
||||||
@ -133,6 +135,7 @@ class FieldDescriptions:
|
|||||||
main_model = "Main model (UNet, VAE, CLIP) to load"
|
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||||
|
sd3_main_model = "SD3 Main Model (Transformer, CLIP1, CLIP2, CLIP3, VAE) to load"
|
||||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||||
lora_weight = "The weight at which the LoRA is applied to each model"
|
lora_weight = "The weight at which the LoRA is applied to each model"
|
||||||
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
||||||
|
@ -8,13 +8,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
|||||||
from invokeai.app.shared.models import FreeUConfig
|
from invokeai.app.shared.models import FreeUConfig
|
||||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||||
|
|
||||||
from .baseinvocation import (
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
|
||||||
BaseInvocation,
|
|
||||||
BaseInvocationOutput,
|
|
||||||
Classification,
|
|
||||||
invocation,
|
|
||||||
invocation_output,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelIdentifierField(BaseModel):
|
class ModelIdentifierField(BaseModel):
|
||||||
@ -54,6 +48,11 @@ class UNetField(BaseModel):
|
|||||||
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
|
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerField(BaseModel):
|
||||||
|
transformer: ModelIdentifierField = Field(description="Info to load unet submodel")
|
||||||
|
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
|
||||||
|
|
||||||
|
|
||||||
class CLIPField(BaseModel):
|
class CLIPField(BaseModel):
|
||||||
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
|
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
|
||||||
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
|
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
|
||||||
|
47
invokeai/app/invocations/sd3.py
Normal file
47
invokeai/app/invocations/sd3.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||||
|
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
|
||||||
|
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, TransformerField, VAEField
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
from invokeai.backend.model_manager.config import SubModelType
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("sd3_model_loader_output")
|
||||||
|
class SD3ModelLoaderOutput(BaseInvocationOutput):
|
||||||
|
"""Stable Diffuion 3 base model loader output"""
|
||||||
|
|
||||||
|
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
|
||||||
|
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
||||||
|
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||||
|
clip3: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 3")
|
||||||
|
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation("sd3_model_loader", title="SD3 Main Model", tags=["model", "sd3"], category="model", version="1.0.0")
|
||||||
|
class SD3ModelLoaderInvocation(BaseInvocation):
|
||||||
|
"""Loads an SD3 base model, outputting its submodels."""
|
||||||
|
|
||||||
|
model: ModelIdentifierField = InputField(description=FieldDescriptions.sd3_main_model, ui_type=UIType.SD3MainModel)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> SD3ModelLoaderOutput:
|
||||||
|
model_key = self.model.key
|
||||||
|
|
||||||
|
if not context.models.exists(model_key):
|
||||||
|
raise Exception(f"Unknown model: {model_key}")
|
||||||
|
|
||||||
|
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||||
|
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||||
|
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||||
|
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||||
|
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||||
|
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||||
|
tokenizer3 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
|
||||||
|
text_encoder3 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
|
||||||
|
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||||
|
|
||||||
|
return SD3ModelLoaderOutput(
|
||||||
|
transformer=TransformerField(transformer=transformer, scheduler=scheduler),
|
||||||
|
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
|
||||||
|
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
|
||||||
|
clip3=CLIPField(tokenizer=tokenizer3, text_encoder=text_encoder3, loras=[], skipped_layers=0),
|
||||||
|
vae=VAEField(vae=vae),
|
||||||
|
)
|
@ -11,6 +11,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
|
|||||||
any: 'base',
|
any: 'base',
|
||||||
'sd-1': 'green',
|
'sd-1': 'green',
|
||||||
'sd-2': 'teal',
|
'sd-2': 'teal',
|
||||||
|
'sd-3': 'purple',
|
||||||
sdxl: 'invokeBlue',
|
sdxl: 'invokeBlue',
|
||||||
'sdxl-refiner': 'invokeBlue',
|
'sdxl-refiner': 'invokeBlue',
|
||||||
};
|
};
|
||||||
|
@ -10,6 +10,7 @@ import type { UpdateModelArg } from 'services/api/endpoints/models';
|
|||||||
const options: ComboboxOption[] = [
|
const options: ComboboxOption[] = [
|
||||||
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
||||||
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
|
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
|
||||||
|
{ value: 'sd-3', label: MODEL_TYPE_MAP['sd-3'] },
|
||||||
{ value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] },
|
{ value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] },
|
||||||
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
|
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
|
||||||
];
|
];
|
||||||
|
@ -28,6 +28,8 @@ import {
|
|||||||
isModelIdentifierFieldInputTemplate,
|
isModelIdentifierFieldInputTemplate,
|
||||||
isSchedulerFieldInputInstance,
|
isSchedulerFieldInputInstance,
|
||||||
isSchedulerFieldInputTemplate,
|
isSchedulerFieldInputTemplate,
|
||||||
|
isSD3MainModelFieldInputInstance,
|
||||||
|
isSD3MainModelFieldInputTemplate,
|
||||||
isSDXLMainModelFieldInputInstance,
|
isSDXLMainModelFieldInputInstance,
|
||||||
isSDXLMainModelFieldInputTemplate,
|
isSDXLMainModelFieldInputTemplate,
|
||||||
isSDXLRefinerModelFieldInputInstance,
|
isSDXLRefinerModelFieldInputInstance,
|
||||||
@ -53,6 +55,7 @@ import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent'
|
|||||||
import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
|
import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
|
||||||
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
|
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
|
||||||
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
|
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
|
||||||
|
import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent';
|
||||||
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
|
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
|
||||||
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
|
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
|
||||||
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
|
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
|
||||||
@ -133,6 +136,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
|||||||
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
|
||||||
|
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||||
|
}
|
||||||
|
|
||||||
if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) {
|
if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) {
|
||||||
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,55 @@
|
|||||||
|
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||||
|
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useSD3Models } from 'services/api/hooks/modelsByType';
|
||||||
|
import type { MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
|
import type { FieldComponentProps } from './types';
|
||||||
|
|
||||||
|
type Props = FieldComponentProps<SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate>;
|
||||||
|
|
||||||
|
const SD3MainModelFieldInputComponent = (props: Props) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const [modelConfigs, { isLoading }] = useSD3Models();
|
||||||
|
const _onChange = useCallback(
|
||||||
|
(value: MainModelConfig | null) => {
|
||||||
|
if (!value) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dispatch(
|
||||||
|
fieldMainModelValueChanged({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
value,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch, field.name, nodeId]
|
||||||
|
);
|
||||||
|
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
|
modelConfigs,
|
||||||
|
onChange: _onChange,
|
||||||
|
isLoading,
|
||||||
|
selectedModel: field.value,
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex w="full" alignItems="center" gap={2}>
|
||||||
|
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
|
||||||
|
<Combobox
|
||||||
|
value={value}
|
||||||
|
placeholder={placeholder}
|
||||||
|
options={options}
|
||||||
|
onChange={onChange}
|
||||||
|
noOptionsMessage={noOptionsMessage}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(SD3MainModelFieldInputComponent);
|
@ -839,7 +839,7 @@ export const schema = {
|
|||||||
},
|
},
|
||||||
BaseModelType: {
|
BaseModelType: {
|
||||||
description: 'Base model type.',
|
description: 'Base model type.',
|
||||||
enum: ['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
|
enum: ['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner'],
|
||||||
title: 'BaseModelType',
|
title: 'BaseModelType',
|
||||||
type: 'string',
|
type: 'string',
|
||||||
},
|
},
|
||||||
@ -855,8 +855,11 @@ export const schema = {
|
|||||||
'unet',
|
'unet',
|
||||||
'text_encoder',
|
'text_encoder',
|
||||||
'text_encoder_2',
|
'text_encoder_2',
|
||||||
|
'text_encoder_3',
|
||||||
'tokenizer',
|
'tokenizer',
|
||||||
'tokenizer_2',
|
'tokenizer_2',
|
||||||
|
'tokenizer_3',
|
||||||
|
'transformer',
|
||||||
'vae',
|
'vae',
|
||||||
'vae_decoder',
|
'vae_decoder',
|
||||||
'vae_encoder',
|
'vae_encoder',
|
||||||
|
@ -55,7 +55,7 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
|
|||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
// #region Model-related schemas
|
// #region Model-related schemas
|
||||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner']);
|
||||||
const zModelType = z.enum([
|
const zModelType = z.enum([
|
||||||
'main',
|
'main',
|
||||||
'vae',
|
'vae',
|
||||||
@ -71,8 +71,11 @@ const zSubModelType = z.enum([
|
|||||||
'unet',
|
'unet',
|
||||||
'text_encoder',
|
'text_encoder',
|
||||||
'text_encoder_2',
|
'text_encoder_2',
|
||||||
|
'text_encoder_3',
|
||||||
'tokenizer',
|
'tokenizer',
|
||||||
'tokenizer_2',
|
'tokenizer_2',
|
||||||
|
'tokenizer_3',
|
||||||
|
'transformer',
|
||||||
'vae',
|
'vae',
|
||||||
'vae_decoder',
|
'vae_decoder',
|
||||||
'vae_encoder',
|
'vae_encoder',
|
||||||
|
@ -32,9 +32,11 @@ export const MODEL_TYPES = [
|
|||||||
'LoRAModelField',
|
'LoRAModelField',
|
||||||
'MainModelField',
|
'MainModelField',
|
||||||
'SDXLMainModelField',
|
'SDXLMainModelField',
|
||||||
|
'SD3MainModelField',
|
||||||
'SDXLRefinerModelField',
|
'SDXLRefinerModelField',
|
||||||
'VaeModelField',
|
'VaeModelField',
|
||||||
'UNetField',
|
'UNetField',
|
||||||
|
'TransformerField',
|
||||||
'VAEField',
|
'VAEField',
|
||||||
'CLIPField',
|
'CLIPField',
|
||||||
'T2IAdapterModelField',
|
'T2IAdapterModelField',
|
||||||
@ -62,10 +64,12 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
|||||||
MainModelField: 'teal.500',
|
MainModelField: 'teal.500',
|
||||||
SDXLMainModelField: 'teal.500',
|
SDXLMainModelField: 'teal.500',
|
||||||
SDXLRefinerModelField: 'teal.500',
|
SDXLRefinerModelField: 'teal.500',
|
||||||
|
SD3MainModelField: 'teal.500',
|
||||||
StringField: 'yellow.500',
|
StringField: 'yellow.500',
|
||||||
T2IAdapterField: 'teal.500',
|
T2IAdapterField: 'teal.500',
|
||||||
T2IAdapterModelField: 'teal.500',
|
T2IAdapterModelField: 'teal.500',
|
||||||
UNetField: 'red.500',
|
UNetField: 'red.500',
|
||||||
|
TransformerField: 'red.500',
|
||||||
VAEField: 'blue.500',
|
VAEField: 'blue.500',
|
||||||
VAEModelField: 'teal.500',
|
VAEModelField: 'teal.500',
|
||||||
};
|
};
|
||||||
|
@ -119,6 +119,10 @@ const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
|
|||||||
name: z.literal('SDXLRefinerModelField'),
|
name: z.literal('SDXLRefinerModelField'),
|
||||||
originalType: zStatelessFieldType.optional(),
|
originalType: zStatelessFieldType.optional(),
|
||||||
});
|
});
|
||||||
|
const zSD3MainModelFieldType = zFieldTypeBase.extend({
|
||||||
|
name: z.literal('SD3MainModelField'),
|
||||||
|
originalType: zStatelessFieldType.optional(),
|
||||||
|
});
|
||||||
const zVAEModelFieldType = zFieldTypeBase.extend({
|
const zVAEModelFieldType = zFieldTypeBase.extend({
|
||||||
name: z.literal('VAEModelField'),
|
name: z.literal('VAEModelField'),
|
||||||
originalType: zStatelessFieldType.optional(),
|
originalType: zStatelessFieldType.optional(),
|
||||||
@ -155,6 +159,7 @@ const zStatefulFieldType = z.union([
|
|||||||
zMainModelFieldType,
|
zMainModelFieldType,
|
||||||
zSDXLMainModelFieldType,
|
zSDXLMainModelFieldType,
|
||||||
zSDXLRefinerModelFieldType,
|
zSDXLRefinerModelFieldType,
|
||||||
|
zSD3MainModelFieldType,
|
||||||
zVAEModelFieldType,
|
zVAEModelFieldType,
|
||||||
zLoRAModelFieldType,
|
zLoRAModelFieldType,
|
||||||
zControlNetModelFieldType,
|
zControlNetModelFieldType,
|
||||||
@ -466,6 +471,28 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR
|
|||||||
zSDXLRefinerModelFieldInputTemplate.safeParse(val).success;
|
zSDXLRefinerModelFieldInputTemplate.safeParse(val).success;
|
||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
|
// #region SD3MainModelField
|
||||||
|
|
||||||
|
const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only.
|
||||||
|
const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||||
|
value: zSD3MainModelFieldValue,
|
||||||
|
});
|
||||||
|
const zSD3MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||||
|
type: zSD3MainModelFieldType,
|
||||||
|
originalType: zFieldType.optional(),
|
||||||
|
default: zSD3MainModelFieldValue,
|
||||||
|
});
|
||||||
|
const zSD3MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
|
type: zSD3MainModelFieldType,
|
||||||
|
});
|
||||||
|
export type SD3MainModelFieldInputInstance = z.infer<typeof zSD3MainModelFieldInputInstance>;
|
||||||
|
export type SD3MainModelFieldInputTemplate = z.infer<typeof zSD3MainModelFieldInputTemplate>;
|
||||||
|
export const isSD3MainModelFieldInputInstance = (val: unknown): val is SD3MainModelFieldInputInstance =>
|
||||||
|
zSD3MainModelFieldInputInstance.safeParse(val).success;
|
||||||
|
export const isSD3MainModelFieldInputTemplate = (val: unknown): val is SD3MainModelFieldInputTemplate =>
|
||||||
|
zSD3MainModelFieldInputTemplate.safeParse(val).success;
|
||||||
|
// #endregion
|
||||||
|
|
||||||
// #region VAEModelField
|
// #region VAEModelField
|
||||||
|
|
||||||
export const zVAEModelFieldValue = zModelIdentifierField.optional();
|
export const zVAEModelFieldValue = zModelIdentifierField.optional();
|
||||||
@ -662,6 +689,7 @@ export const zStatefulFieldValue = z.union([
|
|||||||
zMainModelFieldValue,
|
zMainModelFieldValue,
|
||||||
zSDXLMainModelFieldValue,
|
zSDXLMainModelFieldValue,
|
||||||
zSDXLRefinerModelFieldValue,
|
zSDXLRefinerModelFieldValue,
|
||||||
|
zSD3MainModelFieldValue,
|
||||||
zVAEModelFieldValue,
|
zVAEModelFieldValue,
|
||||||
zLoRAModelFieldValue,
|
zLoRAModelFieldValue,
|
||||||
zControlNetModelFieldValue,
|
zControlNetModelFieldValue,
|
||||||
@ -689,6 +717,7 @@ const zStatefulFieldInputInstance = z.union([
|
|||||||
zMainModelFieldInputInstance,
|
zMainModelFieldInputInstance,
|
||||||
zSDXLMainModelFieldInputInstance,
|
zSDXLMainModelFieldInputInstance,
|
||||||
zSDXLRefinerModelFieldInputInstance,
|
zSDXLRefinerModelFieldInputInstance,
|
||||||
|
zSD3MainModelFieldInputInstance,
|
||||||
zVAEModelFieldInputInstance,
|
zVAEModelFieldInputInstance,
|
||||||
zLoRAModelFieldInputInstance,
|
zLoRAModelFieldInputInstance,
|
||||||
zControlNetModelFieldInputInstance,
|
zControlNetModelFieldInputInstance,
|
||||||
@ -717,6 +746,7 @@ const zStatefulFieldInputTemplate = z.union([
|
|||||||
zMainModelFieldInputTemplate,
|
zMainModelFieldInputTemplate,
|
||||||
zSDXLMainModelFieldInputTemplate,
|
zSDXLMainModelFieldInputTemplate,
|
||||||
zSDXLRefinerModelFieldInputTemplate,
|
zSDXLRefinerModelFieldInputTemplate,
|
||||||
|
zSD3MainModelFieldInputTemplate,
|
||||||
zVAEModelFieldInputTemplate,
|
zVAEModelFieldInputTemplate,
|
||||||
zLoRAModelFieldInputTemplate,
|
zLoRAModelFieldInputTemplate,
|
||||||
zControlNetModelFieldInputTemplate,
|
zControlNetModelFieldInputTemplate,
|
||||||
@ -746,6 +776,7 @@ const zStatefulFieldOutputTemplate = z.union([
|
|||||||
zMainModelFieldOutputTemplate,
|
zMainModelFieldOutputTemplate,
|
||||||
zSDXLMainModelFieldOutputTemplate,
|
zSDXLMainModelFieldOutputTemplate,
|
||||||
zSDXLRefinerModelFieldOutputTemplate,
|
zSDXLRefinerModelFieldOutputTemplate,
|
||||||
|
zSD3MainModelFieldOutputTemplate,
|
||||||
zVAEModelFieldOutputTemplate,
|
zVAEModelFieldOutputTemplate,
|
||||||
zLoRAModelFieldOutputTemplate,
|
zLoRAModelFieldOutputTemplate,
|
||||||
zControlNetModelFieldOutputTemplate,
|
zControlNetModelFieldOutputTemplate,
|
||||||
|
@ -44,7 +44,7 @@ export const zSchedulerField = z.enum([
|
|||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
// #region Model-related schemas
|
// #region Model-related schemas
|
||||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner']);
|
||||||
const zModelName = z.string().min(3);
|
const zModelName = z.string().min(3);
|
||||||
export const zModelIdentifier = z.object({
|
export const zModelIdentifier = z.object({
|
||||||
model_name: zModelName,
|
model_name: zModelName,
|
||||||
|
@ -217,6 +217,20 @@ const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
|||||||
});
|
});
|
||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
|
// #region SDXLMainModelField
|
||||||
|
const zSD3MainModelFieldType = zFieldTypeBase.extend({
|
||||||
|
name: z.literal('SD3MainModelField'),
|
||||||
|
});
|
||||||
|
const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only.
|
||||||
|
const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||||
|
type: zSD3MainModelFieldType,
|
||||||
|
value: zSD3MainModelFieldValue,
|
||||||
|
});
|
||||||
|
const zSD3MainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||||
|
type: zSD3MainModelFieldType,
|
||||||
|
});
|
||||||
|
// #endregion
|
||||||
|
|
||||||
// #region VAEModelField
|
// #region VAEModelField
|
||||||
const zVAEModelFieldType = zFieldTypeBase.extend({
|
const zVAEModelFieldType = zFieldTypeBase.extend({
|
||||||
name: z.literal('VAEModelField'),
|
name: z.literal('VAEModelField'),
|
||||||
@ -339,6 +353,7 @@ const zStatefulFieldType = z.union([
|
|||||||
zMainModelFieldType,
|
zMainModelFieldType,
|
||||||
zSDXLMainModelFieldType,
|
zSDXLMainModelFieldType,
|
||||||
zSDXLRefinerModelFieldType,
|
zSDXLRefinerModelFieldType,
|
||||||
|
zSD3MainModelFieldType,
|
||||||
zVAEModelFieldType,
|
zVAEModelFieldType,
|
||||||
zLoRAModelFieldType,
|
zLoRAModelFieldType,
|
||||||
zControlNetModelFieldType,
|
zControlNetModelFieldType,
|
||||||
@ -378,6 +393,7 @@ const zStatefulFieldInputInstance = z.union([
|
|||||||
zMainModelFieldInputInstance,
|
zMainModelFieldInputInstance,
|
||||||
zSDXLMainModelFieldInputInstance,
|
zSDXLMainModelFieldInputInstance,
|
||||||
zSDXLRefinerModelFieldInputInstance,
|
zSDXLRefinerModelFieldInputInstance,
|
||||||
|
zSD3MainModelFieldInputInstance,
|
||||||
zVAEModelFieldInputInstance,
|
zVAEModelFieldInputInstance,
|
||||||
zLoRAModelFieldInputInstance,
|
zLoRAModelFieldInputInstance,
|
||||||
zControlNetModelFieldInputInstance,
|
zControlNetModelFieldInputInstance,
|
||||||
@ -402,6 +418,7 @@ const zStatefulFieldOutputInstance = z.union([
|
|||||||
zMainModelFieldOutputInstance,
|
zMainModelFieldOutputInstance,
|
||||||
zSDXLMainModelFieldOutputInstance,
|
zSDXLMainModelFieldOutputInstance,
|
||||||
zSDXLRefinerModelFieldOutputInstance,
|
zSDXLRefinerModelFieldOutputInstance,
|
||||||
|
zSD3MainModelFieldOutputInstance,
|
||||||
zVAEModelFieldOutputInstance,
|
zVAEModelFieldOutputInstance,
|
||||||
zLoRAModelFieldOutputInstance,
|
zLoRAModelFieldOutputInstance,
|
||||||
zControlNetModelFieldOutputInstance,
|
zControlNetModelFieldOutputInstance,
|
||||||
|
@ -15,6 +15,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
|||||||
MainModelField: undefined,
|
MainModelField: undefined,
|
||||||
SchedulerField: 'euler',
|
SchedulerField: 'euler',
|
||||||
SDXLMainModelField: undefined,
|
SDXLMainModelField: undefined,
|
||||||
|
SD3MainModelField: undefined,
|
||||||
SDXLRefinerModelField: undefined,
|
SDXLRefinerModelField: undefined,
|
||||||
StringField: '',
|
StringField: '',
|
||||||
T2IAdapterModelField: undefined,
|
T2IAdapterModelField: undefined,
|
||||||
|
@ -15,6 +15,7 @@ import type {
|
|||||||
MainModelFieldInputTemplate,
|
MainModelFieldInputTemplate,
|
||||||
ModelIdentifierFieldInputTemplate,
|
ModelIdentifierFieldInputTemplate,
|
||||||
SchedulerFieldInputTemplate,
|
SchedulerFieldInputTemplate,
|
||||||
|
SD3MainModelFieldInputTemplate,
|
||||||
SDXLMainModelFieldInputTemplate,
|
SDXLMainModelFieldInputTemplate,
|
||||||
SDXLRefinerModelFieldInputTemplate,
|
SDXLRefinerModelFieldInputTemplate,
|
||||||
StatefulFieldType,
|
StatefulFieldType,
|
||||||
@ -193,6 +194,20 @@ const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefiner
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder<SD3MainModelFieldInputTemplate> = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
fieldType,
|
||||||
|
}) => {
|
||||||
|
const template: SD3MainModelFieldInputTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: fieldType,
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder<VAEModelFieldInputTemplate> = ({
|
const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder<VAEModelFieldInputTemplate> = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -375,6 +390,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
|||||||
SchedulerField: buildSchedulerFieldInputTemplate,
|
SchedulerField: buildSchedulerFieldInputTemplate,
|
||||||
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
|
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
|
||||||
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
|
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
|
||||||
|
SD3MainModelField: buildSD3MainModelFieldInputTemplate,
|
||||||
StringField: buildStringFieldInputTemplate,
|
StringField: buildStringFieldInputTemplate,
|
||||||
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
|
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
|
||||||
VAEModelField: buildVAEModelFieldInputTemplate,
|
VAEModelField: buildVAEModelFieldInputTemplate,
|
||||||
|
@ -30,6 +30,7 @@ const MODEL_FIELD_TYPES = [
|
|||||||
'MainModelField',
|
'MainModelField',
|
||||||
'SDXLMainModelField',
|
'SDXLMainModelField',
|
||||||
'SDXLRefinerModelField',
|
'SDXLRefinerModelField',
|
||||||
|
'SD3MainModelField',
|
||||||
'VAEModelField',
|
'VAEModelField',
|
||||||
'LoRAModelField',
|
'LoRAModelField',
|
||||||
'ControlNetModelField',
|
'ControlNetModelField',
|
||||||
|
@ -39,7 +39,7 @@ const ParamClipSkip = () => {
|
|||||||
return CLIP_SKIP_MAP[model.base].markers;
|
return CLIP_SKIP_MAP[model.base].markers;
|
||||||
}, [model]);
|
}, [model]);
|
||||||
|
|
||||||
if (model?.base === 'sdxl') {
|
if (model?.base === 'sdxl' || model?.base === 'sd-3') {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ export const MODEL_TYPE_MAP = {
|
|||||||
any: 'Any',
|
any: 'Any',
|
||||||
'sd-1': 'Stable Diffusion 1.x',
|
'sd-1': 'Stable Diffusion 1.x',
|
||||||
'sd-2': 'Stable Diffusion 2.x',
|
'sd-2': 'Stable Diffusion 2.x',
|
||||||
|
'sd-3': 'Stable Diffusion 3.x',
|
||||||
sdxl: 'Stable Diffusion XL',
|
sdxl: 'Stable Diffusion XL',
|
||||||
'sdxl-refiner': 'Stable Diffusion XL Refiner',
|
'sdxl-refiner': 'Stable Diffusion XL Refiner',
|
||||||
};
|
};
|
||||||
@ -18,6 +19,7 @@ export const MODEL_TYPE_SHORT_MAP = {
|
|||||||
any: 'Any',
|
any: 'Any',
|
||||||
'sd-1': 'SD1.X',
|
'sd-1': 'SD1.X',
|
||||||
'sd-2': 'SD2.X',
|
'sd-2': 'SD2.X',
|
||||||
|
'sd-3': 'SD3.X',
|
||||||
sdxl: 'SDXL',
|
sdxl: 'SDXL',
|
||||||
'sdxl-refiner': 'SDXLR',
|
'sdxl-refiner': 'SDXLR',
|
||||||
};
|
};
|
||||||
@ -38,6 +40,11 @@ export const CLIP_SKIP_MAP = {
|
|||||||
maxClip: 24,
|
maxClip: 24,
|
||||||
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
|
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
|
||||||
},
|
},
|
||||||
|
// TODO: Update this when we have more details on how CLIP SKIP works with SD3
|
||||||
|
'sd-3': {
|
||||||
|
maxClip: 24,
|
||||||
|
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
|
||||||
|
},
|
||||||
sdxl: {
|
sdxl: {
|
||||||
maxClip: 24,
|
maxClip: 24,
|
||||||
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
|
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
|
||||||
|
@ -10,6 +10,7 @@ import {
|
|||||||
isNonRefinerMainModelConfig,
|
isNonRefinerMainModelConfig,
|
||||||
isNonSDXLMainModelConfig,
|
isNonSDXLMainModelConfig,
|
||||||
isRefinerMainModelModelConfig,
|
isRefinerMainModelModelConfig,
|
||||||
|
isSD3MainModelModelConfig,
|
||||||
isSDXLMainModelModelConfig,
|
isSDXLMainModelModelConfig,
|
||||||
isT2IAdapterModelConfig,
|
isT2IAdapterModelConfig,
|
||||||
isTIModelConfig,
|
isTIModelConfig,
|
||||||
@ -35,6 +36,7 @@ export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
|
|||||||
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
|
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
|
||||||
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
|
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
|
||||||
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
|
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
|
||||||
|
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
|
||||||
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
|
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
|
||||||
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
|
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
|
||||||
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
|
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
|
||||||
|
File diff suppressed because one or more lines are too long
@ -109,7 +109,15 @@ export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is Ma
|
|||||||
};
|
};
|
||||||
|
|
||||||
export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||||
return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2');
|
return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2' || config.base === 'sd-3');
|
||||||
|
};
|
||||||
|
|
||||||
|
export const isSD3MainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||||
|
return config.type === 'main' && config.base === 'sd-3';
|
||||||
|
};
|
||||||
|
|
||||||
|
export const isNonSD3MainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||||
|
return config.type === 'main' && !(config.base === 'sd-3');
|
||||||
};
|
};
|
||||||
|
|
||||||
export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||||
|
@ -39,6 +39,7 @@ from invokeai.app.invocations.model import (
|
|||||||
ModelIdentifierField,
|
ModelIdentifierField,
|
||||||
ModelLoaderOutput,
|
ModelLoaderOutput,
|
||||||
SDXLLoRALoaderOutput,
|
SDXLLoRALoaderOutput,
|
||||||
|
TransformerField,
|
||||||
UNetField,
|
UNetField,
|
||||||
UNetOutput,
|
UNetOutput,
|
||||||
VAEField,
|
VAEField,
|
||||||
@ -117,6 +118,7 @@ __all__ = [
|
|||||||
# invokeai.app.invocations.model
|
# invokeai.app.invocations.model
|
||||||
"ModelIdentifierField",
|
"ModelIdentifierField",
|
||||||
"UNetField",
|
"UNetField",
|
||||||
|
"TransformerField",
|
||||||
"CLIPField",
|
"CLIPField",
|
||||||
"VAEField",
|
"VAEField",
|
||||||
"UNetOutput",
|
"UNetOutput",
|
||||||
|
Loading…
Reference in New Issue
Block a user