mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): add ModelIdentifierField field type
This new field type accepts _any_ model. A field renderer lets the user select any available model.
This commit is contained in:
parent
6a2c53f6c5
commit
a012bb6e07
@ -1,3 +1,4 @@
|
|||||||
|
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
|
||||||
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
|
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
|
||||||
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
|
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
|
||||||
import {
|
import {
|
||||||
@ -23,6 +24,8 @@ import {
|
|||||||
isLoRAModelFieldInputTemplate,
|
isLoRAModelFieldInputTemplate,
|
||||||
isMainModelFieldInputInstance,
|
isMainModelFieldInputInstance,
|
||||||
isMainModelFieldInputTemplate,
|
isMainModelFieldInputTemplate,
|
||||||
|
isModelIdentifierFieldInputInstance,
|
||||||
|
isModelIdentifierFieldInputTemplate,
|
||||||
isSchedulerFieldInputInstance,
|
isSchedulerFieldInputInstance,
|
||||||
isSchedulerFieldInputTemplate,
|
isSchedulerFieldInputTemplate,
|
||||||
isSDXLMainModelFieldInputInstance,
|
isSDXLMainModelFieldInputInstance,
|
||||||
@ -95,6 +98,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
|||||||
return <MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
return <MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isModelIdentifierFieldInputInstance(fieldInstance) && isModelIdentifierFieldInputTemplate(fieldTemplate)) {
|
||||||
|
return <ModelIdentifierFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||||
|
}
|
||||||
|
|
||||||
if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) {
|
if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) {
|
||||||
return <RefinerModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
return <RefinerModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,68 @@
|
|||||||
|
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||||
|
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||||
|
import { fieldModelIdentifierValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import type { ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
|
import type { FieldComponentProps } from './types';
|
||||||
|
|
||||||
|
type Props = FieldComponentProps<ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate>;
|
||||||
|
|
||||||
|
const ModelIdentifierFieldInputComponent = (props: Props) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { data, isLoading } = useGetModelConfigsQuery();
|
||||||
|
const _onChange = useCallback(
|
||||||
|
(value: AnyModelConfig | null) => {
|
||||||
|
if (!value) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dispatch(
|
||||||
|
fieldModelIdentifierValueChanged({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
value,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch, field.name, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const modelConfigs = useMemo(() => {
|
||||||
|
if (!data) {
|
||||||
|
return EMPTY_ARRAY;
|
||||||
|
}
|
||||||
|
|
||||||
|
return modelConfigsAdapterSelectors.selectAll(data);
|
||||||
|
}, [data]);
|
||||||
|
|
||||||
|
console.log(modelConfigs);
|
||||||
|
|
||||||
|
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
|
modelConfigs,
|
||||||
|
onChange: _onChange,
|
||||||
|
isLoading,
|
||||||
|
selectedModel: field.value,
|
||||||
|
groupByType: true,
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex w="full" alignItems="center" gap={2}>
|
||||||
|
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
|
||||||
|
<Combobox
|
||||||
|
value={value}
|
||||||
|
placeholder={placeholder}
|
||||||
|
options={options}
|
||||||
|
onChange={onChange}
|
||||||
|
noOptionsMessage={noOptionsMessage}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ModelIdentifierFieldInputComponent);
|
@ -16,6 +16,7 @@ import type {
|
|||||||
IPAdapterModelFieldValue,
|
IPAdapterModelFieldValue,
|
||||||
LoRAModelFieldValue,
|
LoRAModelFieldValue,
|
||||||
MainModelFieldValue,
|
MainModelFieldValue,
|
||||||
|
ModelIdentifierFieldValue,
|
||||||
SchedulerFieldValue,
|
SchedulerFieldValue,
|
||||||
SDXLRefinerModelFieldValue,
|
SDXLRefinerModelFieldValue,
|
||||||
StatefulFieldValue,
|
StatefulFieldValue,
|
||||||
@ -35,6 +36,7 @@ import {
|
|||||||
zIPAdapterModelFieldValue,
|
zIPAdapterModelFieldValue,
|
||||||
zLoRAModelFieldValue,
|
zLoRAModelFieldValue,
|
||||||
zMainModelFieldValue,
|
zMainModelFieldValue,
|
||||||
|
zModelIdentifierFieldValue,
|
||||||
zSchedulerFieldValue,
|
zSchedulerFieldValue,
|
||||||
zSDXLRefinerModelFieldValue,
|
zSDXLRefinerModelFieldValue,
|
||||||
zStatefulFieldValue,
|
zStatefulFieldValue,
|
||||||
@ -344,6 +346,9 @@ export const nodesSlice = createSlice({
|
|||||||
fieldMainModelValueChanged: (state, action: FieldValueAction<MainModelFieldValue>) => {
|
fieldMainModelValueChanged: (state, action: FieldValueAction<MainModelFieldValue>) => {
|
||||||
fieldValueReducer(state, action, zMainModelFieldValue);
|
fieldValueReducer(state, action, zMainModelFieldValue);
|
||||||
},
|
},
|
||||||
|
fieldModelIdentifierValueChanged: (state, action: FieldValueAction<ModelIdentifierFieldValue>) => {
|
||||||
|
fieldValueReducer(state, action, zModelIdentifierFieldValue);
|
||||||
|
},
|
||||||
fieldRefinerModelValueChanged: (state, action: FieldValueAction<SDXLRefinerModelFieldValue>) => {
|
fieldRefinerModelValueChanged: (state, action: FieldValueAction<SDXLRefinerModelFieldValue>) => {
|
||||||
fieldValueReducer(state, action, zSDXLRefinerModelFieldValue);
|
fieldValueReducer(state, action, zSDXLRefinerModelFieldValue);
|
||||||
},
|
},
|
||||||
@ -469,6 +474,7 @@ export const {
|
|||||||
fieldT2IAdapterModelValueChanged,
|
fieldT2IAdapterModelValueChanged,
|
||||||
fieldLabelChanged,
|
fieldLabelChanged,
|
||||||
fieldLoRAModelValueChanged,
|
fieldLoRAModelValueChanged,
|
||||||
|
fieldModelIdentifierValueChanged,
|
||||||
fieldMainModelValueChanged,
|
fieldMainModelValueChanged,
|
||||||
fieldNumberValueChanged,
|
fieldNumberValueChanged,
|
||||||
fieldRefinerModelValueChanged,
|
fieldRefinerModelValueChanged,
|
||||||
|
@ -106,6 +106,10 @@ const zMainModelFieldType = zFieldTypeBase.extend({
|
|||||||
name: z.literal('MainModelField'),
|
name: z.literal('MainModelField'),
|
||||||
originalType: zStatelessFieldType.optional(),
|
originalType: zStatelessFieldType.optional(),
|
||||||
});
|
});
|
||||||
|
const zModelIdentifierFieldType = zFieldTypeBase.extend({
|
||||||
|
name: z.literal('ModelIdentifierField'),
|
||||||
|
originalType: zStatelessFieldType.optional(),
|
||||||
|
});
|
||||||
const zSDXLMainModelFieldType = zFieldTypeBase.extend({
|
const zSDXLMainModelFieldType = zFieldTypeBase.extend({
|
||||||
name: z.literal('SDXLMainModelField'),
|
name: z.literal('SDXLMainModelField'),
|
||||||
originalType: zStatelessFieldType.optional(),
|
originalType: zStatelessFieldType.optional(),
|
||||||
@ -146,6 +150,7 @@ const zStatefulFieldType = z.union([
|
|||||||
zEnumFieldType,
|
zEnumFieldType,
|
||||||
zImageFieldType,
|
zImageFieldType,
|
||||||
zBoardFieldType,
|
zBoardFieldType,
|
||||||
|
zModelIdentifierFieldType,
|
||||||
zMainModelFieldType,
|
zMainModelFieldType,
|
||||||
zSDXLMainModelFieldType,
|
zSDXLMainModelFieldType,
|
||||||
zSDXLRefinerModelFieldType,
|
zSDXLRefinerModelFieldType,
|
||||||
@ -396,6 +401,29 @@ export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFie
|
|||||||
zMainModelFieldInputTemplate.safeParse(val).success;
|
zMainModelFieldInputTemplate.safeParse(val).success;
|
||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
|
// #region ModelIdentifierField
|
||||||
|
export const zModelIdentifierFieldValue = zModelIdentifierField.optional();
|
||||||
|
const zModelIdentifierFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||||
|
value: zModelIdentifierFieldValue,
|
||||||
|
});
|
||||||
|
const zModelIdentifierFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||||
|
type: zModelIdentifierFieldType,
|
||||||
|
originalType: zFieldType.optional(),
|
||||||
|
default: zModelIdentifierFieldValue,
|
||||||
|
});
|
||||||
|
const zModelIdentifierFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||||
|
type: zModelIdentifierFieldType,
|
||||||
|
originalType: zFieldType.optional(),
|
||||||
|
});
|
||||||
|
export type ModelIdentifierFieldValue = z.infer<typeof zModelIdentifierFieldValue>;
|
||||||
|
export type ModelIdentifierFieldInputInstance = z.infer<typeof zModelIdentifierFieldInputInstance>;
|
||||||
|
export type ModelIdentifierFieldInputTemplate = z.infer<typeof zModelIdentifierFieldInputTemplate>;
|
||||||
|
export const isModelIdentifierFieldInputInstance = (val: unknown): val is ModelIdentifierFieldInputInstance =>
|
||||||
|
zModelIdentifierFieldInputInstance.safeParse(val).success;
|
||||||
|
export const isModelIdentifierFieldInputTemplate = (val: unknown): val is ModelIdentifierFieldInputTemplate =>
|
||||||
|
zModelIdentifierFieldInputTemplate.safeParse(val).success;
|
||||||
|
// #endregion
|
||||||
|
|
||||||
// #region SDXLMainModelField
|
// #region SDXLMainModelField
|
||||||
|
|
||||||
const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
|
const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
|
||||||
@ -643,6 +671,7 @@ export const zStatefulFieldValue = z.union([
|
|||||||
zEnumFieldValue,
|
zEnumFieldValue,
|
||||||
zImageFieldValue,
|
zImageFieldValue,
|
||||||
zBoardFieldValue,
|
zBoardFieldValue,
|
||||||
|
zModelIdentifierFieldValue,
|
||||||
zMainModelFieldValue,
|
zMainModelFieldValue,
|
||||||
zSDXLMainModelFieldValue,
|
zSDXLMainModelFieldValue,
|
||||||
zSDXLRefinerModelFieldValue,
|
zSDXLRefinerModelFieldValue,
|
||||||
@ -669,6 +698,7 @@ const zStatefulFieldInputInstance = z.union([
|
|||||||
zEnumFieldInputInstance,
|
zEnumFieldInputInstance,
|
||||||
zImageFieldInputInstance,
|
zImageFieldInputInstance,
|
||||||
zBoardFieldInputInstance,
|
zBoardFieldInputInstance,
|
||||||
|
zModelIdentifierFieldInputInstance,
|
||||||
zMainModelFieldInputInstance,
|
zMainModelFieldInputInstance,
|
||||||
zSDXLMainModelFieldInputInstance,
|
zSDXLMainModelFieldInputInstance,
|
||||||
zSDXLRefinerModelFieldInputInstance,
|
zSDXLRefinerModelFieldInputInstance,
|
||||||
@ -696,6 +726,7 @@ const zStatefulFieldInputTemplate = z.union([
|
|||||||
zEnumFieldInputTemplate,
|
zEnumFieldInputTemplate,
|
||||||
zImageFieldInputTemplate,
|
zImageFieldInputTemplate,
|
||||||
zBoardFieldInputTemplate,
|
zBoardFieldInputTemplate,
|
||||||
|
zModelIdentifierFieldInputTemplate,
|
||||||
zMainModelFieldInputTemplate,
|
zMainModelFieldInputTemplate,
|
||||||
zSDXLMainModelFieldInputTemplate,
|
zSDXLMainModelFieldInputTemplate,
|
||||||
zSDXLRefinerModelFieldInputTemplate,
|
zSDXLRefinerModelFieldInputTemplate,
|
||||||
@ -724,6 +755,7 @@ const zStatefulFieldOutputTemplate = z.union([
|
|||||||
zEnumFieldOutputTemplate,
|
zEnumFieldOutputTemplate,
|
||||||
zImageFieldOutputTemplate,
|
zImageFieldOutputTemplate,
|
||||||
zBoardFieldOutputTemplate,
|
zBoardFieldOutputTemplate,
|
||||||
|
zModelIdentifierFieldOutputTemplate,
|
||||||
zMainModelFieldOutputTemplate,
|
zMainModelFieldOutputTemplate,
|
||||||
zSDXLMainModelFieldOutputTemplate,
|
zSDXLMainModelFieldOutputTemplate,
|
||||||
zSDXLRefinerModelFieldOutputTemplate,
|
zSDXLRefinerModelFieldOutputTemplate,
|
||||||
|
@ -11,6 +11,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
|||||||
IntegerField: 0,
|
IntegerField: 0,
|
||||||
IPAdapterModelField: undefined,
|
IPAdapterModelField: undefined,
|
||||||
LoRAModelField: undefined,
|
LoRAModelField: undefined,
|
||||||
|
ModelIdentifierField: undefined,
|
||||||
MainModelField: undefined,
|
MainModelField: undefined,
|
||||||
SchedulerField: 'euler',
|
SchedulerField: 'euler',
|
||||||
SDXLMainModelField: undefined,
|
SDXLMainModelField: undefined,
|
||||||
|
@ -13,6 +13,7 @@ import type {
|
|||||||
IPAdapterModelFieldInputTemplate,
|
IPAdapterModelFieldInputTemplate,
|
||||||
LoRAModelFieldInputTemplate,
|
LoRAModelFieldInputTemplate,
|
||||||
MainModelFieldInputTemplate,
|
MainModelFieldInputTemplate,
|
||||||
|
ModelIdentifierFieldInputTemplate,
|
||||||
SchedulerFieldInputTemplate,
|
SchedulerFieldInputTemplate,
|
||||||
SDXLMainModelFieldInputTemplate,
|
SDXLMainModelFieldInputTemplate,
|
||||||
SDXLRefinerModelFieldInputTemplate,
|
SDXLRefinerModelFieldInputTemplate,
|
||||||
@ -136,6 +137,20 @@ const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder<BooleanFieldInpu
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildModelIdentifierFieldInputTemplate: FieldInputTemplateBuilder<ModelIdentifierFieldInputTemplate> = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
fieldType,
|
||||||
|
}) => {
|
||||||
|
const template: ModelIdentifierFieldInputTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: fieldType,
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder<MainModelFieldInputTemplate> = ({
|
const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder<MainModelFieldInputTemplate> = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -355,6 +370,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
|||||||
IntegerField: buildIntegerFieldInputTemplate,
|
IntegerField: buildIntegerFieldInputTemplate,
|
||||||
IPAdapterModelField: buildIPAdapterModelFieldInputTemplate,
|
IPAdapterModelField: buildIPAdapterModelFieldInputTemplate,
|
||||||
LoRAModelField: buildLoRAModelFieldInputTemplate,
|
LoRAModelField: buildLoRAModelFieldInputTemplate,
|
||||||
|
ModelIdentifierField: buildModelIdentifierFieldInputTemplate,
|
||||||
MainModelField: buildMainModelFieldInputTemplate,
|
MainModelField: buildMainModelFieldInputTemplate,
|
||||||
SchedulerField: buildSchedulerFieldInputTemplate,
|
SchedulerField: buildSchedulerFieldInputTemplate,
|
||||||
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
|
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
|
||||||
|
Loading…
Reference in New Issue
Block a user