mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: add Scheduler
as field type
- update node schemas - add `UIType.Scheduler` - add field type to schema parser, input components
This commit is contained in:
parent
210a3f9aa7
commit
98431b3de4
@ -143,6 +143,7 @@ class UIType(str, Enum):
|
||||
# region Misc
|
||||
FilePath = "FilePath"
|
||||
Enum = "enum"
|
||||
Scheduler = "Scheduler"
|
||||
# endregion
|
||||
|
||||
|
||||
|
@ -119,7 +119,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(default="euler", description=FieldDescriptions.scheduler)
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||
default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
|
||||
)
|
||||
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection)
|
||||
control: Union[ControlField, list[ControlField]] = InputField(
|
||||
default=None, description=FieldDescriptions.control, input=Input.Connection
|
||||
|
@ -169,7 +169,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
ui_type=UIType.Float,
|
||||
)
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct
|
||||
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler
|
||||
)
|
||||
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
|
||||
unet: UNetField = InputField(
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import { Box, Text } from '@chakra-ui/react';
|
||||
import {
|
||||
useFieldData,
|
||||
useFieldTemplate,
|
||||
@ -21,6 +21,7 @@ import MainModelInputField from './fieldTypes/MainModelInputField';
|
||||
import NumberInputField from './fieldTypes/NumberInputField';
|
||||
import RefinerModelInputField from './fieldTypes/RefinerModelInputField';
|
||||
import SDXLMainModelInputField from './fieldTypes/SDXLMainModelInputField';
|
||||
import SchedulerInputField from './fieldTypes/SchedulerInputField';
|
||||
import StringInputField from './fieldTypes/StringInputField';
|
||||
import UnetInputField from './fieldTypes/UnetInputField';
|
||||
import VaeInputField from './fieldTypes/VaeInputField';
|
||||
@ -286,7 +287,30 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
return <Box p={2}>Unknown field type: {field?.type}</Box>;
|
||||
if (field?.type === 'Scheduler' && fieldTemplate?.type === 'Scheduler') {
|
||||
return (
|
||||
<SchedulerInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Box p={1}>
|
||||
<Text
|
||||
sx={{
|
||||
fontSize: 'sm',
|
||||
fontWeight: 600,
|
||||
color: 'error.400',
|
||||
_dark: { color: 'error.300' },
|
||||
}}
|
||||
>
|
||||
Unknown field type: {field?.type}
|
||||
</Text>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(InputFieldRenderer);
|
||||
|
@ -0,0 +1,75 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { fieldSchedulerValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
SchedulerInputFieldTemplate,
|
||||
SchedulerInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import {
|
||||
SCHEDULER_LABEL_MAP,
|
||||
SchedulerParam,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ ui }) => {
|
||||
const { favoriteSchedulers: enabledSchedulers } = ui;
|
||||
|
||||
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
|
||||
value: name,
|
||||
label: label,
|
||||
group: enabledSchedulers.includes(name as SchedulerParam)
|
||||
? 'Favorites'
|
||||
: undefined,
|
||||
})).sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
return {
|
||||
data,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const SchedulerInputField = (
|
||||
props: FieldComponentProps<
|
||||
SchedulerInputFieldValue,
|
||||
SchedulerInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { data } = useAppSelector(selector);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(value: string | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldSchedulerValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: value as SchedulerParam,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineSearchableSelect
|
||||
className="nowheel nodrag"
|
||||
value={field.value}
|
||||
data={data}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SchedulerInputField);
|
@ -45,6 +45,7 @@ import {
|
||||
MainModelInputFieldValue,
|
||||
NodeStatus,
|
||||
NotesNodeData,
|
||||
SchedulerInputFieldValue,
|
||||
SDXLRefinerModelInputFieldValue,
|
||||
StringInputFieldValue,
|
||||
VaeModelInputFieldValue,
|
||||
@ -447,6 +448,12 @@ const nodesSlice = createSlice({
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldSchedulerValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<SchedulerInputFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
imageCollectionFieldValueChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
@ -670,6 +677,7 @@ export const {
|
||||
fieldEnumModelValueChanged,
|
||||
fieldControlNetModelValueChanged,
|
||||
fieldRefinerModelValueChanged,
|
||||
fieldSchedulerValueChanged,
|
||||
nodeIsOpenChanged,
|
||||
nodeLabelChanged,
|
||||
nodeNotesChanged,
|
||||
|
@ -127,6 +127,11 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
title: 'ControlNet',
|
||||
description: 'TODO',
|
||||
},
|
||||
Scheduler: {
|
||||
color: 'base.500',
|
||||
title: 'Scheduler',
|
||||
description: 'TODO',
|
||||
},
|
||||
Collection: {
|
||||
color: 'base.500',
|
||||
title: 'Collection',
|
||||
|
@ -3,6 +3,7 @@ import {
|
||||
LoRAModelParam,
|
||||
MainModelParam,
|
||||
OnnxModelParam,
|
||||
SchedulerParam,
|
||||
VaeModelParam,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
@ -98,6 +99,7 @@ export const zFieldType = z.enum([
|
||||
// region Misc
|
||||
'FilePath',
|
||||
'enum',
|
||||
'Scheduler',
|
||||
// endregion
|
||||
]);
|
||||
|
||||
@ -137,7 +139,8 @@ export type InputFieldValue =
|
||||
| CollectionInputFieldValue
|
||||
| CollectionItemInputFieldValue
|
||||
| ColorInputFieldValue
|
||||
| ImageCollectionInputFieldValue;
|
||||
| ImageCollectionInputFieldValue
|
||||
| SchedulerInputFieldValue;
|
||||
|
||||
/**
|
||||
* An input field template is generated on each page load from the OpenAPI schema.
|
||||
@ -167,7 +170,8 @@ export type InputFieldTemplate =
|
||||
| CollectionInputFieldTemplate
|
||||
| CollectionItemInputFieldTemplate
|
||||
| ColorInputFieldTemplate
|
||||
| ImageCollectionInputFieldTemplate;
|
||||
| ImageCollectionInputFieldTemplate
|
||||
| SchedulerInputFieldTemplate;
|
||||
|
||||
/**
|
||||
* An output field is persisted across as part of the user's local state.
|
||||
@ -322,6 +326,11 @@ export type ColorInputFieldValue = InputFieldValueBase & {
|
||||
value?: RgbaColor;
|
||||
};
|
||||
|
||||
export type SchedulerInputFieldValue = InputFieldValueBase & {
|
||||
type: 'Scheduler';
|
||||
value?: SchedulerParam;
|
||||
};
|
||||
|
||||
export type InputFieldTemplateBase = {
|
||||
name: string;
|
||||
title: string;
|
||||
@ -456,6 +465,11 @@ export type ColorInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'ColorField';
|
||||
};
|
||||
|
||||
export type SchedulerInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: SchedulerParam;
|
||||
type: 'Scheduler';
|
||||
};
|
||||
|
||||
export const isInputFieldValue = (
|
||||
field?: InputFieldValue | OutputFieldValue
|
||||
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
|
||||
|
@ -27,6 +27,7 @@ import {
|
||||
OutputFieldTemplate,
|
||||
SDXLMainModelInputFieldTemplate,
|
||||
SDXLRefinerModelInputFieldTemplate,
|
||||
SchedulerInputFieldTemplate,
|
||||
StringInputFieldTemplate,
|
||||
UNetInputFieldTemplate,
|
||||
VaeInputFieldTemplate,
|
||||
@ -400,6 +401,19 @@ const buildColorInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildSchedulerInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): SchedulerInputFieldTemplate => {
|
||||
const template: SchedulerInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'Scheduler',
|
||||
default: schemaObject.default ?? 'euler',
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
export const getFieldType = (
|
||||
schemaObject: InvocationFieldSchema
|
||||
): FieldType => {
|
||||
@ -606,6 +620,12 @@ export const buildInputFieldTemplate = (
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'Scheduler') {
|
||||
return buildSchedulerInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
return;
|
||||
};
|
||||
|
||||
|
@ -93,5 +93,9 @@ export const buildInputFieldValue = (
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'Scheduler') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
return fieldValue;
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user