mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
WIP - A bunch of boilerplate to support Spandrel Image-to-Image models throughout the model manager and the frontend.
This commit is contained in:
@ -32,6 +32,8 @@ import {
|
||||
isSDXLMainModelFieldInputTemplate,
|
||||
isSDXLRefinerModelFieldInputInstance,
|
||||
isSDXLRefinerModelFieldInputTemplate,
|
||||
isSpandrelImageToImageModelFieldInputInstance,
|
||||
isSpandrelImageToImageModelFieldInputTemplate,
|
||||
isStringFieldInputInstance,
|
||||
isStringFieldInputTemplate,
|
||||
isT2IAdapterModelFieldInputInstance,
|
||||
@ -54,6 +56,7 @@ import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
|
||||
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
|
||||
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
|
||||
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
|
||||
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
|
||||
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
|
||||
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
|
||||
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
|
||||
@ -125,6 +128,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isSpandrelImageToImageModelFieldInputInstance(fieldInstance) && isSpandrelImageToImageModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <SpandrelImageToImageModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
|
||||
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
@ -0,0 +1,56 @@
|
||||
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldSpandrelImageToImageModelValueChanged, } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
SpandrelImageToImageModelFieldInputInstance,
|
||||
SpandrelImageToImageModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType';
|
||||
import type { SpandrelImageToImageModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const SpandrelImageToImageModelFieldInputComponent = (
|
||||
props: FieldComponentProps<SpandrelImageToImageModelFieldInputInstance, SpandrelImageToImageModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(value: SpandrelImageToImageModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
|
||||
fieldSpandrelImageToImageModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const { options, value, onChange } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
<Tooltip label={value?.description}>
|
||||
<FormControl className="nowheel nodrag" isInvalid={!value}>
|
||||
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SpandrelImageToImageModelFieldInputComponent);
|
@ -19,6 +19,7 @@ import type {
|
||||
ModelIdentifierFieldValue,
|
||||
SchedulerFieldValue,
|
||||
SDXLRefinerModelFieldValue,
|
||||
SpandrelImageToImageModelFieldValue,
|
||||
StatefulFieldValue,
|
||||
StringFieldValue,
|
||||
T2IAdapterModelFieldValue,
|
||||
@ -39,6 +40,7 @@ import {
|
||||
zModelIdentifierFieldValue,
|
||||
zSchedulerFieldValue,
|
||||
zSDXLRefinerModelFieldValue,
|
||||
zSpandrelImageToImageModelFieldValue,
|
||||
zStatefulFieldValue,
|
||||
zStringFieldValue,
|
||||
zT2IAdapterModelFieldValue,
|
||||
@ -333,6 +335,9 @@ export const nodesSlice = createSlice({
|
||||
fieldT2IAdapterModelValueChanged: (state, action: FieldValueAction<T2IAdapterModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zT2IAdapterModelFieldValue);
|
||||
},
|
||||
fieldSpandrelImageToImageModelValueChanged: (state, action: FieldValueAction<SpandrelImageToImageModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zSpandrelImageToImageModelFieldValue);
|
||||
},
|
||||
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
|
||||
fieldValueReducer(state, action, zEnumFieldValue);
|
||||
},
|
||||
@ -384,6 +389,7 @@ export const {
|
||||
fieldImageValueChanged,
|
||||
fieldIPAdapterModelValueChanged,
|
||||
fieldT2IAdapterModelValueChanged,
|
||||
fieldSpandrelImageToImageModelValueChanged,
|
||||
fieldLabelChanged,
|
||||
fieldLoRAModelValueChanged,
|
||||
fieldModelIdentifierValueChanged,
|
||||
|
@ -66,6 +66,7 @@ const zModelType = z.enum([
|
||||
'embedding',
|
||||
'onnx',
|
||||
'clip_vision',
|
||||
'spandrel_image_to_image',
|
||||
]);
|
||||
const zSubModelType = z.enum([
|
||||
'unet',
|
||||
|
@ -38,6 +38,7 @@ export const MODEL_TYPES = [
|
||||
'VAEField',
|
||||
'CLIPField',
|
||||
'T2IAdapterModelField',
|
||||
'SpandrelImageToImageModelField',
|
||||
];
|
||||
|
||||
/**
|
||||
@ -62,6 +63,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
||||
MainModelField: 'teal.500',
|
||||
SDXLMainModelField: 'teal.500',
|
||||
SDXLRefinerModelField: 'teal.500',
|
||||
SpandrelImageToImageModelField: 'teal.500',
|
||||
StringField: 'yellow.500',
|
||||
T2IAdapterField: 'teal.500',
|
||||
T2IAdapterModelField: 'teal.500',
|
||||
|
@ -139,6 +139,10 @@ const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('T2IAdapterModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSpandrelImageToImageModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SpandrelImageToImageModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSchedulerFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SchedulerField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@ -160,6 +164,7 @@ const zStatefulFieldType = z.union([
|
||||
zControlNetModelFieldType,
|
||||
zIPAdapterModelFieldType,
|
||||
zT2IAdapterModelFieldType,
|
||||
zSpandrelImageToImageModelFieldType,
|
||||
zColorFieldType,
|
||||
zSchedulerFieldType,
|
||||
]);
|
||||
@ -581,6 +586,30 @@ export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAda
|
||||
zT2IAdapterModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region SpandrelModelToModelField
|
||||
|
||||
export const zSpandrelImageToImageModelFieldValue = zModelIdentifierField.optional();
|
||||
const zSpandrelImageToImageModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zSpandrelImageToImageModelFieldValue,
|
||||
});
|
||||
const zSpandrelImageToImageModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSpandrelImageToImageModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zSpandrelImageToImageModelFieldValue,
|
||||
});
|
||||
const zSpandrelImageToImageModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zSpandrelImageToImageModelFieldType,
|
||||
});
|
||||
export type SpandrelImageToImageModelFieldValue = z.infer<typeof zSpandrelImageToImageModelFieldValue>;
|
||||
export type SpandrelImageToImageModelFieldInputInstance = z.infer<typeof zSpandrelImageToImageModelFieldInputInstance>;
|
||||
export type SpandrelImageToImageModelFieldInputTemplate = z.infer<typeof zSpandrelImageToImageModelFieldInputTemplate>;
|
||||
export const isSpandrelImageToImageModelFieldInputInstance = (val: unknown): val is SpandrelImageToImageModelFieldInputInstance =>
|
||||
zSpandrelImageToImageModelFieldInputInstance.safeParse(val).success;
|
||||
export const isSpandrelImageToImageModelFieldInputTemplate = (val: unknown): val is SpandrelImageToImageModelFieldInputTemplate =>
|
||||
zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
|
||||
// #region SchedulerField
|
||||
|
||||
export const zSchedulerFieldValue = zSchedulerField.optional();
|
||||
@ -667,6 +696,7 @@ export const zStatefulFieldValue = z.union([
|
||||
zControlNetModelFieldValue,
|
||||
zIPAdapterModelFieldValue,
|
||||
zT2IAdapterModelFieldValue,
|
||||
zSpandrelImageToImageModelFieldValue,
|
||||
zColorFieldValue,
|
||||
zSchedulerFieldValue,
|
||||
]);
|
||||
@ -694,6 +724,7 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zControlNetModelFieldInputInstance,
|
||||
zIPAdapterModelFieldInputInstance,
|
||||
zT2IAdapterModelFieldInputInstance,
|
||||
zSpandrelImageToImageModelFieldInputInstance,
|
||||
zColorFieldInputInstance,
|
||||
zSchedulerFieldInputInstance,
|
||||
]);
|
||||
@ -722,6 +753,7 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zControlNetModelFieldInputTemplate,
|
||||
zIPAdapterModelFieldInputTemplate,
|
||||
zT2IAdapterModelFieldInputTemplate,
|
||||
zSpandrelImageToImageModelFieldInputTemplate,
|
||||
zColorFieldInputTemplate,
|
||||
zSchedulerFieldInputTemplate,
|
||||
zStatelessFieldInputTemplate,
|
||||
@ -751,6 +783,7 @@ const zStatefulFieldOutputTemplate = z.union([
|
||||
zControlNetModelFieldOutputTemplate,
|
||||
zIPAdapterModelFieldOutputTemplate,
|
||||
zT2IAdapterModelFieldOutputTemplate,
|
||||
zSpandrelImageToImageModelFieldOutputTemplate,
|
||||
zColorFieldOutputTemplate,
|
||||
zSchedulerFieldOutputTemplate,
|
||||
]);
|
||||
|
@ -18,6 +18,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
||||
SDXLRefinerModelField: undefined,
|
||||
StringField: '',
|
||||
T2IAdapterModelField: undefined,
|
||||
SpandrelImageToImageModelField: undefined,
|
||||
VAEModelField: undefined,
|
||||
ControlNetModelField: undefined,
|
||||
};
|
||||
|
@ -17,6 +17,7 @@ import type {
|
||||
SchedulerFieldInputTemplate,
|
||||
SDXLMainModelFieldInputTemplate,
|
||||
SDXLRefinerModelFieldInputTemplate,
|
||||
SpandrelImageToImageModelFieldInputTemplate,
|
||||
StatefulFieldType,
|
||||
StatelessFieldInputTemplate,
|
||||
StringFieldInputTemplate,
|
||||
@ -263,6 +264,17 @@ const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<T2IAdapt
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildSpandrelImageToImageModelFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
SpandrelImageToImageModelFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, fieldType }) => {
|
||||
const template: SpandrelImageToImageModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<BoardFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -377,6 +389,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
||||
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
|
||||
StringField: buildStringFieldInputTemplate,
|
||||
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
|
||||
SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate,
|
||||
VAEModelField: buildVAEModelFieldInputTemplate,
|
||||
} as const;
|
||||
|
||||
|
@ -35,6 +35,7 @@ const MODEL_FIELD_TYPES = [
|
||||
'ControlNetModelField',
|
||||
'IPAdapterModelField',
|
||||
'T2IAdapterModelField',
|
||||
'SpandrelImageToImageModelField',
|
||||
];
|
||||
|
||||
/**
|
||||
|
@ -11,6 +11,7 @@ import {
|
||||
isNonSDXLMainModelConfig,
|
||||
isRefinerMainModelModelConfig,
|
||||
isSDXLMainModelModelConfig,
|
||||
isSpandrelImageToImageModelConfig,
|
||||
isT2IAdapterModelConfig,
|
||||
isTIModelConfig,
|
||||
isVAEModelConfig,
|
||||
@ -39,6 +40,7 @@ export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
|
||||
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
|
||||
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
|
||||
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
|
||||
export const useSpandrelImageToImageModels = buildModelsHook(isSpandrelImageToImageModelConfig);
|
||||
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
|
||||
export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
|
||||
export const useVAEModels = buildModelsHook(isVAEModelConfig);
|
||||
|
File diff suppressed because one or more lines are too long
@ -49,6 +49,7 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
|
||||
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
|
||||
export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
|
||||
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
|
||||
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
|
||||
type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
|
||||
type DiffusersModelConfig = S['MainDiffusersConfig'];
|
||||
type CheckpointModelConfig = S['MainCheckpointConfig'];
|
||||
@ -60,6 +61,7 @@ export type AnyModelConfig =
|
||||
| ControlNetModelConfig
|
||||
| IPAdapterModelConfig
|
||||
| T2IAdapterModelConfig
|
||||
| SpandrelImageToImageModelConfig
|
||||
| TextualInversionModelConfig
|
||||
| MainModelConfig
|
||||
| CLIPVisionDiffusersConfig;
|
||||
@ -84,6 +86,10 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd
|
||||
return config.type === 't2i_adapter';
|
||||
};
|
||||
|
||||
export const isSpandrelImageToImageModelConfig = (config: AnyModelConfig): config is SpandrelImageToImageModelConfig => {
|
||||
return config.type === 'spandrel_image_to_image';
|
||||
}
|
||||
|
||||
export const isControlAdapterModelConfig = (
|
||||
config: AnyModelConfig
|
||||
): config is ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig => {
|
||||
|
Reference in New Issue
Block a user