feat(ui): add ModelIdentifierField field type

This new field type accepts _any_ model. A field renderer lets the user select any available model.
This commit is contained in:
psychedelicious 2024-05-17 20:47:00 +10:00
parent 6a2c53f6c5
commit a012bb6e07
6 changed files with 130 additions and 0 deletions

View File

@ -1,3 +1,4 @@
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance'; import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate'; import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
import { import {
@ -23,6 +24,8 @@ import {
isLoRAModelFieldInputTemplate, isLoRAModelFieldInputTemplate,
isMainModelFieldInputInstance, isMainModelFieldInputInstance,
isMainModelFieldInputTemplate, isMainModelFieldInputTemplate,
isModelIdentifierFieldInputInstance,
isModelIdentifierFieldInputTemplate,
isSchedulerFieldInputInstance, isSchedulerFieldInputInstance,
isSchedulerFieldInputTemplate, isSchedulerFieldInputTemplate,
isSDXLMainModelFieldInputInstance, isSDXLMainModelFieldInputInstance,
@ -95,6 +98,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return <MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
} }
if (isModelIdentifierFieldInputInstance(fieldInstance) && isModelIdentifierFieldInputTemplate(fieldTemplate)) {
return <ModelIdentifierFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) { if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) {
return <RefinerModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return <RefinerModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
} }

View File

@ -0,0 +1,68 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldModelIdentifierValueChanged } from 'features/nodes/store/nodesSlice';
import type { ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback, useMemo } from 'react';
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate>;
const ModelIdentifierFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data, isLoading } = useGetModelConfigsQuery();
const _onChange = useCallback(
(value: AnyModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldModelIdentifierValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const modelConfigs = useMemo(() => {
if (!data) {
return EMPTY_ARRAY;
}
return modelConfigsAdapterSelectors.selectAll(data);
}, [data]);
console.log(modelConfigs);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
groupByType: true,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
);
};
export default memo(ModelIdentifierFieldInputComponent);

View File

