From e7e3045a8ac1ed714cf4ccbb6ae16a9f3e3958ed Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 11:14:08 +1100 Subject: [PATCH] fix(ui): get vae model select working --- .../common/hooks/useGroupedModelCombobox.ts | 4 +--- .../web/src/common/hooks/useModelCombobox.ts | 12 ++++++------ .../VAEModel/ParamVAEModelSelect.tsx | 18 ++++++++++-------- .../web/src/services/api/endpoints/models.ts | 3 --- 4 files changed, 17 insertions(+), 20 deletions(-) diff --git a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts index 140cf3eaa6..fc5bc455ee 100644 --- a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts @@ -6,7 +6,6 @@ import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; import { groupBy, map, reduce } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { getModelId } from 'services/api/endpoints/models'; import type { AnyModelConfig } from 'services/api/types'; type UseGroupedModelComboboxArg = { @@ -58,8 +57,7 @@ export const useGroupedModelCombobox = ( const value = useMemo( () => - options.flatMap((o) => o.options).find((m) => (selectedModel ? m.value === getModelId(selectedModel) : false)) ?? - null, + options.flatMap((o) => o.options).find((m) => (selectedModel ? m.value === selectedModel.key : false)) ?? null, [options, selectedModel] ); diff --git a/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts index 07e6aeb34c..e0718d6413 100644 --- a/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts @@ -1,14 +1,14 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; import type { EntityState } from '@reduxjs/toolkit'; +import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; import { map } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import type { AnyModelConfig } from 'services/api/endpoints/models'; -import { getModelId } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/types'; type UseModelComboboxArg = { modelEntities: EntityState | undefined; - selectedModel?: Pick | null; + selectedModel?: ModelIdentifierWithBase | null; onChange: (value: T | null) => void; getIsDisabled?: (model: T) => boolean; optionsFilter?: (model: T) => boolean; @@ -33,14 +33,14 @@ export const useModelCombobox = (arg: UseModelCombobox return map(modelEntities.entities) .filter(optionsFilter) .map((model) => ({ - label: model.model_name, - value: model.id, + label: model.name, + value: model.key, isDisabled: getIsDisabled ? getIsDisabled(model) : false, })); }, [optionsFilter, getIsDisabled, modelEntities]); const value = useMemo( - () => options.find((m) => (selectedModel ? m.value === getModelId(selectedModel) : false)), + () => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)), [options, selectedModel] ); diff --git a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx index cc0164153d..1810c3ff68 100644 --- a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx @@ -7,8 +7,8 @@ import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/ge import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import type { VAEConfig } from 'services/api/endpoints/models'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; +import type { VAEConfig } from 'services/api/types'; const selector = createMemoizedSelector(selectGenerationSlice, (generation) => { const { model, vae } = generation; @@ -22,25 +22,27 @@ const ParamVAEModelSelect = () => { const { data, isLoading } = useGetVaeModelsQuery(); const getIsDisabled = useCallback( (vae: VAEConfig): boolean => { - const isCompatible = model?.base_model === vae.base_model; - const hasMainModel = Boolean(model?.base_model); + const isCompatible = model?.base === vae.base; + const hasMainModel = Boolean(model?.base); return !hasMainModel || !isCompatible; }, - [model?.base_model] + [model?.base] ); const _onChange = useCallback( (vae: VAEConfig | null) => { - dispatch(vaeSelected(vae ? pick(vae, 'base_model', 'model_name') : null)); + dispatch(vaeSelected(vae ? pick(vae, 'key', 'base') : null)); }, [dispatch] ); - const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ + const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({ modelEntities: data, onChange: _onChange, - selectedModel: vae ? { ...vae, model_type: 'vae' } : null, + selectedModel: vae ? pick(vae, 'key', 'base') : null, isLoading, getIsDisabled, }); + + console.log(value) return ( @@ -50,7 +52,7 @@ const ParamVAEModelSelect = () => { input; - type UpdateMainModelArg = { base_model: BaseModelType; model_name: string;