From a012bb6e071ab3dbd1d023d45068f698155531ff Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Fri, 17 May 2024 20:47:00 +1000
Subject: [PATCH] feat(ui): add ModelIdentifierField field type
This new field type accepts _any_ model. A field renderer lets the user select any available model.
---
.../Invocation/fields/InputFieldRenderer.tsx | 7 ++
.../ModelIdentifierFieldInputComponent.tsx | 68 +++++++++++++++++++
.../src/features/nodes/store/nodesSlice.ts | 6 ++
.../web/src/features/nodes/types/field.ts | 32 +++++++++
.../util/schema/buildFieldInputInstance.ts | 1 +
.../util/schema/buildFieldInputTemplate.ts | 16 +++++
6 files changed, 130 insertions(+)
create mode 100644 invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx
index b6e331c114..99937ceec4 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx
@@ -1,3 +1,4 @@
+import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
import {
@@ -23,6 +24,8 @@ import {
isLoRAModelFieldInputTemplate,
isMainModelFieldInputInstance,
isMainModelFieldInputTemplate,
+ isModelIdentifierFieldInputInstance,
+ isModelIdentifierFieldInputTemplate,
isSchedulerFieldInputInstance,
isSchedulerFieldInputTemplate,
isSDXLMainModelFieldInputInstance,
@@ -95,6 +98,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return ;
}
+ if (isModelIdentifierFieldInputInstance(fieldInstance) && isModelIdentifierFieldInputTemplate(fieldTemplate)) {
+ return ;
+ }
+
if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) {
return ;
}
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx
new file mode 100644
index 0000000000..6a0c9b63fa
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx
@@ -0,0 +1,68 @@
+import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
+import { EMPTY_ARRAY } from 'app/store/constants';
+import { useAppDispatch } from 'app/store/storeHooks';
+import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
+import { fieldModelIdentifierValueChanged } from 'features/nodes/store/nodesSlice';
+import type { ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate } from 'features/nodes/types/field';
+import { memo, useCallback, useMemo } from 'react';
+import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
+import type { AnyModelConfig } from 'services/api/types';
+
+import type { FieldComponentProps } from './types';
+
+type Props = FieldComponentProps;
+
+const ModelIdentifierFieldInputComponent = (props: Props) => {
+ const { nodeId, field } = props;
+ const dispatch = useAppDispatch();
+ const { data, isLoading } = useGetModelConfigsQuery();
+ const _onChange = useCallback(
+ (value: AnyModelConfig | null) => {
+ if (!value) {
+ return;
+ }
+ dispatch(
+ fieldModelIdentifierValueChanged({
+ nodeId,
+ fieldName: field.name,
+ value,
+ })
+ );
+ },
+ [dispatch, field.name, nodeId]
+ );
+
+ const modelConfigs = useMemo(() => {
+ if (!data) {
+ return EMPTY_ARRAY;
+ }
+
+ return modelConfigsAdapterSelectors.selectAll(data);
+ }, [data]);
+
+ console.log(modelConfigs);
+
+ const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
+ modelConfigs,
+ onChange: _onChange,
+ isLoading,
+ selectedModel: field.value,
+ groupByType: true,
+ });
+
+ return (
+
+
+
+
+
+ );
+};
+
+export default memo(ModelIdentifierFieldInputComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
index 1f61c77e83..cec13e8df4 100644
--- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
@@ -16,6 +16,7 @@ import type {
IPAdapterModelFieldValue,
LoRAModelFieldValue,
MainModelFieldValue,
+ ModelIdentifierFieldValue,
SchedulerFieldValue,
SDXLRefinerModelFieldValue,
StatefulFieldValue,
@@ -35,6 +36,7 @@ import {
zIPAdapterModelFieldValue,
zLoRAModelFieldValue,
zMainModelFieldValue,
+ zModelIdentifierFieldValue,
zSchedulerFieldValue,
zSDXLRefinerModelFieldValue,
zStatefulFieldValue,
@@ -344,6 +346,9 @@ export const nodesSlice = createSlice({
fieldMainModelValueChanged: (state, action: FieldValueAction) => {
fieldValueReducer(state, action, zMainModelFieldValue);
},
+ fieldModelIdentifierValueChanged: (state, action: FieldValueAction) => {
+ fieldValueReducer(state, action, zModelIdentifierFieldValue);
+ },
fieldRefinerModelValueChanged: (state, action: FieldValueAction) => {
fieldValueReducer(state, action, zSDXLRefinerModelFieldValue);
},
@@ -469,6 +474,7 @@ export const {
fieldT2IAdapterModelValueChanged,
fieldLabelChanged,
fieldLoRAModelValueChanged,
+ fieldModelIdentifierValueChanged,
fieldMainModelValueChanged,
fieldNumberValueChanged,
fieldRefinerModelValueChanged,
diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts
index 4dcc478352..a98f773c7e 100644
--- a/invokeai/frontend/web/src/features/nodes/types/field.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/field.ts
@@ -106,6 +106,10 @@ const zMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('MainModelField'),
originalType: zStatelessFieldType.optional(),
});
+const zModelIdentifierFieldType = zFieldTypeBase.extend({
+ name: z.literal('ModelIdentifierField'),
+ originalType: zStatelessFieldType.optional(),
+});
const zSDXLMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLMainModelField'),
originalType: zStatelessFieldType.optional(),
@@ -146,6 +150,7 @@ const zStatefulFieldType = z.union([
zEnumFieldType,
zImageFieldType,
zBoardFieldType,
+ zModelIdentifierFieldType,
zMainModelFieldType,
zSDXLMainModelFieldType,
zSDXLRefinerModelFieldType,
@@ -396,6 +401,29 @@ export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFie
zMainModelFieldInputTemplate.safeParse(val).success;
// #endregion
+// #region ModelIdentifierField
+export const zModelIdentifierFieldValue = zModelIdentifierField.optional();
+const zModelIdentifierFieldInputInstance = zFieldInputInstanceBase.extend({
+ value: zModelIdentifierFieldValue,
+});
+const zModelIdentifierFieldInputTemplate = zFieldInputTemplateBase.extend({
+ type: zModelIdentifierFieldType,
+ originalType: zFieldType.optional(),
+ default: zModelIdentifierFieldValue,
+});
+const zModelIdentifierFieldOutputTemplate = zFieldOutputTemplateBase.extend({
+ type: zModelIdentifierFieldType,
+ originalType: zFieldType.optional(),
+});
+export type ModelIdentifierFieldValue = z.infer;
+export type ModelIdentifierFieldInputInstance = z.infer;
+export type ModelIdentifierFieldInputTemplate = z.infer;
+export const isModelIdentifierFieldInputInstance = (val: unknown): val is ModelIdentifierFieldInputInstance =>
+ zModelIdentifierFieldInputInstance.safeParse(val).success;
+export const isModelIdentifierFieldInputTemplate = (val: unknown): val is ModelIdentifierFieldInputTemplate =>
+ zModelIdentifierFieldInputTemplate.safeParse(val).success;
+// #endregion
+
// #region SDXLMainModelField
const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
@@ -643,6 +671,7 @@ export const zStatefulFieldValue = z.union([
zEnumFieldValue,
zImageFieldValue,
zBoardFieldValue,
+ zModelIdentifierFieldValue,
zMainModelFieldValue,
zSDXLMainModelFieldValue,
zSDXLRefinerModelFieldValue,
@@ -669,6 +698,7 @@ const zStatefulFieldInputInstance = z.union([
zEnumFieldInputInstance,
zImageFieldInputInstance,
zBoardFieldInputInstance,
+ zModelIdentifierFieldInputInstance,
zMainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
@@ -696,6 +726,7 @@ const zStatefulFieldInputTemplate = z.union([
zEnumFieldInputTemplate,
zImageFieldInputTemplate,
zBoardFieldInputTemplate,
+ zModelIdentifierFieldInputTemplate,
zMainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate,
zSDXLRefinerModelFieldInputTemplate,
@@ -724,6 +755,7 @@ const zStatefulFieldOutputTemplate = z.union([
zEnumFieldOutputTemplate,
zImageFieldOutputTemplate,
zBoardFieldOutputTemplate,
+ zModelIdentifierFieldOutputTemplate,
zMainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate,
zSDXLRefinerModelFieldOutputTemplate,
diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts
index f8097566c9..597779fd61 100644
--- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts
@@ -11,6 +11,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record =
IntegerField: 0,
IPAdapterModelField: undefined,
LoRAModelField: undefined,
+ ModelIdentifierField: undefined,
MainModelField: undefined,
SchedulerField: 'euler',
SDXLMainModelField: undefined,
diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts
index 6b4c4d8b29..2b77274526 100644
--- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts
@@ -13,6 +13,7 @@ import type {
IPAdapterModelFieldInputTemplate,
LoRAModelFieldInputTemplate,
MainModelFieldInputTemplate,
+ ModelIdentifierFieldInputTemplate,
SchedulerFieldInputTemplate,
SDXLMainModelFieldInputTemplate,
SDXLRefinerModelFieldInputTemplate,
@@ -136,6 +137,20 @@ const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder = ({
+ schemaObject,
+ baseField,
+ fieldType,
+}) => {
+ const template: ModelIdentifierFieldInputTemplate = {
+ ...baseField,
+ type: fieldType,
+ default: schemaObject.default ?? undefined,
+ };
+
+ return template;
+};
+
const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder = ({
schemaObject,
baseField,
@@ -355,6 +370,7 @@ export const TEMPLATE_BUILDER_MAP: Record