mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): get lora select working
This commit is contained in:
parent
2f2097662a
commit
663f135b3c
@ -2,15 +2,16 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { GroupBase } from 'chakra-react-select';
|
||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import { groupBy, map, reduce } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig } from 'services/api/endpoints/models';
|
||||
import { getModelId } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
||||
modelEntities: EntityState<T, string> | undefined;
|
||||
selectedModel?: Pick<T, 'base_model' | 'model_name' | 'model_type'> | null;
|
||||
selectedModel?: ModelIdentifierWithBase | null;
|
||||
onChange: (value: T | null) => void;
|
||||
getIsDisabled?: (model: T) => boolean;
|
||||
isLoading?: boolean;
|
||||
@ -28,7 +29,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
||||
arg: UseGroupedModelComboboxArg<T>
|
||||
): UseGroupedModelComboboxReturn => {
|
||||
const { t } = useTranslation();
|
||||
const base_model = useAppSelector((s) => s.generation.model?.base_model ?? 'sdxl');
|
||||
const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
|
||||
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg;
|
||||
const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
|
||||
if (!modelEntities) {
|
||||
@ -42,8 +43,8 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
||||
acc.push({
|
||||
label,
|
||||
options: val.map((model) => ({
|
||||
label: model.model_name,
|
||||
value: model.id,
|
||||
label: model.name,
|
||||
value: model.key,
|
||||
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
|
||||
})),
|
||||
});
|
||||
|
@ -25,18 +25,18 @@ export const LoRACard = memo((props: LoRACardProps) => {
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: number) => {
|
||||
dispatch(loraWeightChanged({ id: lora.id, weight: v }));
|
||||
dispatch(loraWeightChanged({ key: lora.key, weight: v }));
|
||||
},
|
||||
[dispatch, lora.id]
|
||||
[dispatch, lora.key]
|
||||
);
|
||||
|
||||
const handleSetLoraToggle = useCallback(() => {
|
||||
dispatch(loraIsEnabledChanged({ id: lora.id, isEnabled: !lora.isEnabled }));
|
||||
}, [dispatch, lora.id, lora.isEnabled]);
|
||||
dispatch(loraIsEnabledChanged({ key: lora.key, isEnabled: !lora.isEnabled }));
|
||||
}, [dispatch, lora.key, lora.isEnabled]);
|
||||
|
||||
const handleRemoveLora = useCallback(() => {
|
||||
dispatch(loraRemoved(lora.id));
|
||||
}, [dispatch, lora.id]);
|
||||
dispatch(loraRemoved(lora.key));
|
||||
}, [dispatch, lora.key]);
|
||||
|
||||
return (
|
||||
<Card variant="lora">
|
||||
|
@ -6,8 +6,8 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { LoRAConfig } from 'services/api/endpoints/models';
|
||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { LoRAConfig } from 'services/api/types';
|
||||
|
||||
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
|
||||
|
||||
@ -16,11 +16,11 @@ const LoRASelect = () => {
|
||||
const { data, isLoading } = useGetLoRAModelsQuery();
|
||||
const { t } = useTranslation();
|
||||
const addedLoRAs = useAppSelector(selectAddedLoRAs);
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model);
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
|
||||
const getIsDisabled = (lora: LoRAConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === lora.base_model;
|
||||
const isAdded = Boolean(addedLoRAs[lora.id]);
|
||||
const isCompatible = currentBaseModel === lora.base;
|
||||
const isAdded = Boolean(addedLoRAs[lora.key]);
|
||||
const hasMainModel = Boolean(currentBaseModel);
|
||||
return !hasMainModel || !isCompatible || isAdded;
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user