@ -16,6 +16,7 @@ import type {
IPAdapterModelFieldValue, IPAdapterModelFieldValue,
LoRAModelFieldValue, LoRAModelFieldValue,
MainModelFieldValue, MainModelFieldValue,
ModelIdentifierFieldValue,
SchedulerFieldValue, SchedulerFieldValue,
SDXLRefinerModelFieldValue, SDXLRefinerModelFieldValue,
StatefulFieldValue, StatefulFieldValue,
@ -35,6 +36,7 @@ import {
zIPAdapterModelFieldValue, zIPAdapterModelFieldValue,
zLoRAModelFieldValue, zLoRAModelFieldValue,
zMainModelFieldValue, zMainModelFieldValue,
zModelIdentifierFieldValue,
zSchedulerFieldValue, zSchedulerFieldValue,
zSDXLRefinerModelFieldValue, zSDXLRefinerModelFieldValue,
zStatefulFieldValue, zStatefulFieldValue,
@ -344,6 +346,9 @@ export const nodesSlice = createSlice({
fieldMainModelValueChanged: (state, action: FieldValueAction<MainModelFieldValue>) => { fieldMainModelValueChanged: (state, action: FieldValueAction<MainModelFieldValue>) => {
fieldValueReducer(state, action, zMainModelFieldValue); fieldValueReducer(state, action, zMainModelFieldValue);
}, },
fieldModelIdentifierValueChanged: (state, action: FieldValueAction<ModelIdentifierFieldValue>) => {
fieldValueReducer(state, action, zModelIdentifierFieldValue);
},
fieldRefinerModelValueChanged: (state, action: FieldValueAction<SDXLRefinerModelFieldValue>) => { fieldRefinerModelValueChanged: (state, action: FieldValueAction<SDXLRefinerModelFieldValue>) => {
fieldValueReducer(state, action, zSDXLRefinerModelFieldValue); fieldValueReducer(state, action, zSDXLRefinerModelFieldValue);
}, },
@ -469,6 +474,7 @@ export const {
fieldT2IAdapterModelValueChanged, fieldT2IAdapterModelValueChanged,
fieldLabelChanged, fieldLabelChanged,
fieldLoRAModelValueChanged, fieldLoRAModelValueChanged,
fieldModelIdentifierValueChanged,
fieldMainModelValueChanged, fieldMainModelValueChanged,
fieldNumberValueChanged, fieldNumberValueChanged,
fieldRefinerModelValueChanged, fieldRefinerModelValueChanged,

View File

@ -106,6 +106,10 @@ const zMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('MainModelField'), name: z.literal('MainModelField'),
originalType: zStatelessFieldType.optional(), originalType: zStatelessFieldType.optional(),
}); });
const zModelIdentifierFieldType = zFieldTypeBase.extend({
name: z.literal('ModelIdentifierField'),
originalType: zStatelessFieldType.optional(),
});
const zSDXLMainModelFieldType = zFieldTypeBase.extend({ const zSDXLMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLMainModelField'), name: z.literal('SDXLMainModelField'),
originalType: zStatelessFieldType.optional(), originalType: zStatelessFieldType.optional(),
@ -146,6 +150,7 @@ const zStatefulFieldType = z.union([
zEnumFieldType, zEnumFieldType,
zImageFieldType, zImageFieldType,
zBoardFieldType, zBoardFieldType,
zModelIdentifierFieldType,
zMainModelFieldType, zMainModelFieldType,
zSDXLMainModelFieldType, zSDXLMainModelFieldType,
zSDXLRefinerModelFieldType, zSDXLRefinerModelFieldType,
@ -396,6 +401,29 @@ export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFie
zMainModelFieldInputTemplate.safeParse(val).success; zMainModelFieldInputTemplate.safeParse(val).success;
// #endregion // #endregion
// #region ModelIdentifierField
export const zModelIdentifierFieldValue = zModelIdentifierField.optional();
const zModelIdentifierFieldInputInstance = zFieldInputInstanceBase.extend({
value: zModelIdentifierFieldValue,
});
const zModelIdentifierFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zModelIdentifierFieldType,
originalType: zFieldType.optional(),
default: zModelIdentifierFieldValue,
});
const zModelIdentifierFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zModelIdentifierFieldType,
originalType: zFieldType.optional(),
});
export type ModelIdentifierFieldValue = z.infer<typeof zModelIdentifierFieldValue>;
export type ModelIdentifierFieldInputInstance = z.infer<typeof zModelIdentifierFieldInputInstance>;
export type ModelIdentifierFieldInputTemplate = z.infer<typeof zModelIdentifierFieldInputTemplate>;
export const isModelIdentifierFieldInputInstance = (val: unknown): val is ModelIdentifierFieldInputInstance =>
zModelIdentifierFieldInputInstance.safeParse(val).success;
export const isModelIdentifierFieldInputTemplate = (val: unknown): val is ModelIdentifierFieldInputTemplate =>
zModelIdentifierFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SDXLMainModelField // #region SDXLMainModelField
const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
@ -643,6 +671,7 @@ export const zStatefulFieldValue = z.union([
zEnumFieldValue, zEnumFieldValue,
zImageFieldValue, zImageFieldValue,
zBoardFieldValue, zBoardFieldValue,
zModelIdentifierFieldValue,
zMainModelFieldValue, zMainModelFieldValue,
zSDXLMainModelFieldValue, zSDXLMainModelFieldValue,
zSDXLRefinerModelFieldValue, zSDXLRefinerModelFieldValue,
@ -669,6 +698,7 @@ const zStatefulFieldInputInstance = z.union([
zEnumFieldInputInstance, zEnumFieldInputInstance,
zImageFieldInputInstance, zImageFieldInputInstance,
zBoardFieldInputInstance, zBoardFieldInputInstance,
zModelIdentifierFieldInputInstance,
zMainModelFieldInputInstance, zMainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance, zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance, zSDXLRefinerModelFieldInputInstance,
@ -696,6 +726,7 @@ const zStatefulFieldInputTemplate = z.union([
zEnumFieldInputTemplate, zEnumFieldInputTemplate,
zImageFieldInputTemplate, zImageFieldInputTemplate,
zBoardFieldInputTemplate, zBoardFieldInputTemplate,
zModelIdentifierFieldInputTemplate,
zMainModelFieldInputTemplate, zMainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate, zSDXLMainModelFieldInputTemplate,
zSDXLRefinerModelFieldInputTemplate, zSDXLRefinerModelFieldInputTemplate,
@ -724,6 +755,7 @@ const zStatefulFieldOutputTemplate = z.union([
zEnumFieldOutputTemplate, zEnumFieldOutputTemplate,
zImageFieldOutputTemplate, zImageFieldOutputTemplate,
zBoardFieldOutputTemplate, zBoardFieldOutputTemplate,
zModelIdentifierFieldOutputTemplate,
zMainModelFieldOutputTemplate, zMainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate, zSDXLMainModelFieldOutputTemplate,
zSDXLRefinerModelFieldOutputTemplate, zSDXLRefinerModelFieldOutputTemplate,

View File

@ -11,6 +11,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
IntegerField: 0, IntegerField: 0,
IPAdapterModelField: undefined, IPAdapterModelField: undefined,
LoRAModelField: undefined, LoRAModelField: undefined,
ModelIdentifierField: undefined,
MainModelField: undefined, MainModelField: undefined,
SchedulerField: 'euler', SchedulerField: 'euler',
SDXLMainModelField: undefined, SDXLMainModelField: undefined,

View File

@ -13,6 +13,7 @@ import type {
IPAdapterModelFieldInputTemplate, IPAdapterModelFieldInputTemplate,
LoRAModelFieldInputTemplate, LoRAModelFieldInputTemplate,
MainModelFieldInputTemplate, MainModelFieldInputTemplate,
ModelIdentifierFieldInputTemplate,
SchedulerFieldInputTemplate, SchedulerFieldInputTemplate,
SDXLMainModelFieldInputTemplate, SDXLMainModelFieldInputTemplate,
SDXLRefinerModelFieldInputTemplate, SDXLRefinerModelFieldInputTemplate,
@ -136,6 +137,20 @@ const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder<BooleanFieldInpu
return template; return template;
}; };
const buildModelIdentifierFieldInputTemplate: FieldInputTemplateBuilder<ModelIdentifierFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: ModelIdentifierFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder<MainModelFieldInputTemplate> = ({ const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder<MainModelFieldInputTemplate> = ({
schemaObject, schemaObject,
baseField, baseField,
@ -355,6 +370,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
IntegerField: buildIntegerFieldInputTemplate, IntegerField: buildIntegerFieldInputTemplate,
IPAdapterModelField: buildIPAdapterModelFieldInputTemplate, IPAdapterModelField: buildIPAdapterModelFieldInputTemplate,
LoRAModelField: buildLoRAModelFieldInputTemplate, LoRAModelField: buildLoRAModelFieldInputTemplate,
ModelIdentifierField: buildModelIdentifierFieldInputTemplate,
MainModelField: buildMainModelFieldInputTemplate, MainModelField: buildMainModelFieldInputTemplate,
SchedulerField: buildSchedulerFieldInputTemplate, SchedulerField: buildSchedulerFieldInputTemplate,
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate, SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,