mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): get lora select working
This commit is contained in:
parent
2f2097662a
commit
663f135b3c
@ -2,15 +2,16 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
|||||||
import type { EntityState } from '@reduxjs/toolkit';
|
import type { EntityState } from '@reduxjs/toolkit';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import type { GroupBase } from 'chakra-react-select';
|
import type { GroupBase } from 'chakra-react-select';
|
||||||
|
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||||
import { groupBy, map, reduce } from 'lodash-es';
|
import { groupBy, map, reduce } from 'lodash-es';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { AnyModelConfig } from 'services/api/endpoints/models';
|
|
||||||
import { getModelId } from 'services/api/endpoints/models';
|
import { getModelId } from 'services/api/endpoints/models';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
||||||
modelEntities: EntityState<T, string> | undefined;
|
modelEntities: EntityState<T, string> | undefined;
|
||||||
selectedModel?: Pick<T, 'base_model' | 'model_name' | 'model_type'> | null;
|
selectedModel?: ModelIdentifierWithBase | null;
|
||||||
onChange: (value: T | null) => void;
|
onChange: (value: T | null) => void;
|
||||||
getIsDisabled?: (model: T) => boolean;
|
getIsDisabled?: (model: T) => boolean;
|
||||||
isLoading?: boolean;
|
isLoading?: boolean;
|
||||||
@ -28,7 +29,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
|||||||
arg: UseGroupedModelComboboxArg<T>
|
arg: UseGroupedModelComboboxArg<T>
|
||||||
): UseGroupedModelComboboxReturn => {
|
): UseGroupedModelComboboxReturn => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const base_model = useAppSelector((s) => s.generation.model?.base_model ?? 'sdxl');
|
const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
|
||||||
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg;
|
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg;
|
||||||
const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
|
const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
|
||||||
if (!modelEntities) {
|
if (!modelEntities) {
|
||||||
@ -42,8 +43,8 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
|||||||
acc.push({
|
acc.push({
|
||||||
label,
|
label,
|
||||||
options: val.map((model) => ({
|
options: val.map((model) => ({
|
||||||
label: model.model_name,
|
label: model.name,
|
||||||
value: model.id,
|
value: model.key,
|
||||||
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
|
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
|
||||||
})),
|
})),
|
||||||
});
|
});
|
||||||
|
@ -25,18 +25,18 @@ export const LoRACard = memo((props: LoRACardProps) => {
|
|||||||
|
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
dispatch(loraWeightChanged({ id: lora.id, weight: v }));
|
dispatch(loraWeightChanged({ key: lora.key, weight: v }));
|
||||||
},
|
},
|
||||||
[dispatch, lora.id]
|
[dispatch, lora.key]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleSetLoraToggle = useCallback(() => {
|
const handleSetLoraToggle = useCallback(() => {
|
||||||
dispatch(loraIsEnabledChanged({ id: lora.id, isEnabled: !lora.isEnabled }));
|
dispatch(loraIsEnabledChanged({ key: lora.key, isEnabled: !lora.isEnabled }));
|
||||||
}, [dispatch, lora.id, lora.isEnabled]);
|
}, [dispatch, lora.key, lora.isEnabled]);
|
||||||
|
|
||||||
const handleRemoveLora = useCallback(() => {
|
const handleRemoveLora = useCallback(() => {
|
||||||
dispatch(loraRemoved(lora.id));
|
dispatch(loraRemoved(lora.key));
|
||||||
}, [dispatch, lora.id]);
|
}, [dispatch, lora.key]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Card variant="lora">
|
<Card variant="lora">
|
||||||
|
@ -6,8 +6,8 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
|||||||
import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
|
import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { LoRAConfig } from 'services/api/endpoints/models';
|
|
||||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
import type { LoRAConfig } from 'services/api/types';
|
||||||
|
|
||||||
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
|
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
|
||||||
|
|
||||||
@ -16,11 +16,11 @@ const LoRASelect = () => {
|
|||||||
const { data, isLoading } = useGetLoRAModelsQuery();
|
const { data, isLoading } = useGetLoRAModelsQuery();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const addedLoRAs = useAppSelector(selectAddedLoRAs);
|
const addedLoRAs = useAppSelector(selectAddedLoRAs);
|
||||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model);
|
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||||
|
|
||||||
const getIsDisabled = (lora: LoRAConfig): boolean => {
|
const getIsDisabled = (lora: LoRAConfig): boolean => {
|
||||||
const isCompatible = currentBaseModel === lora.base_model;
|
const isCompatible = currentBaseModel === lora.base;
|
||||||
const isAdded = Boolean(addedLoRAs[lora.id]);
|
const isAdded = Boolean(addedLoRAs[lora.key]);
|
||||||
const hasMainModel = Boolean(currentBaseModel);
|
const hasMainModel = Boolean(currentBaseModel);
|
||||||
return !hasMainModel || !isCompatible || isAdded;
|
return !hasMainModel || !isCompatible || isAdded;
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user