mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: add Scheduler
as field type
- update node schemas - add `UIType.Scheduler` - add field type to schema parser, input components
This commit is contained in:
parent
210a3f9aa7
commit
98431b3de4
@ -143,6 +143,7 @@ class UIType(str, Enum):
|
|||||||
# region Misc
|
# region Misc
|
||||||
FilePath = "FilePath"
|
FilePath = "FilePath"
|
||||||
Enum = "enum"
|
Enum = "enum"
|
||||||
|
Scheduler = "Scheduler"
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
|
@ -119,7 +119,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
|
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
|
||||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||||
scheduler: SAMPLER_NAME_VALUES = InputField(default="euler", description=FieldDescriptions.scheduler)
|
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||||
|
default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
|
||||||
|
)
|
||||||
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection)
|
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection)
|
||||||
control: Union[ControlField, list[ControlField]] = InputField(
|
control: Union[ControlField, list[ControlField]] = InputField(
|
||||||
default=None, description=FieldDescriptions.control, input=Input.Connection
|
default=None, description=FieldDescriptions.control, input=Input.Connection
|
||||||
|
@ -169,7 +169,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
ui_type=UIType.Float,
|
ui_type=UIType.Float,
|
||||||
)
|
)
|
||||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||||
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct
|
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler
|
||||||
)
|
)
|
||||||
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
|
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
|
||||||
unet: UNetField = InputField(
|
unet: UNetField = InputField(
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Box } from '@chakra-ui/react';
|
import { Box, Text } from '@chakra-ui/react';
|
||||||
import {
|
import {
|
||||||
useFieldData,
|
useFieldData,
|
||||||
useFieldTemplate,
|
useFieldTemplate,
|
||||||
@ -21,6 +21,7 @@ import MainModelInputField from './fieldTypes/MainModelInputField';
|
|||||||
import NumberInputField from './fieldTypes/NumberInputField';
|
import NumberInputField from './fieldTypes/NumberInputField';
|
||||||
import RefinerModelInputField from './fieldTypes/RefinerModelInputField';
|
import RefinerModelInputField from './fieldTypes/RefinerModelInputField';
|
||||||
import SDXLMainModelInputField from './fieldTypes/SDXLMainModelInputField';
|
import SDXLMainModelInputField from './fieldTypes/SDXLMainModelInputField';
|
||||||
|
import SchedulerInputField from './fieldTypes/SchedulerInputField';
|
||||||
import StringInputField from './fieldTypes/StringInputField';
|
import StringInputField from './fieldTypes/StringInputField';
|
||||||
import UnetInputField from './fieldTypes/UnetInputField';
|
import UnetInputField from './fieldTypes/UnetInputField';
|
||||||
import VaeInputField from './fieldTypes/VaeInputField';
|
import VaeInputField from './fieldTypes/VaeInputField';
|
||||||
@ -286,7 +287,30 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
return <Box p={2}>Unknown field type: {field?.type}</Box>;
|
if (field?.type === 'Scheduler' && fieldTemplate?.type === 'Scheduler') {
|
||||||
|
return (
|
||||||
|
<SchedulerInputField
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
fieldTemplate={fieldTemplate}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box p={1}>
|
||||||
|
<Text
|
||||||
|
sx={{
|
||||||
|
fontSize: 'sm',
|
||||||
|
fontWeight: 600,
|
||||||
|
color: 'error.400',
|
||||||
|
_dark: { color: 'error.300' },
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
Unknown field type: {field?.type}
|
||||||
|
</Text>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(InputFieldRenderer);
|
export default memo(InputFieldRenderer);
|
||||||
|
@ -0,0 +1,75 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||||
|
import { fieldSchedulerValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import {
|
||||||
|
SchedulerInputFieldTemplate,
|
||||||
|
SchedulerInputFieldValue,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import {
|
||||||
|
SCHEDULER_LABEL_MAP,
|
||||||
|
SchedulerParam,
|
||||||
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { map } from 'lodash-es';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[stateSelector],
|
||||||
|
({ ui }) => {
|
||||||
|
const { favoriteSchedulers: enabledSchedulers } = ui;
|
||||||
|
|
||||||
|
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
|
||||||
|
value: name,
|
||||||
|
label: label,
|
||||||
|
group: enabledSchedulers.includes(name as SchedulerParam)
|
||||||
|
? 'Favorites'
|
||||||
|
: undefined,
|
||||||
|
})).sort((a, b) => a.label.localeCompare(b.label));
|
||||||
|
|
||||||
|
return {
|
||||||
|
data,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const SchedulerInputField = (
|
||||||
|
props: FieldComponentProps<
|
||||||
|
SchedulerInputFieldValue,
|
||||||
|
SchedulerInputFieldTemplate
|
||||||
|
>
|
||||||
|
) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { data } = useAppSelector(selector);
|
||||||
|
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(value: string | null) => {
|
||||||
|
if (!value) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dispatch(
|
||||||
|
fieldSchedulerValueChanged({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
value: value as SchedulerParam,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch, field.name, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIMantineSearchableSelect
|
||||||
|
className="nowheel nodrag"
|
||||||
|
value={field.value}
|
||||||
|
data={data}
|
||||||
|
onChange={handleChange}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(SchedulerInputField);
|
@ -45,6 +45,7 @@ import {
|
|||||||
MainModelInputFieldValue,
|
MainModelInputFieldValue,
|
||||||
NodeStatus,
|
NodeStatus,
|
||||||
NotesNodeData,
|
NotesNodeData,
|
||||||
|
SchedulerInputFieldValue,
|
||||||
SDXLRefinerModelInputFieldValue,
|
SDXLRefinerModelInputFieldValue,
|
||||||
StringInputFieldValue,
|
StringInputFieldValue,
|
||||||
VaeModelInputFieldValue,
|
VaeModelInputFieldValue,
|
||||||
@ -447,6 +448,12 @@ const nodesSlice = createSlice({
|
|||||||
) => {
|
) => {
|
||||||
fieldValueReducer(state, action);
|
fieldValueReducer(state, action);
|
||||||
},
|
},
|
||||||
|
fieldSchedulerValueChanged: (
|
||||||
|
state,
|
||||||
|
action: FieldValueAction<SchedulerInputFieldValue>
|
||||||
|
) => {
|
||||||
|
fieldValueReducer(state, action);
|
||||||
|
},
|
||||||
imageCollectionFieldValueChanged: (
|
imageCollectionFieldValueChanged: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{
|
action: PayloadAction<{
|
||||||
@ -670,6 +677,7 @@ export const {
|
|||||||
fieldEnumModelValueChanged,
|
fieldEnumModelValueChanged,
|
||||||
fieldControlNetModelValueChanged,
|
fieldControlNetModelValueChanged,
|
||||||
fieldRefinerModelValueChanged,
|
fieldRefinerModelValueChanged,
|
||||||
|
fieldSchedulerValueChanged,
|
||||||
nodeIsOpenChanged,
|
nodeIsOpenChanged,
|
||||||
nodeLabelChanged,
|
nodeLabelChanged,
|
||||||
nodeNotesChanged,
|
nodeNotesChanged,
|
||||||
|
@ -127,6 +127,11 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
|||||||
title: 'ControlNet',
|
title: 'ControlNet',
|
||||||
description: 'TODO',
|
description: 'TODO',
|
||||||
},
|
},
|
||||||
|
Scheduler: {
|
||||||
|
color: 'base.500',
|
||||||
|
title: 'Scheduler',
|
||||||
|
description: 'TODO',
|
||||||
|
},
|
||||||
Collection: {
|
Collection: {
|
||||||
color: 'base.500',
|
color: 'base.500',
|
||||||
title: 'Collection',
|
title: 'Collection',
|
||||||
|
@ -3,6 +3,7 @@ import {
|
|||||||
LoRAModelParam,
|
LoRAModelParam,
|
||||||
MainModelParam,
|
MainModelParam,
|
||||||
OnnxModelParam,
|
OnnxModelParam,
|
||||||
|
SchedulerParam,
|
||||||
VaeModelParam,
|
VaeModelParam,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
import { OpenAPIV3 } from 'openapi-types';
|
import { OpenAPIV3 } from 'openapi-types';
|
||||||
@ -98,6 +99,7 @@ export const zFieldType = z.enum([
|
|||||||
// region Misc
|
// region Misc
|
||||||
'FilePath',
|
'FilePath',
|
||||||
'enum',
|
'enum',
|
||||||
|
'Scheduler',
|
||||||
// endregion
|
// endregion
|
||||||
]);
|
]);
|
||||||
|
|
||||||
@ -137,7 +139,8 @@ export type InputFieldValue =
|
|||||||
| CollectionInputFieldValue
|
| CollectionInputFieldValue
|
||||||
| CollectionItemInputFieldValue
|
| CollectionItemInputFieldValue
|
||||||
| ColorInputFieldValue
|
| ColorInputFieldValue
|
||||||
| ImageCollectionInputFieldValue;
|
| ImageCollectionInputFieldValue
|
||||||
|
| SchedulerInputFieldValue;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An input field template is generated on each page load from the OpenAPI schema.
|
* An input field template is generated on each page load from the OpenAPI schema.
|
||||||
@ -167,7 +170,8 @@ export type InputFieldTemplate =
|
|||||||
| CollectionInputFieldTemplate
|
| CollectionInputFieldTemplate
|
||||||
| CollectionItemInputFieldTemplate
|
| CollectionItemInputFieldTemplate
|
||||||
| ColorInputFieldTemplate
|
| ColorInputFieldTemplate
|
||||||
| ImageCollectionInputFieldTemplate;
|
| ImageCollectionInputFieldTemplate
|
||||||
|
| SchedulerInputFieldTemplate;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An output field is persisted across as part of the user's local state.
|
* An output field is persisted across as part of the user's local state.
|
||||||
@ -322,6 +326,11 @@ export type ColorInputFieldValue = InputFieldValueBase & {
|
|||||||
value?: RgbaColor;
|
value?: RgbaColor;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type SchedulerInputFieldValue = InputFieldValueBase & {
|
||||||
|
type: 'Scheduler';
|
||||||
|
value?: SchedulerParam;
|
||||||
|
};
|
||||||
|
|
||||||
export type InputFieldTemplateBase = {
|
export type InputFieldTemplateBase = {
|
||||||
name: string;
|
name: string;
|
||||||
title: string;
|
title: string;
|
||||||
@ -456,6 +465,11 @@ export type ColorInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'ColorField';
|
type: 'ColorField';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type SchedulerInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: SchedulerParam;
|
||||||
|
type: 'Scheduler';
|
||||||
|
};
|
||||||
|
|
||||||
export const isInputFieldValue = (
|
export const isInputFieldValue = (
|
||||||
field?: InputFieldValue | OutputFieldValue
|
field?: InputFieldValue | OutputFieldValue
|
||||||
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
|
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
|
||||||
|
@ -27,6 +27,7 @@ import {
|
|||||||
OutputFieldTemplate,
|
OutputFieldTemplate,
|
||||||
SDXLMainModelInputFieldTemplate,
|
SDXLMainModelInputFieldTemplate,
|
||||||
SDXLRefinerModelInputFieldTemplate,
|
SDXLRefinerModelInputFieldTemplate,
|
||||||
|
SchedulerInputFieldTemplate,
|
||||||
StringInputFieldTemplate,
|
StringInputFieldTemplate,
|
||||||
UNetInputFieldTemplate,
|
UNetInputFieldTemplate,
|
||||||
VaeInputFieldTemplate,
|
VaeInputFieldTemplate,
|
||||||
@ -400,6 +401,19 @@ const buildColorInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildSchedulerInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): SchedulerInputFieldTemplate => {
|
||||||
|
const template: SchedulerInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'Scheduler',
|
||||||
|
default: schemaObject.default ?? 'euler',
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
export const getFieldType = (
|
export const getFieldType = (
|
||||||
schemaObject: InvocationFieldSchema
|
schemaObject: InvocationFieldSchema
|
||||||
): FieldType => {
|
): FieldType => {
|
||||||
@ -606,6 +620,12 @@ export const buildInputFieldTemplate = (
|
|||||||
baseField,
|
baseField,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
if (fieldType === 'Scheduler') {
|
||||||
|
return buildSchedulerInputFieldTemplate({
|
||||||
|
schemaObject: fieldSchema,
|
||||||
|
baseField,
|
||||||
|
});
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -93,5 +93,9 @@ export const buildInputFieldValue = (
|
|||||||
fieldValue.value = undefined;
|
fieldValue.value = undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (template.type === 'Scheduler') {
|
||||||
|
fieldValue.value = undefined;
|
||||||
|
}
|
||||||
|
|
||||||
return fieldValue;
|
return fieldValue;
|
||||||
};
|
};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user