diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index b6e331c114..99937ceec4 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -1,3 +1,4 @@ +import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent'; import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance'; import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate'; import { @@ -23,6 +24,8 @@ import { isLoRAModelFieldInputTemplate, isMainModelFieldInputInstance, isMainModelFieldInputTemplate, + isModelIdentifierFieldInputInstance, + isModelIdentifierFieldInputTemplate, isSchedulerFieldInputInstance, isSchedulerFieldInputTemplate, isSDXLMainModelFieldInputInstance, @@ -95,6 +98,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { return ; } + if (isModelIdentifierFieldInputInstance(fieldInstance) && isModelIdentifierFieldInputTemplate(fieldTemplate)) { + return ; + } + if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) { return ; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx new file mode 100644 index 0000000000..6a0c9b63fa --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx @@ -0,0 +1,68 @@ +import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library'; +import { EMPTY_ARRAY } from 'app/store/constants'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { fieldModelIdentifierValueChanged } from 'features/nodes/store/nodesSlice'; +import type { ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate } from 'features/nodes/types/field'; +import { memo, useCallback, useMemo } from 'react'; +import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/types'; + +import type { FieldComponentProps } from './types'; + +type Props = FieldComponentProps; + +const ModelIdentifierFieldInputComponent = (props: Props) => { + const { nodeId, field } = props; + const dispatch = useAppDispatch(); + const { data, isLoading } = useGetModelConfigsQuery(); + const _onChange = useCallback( + (value: AnyModelConfig | null) => { + if (!value) { + return; + } + dispatch( + fieldModelIdentifierValueChanged({ + nodeId, + fieldName: field.name, + value, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + const modelConfigs = useMemo(() => { + if (!data) { + return EMPTY_ARRAY; + } + + return modelConfigsAdapterSelectors.selectAll(data); + }, [data]); + + console.log(modelConfigs); + + const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ + modelConfigs, + onChange: _onChange, + isLoading, + selectedModel: field.value, + groupByType: true, + }); + + return ( + + + + + + ); +}; + +export default memo(ModelIdentifierFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 1f61c77e83..cec13e8df4 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -16,6 +16,7 @@ import type { IPAdapterModelFieldValue, LoRAModelFieldValue, MainModelFieldValue, + ModelIdentifierFieldValue, SchedulerFieldValue, SDXLRefinerModelFieldValue, StatefulFieldValue, @@ -35,6 +36,7 @@ import { zIPAdapterModelFieldValue, zLoRAModelFieldValue, zMainModelFieldValue, + zModelIdentifierFieldValue, zSchedulerFieldValue, zSDXLRefinerModelFieldValue, zStatefulFieldValue, @@ -344,6 +346,9 @@ export const nodesSlice = createSlice({ fieldMainModelValueChanged: (state, action: FieldValueAction) => { fieldValueReducer(state, action, zMainModelFieldValue); }, + fieldModelIdentifierValueChanged: (state, action: FieldValueAction) => { + fieldValueReducer(state, action, zModelIdentifierFieldValue); + }, fieldRefinerModelValueChanged: (state, action: FieldValueAction) => { fieldValueReducer(state, action, zSDXLRefinerModelFieldValue); }, @@ -469,6 +474,7 @@ export const { fieldT2IAdapterModelValueChanged, fieldLabelChanged, fieldLoRAModelValueChanged, + fieldModelIdentifierValueChanged, fieldMainModelValueChanged, fieldNumberValueChanged, fieldRefinerModelValueChanged, diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index 4dcc478352..a98f773c7e 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -106,6 +106,10 @@ const zMainModelFieldType = zFieldTypeBase.extend({ name: z.literal('MainModelField'), originalType: zStatelessFieldType.optional(), }); +const zModelIdentifierFieldType = zFieldTypeBase.extend({ + name: z.literal('ModelIdentifierField'), + originalType: zStatelessFieldType.optional(), +}); const zSDXLMainModelFieldType = zFieldTypeBase.extend({ name: z.literal('SDXLMainModelField'), originalType: zStatelessFieldType.optional(), @@ -146,6 +150,7 @@ const zStatefulFieldType = z.union([ zEnumFieldType, zImageFieldType, zBoardFieldType, + zModelIdentifierFieldType, zMainModelFieldType, zSDXLMainModelFieldType, zSDXLRefinerModelFieldType, @@ -396,6 +401,29 @@ export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFie zMainModelFieldInputTemplate.safeParse(val).success; // #endregion +// #region ModelIdentifierField +export const zModelIdentifierFieldValue = zModelIdentifierField.optional(); +const zModelIdentifierFieldInputInstance = zFieldInputInstanceBase.extend({ + value: zModelIdentifierFieldValue, +}); +const zModelIdentifierFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zModelIdentifierFieldType, + originalType: zFieldType.optional(), + default: zModelIdentifierFieldValue, +}); +const zModelIdentifierFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zModelIdentifierFieldType, + originalType: zFieldType.optional(), +}); +export type ModelIdentifierFieldValue = z.infer; +export type ModelIdentifierFieldInputInstance = z.infer; +export type ModelIdentifierFieldInputTemplate = z.infer; +export const isModelIdentifierFieldInputInstance = (val: unknown): val is ModelIdentifierFieldInputInstance => + zModelIdentifierFieldInputInstance.safeParse(val).success; +export const isModelIdentifierFieldInputTemplate = (val: unknown): val is ModelIdentifierFieldInputTemplate => + zModelIdentifierFieldInputTemplate.safeParse(val).success; +// #endregion + // #region SDXLMainModelField const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. @@ -643,6 +671,7 @@ export const zStatefulFieldValue = z.union([ zEnumFieldValue, zImageFieldValue, zBoardFieldValue, + zModelIdentifierFieldValue, zMainModelFieldValue, zSDXLMainModelFieldValue, zSDXLRefinerModelFieldValue, @@ -669,6 +698,7 @@ const zStatefulFieldInputInstance = z.union([ zEnumFieldInputInstance, zImageFieldInputInstance, zBoardFieldInputInstance, + zModelIdentifierFieldInputInstance, zMainModelFieldInputInstance, zSDXLMainModelFieldInputInstance, zSDXLRefinerModelFieldInputInstance, @@ -696,6 +726,7 @@ const zStatefulFieldInputTemplate = z.union([ zEnumFieldInputTemplate, zImageFieldInputTemplate, zBoardFieldInputTemplate, + zModelIdentifierFieldInputTemplate, zMainModelFieldInputTemplate, zSDXLMainModelFieldInputTemplate, zSDXLRefinerModelFieldInputTemplate, @@ -724,6 +755,7 @@ const zStatefulFieldOutputTemplate = z.union([ zEnumFieldOutputTemplate, zImageFieldOutputTemplate, zBoardFieldOutputTemplate, + zModelIdentifierFieldOutputTemplate, zMainModelFieldOutputTemplate, zSDXLMainModelFieldOutputTemplate, zSDXLRefinerModelFieldOutputTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index f8097566c9..597779fd61 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -11,6 +11,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record = IntegerField: 0, IPAdapterModelField: undefined, LoRAModelField: undefined, + ModelIdentifierField: undefined, MainModelField: undefined, SchedulerField: 'euler', SDXLMainModelField: undefined, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts index 6b4c4d8b29..2b77274526 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts @@ -13,6 +13,7 @@ import type { IPAdapterModelFieldInputTemplate, LoRAModelFieldInputTemplate, MainModelFieldInputTemplate, + ModelIdentifierFieldInputTemplate, SchedulerFieldInputTemplate, SDXLMainModelFieldInputTemplate, SDXLRefinerModelFieldInputTemplate, @@ -136,6 +137,20 @@ const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder = ({ + schemaObject, + baseField, + fieldType, +}) => { + const template: ModelIdentifierFieldInputTemplate = { + ...baseField, + type: fieldType, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, @@ -355,6 +370,7 @@ export const TEMPLATE_BUILDER_MAP: Record