fix(ui): get lora select working

This commit is contained in:
psychedelicious 2024-02-20 10:42:44 +11:00
parent 2f2097662a
commit 663f135b3c
3 changed files with 16 additions and 15 deletions

View File

@ -2,15 +2,16 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit'; import type { EntityState } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select'; import type { GroupBase } from 'chakra-react-select';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { groupBy, map, reduce } from 'lodash-es'; import { groupBy, map, reduce } from 'lodash-es';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/endpoints/models';
import { getModelId } from 'services/api/endpoints/models'; import { getModelId } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = { type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined; modelEntities: EntityState<T, string> | undefined;
selectedModel?: Pick<T, 'base_model' | 'model_name' | 'model_type'> | null; selectedModel?: ModelIdentifierWithBase | null;
onChange: (value: T | null) => void; onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean; getIsDisabled?: (model: T) => boolean;
isLoading?: boolean; isLoading?: boolean;
@ -28,7 +29,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
arg: UseGroupedModelComboboxArg<T> arg: UseGroupedModelComboboxArg<T>
): UseGroupedModelComboboxReturn => { ): UseGroupedModelComboboxReturn => {
const { t } = useTranslation(); const { t } = useTranslation();
const base_model = useAppSelector((s) => s.generation.model?.base_model ?? 'sdxl'); const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg; const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg;
const options = useMemo<GroupBase<ComboboxOption>[]>(() => { const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
if (!modelEntities) { if (!modelEntities) {
@ -42,8 +43,8 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
acc.push({ acc.push({
label, label,
options: val.map((model) => ({ options: val.map((model) => ({
label: model.model_name, label: model.name,
value: model.id, value: model.key,
isDisabled: getIsDisabled ? getIsDisabled(model) : false, isDisabled: getIsDisabled ? getIsDisabled(model) : false,
})), })),
}); });

View File

@ -25,18 +25,18 @@ export const LoRACard = memo((props: LoRACardProps) => {
const handleChange = useCallback( const handleChange = useCallback(
(v: number) => { (v: number) => {
dispatch(loraWeightChanged({ id: lora.id, weight: v })); dispatch(loraWeightChanged({ key: lora.key, weight: v }));
}, },
[dispatch, lora.id] [dispatch, lora.key]
); );
const handleSetLoraToggle = useCallback(() => { const handleSetLoraToggle = useCallback(() => {
dispatch(loraIsEnabledChanged({ id: lora.id, isEnabled: !lora.isEnabled })); dispatch(loraIsEnabledChanged({ key: lora.key, isEnabled: !lora.isEnabled }));
}, [dispatch, lora.id, lora.isEnabled]); }, [dispatch, lora.key, lora.isEnabled]);
const handleRemoveLora = useCallback(() => { const handleRemoveLora = useCallback(() => {
dispatch(loraRemoved(lora.id)); dispatch(loraRemoved(lora.key));
}, [dispatch, lora.id]); }, [dispatch, lora.key]);
return ( return (
<Card variant="lora"> <Card variant="lora">

View File

@ -6,8 +6,8 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice'; import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { LoRAConfig } from 'services/api/endpoints/models';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import type { LoRAConfig } from 'services/api/types';
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras); const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
@ -16,11 +16,11 @@ const LoRASelect = () => {
const { data, isLoading } = useGetLoRAModelsQuery(); const { data, isLoading } = useGetLoRAModelsQuery();
const { t } = useTranslation(); const { t } = useTranslation();
const addedLoRAs = useAppSelector(selectAddedLoRAs); const addedLoRAs = useAppSelector(selectAddedLoRAs);
const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model); const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const getIsDisabled = (lora: LoRAConfig): boolean => { const getIsDisabled = (lora: LoRAConfig): boolean => {
const isCompatible = currentBaseModel === lora.base_model; const isCompatible = currentBaseModel === lora.base;
const isAdded = Boolean(addedLoRAs[lora.id]); const isAdded = Boolean(addedLoRAs[lora.key]);
const hasMainModel = Boolean(currentBaseModel); const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible || isAdded; return !hasMainModel || !isCompatible || isAdded;
}; };