feat(ui): simplify lora recall check

This commit is contained in:
psychedelicious 2023-09-18 15:24:11 +10:00
parent fdf9833c39
commit cc0482ae8b
3 changed files with 28 additions and 35 deletions

View File

@ -1,9 +1,9 @@
import { CoreMetadata, LoRAMetadataType } from 'features/nodes/types/types'; import { CoreMetadata, LoRAMetadataType } from 'features/nodes/types/types';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import ImageMetadataItem from './ImageMetadataItem';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas'; import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas';
import ImageMetadataItem from './ImageMetadataItem';
type Props = { type Props = {
metadata?: CoreMetadata; metadata?: CoreMetadata;

View File

@ -1,10 +1,6 @@
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { import { CoreMetadata, LoRAMetadataType } from 'features/nodes/types/types';
CoreMetadata,
LoRAMetadataType,
LoraInfo,
} from 'features/nodes/types/types';
import { import {
refinerModelChanged, refinerModelChanged,
setNegativeStylePromptSDXL, setNegativeStylePromptSDXL,
@ -19,6 +15,11 @@ import {
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import {
loraModelsAdapter,
useGetLoRAModelsQuery,
} from '../../../services/api/endpoints/models';
import { loraRecalled } from '../../lora/store/loraSlice';
import { initialImageSelected, modelSelected } from '../store/actions'; import { initialImageSelected, modelSelected } from '../store/actions';
import { import {
setCfgScale, setCfgScale,
@ -50,8 +51,6 @@ import {
isValidStrength, isValidStrength,
isValidWidth, isValidWidth,
} from '../types/parameterSchemas'; } from '../types/parameterSchemas';
import { loraRecalled } from '../../lora/store/loraSlice';
import { useGetLoRAModelsQuery } from '../../../services/api/endpoints/models';
export const useRecallParameters = () => { export const useRecallParameters = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
@ -318,45 +317,39 @@ export const useRecallParameters = () => {
* Recall LoRA with toast * Recall LoRA with toast
*/ */
const { data: loraModels } = useGetLoRAModelsQuery(); const { loras } = useGetLoRAModelsQuery(undefined, {
selectFromResult: (result) => ({
const recallLoRA = useCallback( loras: result.data
(lora: LoRAMetadataType) => { ? loraModelsAdapter.getSelectors().selectAll(result.data)
if (!isValidLoRAModel(lora.lora)) { : [],
parameterNotSetToast(); }),
return;
}
if (!loraModels || !loraModels.entities) {
parameterNotSetToast();
return;
}
const matchingId = Object.keys(loraModels.entities).find((loraId) => {
const matchesBaseModel =
loraModels.entities[loraId]?.base_model === lora.lora.base_model;
const matchesModelName =
loraModels.entities[loraId]?.model_name === lora.lora.model_name;
return matchesBaseModel && matchesModelName;
}); });
if (!matchingId) { const recallLoRA = useCallback(
(loraMetadataItem: LoRAMetadataType) => {
if (!isValidLoRAModel(loraMetadataItem.lora)) {
parameterNotSetToast(); parameterNotSetToast();
return; return;
} }
const fullLoRA = loraModels.entities[matchingId]; const { base_model, model_name } = loraMetadataItem.lora;
if (!fullLoRA) { const matchingLoRA = loras.find(
(l) => l.base_model === base_model && l.model_name === model_name
);
if (!matchingLoRA) {
parameterNotSetToast(); parameterNotSetToast();
return; return;
} }
dispatch(loraRecalled({ ...fullLoRA, weight: lora.weight })); dispatch(
loraRecalled({ ...matchingLoRA, weight: loraMetadataItem.weight })
);
parameterSetToast(); parameterSetToast();
}, },
[dispatch, parameterSetToast, parameterNotSetToast, loraModels] [loras, dispatch, parameterSetToast, parameterNotSetToast]
); );
/* /*

View File

@ -128,7 +128,7 @@ export const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
const onnxModelsAdapter = createEntityAdapter<OnnxModelConfigEntity>({ const onnxModelsAdapter = createEntityAdapter<OnnxModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({ export const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
export const controlNetModelsAdapter = export const controlNetModelsAdapter =