feat(ui): simplify lora recall check

This commit is contained in:
psychedelicious 2023-09-18 15:24:11 +10:00
parent fdf9833c39
commit cc0482ae8b
3 changed files with 28 additions and 35 deletions

View File

@ -1,9 +1,9 @@
import { CoreMetadata, LoRAMetadataType } from 'features/nodes/types/types';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { memo, useCallback } from 'react';
import ImageMetadataItem from './ImageMetadataItem';
import { useTranslation } from 'react-i18next';
import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas';
import ImageMetadataItem from './ImageMetadataItem';
type Props = {
metadata?: CoreMetadata;

View File

@ -1,10 +1,6 @@
import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import {
CoreMetadata,
LoRAMetadataType,
LoraInfo,
} from 'features/nodes/types/types';
import { CoreMetadata, LoRAMetadataType } from 'features/nodes/types/types';
import {
refinerModelChanged,
setNegativeStylePromptSDXL,
@ -19,6 +15,11 @@ import {
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { ImageDTO } from 'services/api/types';
import {
loraModelsAdapter,
useGetLoRAModelsQuery,
} from '../../../services/api/endpoints/models';
import { loraRecalled } from '../../lora/store/loraSlice';
import { initialImageSelected, modelSelected } from '../store/actions';
import {
setCfgScale,
@ -50,8 +51,6 @@ import {
isValidStrength,
isValidWidth,
} from '../types/parameterSchemas';
import { loraRecalled } from '../../lora/store/loraSlice';
import { useGetLoRAModelsQuery } from '../../../services/api/endpoints/models';
export const useRecallParameters = () => {
const dispatch = useAppDispatch();
@ -318,45 +317,39 @@ export const useRecallParameters = () => {
* Recall LoRA with toast
*/
const { data: loraModels } = useGetLoRAModelsQuery();
const { loras } = useGetLoRAModelsQuery(undefined, {
selectFromResult: (result) => ({
loras: result.data
? loraModelsAdapter.getSelectors().selectAll(result.data)
: [],
}),
});
const recallLoRA = useCallback(
(lora: LoRAMetadataType) => {
if (!isValidLoRAModel(lora.lora)) {
(loraMetadataItem: LoRAMetadataType) => {
if (!isValidLoRAModel(loraMetadataItem.lora)) {
parameterNotSetToast();
return;
}
if (!loraModels || !loraModels.entities) {
const { base_model, model_name } = loraMetadataItem.lora;
const matchingLoRA = loras.find(
(l) => l.base_model === base_model && l.model_name === model_name
);
if (!matchingLoRA) {
parameterNotSetToast();
return;
}
const matchingId = Object.keys(loraModels.entities).find((loraId) => {
const matchesBaseModel =
loraModels.entities[loraId]?.base_model === lora.lora.base_model;
const matchesModelName =
loraModels.entities[loraId]?.model_name === lora.lora.model_name;
return matchesBaseModel && matchesModelName;
});
if (!matchingId) {
parameterNotSetToast();
return;
}
const fullLoRA = loraModels.entities[matchingId];
if (!fullLoRA) {
parameterNotSetToast();
return;
}
dispatch(loraRecalled({ ...fullLoRA, weight: lora.weight }));
dispatch(
loraRecalled({ ...matchingLoRA, weight: loraMetadataItem.weight })
);
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast, loraModels]
[loras, dispatch, parameterSetToast, parameterNotSetToast]
);
/*

View File

@ -128,7 +128,7 @@ export const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
const onnxModelsAdapter = createEntityAdapter<OnnxModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
export const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
export const controlNetModelsAdapter =