From 663f135b3cae68b224db2b2bf4e2ae0b012fd9f0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 10:42:44 +1100 Subject: [PATCH] fix(ui): get lora select working --- .../web/src/common/hooks/useGroupedModelCombobox.ts | 11 ++++++----- .../web/src/features/lora/components/LoRACard.tsx | 12 ++++++------ .../web/src/features/lora/components/LoRASelect.tsx | 8 ++++---- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts index 875ce1f1c4..140cf3eaa6 100644 --- a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts @@ -2,15 +2,16 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; import type { EntityState } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import type { GroupBase } from 'chakra-react-select'; +import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; import { groupBy, map, reduce } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import type { AnyModelConfig } from 'services/api/endpoints/models'; import { getModelId } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/types'; type UseGroupedModelComboboxArg = { modelEntities: EntityState | undefined; - selectedModel?: Pick | null; + selectedModel?: ModelIdentifierWithBase | null; onChange: (value: T | null) => void; getIsDisabled?: (model: T) => boolean; isLoading?: boolean; @@ -28,7 +29,7 @@ export const useGroupedModelCombobox = ( arg: UseGroupedModelComboboxArg ): UseGroupedModelComboboxReturn => { const { t } = useTranslation(); - const base_model = useAppSelector((s) => s.generation.model?.base_model ?? 'sdxl'); + const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl'); const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg; const options = useMemo[]>(() => { if (!modelEntities) { @@ -42,8 +43,8 @@ export const useGroupedModelCombobox = ( acc.push({ label, options: val.map((model) => ({ - label: model.model_name, - value: model.id, + label: model.name, + value: model.key, isDisabled: getIsDisabled ? getIsDisabled(model) : false, })), }); diff --git a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx index a194fb1361..05a3c14cb4 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx @@ -25,18 +25,18 @@ export const LoRACard = memo((props: LoRACardProps) => { const handleChange = useCallback( (v: number) => { - dispatch(loraWeightChanged({ id: lora.id, weight: v })); + dispatch(loraWeightChanged({ key: lora.key, weight: v })); }, - [dispatch, lora.id] + [dispatch, lora.key] ); const handleSetLoraToggle = useCallback(() => { - dispatch(loraIsEnabledChanged({ id: lora.id, isEnabled: !lora.isEnabled })); - }, [dispatch, lora.id, lora.isEnabled]); + dispatch(loraIsEnabledChanged({ key: lora.key, isEnabled: !lora.isEnabled })); + }, [dispatch, lora.key, lora.isEnabled]); const handleRemoveLora = useCallback(() => { - dispatch(loraRemoved(lora.id)); - }, [dispatch, lora.id]); + dispatch(loraRemoved(lora.key)); + }, [dispatch, lora.key]); return ( diff --git a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx index 910f7087df..7dd606dcc2 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx @@ -6,8 +6,8 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import type { LoRAConfig } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; +import type { LoRAConfig } from 'services/api/types'; const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras); @@ -16,11 +16,11 @@ const LoRASelect = () => { const { data, isLoading } = useGetLoRAModelsQuery(); const { t } = useTranslation(); const addedLoRAs = useAppSelector(selectAddedLoRAs); - const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model); + const currentBaseModel = useAppSelector((s) => s.generation.model?.base); const getIsDisabled = (lora: LoRAConfig): boolean => { - const isCompatible = currentBaseModel === lora.base_model; - const isAdded = Boolean(addedLoRAs[lora.id]); + const isCompatible = currentBaseModel === lora.base; + const isAdded = Boolean(addedLoRAs[lora.key]); const hasMainModel = Boolean(currentBaseModel); return !hasMainModel || !isCompatible || isAdded; };