WIP - A bunch of boilerplate to support Spandrel Image-to-Image models throughout the model manager and the frontend.

This commit is contained in:
Ryan Dick 2024-06-28 18:03:09 -04:00
parent 95079dc7d4
commit 29c8ddfb88
15 changed files with 287 additions and 19 deletions

View File

@ -48,6 +48,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
ControlNetModel = "ControlNetModelField" ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField" IPAdapterModel = "IPAdapterModelField"
T2IAdapterModel = "T2IAdapterModelField" T2IAdapterModel = "T2IAdapterModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
# endregion # endregion
# region Misc Field Types # region Misc Field Types
@ -134,6 +135,7 @@ class FieldDescriptions:
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load" sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
spandrel_image_to_image_model = "Spandrel Image-to-Image model"
lora_weight = "The weight at which the LoRA is applied to each model" lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor" compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
raw_prompt = "Raw prompt text (no parsing)" raw_prompt = "Raw prompt text (no parsing)"

View File

@ -373,6 +373,17 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}") return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
class SpandrelImageToImageConfig(ModelConfigBase):
"""Model config for Spandrel Image to Image models."""
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
def get_model_discriminator_value(v: Any) -> str: def get_model_discriminator_value(v: Any) -> str:
""" """
Computes the discriminator value for a model config. Computes the discriminator value for a model config.
@ -409,6 +420,7 @@ AnyModelConfig = Annotated[
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()], Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()], Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()], Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()], Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
], ],
Discriminator(get_model_discriminator_value), Discriminator(get_model_discriminator_value),

View File

@ -243,10 +243,14 @@ class ModelProbe(object):
# Check if the model can be loaded as a SpandrelImageToImageModel. # Check if the model can be loaded as a SpandrelImageToImageModel.
try: try:
_ = SpandrelImageToImageModel.load_from_state_dict(ckpt) # TODO(ryand): Figure out why load_from_state_dict() doesn't work as expected.
# _ = SpandrelImageToImageModel.load_from_state_dict(ckpt)
_ = SpandrelImageToImageModel.load_from_file(model_path)
return ModelType.SpandrelImageToImage return ModelType.SpandrelImageToImage
except Exception: except Exception as e:
# TODO(ryand): Catch a more specific exception type here if we can. # TODO(ryand): Catch a more specific exception type here if we can.
# TODO(ryand): Delete this print statement.
print(e)
pass pass
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}") raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
@ -579,9 +583,9 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
raise NotImplementedError() raise NotImplementedError()
class SpandrelImageToImageModelProbe(CheckpointProbeBase): class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType: def get_base_type(self) -> BaseModelType:
raise NotImplementedError() return BaseModelType.Any
######################################################## ########################################################
@ -791,6 +795,11 @@ class CLIPVisionFolderProbe(FolderProbeBase):
return BaseModelType.Any return BaseModelType.Any
class SpandrelImageToImageFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
class T2IAdapterFolderProbe(FolderProbeBase): class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType: def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json" config_file = self.model_path / "config.json"
@ -820,6 +829,7 @@ ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderPro
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe) ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe) ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
@ -829,5 +839,6 @@ ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpoi
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe) ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)

View File

@ -32,6 +32,8 @@ import {
isSDXLMainModelFieldInputTemplate, isSDXLMainModelFieldInputTemplate,
isSDXLRefinerModelFieldInputInstance, isSDXLRefinerModelFieldInputInstance,
isSDXLRefinerModelFieldInputTemplate, isSDXLRefinerModelFieldInputTemplate,
isSpandrelImageToImageModelFieldInputInstance,
isSpandrelImageToImageModelFieldInputTemplate,
isStringFieldInputInstance, isStringFieldInputInstance,
isStringFieldInputTemplate, isStringFieldInputTemplate,
isT2IAdapterModelFieldInputInstance, isT2IAdapterModelFieldInputInstance,
@ -54,6 +56,7 @@ import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent'; import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent'; import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent'; import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent'; import StringFieldInputComponent from './inputs/StringFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent'; import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent'; import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
@ -125,6 +128,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) { if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) {
return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
} }
if (isSpandrelImageToImageModelFieldInputInstance(fieldInstance) && isSpandrelImageToImageModelFieldInputTemplate(fieldTemplate)) {
return <SpandrelImageToImageModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) { if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
} }

View File

