mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
WIP - A bunch of boilerplate to support Spandrel Image-to-Image models throughout the model manager and the frontend.
This commit is contained in:
parent
95079dc7d4
commit
29c8ddfb88
@ -48,6 +48,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
|||||||
ControlNetModel = "ControlNetModelField"
|
ControlNetModel = "ControlNetModelField"
|
||||||
IPAdapterModel = "IPAdapterModelField"
|
IPAdapterModel = "IPAdapterModelField"
|
||||||
T2IAdapterModel = "T2IAdapterModelField"
|
T2IAdapterModel = "T2IAdapterModelField"
|
||||||
|
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Misc Field Types
|
# region Misc Field Types
|
||||||
@ -134,6 +135,7 @@ class FieldDescriptions:
|
|||||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||||
|
spandrel_image_to_image_model = "Spandrel Image-to-Image model"
|
||||||
lora_weight = "The weight at which the LoRA is applied to each model"
|
lora_weight = "The weight at which the LoRA is applied to each model"
|
||||||
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
||||||
raw_prompt = "Raw prompt text (no parsing)"
|
raw_prompt = "Raw prompt text (no parsing)"
|
||||||
|
@ -373,6 +373,17 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
|
|||||||
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
|
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
|
||||||
|
|
||||||
|
|
||||||
|
class SpandrelImageToImageConfig(ModelConfigBase):
|
||||||
|
"""Model config for Spandrel Image to Image models."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
|
||||||
|
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
|
||||||
|
|
||||||
|
|
||||||
def get_model_discriminator_value(v: Any) -> str:
|
def get_model_discriminator_value(v: Any) -> str:
|
||||||
"""
|
"""
|
||||||
Computes the discriminator value for a model config.
|
Computes the discriminator value for a model config.
|
||||||
@ -409,6 +420,7 @@ AnyModelConfig = Annotated[
|
|||||||
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
|
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
|
||||||
Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
|
Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
|
||||||
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
|
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
|
||||||
|
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
|
||||||
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
||||||
],
|
],
|
||||||
Discriminator(get_model_discriminator_value),
|
Discriminator(get_model_discriminator_value),
|
||||||
|
@ -243,10 +243,14 @@ class ModelProbe(object):
|
|||||||
|
|
||||||
# Check if the model can be loaded as a SpandrelImageToImageModel.
|
# Check if the model can be loaded as a SpandrelImageToImageModel.
|
||||||
try:
|
try:
|
||||||
_ = SpandrelImageToImageModel.load_from_state_dict(ckpt)
|
# TODO(ryand): Figure out why load_from_state_dict() doesn't work as expected.
|
||||||
|
# _ = SpandrelImageToImageModel.load_from_state_dict(ckpt)
|
||||||
|
_ = SpandrelImageToImageModel.load_from_file(model_path)
|
||||||
return ModelType.SpandrelImageToImage
|
return ModelType.SpandrelImageToImage
|
||||||
except Exception:
|
except Exception as e:
|
||||||
# TODO(ryand): Catch a more specific exception type here if we can.
|
# TODO(ryand): Catch a more specific exception type here if we can.
|
||||||
|
# TODO(ryand): Delete this print statement.
|
||||||
|
print(e)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
|
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
|
||||||
@ -579,9 +583,9 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class SpandrelImageToImageModelProbe(CheckpointProbeBase):
|
class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
raise NotImplementedError()
|
return BaseModelType.Any
|
||||||
|
|
||||||
|
|
||||||
########################################################
|
########################################################
|
||||||
@ -791,6 +795,11 @@ class CLIPVisionFolderProbe(FolderProbeBase):
|
|||||||
return BaseModelType.Any
|
return BaseModelType.Any
|
||||||
|
|
||||||
|
|
||||||
|
class SpandrelImageToImageFolderProbe(FolderProbeBase):
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterFolderProbe(FolderProbeBase):
|
class T2IAdapterFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
config_file = self.model_path / "config.json"
|
config_file = self.model_path / "config.json"
|
||||||
@ -820,6 +829,7 @@ ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderPro
|
|||||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||||
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
||||||
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||||
|
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
|
||||||
|
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
|
||||||
@ -829,5 +839,6 @@ ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpoi
|
|||||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
|
||||||
|
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
|
||||||
|
|
||||||
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
||||||
|
@ -32,6 +32,8 @@ import {
|
|||||||
isSDXLMainModelFieldInputTemplate,
|
isSDXLMainModelFieldInputTemplate,
|
||||||
isSDXLRefinerModelFieldInputInstance,
|
isSDXLRefinerModelFieldInputInstance,
|
||||||
isSDXLRefinerModelFieldInputTemplate,
|
isSDXLRefinerModelFieldInputTemplate,
|
||||||
|
isSpandrelImageToImageModelFieldInputInstance,
|
||||||
|
isSpandrelImageToImageModelFieldInputTemplate,
|
||||||
isStringFieldInputInstance,
|
isStringFieldInputInstance,
|
||||||
isStringFieldInputTemplate,
|
isStringFieldInputTemplate,
|
||||||
isT2IAdapterModelFieldInputInstance,
|
isT2IAdapterModelFieldInputInstance,
|
||||||
@ -54,6 +56,7 @@ import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
|
|||||||
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
|
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
|
||||||
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
|
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
|
||||||
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
|
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
|
||||||
|
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
|
||||||
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
|
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
|
||||||
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
|
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
|
||||||
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
|
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
|
||||||
@ -125,6 +128,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
|||||||
if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) {
|
if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) {
|
||||||
return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isSpandrelImageToImageModelFieldInputInstance(fieldInstance) && isSpandrelImageToImageModelFieldInputTemplate(fieldTemplate)) {
|
||||||
|
return <SpandrelImageToImageModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||||
|
}
|
||||||
|
|
||||||
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
|
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
|
||||||
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,56 @@
|
|||||||
|
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||||
|
import { fieldSpandrelImageToImageModelValueChanged, } from 'features/nodes/store/nodesSlice';
|
||||||
|
import type {
|
||||||
|
SpandrelImageToImageModelFieldInputInstance,
|
||||||
|
SpandrelImageToImageModelFieldInputTemplate,
|
||||||
|
} from 'features/nodes/types/field';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType';
|
||||||
|
import type { SpandrelImageToImageModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
|
import type { FieldComponentProps } from './types';
|
||||||
|
|
||||||
|
const SpandrelImageToImageModelFieldInputComponent = (
|
||||||
|
props: FieldComponentProps<SpandrelImageToImageModelFieldInputInstance, SpandrelImageToImageModelFieldInputTemplate>
|
||||||
|
) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();
|
||||||
|
|
||||||
|
const _onChange = useCallback(
|
||||||
|
(value: SpandrelImageToImageModelConfig | null) => {
|
||||||
|
if (!value) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dispatch(
|
||||||
|
|
||||||
|
fieldSpandrelImageToImageModelValueChanged({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
value,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch, field.name, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const { options, value, onChange } = useGroupedModelCombobox({
|
||||||
|
modelConfigs,
|
||||||
|
onChange: _onChange,
|
||||||
|
selectedModel: field.value,
|
||||||
|
isLoading,
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Tooltip label={value?.description}>
|
||||||
|
<FormControl className="nowheel nodrag" isInvalid={!value}>
|
||||||
|
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
|
||||||
|
</FormControl>
|
||||||
|
</Tooltip>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(SpandrelImageToImageModelFieldInputComponent);
|
@ -19,6 +19,7 @@ import type {
|
|||||||
ModelIdentifierFieldValue,
|
ModelIdentifierFieldValue,
|
||||||
SchedulerFieldValue,
|
SchedulerFieldValue,
|
||||||
SDXLRefinerModelFieldValue,
|
SDXLRefinerModelFieldValue,
|
||||||
|
SpandrelImageToImageModelFieldValue,
|
||||||
StatefulFieldValue,
|
StatefulFieldValue,
|
||||||
StringFieldValue,
|
StringFieldValue,
|
||||||
T2IAdapterModelFieldValue,
|
T2IAdapterModelFieldValue,
|
||||||
@ -39,6 +40,7 @@ import {
|
|||||||
zModelIdentifierFieldValue,
|
zModelIdentifierFieldValue,
|
||||||
zSchedulerFieldValue,
|
zSchedulerFieldValue,
|
||||||
zSDXLRefinerModelFieldValue,
|
zSDXLRefinerModelFieldValue,
|
||||||
|
zSpandrelImageToImageModelFieldValue,
|
||||||
zStatefulFieldValue,
|
zStatefulFieldValue,
|
||||||
zStringFieldValue,
|
zStringFieldValue,
|
||||||
zT2IAdapterModelFieldValue,
|
zT2IAdapterModelFieldValue,
|
||||||
@ -333,6 +335,9 @@ export const nodesSlice = createSlice({
|
|||||||
fieldT2IAdapterModelValueChanged: (state, action: FieldValueAction<T2IAdapterModelFieldValue>) => {
|
fieldT2IAdapterModelValueChanged: (state, action: FieldValueAction<T2IAdapterModelFieldValue>) => {
|
||||||
fieldValueReducer(state, action, zT2IAdapterModelFieldValue);
|
fieldValueReducer(state, action, zT2IAdapterModelFieldValue);
|
||||||
},
|
},
|
||||||
|
fieldSpandrelImageToImageModelValueChanged: (state, action: FieldValueAction<SpandrelImageToImageModelFieldValue>) => {
|
||||||
|
fieldValueReducer(state, action, zSpandrelImageToImageModelFieldValue);
|
||||||
|
},
|
||||||
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
|
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
|
||||||
fieldValueReducer(state, action, zEnumFieldValue);
|
fieldValueReducer(state, action, zEnumFieldValue);
|
||||||
},
|
},
|
||||||
@ -384,6 +389,7 @@ export const {
|
|||||||
fieldImageValueChanged,
|
fieldImageValueChanged,
|
||||||
fieldIPAdapterModelValueChanged,
|
fieldIPAdapterModelValueChanged,
|
||||||
fieldT2IAdapterModelValueChanged,
|
fieldT2IAdapterModelValueChanged,
|
||||||
|
fieldSpandrelImageToImageModelValueChanged,
|
||||||
fieldLabelChanged,
|
fieldLabelChanged,
|
||||||
fieldLoRAModelValueChanged,
|
fieldLoRAModelValueChanged,
|
||||||
fieldModelIdentifierValueChanged,
|
fieldModelIdentifierValueChanged,
|
||||||
|
@ -66,6 +66,7 @@ const zModelType = z.enum([
|
|||||||
'embedding',
|
'embedding',
|
||||||
'onnx',
|
'onnx',
|
||||||
'clip_vision',
|
'clip_vision',
|
||||||
|
'spandrel_image_to_image',
|
||||||
]);
|
]);
|
||||||
const zSubModelType = z.enum([
|
const zSubModelType = z.enum([
|
||||||
'unet',
|
'unet',
|
||||||
|
@ -38,6 +38,7 @@ export const MODEL_TYPES = [
|
|||||||
'VAEField',
|
'VAEField',
|
||||||
'CLIPField',
|
'CLIPField',
|
||||||
'T2IAdapterModelField',
|
'T2IAdapterModelField',
|
||||||
|
'SpandrelImageToImageModelField',
|
||||||
];
|
];
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -62,6 +63,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
|||||||
MainModelField: 'teal.500',
|
MainModelField: 'teal.500',
|
||||||
SDXLMainModelField: 'teal.500',
|
SDXLMainModelField: 'teal.500',
|
||||||
SDXLRefinerModelField: 'teal.500',
|
SDXLRefinerModelField: 'teal.500',
|
||||||
|
SpandrelImageToImageModelField: 'teal.500',
|
||||||
StringField: 'yellow.500',
|
StringField: 'yellow.500',
|
||||||
T2IAdapterField: 'teal.500',
|
T2IAdapterField: 'teal.500',
|
||||||
T2IAdapterModelField: 'teal.500',
|
T2IAdapterModelField: 'teal.500',
|
||||||
|
@ -139,6 +139,10 @@ const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
|
|||||||
name: z.literal('T2IAdapterModelField'),
|
name: z.literal('T2IAdapterModelField'),
|
||||||
originalType: zStatelessFieldType.optional(),
|
originalType: zStatelessFieldType.optional(),
|
||||||
});
|
});
|
||||||
|
const zSpandrelImageToImageModelFieldType = zFieldTypeBase.extend({
|
||||||
|
name: z.literal('SpandrelImageToImageModelField'),
|
||||||
|
originalType: zStatelessFieldType.optional(),
|
||||||
|
});
|
||||||
const zSchedulerFieldType = zFieldTypeBase.extend({
|
const zSchedulerFieldType = zFieldTypeBase.extend({
|
||||||
name: z.literal('SchedulerField'),
|
name: z.literal('SchedulerField'),
|
||||||
originalType: zStatelessFieldType.optional(),
|
originalType: zStatelessFieldType.optional(),
|
||||||
@ -160,6 +164,7 @@ const zStatefulFieldType = z.union([
|
|||||||
zControlNetModelFieldType,
|
zControlNetModelFieldType,
|
||||||
zIPAdapterModelFieldType,
|
zIPAdapterModelFieldType,
|
||||||
zT2IAdapterModelFieldType,
|
zT2IAdapterModelFieldType,
|
||||||
|
zSpandrelImageToImageModelFieldType,
|
||||||
zColorFieldType,
|
zColorFieldType,
|
||||||
zSchedulerFieldType,
|
zSchedulerFieldType,
|
||||||
]);
|
]);
|
||||||
@ -581,6 +586,30 @@ export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAda
|
|||||||
zT2IAdapterModelFieldInputTemplate.safeParse(val).success;
|
zT2IAdapterModelFieldInputTemplate.safeParse(val).success;
|
||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
|
// #region SpandrelModelToModelField
|
||||||
|
|
||||||
|
export const zSpandrelImageToImageModelFieldValue = zModelIdentifierField.optional();
|
||||||
|
const zSpandrelImageToImageModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||||
|
value: zSpandrelImageToImageModelFieldValue,
|
||||||
|
});
|
||||||
|
const zSpandrelImageToImageModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||||
|
type: zSpandrelImageToImageModelFieldType,
|
||||||
|
originalType: zFieldType.optional(),
|
||||||
|
default: zSpandrelImageToImageModelFieldValue,
|
||||||
|
});
|
||||||
|
const zSpandrelImageToImageModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
|
type: zSpandrelImageToImageModelFieldType,
|
||||||
|
});
|
||||||
|
export type SpandrelImageToImageModelFieldValue = z.infer<typeof zSpandrelImageToImageModelFieldValue>;
|
||||||
|
export type SpandrelImageToImageModelFieldInputInstance = z.infer<typeof zSpandrelImageToImageModelFieldInputInstance>;
|
||||||
|
export type SpandrelImageToImageModelFieldInputTemplate = z.infer<typeof zSpandrelImageToImageModelFieldInputTemplate>;
|
||||||
|
export const isSpandrelImageToImageModelFieldInputInstance = (val: unknown): val is SpandrelImageToImageModelFieldInputInstance =>
|
||||||
|
zSpandrelImageToImageModelFieldInputInstance.safeParse(val).success;
|
||||||
|
export const isSpandrelImageToImageModelFieldInputTemplate = (val: unknown): val is SpandrelImageToImageModelFieldInputTemplate =>
|
||||||
|
zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success;
|
||||||
|
// #endregion
|
||||||
|
|
||||||
|
|
||||||
// #region SchedulerField
|
// #region SchedulerField
|
||||||
|
|
||||||
export const zSchedulerFieldValue = zSchedulerField.optional();
|
export const zSchedulerFieldValue = zSchedulerField.optional();
|
||||||
@ -667,6 +696,7 @@ export const zStatefulFieldValue = z.union([
|
|||||||
zControlNetModelFieldValue,
|
zControlNetModelFieldValue,
|
||||||
zIPAdapterModelFieldValue,
|
zIPAdapterModelFieldValue,
|
||||||
zT2IAdapterModelFieldValue,
|
zT2IAdapterModelFieldValue,
|
||||||
|
zSpandrelImageToImageModelFieldValue,
|
||||||
zColorFieldValue,
|
zColorFieldValue,
|
||||||
zSchedulerFieldValue,
|
zSchedulerFieldValue,
|
||||||
]);
|
]);
|
||||||
@ -694,6 +724,7 @@ const zStatefulFieldInputInstance = z.union([
|
|||||||
zControlNetModelFieldInputInstance,
|
zControlNetModelFieldInputInstance,
|
||||||
zIPAdapterModelFieldInputInstance,
|
zIPAdapterModelFieldInputInstance,
|
||||||
zT2IAdapterModelFieldInputInstance,
|
zT2IAdapterModelFieldInputInstance,
|
||||||
|
zSpandrelImageToImageModelFieldInputInstance,
|
||||||
zColorFieldInputInstance,
|
zColorFieldInputInstance,
|
||||||
zSchedulerFieldInputInstance,
|
zSchedulerFieldInputInstance,
|
||||||
]);
|
]);
|
||||||
@ -722,6 +753,7 @@ const zStatefulFieldInputTemplate = z.union([
|
|||||||
zControlNetModelFieldInputTemplate,
|
zControlNetModelFieldInputTemplate,
|
||||||
zIPAdapterModelFieldInputTemplate,
|
zIPAdapterModelFieldInputTemplate,
|
||||||
zT2IAdapterModelFieldInputTemplate,
|
zT2IAdapterModelFieldInputTemplate,
|
||||||
|
zSpandrelImageToImageModelFieldInputTemplate,
|
||||||
zColorFieldInputTemplate,
|
zColorFieldInputTemplate,
|
||||||
zSchedulerFieldInputTemplate,
|
zSchedulerFieldInputTemplate,
|
||||||
zStatelessFieldInputTemplate,
|
zStatelessFieldInputTemplate,
|
||||||
@ -751,6 +783,7 @@ const zStatefulFieldOutputTemplate = z.union([
|
|||||||
zControlNetModelFieldOutputTemplate,
|
zControlNetModelFieldOutputTemplate,
|
||||||
zIPAdapterModelFieldOutputTemplate,
|
zIPAdapterModelFieldOutputTemplate,
|
||||||
zT2IAdapterModelFieldOutputTemplate,
|
zT2IAdapterModelFieldOutputTemplate,
|
||||||
|
zSpandrelImageToImageModelFieldOutputTemplate,
|
||||||
zColorFieldOutputTemplate,
|
zColorFieldOutputTemplate,
|
||||||
zSchedulerFieldOutputTemplate,
|
zSchedulerFieldOutputTemplate,
|
||||||
]);
|
]);
|
||||||
|
@ -18,6 +18,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
|||||||
SDXLRefinerModelField: undefined,
|
SDXLRefinerModelField: undefined,
|
||||||
StringField: '',
|
StringField: '',
|
||||||
T2IAdapterModelField: undefined,
|
T2IAdapterModelField: undefined,
|
||||||
|
SpandrelImageToImageModelField: undefined,
|
||||||
VAEModelField: undefined,
|
VAEModelField: undefined,
|
||||||
ControlNetModelField: undefined,
|
ControlNetModelField: undefined,
|
||||||
};
|
};
|
||||||
|
@ -17,6 +17,7 @@ import type {
|
|||||||
SchedulerFieldInputTemplate,
|
SchedulerFieldInputTemplate,
|
||||||
SDXLMainModelFieldInputTemplate,
|
SDXLMainModelFieldInputTemplate,
|
||||||
SDXLRefinerModelFieldInputTemplate,
|
SDXLRefinerModelFieldInputTemplate,
|
||||||
|
SpandrelImageToImageModelFieldInputTemplate,
|
||||||
StatefulFieldType,
|
StatefulFieldType,
|
||||||
StatelessFieldInputTemplate,
|
StatelessFieldInputTemplate,
|
||||||
StringFieldInputTemplate,
|
StringFieldInputTemplate,
|
||||||
@ -263,6 +264,17 @@ const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<T2IAdapt
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildSpandrelImageToImageModelFieldInputTemplate: FieldInputTemplateBuilder<
|
||||||
|
SpandrelImageToImageModelFieldInputTemplate
|
||||||
|
> = ({ schemaObject, baseField, fieldType }) => {
|
||||||
|
const template: SpandrelImageToImageModelFieldInputTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: fieldType,
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<BoardFieldInputTemplate> = ({
|
const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<BoardFieldInputTemplate> = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -377,6 +389,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
|||||||
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
|
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
|
||||||
StringField: buildStringFieldInputTemplate,
|
StringField: buildStringFieldInputTemplate,
|
||||||
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
|
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
|
||||||
|
SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate,
|
||||||
VAEModelField: buildVAEModelFieldInputTemplate,
|
VAEModelField: buildVAEModelFieldInputTemplate,
|
||||||
} as const;
|
} as const;
|
||||||
|
|
||||||
|
@ -35,6 +35,7 @@ const MODEL_FIELD_TYPES = [
|
|||||||
'ControlNetModelField',
|
'ControlNetModelField',
|
||||||
'IPAdapterModelField',
|
'IPAdapterModelField',
|
||||||
'T2IAdapterModelField',
|
'T2IAdapterModelField',
|
||||||
|
'SpandrelImageToImageModelField',
|
||||||
];
|
];
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -11,6 +11,7 @@ import {
|
|||||||
isNonSDXLMainModelConfig,
|
isNonSDXLMainModelConfig,
|
||||||
isRefinerMainModelModelConfig,
|
isRefinerMainModelModelConfig,
|
||||||
isSDXLMainModelModelConfig,
|
isSDXLMainModelModelConfig,
|
||||||
|
isSpandrelImageToImageModelConfig,
|
||||||
isT2IAdapterModelConfig,
|
isT2IAdapterModelConfig,
|
||||||
isTIModelConfig,
|
isTIModelConfig,
|
||||||
isVAEModelConfig,
|
isVAEModelConfig,
|
||||||
@ -39,6 +40,7 @@ export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
|
|||||||
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
|
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
|
||||||
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
|
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
|
||||||
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
|
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
|
||||||
|
export const useSpandrelImageToImageModels = buildModelsHook(isSpandrelImageToImageModelConfig);
|
||||||
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
|
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
|
||||||
export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
|
export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
|
||||||
export const useVAEModels = buildModelsHook(isVAEModelConfig);
|
export const useVAEModels = buildModelsHook(isVAEModelConfig);
|
||||||
|
File diff suppressed because one or more lines are too long
@ -49,6 +49,7 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
|
|||||||
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
|
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
|
||||||
export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
|
export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
|
||||||
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
|
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
|
||||||
|
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
|
||||||
type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
|
type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
|
||||||
type DiffusersModelConfig = S['MainDiffusersConfig'];
|
type DiffusersModelConfig = S['MainDiffusersConfig'];
|
||||||
type CheckpointModelConfig = S['MainCheckpointConfig'];
|
type CheckpointModelConfig = S['MainCheckpointConfig'];
|
||||||
@ -60,6 +61,7 @@ export type AnyModelConfig =
|
|||||||
| ControlNetModelConfig
|
| ControlNetModelConfig
|
||||||
| IPAdapterModelConfig
|
| IPAdapterModelConfig
|
||||||
| T2IAdapterModelConfig
|
| T2IAdapterModelConfig
|
||||||
|
| SpandrelImageToImageModelConfig
|
||||||
| TextualInversionModelConfig
|
| TextualInversionModelConfig
|
||||||
| MainModelConfig
|
| MainModelConfig
|
||||||
| CLIPVisionDiffusersConfig;
|
| CLIPVisionDiffusersConfig;
|
||||||
@ -84,6 +86,10 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd
|
|||||||
return config.type === 't2i_adapter';
|
return config.type === 't2i_adapter';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const isSpandrelImageToImageModelConfig = (config: AnyModelConfig): config is SpandrelImageToImageModelConfig => {
|
||||||
|
return config.type === 'spandrel_image_to_image';
|
||||||
|
}
|
||||||
|
|
||||||
export const isControlAdapterModelConfig = (
|
export const isControlAdapterModelConfig = (
|
||||||
config: AnyModelConfig
|
config: AnyModelConfig
|
||||||
): config is ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig => {
|
): config is ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig => {
|
||||||
|
Loading…
Reference in New Issue
Block a user