feat: add Scheduler as field type

- update node schemas
- add `UIType.Scheduler`
- add field type to schema parser, input components
This commit is contained in:
psychedelicious 2023-08-17 18:58:01 +10:00
parent 210a3f9aa7
commit 98431b3de4
10 changed files with 159 additions and 6 deletions

View File

@ -143,6 +143,7 @@ class UIType(str, Enum):
# region Misc # region Misc
FilePath = "FilePath" FilePath = "FilePath"
Enum = "enum" Enum = "enum"
Scheduler = "Scheduler"
# endregion # endregion

View File

@ -119,7 +119,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
) )
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start) denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end) denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SAMPLER_NAME_VALUES = InputField(default="euler", description=FieldDescriptions.scheduler) scheduler: SAMPLER_NAME_VALUES = InputField(
default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
)
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection) unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection)
control: Union[ControlField, list[ControlField]] = InputField( control: Union[ControlField, list[ControlField]] = InputField(
default=None, description=FieldDescriptions.control, input=Input.Connection default=None, description=FieldDescriptions.control, input=Input.Connection

View File

@ -169,7 +169,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
ui_type=UIType.Float, ui_type=UIType.Float,
) )
scheduler: SAMPLER_NAME_VALUES = InputField( scheduler: SAMPLER_NAME_VALUES = InputField(
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler
) )
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision) precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
unet: UNetField = InputField( unet: UNetField = InputField(

View File

@ -1,4 +1,4 @@
import { Box } from '@chakra-ui/react'; import { Box, Text } from '@chakra-ui/react';
import { import {
useFieldData, useFieldData,
useFieldTemplate, useFieldTemplate,
@ -21,6 +21,7 @@ import MainModelInputField from './fieldTypes/MainModelInputField';
import NumberInputField from './fieldTypes/NumberInputField'; import NumberInputField from './fieldTypes/NumberInputField';
import RefinerModelInputField from './fieldTypes/RefinerModelInputField'; import RefinerModelInputField from './fieldTypes/RefinerModelInputField';
import SDXLMainModelInputField from './fieldTypes/SDXLMainModelInputField'; import SDXLMainModelInputField from './fieldTypes/SDXLMainModelInputField';
import SchedulerInputField from './fieldTypes/SchedulerInputField';
import StringInputField from './fieldTypes/StringInputField'; import StringInputField from './fieldTypes/StringInputField';
import UnetInputField from './fieldTypes/UnetInputField'; import UnetInputField from './fieldTypes/UnetInputField';
import VaeInputField from './fieldTypes/VaeInputField'; import VaeInputField from './fieldTypes/VaeInputField';
@ -286,7 +287,30 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
); );
} }
return <Box p={2}>Unknown field type: {field?.type}</Box>; if (field?.type === 'Scheduler' && fieldTemplate?.type === 'Scheduler') {
return (
<SchedulerInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
return (
<Box p={1}>
<Text
sx={{
fontSize: 'sm',
fontWeight: 600,
color: 'error.400',
_dark: { color: 'error.300' },
}}
>
Unknown field type: {field?.type}
</Text>
</Box>
);
}; };
export default memo(InputFieldRenderer); export default memo(InputFieldRenderer);

View File

@ -0,0 +1,75 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldSchedulerValueChanged } from 'features/nodes/store/nodesSlice';
import {
SchedulerInputFieldTemplate,
SchedulerInputFieldValue,
} from 'features/nodes/types/types';
import {
SCHEDULER_LABEL_MAP,
SchedulerParam,
} from 'features/parameters/types/parameterSchemas';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { FieldComponentProps } from './types';
const selector = createSelector(
[stateSelector],
({ ui }) => {
const { favoriteSchedulers: enabledSchedulers } = ui;
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
value: name,
label: label,
group: enabledSchedulers.includes(name as SchedulerParam)
? 'Favorites'
: undefined,
})).sort((a, b) => a.label.localeCompare(b.label));
return {
data,
};
},
defaultSelectorOptions
);
const SchedulerInputField = (
props: FieldComponentProps<
SchedulerInputFieldValue,
SchedulerInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data } = useAppSelector(selector);
const handleChange = useCallback(
(value: string | null) => {
if (!value) {
return;
}
dispatch(
fieldSchedulerValueChanged({
nodeId,
fieldName: field.name,
value: value as SchedulerParam,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<IAIMantineSearchableSelect
className="nowheel nodrag"
value={field.value}
data={data}
onChange={handleChange}
/>
);
};
export default memo(SchedulerInputField);

View File

@ -45,6 +45,7 @@ import {
MainModelInputFieldValue, MainModelInputFieldValue,
NodeStatus, NodeStatus,
NotesNodeData, NotesNodeData,
SchedulerInputFieldValue,
SDXLRefinerModelInputFieldValue, SDXLRefinerModelInputFieldValue,
StringInputFieldValue, StringInputFieldValue,
VaeModelInputFieldValue, VaeModelInputFieldValue,
@ -447,6 +448,12 @@ const nodesSlice = createSlice({
) => { ) => {
fieldValueReducer(state, action); fieldValueReducer(state, action);
}, },
fieldSchedulerValueChanged: (
state,
action: FieldValueAction<SchedulerInputFieldValue>
) => {
fieldValueReducer(state, action);
},
imageCollectionFieldValueChanged: ( imageCollectionFieldValueChanged: (
state, state,
action: PayloadAction<{ action: PayloadAction<{
@ -670,6 +677,7 @@ export const {
fieldEnumModelValueChanged, fieldEnumModelValueChanged,
fieldControlNetModelValueChanged, fieldControlNetModelValueChanged,
fieldRefinerModelValueChanged, fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
nodeIsOpenChanged, nodeIsOpenChanged,
nodeLabelChanged, nodeLabelChanged,
nodeNotesChanged, nodeNotesChanged,

View File

@ -127,6 +127,11 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'ControlNet', title: 'ControlNet',
description: 'TODO', description: 'TODO',
}, },
Scheduler: {
color: 'base.500',
title: 'Scheduler',
description: 'TODO',
},
Collection: { Collection: {
color: 'base.500', color: 'base.500',
title: 'Collection', title: 'Collection',

View File

@ -3,6 +3,7 @@ import {
LoRAModelParam, LoRAModelParam,
MainModelParam, MainModelParam,
OnnxModelParam, OnnxModelParam,
SchedulerParam,
VaeModelParam, VaeModelParam,
} from 'features/parameters/types/parameterSchemas'; } from 'features/parameters/types/parameterSchemas';
import { OpenAPIV3 } from 'openapi-types'; import { OpenAPIV3 } from 'openapi-types';
@ -98,6 +99,7 @@ export const zFieldType = z.enum([
// region Misc // region Misc
'FilePath', 'FilePath',
'enum', 'enum',
'Scheduler',
// endregion // endregion
]); ]);
@ -137,7 +139,8 @@ export type InputFieldValue =
| CollectionInputFieldValue | CollectionInputFieldValue
| CollectionItemInputFieldValue | CollectionItemInputFieldValue
| ColorInputFieldValue | ColorInputFieldValue
| ImageCollectionInputFieldValue; | ImageCollectionInputFieldValue
| SchedulerInputFieldValue;
/** /**
* An input field template is generated on each page load from the OpenAPI schema. * An input field template is generated on each page load from the OpenAPI schema.
@ -167,7 +170,8 @@ export type InputFieldTemplate =
| CollectionInputFieldTemplate | CollectionInputFieldTemplate
| CollectionItemInputFieldTemplate | CollectionItemInputFieldTemplate
| ColorInputFieldTemplate | ColorInputFieldTemplate
| ImageCollectionInputFieldTemplate; | ImageCollectionInputFieldTemplate
| SchedulerInputFieldTemplate;
/** /**
* An output field is persisted across as part of the user's local state. * An output field is persisted across as part of the user's local state.
@ -322,6 +326,11 @@ export type ColorInputFieldValue = InputFieldValueBase & {
value?: RgbaColor; value?: RgbaColor;
}; };
export type SchedulerInputFieldValue = InputFieldValueBase & {
type: 'Scheduler';
value?: SchedulerParam;
};
export type InputFieldTemplateBase = { export type InputFieldTemplateBase = {
name: string; name: string;
title: string; title: string;
@ -456,6 +465,11 @@ export type ColorInputFieldTemplate = InputFieldTemplateBase & {
type: 'ColorField'; type: 'ColorField';
}; };
export type SchedulerInputFieldTemplate = InputFieldTemplateBase & {
default: SchedulerParam;
type: 'Scheduler';
};
export const isInputFieldValue = ( export const isInputFieldValue = (
field?: InputFieldValue | OutputFieldValue field?: InputFieldValue | OutputFieldValue
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input'); ): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');

View File

@ -27,6 +27,7 @@ import {
OutputFieldTemplate, OutputFieldTemplate,
SDXLMainModelInputFieldTemplate, SDXLMainModelInputFieldTemplate,
SDXLRefinerModelInputFieldTemplate, SDXLRefinerModelInputFieldTemplate,
SchedulerInputFieldTemplate,
StringInputFieldTemplate, StringInputFieldTemplate,
UNetInputFieldTemplate, UNetInputFieldTemplate,
VaeInputFieldTemplate, VaeInputFieldTemplate,
@ -400,6 +401,19 @@ const buildColorInputFieldTemplate = ({
return template; return template;
}; };
const buildSchedulerInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): SchedulerInputFieldTemplate => {
const template: SchedulerInputFieldTemplate = {
...baseField,
type: 'Scheduler',
default: schemaObject.default ?? 'euler',
};
return template;
};
export const getFieldType = ( export const getFieldType = (
schemaObject: InvocationFieldSchema schemaObject: InvocationFieldSchema
): FieldType => { ): FieldType => {
@ -606,6 +620,12 @@ export const buildInputFieldTemplate = (
baseField, baseField,
}); });
} }
if (fieldType === 'Scheduler') {
return buildSchedulerInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
return; return;
}; };

View File

@ -93,5 +93,9 @@ export const buildInputFieldValue = (
fieldValue.value = undefined; fieldValue.value = undefined;
} }
if (template.type === 'Scheduler') {
fieldValue.value = undefined;
}
return fieldValue; return fieldValue;
}; };