diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx
index 861f919b33..41e579b3cc 100644
--- a/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx
@@ -1,15 +1,17 @@
+import { Flex, Text } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
+import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
VaeModelInputFieldTemplate,
VaeModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
-import { forEach, isString } from 'lodash-es';
-import { memo, useCallback, useEffect, useMemo } from 'react';
-import { useTranslation } from 'react-i18next';
+import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam';
+import { forEach } from 'lodash-es';
+import { memo, useCallback, useMemo } from 'react';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
@@ -20,17 +22,10 @@ const LoRAModelInputFieldComponent = (
>
) => {
const { nodeId, field } = props;
-
+ const lora = field.value;
const dispatch = useAppDispatch();
- const { t } = useTranslation();
-
const { data: loraModels } = useGetLoRAModelsQuery();
- const selectedModel = useMemo(
- () => loraModels?.entities[field.value ?? loraModels.ids[0]],
- [loraModels?.entities, loraModels?.ids, field.value]
- );
-
const data = useMemo(() => {
if (!loraModels) {
return [];
@@ -38,62 +33,78 @@ const LoRAModelInputFieldComponent = (
const data: SelectItem[] = [];
- forEach(loraModels.entities, (model, id) => {
- if (!model) {
+ forEach(loraModels.entities, (lora, id) => {
+ if (!lora) {
return;
}
data.push({
value: id,
- label: model.model_name,
- group: MODEL_TYPE_MAP[model.base_model],
+ label: lora.model_name,
+ group: MODEL_TYPE_MAP[lora.base_model],
});
});
- return data;
+ return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
}, [loraModels]);
- const handleValueChanged = useCallback(
+ const selectedLoRAModel = useMemo(
+ () =>
+ loraModels?.entities[`${lora?.base_model}/lora/${lora?.model_name}`] ??
+ null,
+ [loraModels?.entities, lora?.base_model, lora?.model_name]
+ );
+
+ const handleChange = useCallback(
(v: string | null) => {
if (!v) {
return;
}
+ const newLoRAModel = modelIdToLoRAModelParam(v);
+
+ if (!newLoRAModel) {
+ return;
+ }
+
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
- value: v,
+ value: newLoRAModel,
})
);
},
[dispatch, field.name, nodeId]
);
- useEffect(() => {
- if (field.value && loraModels?.ids.includes(field.value)) {
- return;
- }
-
- const firstLora = loraModels?.ids[0];
-
- if (!isString(firstLora)) {
- return;
- }
-
- handleValueChanged(firstLora);
- }, [field.value, handleValueChanged, loraModels?.ids]);
+ if (loraModels?.ids.length === 0) {
+ return (
+
+
+ No LoRAs Loaded
+
+
+ );
+ }
return (
0 ? 'Select a LoRA' : 'No LoRAs available'}
data={data}
- onChange={handleValueChanged}
+ nothingFound="No matching LoRAs"
+ itemComponent={IAIMantineSelectItemWithTooltip}
+ disabled={data.length === 0}
+ filter={(value, item: SelectItem) =>
+ item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
+ item.value.toLowerCase().includes(value.toLowerCase().trim())
+ }
+ onChange={handleChange}
/>
);
};
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx
index 124c180eb3..43dbbba73f 100644
--- a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx
@@ -1,28 +1,29 @@
-import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
+ MainModelInputFieldValue,
ModelInputFieldTemplate,
- ModelInputFieldValue,
} from 'features/nodes/types/types';
+import { SelectItem } from '@mantine/core';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
-import { forEach, isString } from 'lodash-es';
-import { memo, useCallback, useEffect, useMemo } from 'react';
+import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
+import { forEach } from 'lodash-es';
+import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const ModelInputFieldComponent = (
- props: FieldComponentProps
+ props: FieldComponentProps
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
- const { data: mainModels } = useGetMainModelsQuery();
+ const { data: mainModels, isLoading } = useGetMainModelsQuery();
const data = useMemo(() => {
if (!mainModels) {
@@ -46,52 +47,58 @@ const ModelInputFieldComponent = (
return data;
}, [mainModels]);
+ // grab the full model entity from the RTK Query cache
+ // TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo(
- () => mainModels?.entities[field.value ?? mainModels.ids[0]],
- [mainModels?.entities, mainModels?.ids, field.value]
+ () =>
+ mainModels?.entities[
+ `${field.value?.base_model}/main/${field.value?.model_name}`
+ ] ?? null,
+ [field.value?.base_model, field.value?.model_name, mainModels?.entities]
);
- const handleValueChanged = useCallback(
+ const handleChangeModel = useCallback(
(v: string | null) => {
if (!v) {
return;
}
+ const newModel = modelIdToMainModelParam(v);
+
+ if (!newModel) {
+ return;
+ }
+
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
- value: v,
+ value: newModel,
})
);
},
[dispatch, field.name, nodeId]
);
- useEffect(() => {
- if (field.value && mainModels?.ids.includes(field.value)) {
- return;
- }
-
- const firstModel = mainModels?.ids[0];
-
- if (!isString(firstModel)) {
- return;
- }
-
- handleValueChanged(firstModel);
- }, [field.value, handleValueChanged, mainModels?.ids]);
-
- return (
+ return isLoading ? (
+
+ ) : (
0 ? 'Select a model' : 'No models available'}
data={data}
- onChange={handleValueChanged}
+ error={data.length === 0}
+ disabled={data.length === 0}
+ onChange={handleChangeModel}
/>
);
};
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/UNetInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/UNetInputFieldComponent.tsx
deleted file mode 100644
index 5926bf113a..0000000000
--- a/invokeai/frontend/web/src/features/nodes/components/fields/UNetInputFieldComponent.tsx
+++ /dev/null
@@ -1,16 +0,0 @@
-import {
- UNetInputFieldTemplate,
- UNetInputFieldValue,
-} from 'features/nodes/types/types';
-import { memo } from 'react';
-import { FieldComponentProps } from './types';
-
-const UNetInputFieldComponent = (
- props: FieldComponentProps
-) => {
- const { nodeId, field } = props;
-
- return null;
-};
-
-export default memo(UNetInputFieldComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/VaeInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/VaeInputFieldComponent.tsx
deleted file mode 100644
index 0fa11ae34e..0000000000
--- a/invokeai/frontend/web/src/features/nodes/components/fields/VaeInputFieldComponent.tsx
+++ /dev/null
@@ -1,16 +0,0 @@
-import {
- VaeInputFieldTemplate,
- VaeInputFieldValue,
-} from 'features/nodes/types/types';
-import { memo } from 'react';
-import { FieldComponentProps } from './types';
-
-const VaeInputFieldComponent = (
- props: FieldComponentProps
-) => {
- const { nodeId, field } = props;
-
- return null;
-};
-
-export default memo(VaeInputFieldComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx
index 54ab7363ba..afbd294a27 100644
--- a/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx
@@ -1,14 +1,16 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
+import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
VaeModelInputFieldTemplate,
VaeModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
+import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
import { forEach } from 'lodash-es';
-import { memo, useCallback, useEffect, useMemo } from 'react';
+import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
@@ -20,73 +22,83 @@ const VaeModelInputFieldComponent = (
>
) => {
const { nodeId, field } = props;
-
+ const vae = field.value;
const dispatch = useAppDispatch();
const { t } = useTranslation();
-
const { data: vaeModels } = useGetVaeModelsQuery();
- const selectedModel = useMemo(
- () => vaeModels?.entities[field.value ?? vaeModels.ids[0]],
- [vaeModels?.entities, vaeModels?.ids, field.value]
- );
-
const data = useMemo(() => {
if (!vaeModels) {
return [];
}
- const data: SelectItem[] = [];
+ const data: SelectItem[] = [
+ {
+ value: 'default',
+ label: 'Default',
+ group: 'Default',
+ },
+ ];
- forEach(vaeModels.entities, (model, id) => {
- if (!model) {
+ forEach(vaeModels.entities, (vae, id) => {
+ if (!vae) {
return;
}
data.push({
value: id,
- label: model.model_name,
- group: MODEL_TYPE_MAP[model.base_model],
+ label: vae.model_name,
+ group: MODEL_TYPE_MAP[vae.base_model],
});
});
- return data;
+ return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
}, [vaeModels]);
- const handleValueChanged = useCallback(
+ // grab the full model entity from the RTK Query cache
+ const selectedVaeModel = useMemo(
+ () =>
+ vaeModels?.entities[`${vae?.base_model}/vae/${vae?.model_name}`] ?? null,
+ [vaeModels?.entities, vae]
+ );
+
+ const handleChangeModel = useCallback(
(v: string | null) => {
if (!v) {
return;
}
+ const newVaeModel = modelIdToVAEModelParam(v);
+
+ if (!newVaeModel) {
+ return;
+ }
+
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
- value: v,
+ value: newVaeModel,
})
);
},
[dispatch, field.name, nodeId]
);
- useEffect(() => {
- if (field.value && vaeModels?.ids.includes(field.value)) {
- return;
- }
- handleValueChanged('auto');
- }, [field.value, handleValueChanged, vaeModels?.ids]);
-
return (
);
};
diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
index 91b6f685e6..8255c65045 100644
--- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
@@ -1,5 +1,10 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
+import {
+ LoRAModelParam,
+ MainModelParam,
+ VaeModelParam,
+} from 'features/parameters/types/parameterSchemas';
import { cloneDeep, uniqBy } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful';
@@ -73,7 +78,10 @@ const nodesSlice = createSlice({
| ImageField
| RgbaColor
| undefined
- | ImageField[];
+ | ImageField[]
+ | MainModelParam
+ | VaeModelParam
+ | LoRAModelParam;
}>
) => {
const { nodeId, fieldName, value } = action.payload;
diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts
index 3de8cae9ff..4c47c63068 100644
--- a/invokeai/frontend/web/src/features/nodes/types/types.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/types.ts
@@ -1,3 +1,8 @@
+import {
+ LoRAModelParam,
+ MainModelParam,
+ VaeModelParam,
+} from 'features/parameters/types/parameterSchemas';
import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful';
import { Graph, ImageDTO, ImageField } from 'services/api/types';
@@ -92,7 +97,7 @@ export type InputFieldValue =
| VaeInputFieldValue
| ControlInputFieldValue
| EnumInputFieldValue
- | ModelInputFieldValue
+ | MainModelInputFieldValue
| VaeModelInputFieldValue
| LoRAModelInputFieldValue
| ArrayInputFieldValue
@@ -229,19 +234,19 @@ export type ImageCollectionInputFieldValue = FieldValueBase & {
value?: ImageField[];
};
-export type ModelInputFieldValue = FieldValueBase & {
+export type MainModelInputFieldValue = FieldValueBase & {
type: 'model';
- value?: string;
+ value?: MainModelParam;
};
export type VaeModelInputFieldValue = FieldValueBase & {
type: 'vae_model';
- value?: string;
+ value?: VaeModelParam;
};
export type LoRAModelInputFieldValue = FieldValueBase & {
type: 'lora_model';
- value?: string;
+ value?: LoRAModelParam;
};
export type ArrayInputFieldValue = FieldValueBase & {
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts
index 6963cf16b8..64d579ce8b 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts
@@ -1,8 +1,5 @@
import { RootState } from 'app/store/store';
import { InputFieldValue } from 'features/nodes/types/types';
-import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam';
-import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
-import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
import { cloneDeep, omit, reduce } from 'lodash-es';
import { Graph } from 'services/api/types';
import { AnyInvocation } from 'services/events/types';
@@ -27,24 +24,6 @@ export const parseFieldValue = (field: InputFieldValue) => {
}
}
- if (field.type === 'model') {
- if (field.value) {
- return modelIdToMainModelParam(field.value);
- }
- }
-
- if (field.type === 'vae_model') {
- if (field.value) {
- return modelIdToVAEModelParam(field.value);
- }
- }
-
- if (field.type === 'lora_model') {
- if (field.value) {
- return modelIdToLoRAModelParam(field.value);
- }
- }
-
return field.value;
};