diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx index 861f919b33..41e579b3cc 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx @@ -1,15 +1,17 @@ +import { Flex, Text } from '@chakra-ui/react'; import { SelectItem } from '@mantine/core'; import { useAppDispatch } from 'app/store/storeHooks'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { VaeModelInputFieldTemplate, VaeModelInputFieldValue, } from 'features/nodes/types/types'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; -import { forEach, isString } from 'lodash-es'; -import { memo, useCallback, useEffect, useMemo } from 'react'; -import { useTranslation } from 'react-i18next'; +import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam'; +import { forEach } from 'lodash-es'; +import { memo, useCallback, useMemo } from 'react'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; import { FieldComponentProps } from './types'; @@ -20,17 +22,10 @@ const LoRAModelInputFieldComponent = ( > ) => { const { nodeId, field } = props; - + const lora = field.value; const dispatch = useAppDispatch(); - const { t } = useTranslation(); - const { data: loraModels } = useGetLoRAModelsQuery(); - const selectedModel = useMemo( - () => loraModels?.entities[field.value ?? loraModels.ids[0]], - [loraModels?.entities, loraModels?.ids, field.value] - ); - const data = useMemo(() => { if (!loraModels) { return []; @@ -38,62 +33,78 @@ const LoRAModelInputFieldComponent = ( const data: SelectItem[] = []; - forEach(loraModels.entities, (model, id) => { - if (!model) { + forEach(loraModels.entities, (lora, id) => { + if (!lora) { return; } data.push({ value: id, - label: model.model_name, - group: MODEL_TYPE_MAP[model.base_model], + label: lora.model_name, + group: MODEL_TYPE_MAP[lora.base_model], }); }); - return data; + return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1)); }, [loraModels]); - const handleValueChanged = useCallback( + const selectedLoRAModel = useMemo( + () => + loraModels?.entities[`${lora?.base_model}/lora/${lora?.model_name}`] ?? + null, + [loraModels?.entities, lora?.base_model, lora?.model_name] + ); + + const handleChange = useCallback( (v: string | null) => { if (!v) { return; } + const newLoRAModel = modelIdToLoRAModelParam(v); + + if (!newLoRAModel) { + return; + } + dispatch( fieldValueChanged({ nodeId, fieldName: field.name, - value: v, + value: newLoRAModel, }) ); }, [dispatch, field.name, nodeId] ); - useEffect(() => { - if (field.value && loraModels?.ids.includes(field.value)) { - return; - } - - const firstLora = loraModels?.ids[0]; - - if (!isString(firstLora)) { - return; - } - - handleValueChanged(firstLora); - }, [field.value, handleValueChanged, loraModels?.ids]); + if (loraModels?.ids.length === 0) { + return ( + + + No LoRAs Loaded + + + ); + } return ( 0 ? 'Select a LoRA' : 'No LoRAs available'} data={data} - onChange={handleValueChanged} + nothingFound="No matching LoRAs" + itemComponent={IAIMantineSelectItemWithTooltip} + disabled={data.length === 0} + filter={(value, item: SelectItem) => + item.label?.toLowerCase().includes(value.toLowerCase().trim()) || + item.value.toLowerCase().includes(value.toLowerCase().trim()) + } + onChange={handleChange} /> ); }; diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx index 124c180eb3..43dbbba73f 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx @@ -1,28 +1,29 @@ -import { SelectItem } from '@mantine/core'; import { useAppDispatch } from 'app/store/storeHooks'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { + MainModelInputFieldValue, ModelInputFieldTemplate, - ModelInputFieldValue, } from 'features/nodes/types/types'; +import { SelectItem } from '@mantine/core'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; -import { forEach, isString } from 'lodash-es'; -import { memo, useCallback, useEffect, useMemo } from 'react'; +import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam'; +import { forEach } from 'lodash-es'; +import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import { FieldComponentProps } from './types'; const ModelInputFieldComponent = ( - props: FieldComponentProps + props: FieldComponentProps ) => { const { nodeId, field } = props; const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { data: mainModels } = useGetMainModelsQuery(); + const { data: mainModels, isLoading } = useGetMainModelsQuery(); const data = useMemo(() => { if (!mainModels) { @@ -46,52 +47,58 @@ const ModelInputFieldComponent = ( return data; }, [mainModels]); + // grab the full model entity from the RTK Query cache + // TODO: maybe we should just store the full model entity in state? const selectedModel = useMemo( - () => mainModels?.entities[field.value ?? mainModels.ids[0]], - [mainModels?.entities, mainModels?.ids, field.value] + () => + mainModels?.entities[ + `${field.value?.base_model}/main/${field.value?.model_name}` + ] ?? null, + [field.value?.base_model, field.value?.model_name, mainModels?.entities] ); - const handleValueChanged = useCallback( + const handleChangeModel = useCallback( (v: string | null) => { if (!v) { return; } + const newModel = modelIdToMainModelParam(v); + + if (!newModel) { + return; + } + dispatch( fieldValueChanged({ nodeId, fieldName: field.name, - value: v, + value: newModel, }) ); }, [dispatch, field.name, nodeId] ); - useEffect(() => { - if (field.value && mainModels?.ids.includes(field.value)) { - return; - } - - const firstModel = mainModels?.ids[0]; - - if (!isString(firstModel)) { - return; - } - - handleValueChanged(firstModel); - }, [field.value, handleValueChanged, mainModels?.ids]); - - return ( + return isLoading ? ( + + ) : ( 0 ? 'Select a model' : 'No models available'} data={data} - onChange={handleValueChanged} + error={data.length === 0} + disabled={data.length === 0} + onChange={handleChangeModel} /> ); }; diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/UNetInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/UNetInputFieldComponent.tsx deleted file mode 100644 index 5926bf113a..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/fields/UNetInputFieldComponent.tsx +++ /dev/null @@ -1,16 +0,0 @@ -import { - UNetInputFieldTemplate, - UNetInputFieldValue, -} from 'features/nodes/types/types'; -import { memo } from 'react'; -import { FieldComponentProps } from './types'; - -const UNetInputFieldComponent = ( - props: FieldComponentProps -) => { - const { nodeId, field } = props; - - return null; -}; - -export default memo(UNetInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/VaeInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/VaeInputFieldComponent.tsx deleted file mode 100644 index 0fa11ae34e..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/fields/VaeInputFieldComponent.tsx +++ /dev/null @@ -1,16 +0,0 @@ -import { - VaeInputFieldTemplate, - VaeInputFieldValue, -} from 'features/nodes/types/types'; -import { memo } from 'react'; -import { FieldComponentProps } from './types'; - -const VaeInputFieldComponent = ( - props: FieldComponentProps -) => { - const { nodeId, field } = props; - - return null; -}; - -export default memo(VaeInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx index 54ab7363ba..afbd294a27 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx @@ -1,14 +1,16 @@ import { SelectItem } from '@mantine/core'; import { useAppDispatch } from 'app/store/storeHooks'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { VaeModelInputFieldTemplate, VaeModelInputFieldValue, } from 'features/nodes/types/types'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; +import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam'; import { forEach } from 'lodash-es'; -import { memo, useCallback, useEffect, useMemo } from 'react'; +import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; import { FieldComponentProps } from './types'; @@ -20,73 +22,83 @@ const VaeModelInputFieldComponent = ( > ) => { const { nodeId, field } = props; - + const vae = field.value; const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { data: vaeModels } = useGetVaeModelsQuery(); - const selectedModel = useMemo( - () => vaeModels?.entities[field.value ?? vaeModels.ids[0]], - [vaeModels?.entities, vaeModels?.ids, field.value] - ); - const data = useMemo(() => { if (!vaeModels) { return []; } - const data: SelectItem[] = []; + const data: SelectItem[] = [ + { + value: 'default', + label: 'Default', + group: 'Default', + }, + ]; - forEach(vaeModels.entities, (model, id) => { - if (!model) { + forEach(vaeModels.entities, (vae, id) => { + if (!vae) { return; } data.push({ value: id, - label: model.model_name, - group: MODEL_TYPE_MAP[model.base_model], + label: vae.model_name, + group: MODEL_TYPE_MAP[vae.base_model], }); }); - return data; + return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1)); }, [vaeModels]); - const handleValueChanged = useCallback( + // grab the full model entity from the RTK Query cache + const selectedVaeModel = useMemo( + () => + vaeModels?.entities[`${vae?.base_model}/vae/${vae?.model_name}`] ?? null, + [vaeModels?.entities, vae] + ); + + const handleChangeModel = useCallback( (v: string | null) => { if (!v) { return; } + const newVaeModel = modelIdToVAEModelParam(v); + + if (!newVaeModel) { + return; + } + dispatch( fieldValueChanged({ nodeId, fieldName: field.name, - value: v, + value: newVaeModel, }) ); }, [dispatch, field.name, nodeId] ); - useEffect(() => { - if (field.value && vaeModels?.ids.includes(field.value)) { - return; - } - handleValueChanged('auto'); - }, [field.value, handleValueChanged, vaeModels?.ids]); - return ( ); }; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 91b6f685e6..8255c65045 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -1,5 +1,10 @@ import { createSlice, PayloadAction } from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; +import { + LoRAModelParam, + MainModelParam, + VaeModelParam, +} from 'features/parameters/types/parameterSchemas'; import { cloneDeep, uniqBy } from 'lodash-es'; import { OpenAPIV3 } from 'openapi-types'; import { RgbaColor } from 'react-colorful'; @@ -73,7 +78,10 @@ const nodesSlice = createSlice({ | ImageField | RgbaColor | undefined - | ImageField[]; + | ImageField[] + | MainModelParam + | VaeModelParam + | LoRAModelParam; }> ) => { const { nodeId, fieldName, value } = action.payload; diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index 3de8cae9ff..4c47c63068 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -1,3 +1,8 @@ +import { + LoRAModelParam, + MainModelParam, + VaeModelParam, +} from 'features/parameters/types/parameterSchemas'; import { OpenAPIV3 } from 'openapi-types'; import { RgbaColor } from 'react-colorful'; import { Graph, ImageDTO, ImageField } from 'services/api/types'; @@ -92,7 +97,7 @@ export type InputFieldValue = | VaeInputFieldValue | ControlInputFieldValue | EnumInputFieldValue - | ModelInputFieldValue + | MainModelInputFieldValue | VaeModelInputFieldValue | LoRAModelInputFieldValue | ArrayInputFieldValue @@ -229,19 +234,19 @@ export type ImageCollectionInputFieldValue = FieldValueBase & { value?: ImageField[]; }; -export type ModelInputFieldValue = FieldValueBase & { +export type MainModelInputFieldValue = FieldValueBase & { type: 'model'; - value?: string; + value?: MainModelParam; }; export type VaeModelInputFieldValue = FieldValueBase & { type: 'vae_model'; - value?: string; + value?: VaeModelParam; }; export type LoRAModelInputFieldValue = FieldValueBase & { type: 'lora_model'; - value?: string; + value?: LoRAModelParam; }; export type ArrayInputFieldValue = FieldValueBase & { diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts index 6963cf16b8..64d579ce8b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts @@ -1,8 +1,5 @@ import { RootState } from 'app/store/store'; import { InputFieldValue } from 'features/nodes/types/types'; -import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam'; -import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam'; -import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam'; import { cloneDeep, omit, reduce } from 'lodash-es'; import { Graph } from 'services/api/types'; import { AnyInvocation } from 'services/events/types'; @@ -27,24 +24,6 @@ export const parseFieldValue = (field: InputFieldValue) => { } } - if (field.type === 'model') { - if (field.value) { - return modelIdToMainModelParam(field.value); - } - } - - if (field.type === 'vae_model') { - if (field.value) { - return modelIdToVAEModelParam(field.value); - } - } - - if (field.type === 'lora_model') { - if (field.value) { - return modelIdToLoRAModelParam(field.value); - } - } - return field.value; };