feat(ui): provide feedback when recalling invalid lora

This commit is contained in:
psychedelicious 2023-09-18 15:54:40 +10:00
parent cc0482ae8b
commit 94f16b1c69
3 changed files with 83 additions and 32 deletions

View File

@ -1,4 +1,4 @@
import { CoreMetadata, LoRAMetadataType } from 'features/nodes/types/types'; import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -69,7 +69,7 @@ const ImageMetadataActions = (props: Props) => {
}, [metadata?.strength, recallStrength]); }, [metadata?.strength, recallStrength]);
const handleRecallLoRA = useCallback( const handleRecallLoRA = useCallback(
(lora: LoRAMetadataType) => { (lora: LoRAMetadataItem) => {
recallLoRA(lora); recallLoRA(lora);
}, },
[recallLoRA] [recallLoRA]

View File

@ -1057,12 +1057,12 @@ export const isInvocationFieldSchema = (
export type InvocationEdgeExtra = { type: 'default' | 'collapsed' }; export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
const zLoRAObject = z.object({ const zLoRAMetadataItem = z.object({
lora: zLoRAModelField.deepPartial(), lora: zLoRAModelField.deepPartial(),
weight: z.number(), weight: z.number(),
}); });
export type LoRAMetadataType = z.infer<typeof zLoRAObject>; export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
export const zCoreMetadata = z export const zCoreMetadata = z
.object({ .object({
@ -1083,7 +1083,7 @@ export const zCoreMetadata = z
.union([zMainModel.deepPartial(), zOnnxModel.deepPartial()]) .union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
.nullish(), .nullish(),
controlnets: z.array(zControlField.deepPartial()).nullish(), controlnets: z.array(zControlField.deepPartial()).nullish(),
loras: z.array(zLoRAObject).nullish(), loras: z.array(zLoRAMetadataItem).nullish(),
vae: zVaeModelField.nullish(), vae: zVaeModelField.nullish(),
strength: z.number().nullish(), strength: z.number().nullish(),
init_image: z.string().nullish(), init_image: z.string().nullish(),

View File

@ -1,6 +1,8 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks'; import { stateSelector } from 'app/store/store';
import { CoreMetadata, LoRAMetadataType } from 'features/nodes/types/types'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
import { import {
refinerModelChanged, refinerModelChanged,
setNegativeStylePromptSDXL, setNegativeStylePromptSDXL,
@ -12,7 +14,7 @@ import {
setRefinerStart, setRefinerStart,
setRefinerSteps, setRefinerSteps,
} from 'features/sdxl/store/sdxlSlice'; } from 'features/sdxl/store/sdxlSlice';
import { useCallback } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { import {
@ -52,10 +54,16 @@ import {
isValidWidth, isValidWidth,
} from '../types/parameterSchemas'; } from '../types/parameterSchemas';
const selector = createSelector(stateSelector, ({ generation }) => {
const { model } = generation;
return { model };
});
export const useRecallParameters = () => { export const useRecallParameters = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const toaster = useAppToaster(); const toaster = useAppToaster();
const { t } = useTranslation(); const { t } = useTranslation();
const { model } = useAppSelector(selector);
const parameterSetToast = useCallback(() => { const parameterSetToast = useCallback(() => {
toaster({ toaster({
@ -66,14 +74,18 @@ export const useRecallParameters = () => {
}); });
}, [t, toaster]); }, [t, toaster]);
const parameterNotSetToast = useCallback(() => { const parameterNotSetToast = useCallback(
toaster({ (description?: string) => {
title: t('toast.parameterNotSet'), toaster({
status: 'warning', title: t('toast.parameterNotSet'),
duration: 2500, description,
isClosable: true, status: 'warning',
}); duration: 2500,
}, [t, toaster]); isClosable: true,
});
},
[t, toaster]
);
const allParameterSetToast = useCallback(() => { const allParameterSetToast = useCallback(() => {
toaster({ toaster({
@ -84,14 +96,18 @@ export const useRecallParameters = () => {
}); });
}, [t, toaster]); }, [t, toaster]);
const allParameterNotSetToast = useCallback(() => { const allParameterNotSetToast = useCallback(
toaster({ (description?: string) => {
title: t('toast.parametersNotSet'), toaster({
status: 'warning', title: t('toast.parametersNotSet'),
duration: 2500, status: 'warning',
isClosable: true, description,
}); duration: 2500,
}, [t, toaster]); isClosable: true,
});
},
[t, toaster]
);
/** /**
* Recall both prompts with toast * Recall both prompts with toast
@ -325,11 +341,10 @@ export const useRecallParameters = () => {
}), }),
}); });
const recallLoRA = useCallback( const prepareLoRAMetadataItem = useCallback(
(loraMetadataItem: LoRAMetadataType) => { (loraMetadataItem: LoRAMetadataItem) => {
if (!isValidLoRAModel(loraMetadataItem.lora)) { if (!isValidLoRAModel(loraMetadataItem.lora)) {
parameterNotSetToast(); return { lora: null, error: 'Invalid LoRA model' };
return;
} }
const { base_model, model_name } = loraMetadataItem.lora; const { base_model, model_name } = loraMetadataItem.lora;
@ -339,17 +354,40 @@ export const useRecallParameters = () => {
); );
if (!matchingLoRA) { if (!matchingLoRA) {
parameterNotSetToast(); return { lora: null, error: 'LoRA model is not installed' };
}
const isCompatibleBaseModel =
matchingLoRA?.base_model === model?.base_model;
if (!isCompatibleBaseModel) {
return {
lora: null,
error: 'LoRA incompatible with currently-selected model',
};
}
return { lora: matchingLoRA, error: null };
},
[loras, model?.base_model]
);
const recallLoRA = useCallback(
(loraMetadataItem: LoRAMetadataItem) => {
const result = prepareLoRAMetadataItem(loraMetadataItem);
if (!result.lora) {
parameterNotSetToast(result.error);
return; return;
} }
dispatch( dispatch(
loraRecalled({ ...matchingLoRA, weight: loraMetadataItem.weight }) loraRecalled({ ...result.lora, weight: loraMetadataItem.weight })
); );
parameterSetToast(); parameterSetToast();
}, },
[loras, dispatch, parameterSetToast, parameterNotSetToast] [prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
); );
/* /*
@ -389,6 +427,7 @@ export const useRecallParameters = () => {
refiner_positive_aesthetic_score, refiner_positive_aesthetic_score,
refiner_negative_aesthetic_score, refiner_negative_aesthetic_score,
refiner_start, refiner_start,
loras,
} = metadata; } = metadata;
if (isValidCfgScale(cfg_scale)) { if (isValidCfgScale(cfg_scale)) {
@ -470,9 +509,21 @@ export const useRecallParameters = () => {
dispatch(setRefinerStart(refiner_start)); dispatch(setRefinerStart(refiner_start));
} }
loras?.forEach((lora) => {
const result = prepareLoRAMetadataItem(lora);
if (result.lora) {
dispatch(loraRecalled({ ...result.lora, weight: lora.weight }));
}
});
allParameterSetToast(); allParameterSetToast();
}, },
[allParameterNotSetToast, allParameterSetToast, dispatch] [
allParameterNotSetToast,
allParameterSetToast,
dispatch,
prepareLoRAMetadataItem,
]
); );
return { return {