fix(ui): get lora select working

This commit is contained in:
psychedelicious 2024-02-20 10:42:44 +11:00
parent 2f2097662a
commit 663f135b3c
3 changed files with 16 additions and 15 deletions

View File

@ -2,15 +2,16 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { groupBy, map, reduce } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/endpoints/models';
import { getModelId } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined;
selectedModel?: Pick<T, 'base_model' | 'model_name' | 'model_type'> | null;
selectedModel?: ModelIdentifierWithBase | null;
onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean;
isLoading?: boolean;
@ -28,7 +29,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
arg: UseGroupedModelComboboxArg<T>
): UseGroupedModelComboboxReturn => {
const { t } = useTranslation();
const base_model = useAppSelector((s) => s.generation.model?.base_model ?? 'sdxl');
const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg;
const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
if (!modelEntities) {
@ -42,8 +43,8 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
acc.push({
label,
options: val.map((model) => ({
label: model.model_name,
value: model.id,
label: model.name,
value: model.key,
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
})),
});

View File

@ -25,18 +25,18 @@ export const LoRACard = memo((props: LoRACardProps) => {
const handleChange = useCallback(
(v: number) => {
dispatch(loraWeightChanged({ id: lora.id, weight: v }));
dispatch(loraWeightChanged({ key: lora.key, weight: v }));
},
[dispatch, lora.id]
[dispatch, lora.key]
);
const handleSetLoraToggle = useCallback(() => {
dispatch(loraIsEnabledChanged({ id: lora.id, isEnabled: !lora.isEnabled }));
}, [dispatch, lora.id, lora.isEnabled]);
dispatch(loraIsEnabledChanged({ key: lora.key, isEnabled: !lora.isEnabled }));
}, [dispatch, lora.key, lora.isEnabled]);
const handleRemoveLora = useCallback(() => {
dispatch(loraRemoved(lora.id));
}, [dispatch, lora.id]);
dispatch(loraRemoved(lora.key));
}, [dispatch, lora.key]);
return (
<Card variant="lora">

View File

@ -6,8 +6,8 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { LoRAConfig } from 'services/api/endpoints/models';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import type { LoRAConfig } from 'services/api/types';
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
@ -16,11 +16,11 @@ const LoRASelect = () => {
const { data, isLoading } = useGetLoRAModelsQuery();
const { t } = useTranslation();
const addedLoRAs = useAppSelector(selectAddedLoRAs);
const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model);
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const getIsDisabled = (lora: LoRAConfig): boolean => {
const isCompatible = currentBaseModel === lora.base_model;
const isAdded = Boolean(addedLoRAs[lora.id]);
const isCompatible = currentBaseModel === lora.base;
const isAdded = Boolean(addedLoRAs[lora.key]);
const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible || isAdded;
};