@ -0,0 +1,56 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldSpandrelImageToImageModelValueChanged, } from 'features/nodes/store/nodesSlice';
import type {
SpandrelImageToImageModelFieldInputInstance,
SpandrelImageToImageModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType';
import type { SpandrelImageToImageModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const SpandrelImageToImageModelFieldInputComponent = (
props: FieldComponentProps<SpandrelImageToImageModelFieldInputInstance, SpandrelImageToImageModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();
const _onChange = useCallback(
(value: SpandrelImageToImageModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldSpandrelImageToImageModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<Tooltip label={value?.description}>
<FormControl className="nowheel nodrag" isInvalid={!value}>
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
</FormControl>
</Tooltip>
);
};
export default memo(SpandrelImageToImageModelFieldInputComponent);

View File

@ -19,6 +19,7 @@ import type {
ModelIdentifierFieldValue, ModelIdentifierFieldValue,
SchedulerFieldValue, SchedulerFieldValue,
SDXLRefinerModelFieldValue, SDXLRefinerModelFieldValue,
SpandrelImageToImageModelFieldValue,
StatefulFieldValue, StatefulFieldValue,
StringFieldValue, StringFieldValue,
T2IAdapterModelFieldValue, T2IAdapterModelFieldValue,
@ -39,6 +40,7 @@ import {
zModelIdentifierFieldValue, zModelIdentifierFieldValue,
zSchedulerFieldValue, zSchedulerFieldValue,
zSDXLRefinerModelFieldValue, zSDXLRefinerModelFieldValue,
zSpandrelImageToImageModelFieldValue,
zStatefulFieldValue, zStatefulFieldValue,
zStringFieldValue, zStringFieldValue,
zT2IAdapterModelFieldValue, zT2IAdapterModelFieldValue,
@ -333,6 +335,9 @@ export const nodesSlice = createSlice({
fieldT2IAdapterModelValueChanged: (state, action: FieldValueAction<T2IAdapterModelFieldValue>) => { fieldT2IAdapterModelValueChanged: (state, action: FieldValueAction<T2IAdapterModelFieldValue>) => {
fieldValueReducer(state, action, zT2IAdapterModelFieldValue); fieldValueReducer(state, action, zT2IAdapterModelFieldValue);
}, },
fieldSpandrelImageToImageModelValueChanged: (state, action: FieldValueAction<SpandrelImageToImageModelFieldValue>) => {
fieldValueReducer(state, action, zSpandrelImageToImageModelFieldValue);
},
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => { fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
fieldValueReducer(state, action, zEnumFieldValue); fieldValueReducer(state, action, zEnumFieldValue);
}, },
@ -384,6 +389,7 @@ export const {
fieldImageValueChanged, fieldImageValueChanged,
fieldIPAdapterModelValueChanged, fieldIPAdapterModelValueChanged,
fieldT2IAdapterModelValueChanged, fieldT2IAdapterModelValueChanged,
fieldSpandrelImageToImageModelValueChanged,
fieldLabelChanged, fieldLabelChanged,
fieldLoRAModelValueChanged, fieldLoRAModelValueChanged,
fieldModelIdentifierValueChanged, fieldModelIdentifierValueChanged,

View File

@ -66,6 +66,7 @@ const zModelType = z.enum([
'embedding', 'embedding',
'onnx', 'onnx',
'clip_vision', 'clip_vision',
'spandrel_image_to_image',
]); ]);
const zSubModelType = z.enum([ const zSubModelType = z.enum([
'unet', 'unet',

View File

@ -38,6 +38,7 @@ export const MODEL_TYPES = [
'VAEField', 'VAEField',
'CLIPField', 'CLIPField',
'T2IAdapterModelField', 'T2IAdapterModelField',
'SpandrelImageToImageModelField',
]; ];
/** /**
@ -62,6 +63,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
MainModelField: 'teal.500', MainModelField: 'teal.500',
SDXLMainModelField: 'teal.500', SDXLMainModelField: 'teal.500',
SDXLRefinerModelField: 'teal.500', SDXLRefinerModelField: 'teal.500',
SpandrelImageToImageModelField: 'teal.500',
StringField: 'yellow.500', StringField: 'yellow.500',
T2IAdapterField: 'teal.500', T2IAdapterField: 'teal.500',
T2IAdapterModelField: 'teal.500', T2IAdapterModelField: 'teal.500',

View File

@ -139,6 +139,10 @@ const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
name: z.literal('T2IAdapterModelField'), name: z.literal('T2IAdapterModelField'),
originalType: zStatelessFieldType.optional(), originalType: zStatelessFieldType.optional(),
}); });
const zSpandrelImageToImageModelFieldType = zFieldTypeBase.extend({
name: z.literal('SpandrelImageToImageModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSchedulerFieldType = zFieldTypeBase.extend({ const zSchedulerFieldType = zFieldTypeBase.extend({
name: z.literal('SchedulerField'), name: z.literal('SchedulerField'),
originalType: zStatelessFieldType.optional(), originalType: zStatelessFieldType.optional(),
@ -160,6 +164,7 @@ const zStatefulFieldType = z.union([
zControlNetModelFieldType, zControlNetModelFieldType,
zIPAdapterModelFieldType, zIPAdapterModelFieldType,
zT2IAdapterModelFieldType, zT2IAdapterModelFieldType,
zSpandrelImageToImageModelFieldType,
zColorFieldType, zColorFieldType,
zSchedulerFieldType, zSchedulerFieldType,
]); ]);
@ -581,6 +586,30 @@ export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAda
zT2IAdapterModelFieldInputTemplate.safeParse(val).success; zT2IAdapterModelFieldInputTemplate.safeParse(val).success;
// #endregion // #endregion
// #region SpandrelModelToModelField
export const zSpandrelImageToImageModelFieldValue = zModelIdentifierField.optional();
const zSpandrelImageToImageModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zSpandrelImageToImageModelFieldValue,
});
const zSpandrelImageToImageModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSpandrelImageToImageModelFieldType,
originalType: zFieldType.optional(),
default: zSpandrelImageToImageModelFieldValue,
});
const zSpandrelImageToImageModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSpandrelImageToImageModelFieldType,
});
export type SpandrelImageToImageModelFieldValue = z.infer<typeof zSpandrelImageToImageModelFieldValue>;
export type SpandrelImageToImageModelFieldInputInstance = z.infer<typeof zSpandrelImageToImageModelFieldInputInstance>;
export type SpandrelImageToImageModelFieldInputTemplate = z.infer<typeof zSpandrelImageToImageModelFieldInputTemplate>;
export const isSpandrelImageToImageModelFieldInputInstance = (val: unknown): val is SpandrelImageToImageModelFieldInputInstance =>
zSpandrelImageToImageModelFieldInputInstance.safeParse(val).success;
export const isSpandrelImageToImageModelFieldInputTemplate = (val: unknown): val is SpandrelImageToImageModelFieldInputTemplate =>
zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SchedulerField // #region SchedulerField
export const zSchedulerFieldValue = zSchedulerField.optional(); export const zSchedulerFieldValue = zSchedulerField.optional();
@ -667,6 +696,7 @@ export const zStatefulFieldValue = z.union([
zControlNetModelFieldValue, zControlNetModelFieldValue,
zIPAdapterModelFieldValue, zIPAdapterModelFieldValue,
zT2IAdapterModelFieldValue, zT2IAdapterModelFieldValue,
zSpandrelImageToImageModelFieldValue,
zColorFieldValue, zColorFieldValue,
zSchedulerFieldValue, zSchedulerFieldValue,
]); ]);
@ -694,6 +724,7 @@ const zStatefulFieldInputInstance = z.union([
zControlNetModelFieldInputInstance, zControlNetModelFieldInputInstance,
zIPAdapterModelFieldInputInstance, zIPAdapterModelFieldInputInstance,
zT2IAdapterModelFieldInputInstance, zT2IAdapterModelFieldInputInstance,
zSpandrelImageToImageModelFieldInputInstance,
zColorFieldInputInstance, zColorFieldInputInstance,
zSchedulerFieldInputInstance, zSchedulerFieldInputInstance,
]); ]);
@ -722,6 +753,7 @@ const zStatefulFieldInputTemplate = z.union([
zControlNetModelFieldInputTemplate, zControlNetModelFieldInputTemplate,
zIPAdapterModelFieldInputTemplate, zIPAdapterModelFieldInputTemplate,
zT2IAdapterModelFieldInputTemplate, zT2IAdapterModelFieldInputTemplate,
zSpandrelImageToImageModelFieldInputTemplate,
zColorFieldInputTemplate, zColorFieldInputTemplate,
zSchedulerFieldInputTemplate, zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate, zStatelessFieldInputTemplate,
@ -751,6 +783,7 @@ const zStatefulFieldOutputTemplate = z.union([
zControlNetModelFieldOutputTemplate, zControlNetModelFieldOutputTemplate,
zIPAdapterModelFieldOutputTemplate, zIPAdapterModelFieldOutputTemplate,
zT2IAdapterModelFieldOutputTemplate, zT2IAdapterModelFieldOutputTemplate,
zSpandrelImageToImageModelFieldOutputTemplate,
zColorFieldOutputTemplate, zColorFieldOutputTemplate,
zSchedulerFieldOutputTemplate, zSchedulerFieldOutputTemplate,
]); ]);

View File

@ -18,6 +18,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
SDXLRefinerModelField: undefined, SDXLRefinerModelField: undefined,
StringField: '', StringField: '',
T2IAdapterModelField: undefined, T2IAdapterModelField: undefined,
SpandrelImageToImageModelField: undefined,
VAEModelField: undefined, VAEModelField: undefined,
ControlNetModelField: undefined, ControlNetModelField: undefined,
}; };

View File

@ -17,6 +17,7 @@ import type {
SchedulerFieldInputTemplate, SchedulerFieldInputTemplate,
SDXLMainModelFieldInputTemplate, SDXLMainModelFieldInputTemplate,
SDXLRefinerModelFieldInputTemplate, SDXLRefinerModelFieldInputTemplate,
SpandrelImageToImageModelFieldInputTemplate,
StatefulFieldType, StatefulFieldType,
StatelessFieldInputTemplate, StatelessFieldInputTemplate,
StringFieldInputTemplate, StringFieldInputTemplate,
@ -263,6 +264,17 @@ const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<T2IAdapt
return template; return template;
}; };
const buildSpandrelImageToImageModelFieldInputTemplate: FieldInputTemplateBuilder<
SpandrelImageToImageModelFieldInputTemplate
> = ({ schemaObject, baseField, fieldType }) => {
const template: SpandrelImageToImageModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<BoardFieldInputTemplate> = ({ const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<BoardFieldInputTemplate> = ({
schemaObject, schemaObject,
baseField, baseField,
@ -377,6 +389,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate, SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
StringField: buildStringFieldInputTemplate, StringField: buildStringFieldInputTemplate,
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate, T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate,
VAEModelField: buildVAEModelFieldInputTemplate, VAEModelField: buildVAEModelFieldInputTemplate,
} as const; } as const;

View File

@ -35,6 +35,7 @@ const MODEL_FIELD_TYPES = [
'ControlNetModelField', 'ControlNetModelField',
'IPAdapterModelField', 'IPAdapterModelField',
'T2IAdapterModelField', 'T2IAdapterModelField',
'SpandrelImageToImageModelField',
]; ];
/** /**

View File

@ -11,6 +11,7 @@ import {
isNonSDXLMainModelConfig, isNonSDXLMainModelConfig,
isRefinerMainModelModelConfig, isRefinerMainModelModelConfig,
isSDXLMainModelModelConfig, isSDXLMainModelModelConfig,
isSpandrelImageToImageModelConfig,
isT2IAdapterModelConfig, isT2IAdapterModelConfig,
isTIModelConfig, isTIModelConfig,
isVAEModelConfig, isVAEModelConfig,
@ -39,6 +40,7 @@ export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig); export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
export const useControlNetModels = buildModelsHook(isControlNetModelConfig); export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig); export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
export const useSpandrelImageToImageModels = buildModelsHook(isSpandrelImageToImageModelConfig);
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig); export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
export const useEmbeddingModels = buildModelsHook(isTIModelConfig); export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
export const useVAEModels = buildModelsHook(isVAEModelConfig); export const useVAEModels = buildModelsHook(isVAEModelConfig);

File diff suppressed because one or more lines are too long

View File

@ -49,6 +49,7 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig']; export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig']; export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
export type T2IAdapterModelConfig = S['T2IAdapterConfig']; export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig']; type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
type DiffusersModelConfig = S['MainDiffusersConfig']; type DiffusersModelConfig = S['MainDiffusersConfig'];
type CheckpointModelConfig = S['MainCheckpointConfig']; type CheckpointModelConfig = S['MainCheckpointConfig'];
@ -60,6 +61,7 @@ export type AnyModelConfig =
| ControlNetModelConfig | ControlNetModelConfig
| IPAdapterModelConfig | IPAdapterModelConfig
| T2IAdapterModelConfig | T2IAdapterModelConfig
| SpandrelImageToImageModelConfig
| TextualInversionModelConfig | TextualInversionModelConfig
| MainModelConfig | MainModelConfig
| CLIPVisionDiffusersConfig; | CLIPVisionDiffusersConfig;
@ -84,6 +86,10 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd
return config.type === 't2i_adapter'; return config.type === 't2i_adapter';
}; };
export const isSpandrelImageToImageModelConfig = (config: AnyModelConfig): config is SpandrelImageToImageModelConfig => {
return config.type === 'spandrel_image_to_image';
}
export const isControlAdapterModelConfig = ( export const isControlAdapterModelConfig = (
config: AnyModelConfig config: AnyModelConfig
): config is ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig => { ): config is ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig => {