feat: add Scheduler as field type

- update node schemas
- add `UIType.Scheduler`
- add field type to schema parser, input components
This commit is contained in:
psychedelicious 2023-08-17 18:58:01 +10:00
parent 210a3f9aa7
commit 98431b3de4
10 changed files with 159 additions and 6 deletions

View File

@ -143,6 +143,7 @@ class UIType(str, Enum):
# region Misc
FilePath = "FilePath"
Enum = "enum"
Scheduler = "Scheduler"
# endregion

View File

@ -119,7 +119,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SAMPLER_NAME_VALUES = InputField(default="euler", description=FieldDescriptions.scheduler)
scheduler: SAMPLER_NAME_VALUES = InputField(
default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
)
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection)
control: Union[ControlField, list[ControlField]] = InputField(
default=None, description=FieldDescriptions.control, input=Input.Connection

View File

@ -169,7 +169,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
ui_type=UIType.Float,
)
scheduler: SAMPLER_NAME_VALUES = InputField(
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler
)
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
unet: UNetField = InputField(

View File

@ -1,4 +1,4 @@
import { Box } from '@chakra-ui/react';
import { Box, Text } from '@chakra-ui/react';
import {
useFieldData,
useFieldTemplate,
@ -21,6 +21,7 @@ import MainModelInputField from './fieldTypes/MainModelInputField';
import NumberInputField from './fieldTypes/NumberInputField';
import RefinerModelInputField from './fieldTypes/RefinerModelInputField';
import SDXLMainModelInputField from './fieldTypes/SDXLMainModelInputField';
import SchedulerInputField from './fieldTypes/SchedulerInputField';
import StringInputField from './fieldTypes/StringInputField';
import UnetInputField from './fieldTypes/UnetInputField';
import VaeInputField from './fieldTypes/VaeInputField';
@ -286,7 +287,30 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
);
}
return <Box p={2}>Unknown field type: {field?.type}</Box>;
if (field?.type === 'Scheduler' && fieldTemplate?.type === 'Scheduler') {
return (
<SchedulerInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
return (
<Box p={1}>
<Text
sx={{
fontSize: 'sm',
fontWeight: 600,
color: 'error.400',
_dark: { color: 'error.300' },
}}
>
Unknown field type: {field?.type}
</Text>
</Box>
);
};
export default memo(InputFieldRenderer);

View File

@ -0,0 +1,75 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldSchedulerValueChanged } from 'features/nodes/store/nodesSlice';
import {
SchedulerInputFieldTemplate,
SchedulerInputFieldValue,
} from 'features/nodes/types/types';
import {
SCHEDULER_LABEL_MAP,
SchedulerParam,
} from 'features/parameters/types/parameterSchemas';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { FieldComponentProps } from './types';
const selector = createSelector(
[stateSelector],
({ ui }) => {
const { favoriteSchedulers: enabledSchedulers } = ui;
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
value: name,
label: label,
group: enabledSchedulers.includes(name as SchedulerParam)
? 'Favorites'
: undefined,
})).sort((a, b) => a.label.localeCompare(b.label));
return {
data,
};
},
defaultSelectorOptions
);
const SchedulerInputField = (
props: FieldComponentProps<
SchedulerInputFieldValue,
SchedulerInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data } = useAppSelector(selector);
const handleChange = useCallback(
(value: string | null) => {
if (!value) {
return;
}
dispatch(
fieldSchedulerValueChanged({
nodeId,
fieldName: field.name,
value: value as SchedulerParam,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<IAIMantineSearchableSelect
className="nowheel nodrag"
value={field.value}
data={data}
onChange={handleChange}
/>
);
};
export default memo(SchedulerInputField);

View File

@ -45,6 +45,7 @@ import {
MainModelInputFieldValue,
NodeStatus,
NotesNodeData,
SchedulerInputFieldValue,
SDXLRefinerModelInputFieldValue,
StringInputFieldValue,
VaeModelInputFieldValue,
@ -447,6 +448,12 @@ const nodesSlice = createSlice({
) => {
fieldValueReducer(state, action);
},
fieldSchedulerValueChanged: (
state,
action: FieldValueAction<SchedulerInputFieldValue>
) => {
fieldValueReducer(state, action);
},
imageCollectionFieldValueChanged: (
state,
action: PayloadAction<{
@ -670,6 +677,7 @@ export const {
fieldEnumModelValueChanged,
fieldControlNetModelValueChanged,
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
nodeIsOpenChanged,
nodeLabelChanged,
nodeNotesChanged,

View File

@ -127,6 +127,11 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'ControlNet',
description: 'TODO',
},
Scheduler: {
color: 'base.500',
title: 'Scheduler',
description: 'TODO',
},
Collection: {
color: 'base.500',
title: 'Collection',

View File

@ -3,6 +3,7 @@ import {
LoRAModelParam,
MainModelParam,
OnnxModelParam,
SchedulerParam,
VaeModelParam,
} from 'features/parameters/types/parameterSchemas';
import { OpenAPIV3 } from 'openapi-types';
@ -98,6 +99,7 @@ export const zFieldType = z.enum([
// region Misc
'FilePath',
'enum',
'Scheduler',
// endregion
]);
@ -137,7 +139,8 @@ export type InputFieldValue =
| CollectionInputFieldValue
| CollectionItemInputFieldValue
| ColorInputFieldValue
| ImageCollectionInputFieldValue;
| ImageCollectionInputFieldValue
| SchedulerInputFieldValue;
/**
* An input field template is generated on each page load from the OpenAPI schema.
@ -167,7 +170,8 @@ export type InputFieldTemplate =
| CollectionInputFieldTemplate
| CollectionItemInputFieldTemplate
| ColorInputFieldTemplate
| ImageCollectionInputFieldTemplate;
| ImageCollectionInputFieldTemplate
| SchedulerInputFieldTemplate;
/**
* An output field is persisted across as part of the user's local state.
@ -322,6 +326,11 @@ export type ColorInputFieldValue = InputFieldValueBase & {
value?: RgbaColor;
};
export type SchedulerInputFieldValue = InputFieldValueBase & {
type: 'Scheduler';
value?: SchedulerParam;
};
export type InputFieldTemplateBase = {
name: string;
title: string;
@ -456,6 +465,11 @@ export type ColorInputFieldTemplate = InputFieldTemplateBase & {
type: 'ColorField';
};
export type SchedulerInputFieldTemplate = InputFieldTemplateBase & {
default: SchedulerParam;
type: 'Scheduler';
};
export const isInputFieldValue = (
field?: InputFieldValue | OutputFieldValue
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');

View File

@ -27,6 +27,7 @@ import {
OutputFieldTemplate,
SDXLMainModelInputFieldTemplate,
SDXLRefinerModelInputFieldTemplate,
SchedulerInputFieldTemplate,
StringInputFieldTemplate,
UNetInputFieldTemplate,
VaeInputFieldTemplate,
@ -400,6 +401,19 @@ const buildColorInputFieldTemplate = ({
return template;
};
const buildSchedulerInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): SchedulerInputFieldTemplate => {
const template: SchedulerInputFieldTemplate = {
...baseField,
type: 'Scheduler',
default: schemaObject.default ?? 'euler',
};
return template;
};
export const getFieldType = (
schemaObject: InvocationFieldSchema
): FieldType => {
@ -606,6 +620,12 @@ export const buildInputFieldTemplate = (
baseField,
});
}
if (fieldType === 'Scheduler') {
return buildSchedulerInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
return;
};

View File

@ -93,5 +93,9 @@ export const buildInputFieldValue = (
fieldValue.value = undefined;
}
if (template.type === 'Scheduler') {
fieldValue.value = undefined;
}
return fieldValue;
};