mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): provide feedback when recalling invalid lora
This commit is contained in:
parent
cc0482ae8b
commit
94f16b1c69
@ -1,4 +1,4 @@
|
||||
import { CoreMetadata, LoRAMetadataType } from 'features/nodes/types/types';
|
||||
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -69,7 +69,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
}, [metadata?.strength, recallStrength]);
|
||||
|
||||
const handleRecallLoRA = useCallback(
|
||||
(lora: LoRAMetadataType) => {
|
||||
(lora: LoRAMetadataItem) => {
|
||||
recallLoRA(lora);
|
||||
},
|
||||
[recallLoRA]
|
||||
|
@ -1057,12 +1057,12 @@ export const isInvocationFieldSchema = (
|
||||
|
||||
export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
|
||||
|
||||
const zLoRAObject = z.object({
|
||||
const zLoRAMetadataItem = z.object({
|
||||
lora: zLoRAModelField.deepPartial(),
|
||||
weight: z.number(),
|
||||
});
|
||||
|
||||
export type LoRAMetadataType = z.infer<typeof zLoRAObject>;
|
||||
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
|
||||
|
||||
export const zCoreMetadata = z
|
||||
.object({
|
||||
@ -1083,7 +1083,7 @@ export const zCoreMetadata = z
|
||||
.union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
|
||||
.nullish(),
|
||||
controlnets: z.array(zControlField.deepPartial()).nullish(),
|
||||
loras: z.array(zLoRAObject).nullish(),
|
||||
loras: z.array(zLoRAMetadataItem).nullish(),
|
||||
vae: zVaeModelField.nullish(),
|
||||
strength: z.number().nullish(),
|
||||
init_image: z.string().nullish(),
|
||||
|
@ -1,6 +1,8 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { CoreMetadata, LoRAMetadataType } from 'features/nodes/types/types';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
|
||||
import {
|
||||
refinerModelChanged,
|
||||
setNegativeStylePromptSDXL,
|
||||
@ -12,7 +14,7 @@ import {
|
||||
setRefinerStart,
|
||||
setRefinerSteps,
|
||||
} from 'features/sdxl/store/sdxlSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import {
|
||||
@ -52,10 +54,16 @@ import {
|
||||
isValidWidth,
|
||||
} from '../types/parameterSchemas';
|
||||
|
||||
const selector = createSelector(stateSelector, ({ generation }) => {
|
||||
const { model } = generation;
|
||||
return { model };
|
||||
});
|
||||
|
||||
export const useRecallParameters = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const toaster = useAppToaster();
|
||||
const { t } = useTranslation();
|
||||
const { model } = useAppSelector(selector);
|
||||
|
||||
const parameterSetToast = useCallback(() => {
|
||||
toaster({
|
||||
@ -66,14 +74,18 @@ export const useRecallParameters = () => {
|
||||
});
|
||||
}, [t, toaster]);
|
||||
|
||||
const parameterNotSetToast = useCallback(() => {
|
||||
toaster({
|
||||
title: t('toast.parameterNotSet'),
|
||||
status: 'warning',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
}, [t, toaster]);
|
||||
const parameterNotSetToast = useCallback(
|
||||
(description?: string) => {
|
||||
toaster({
|
||||
title: t('toast.parameterNotSet'),
|
||||
description,
|
||||
status: 'warning',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
},
|
||||
[t, toaster]
|
||||
);
|
||||
|
||||
const allParameterSetToast = useCallback(() => {
|
||||
toaster({
|
||||
@ -84,14 +96,18 @@ export const useRecallParameters = () => {
|
||||
});
|
||||
}, [t, toaster]);
|
||||
|
||||
const allParameterNotSetToast = useCallback(() => {
|
||||
toaster({
|
||||
title: t('toast.parametersNotSet'),
|
||||
status: 'warning',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
}, [t, toaster]);
|
||||
const allParameterNotSetToast = useCallback(
|
||||
(description?: string) => {
|
||||
toaster({
|
||||
title: t('toast.parametersNotSet'),
|
||||
status: 'warning',
|
||||
description,
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
},
|
||||
[t, toaster]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall both prompts with toast
|
||||
@ -325,11 +341,10 @@ export const useRecallParameters = () => {
|
||||
}),
|
||||
});
|
||||
|
||||
const recallLoRA = useCallback(
|
||||
(loraMetadataItem: LoRAMetadataType) => {
|
||||
const prepareLoRAMetadataItem = useCallback(
|
||||
(loraMetadataItem: LoRAMetadataItem) => {
|
||||
if (!isValidLoRAModel(loraMetadataItem.lora)) {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
return { lora: null, error: 'Invalid LoRA model' };
|
||||
}
|
||||
|
||||
const { base_model, model_name } = loraMetadataItem.lora;
|
||||
@ -339,17 +354,40 @@ export const useRecallParameters = () => {
|
||||
);
|
||||
|
||||
if (!matchingLoRA) {
|
||||
parameterNotSetToast();
|
||||
return { lora: null, error: 'LoRA model is not installed' };
|
||||
}
|
||||
|
||||
const isCompatibleBaseModel =
|
||||
matchingLoRA?.base_model === model?.base_model;
|
||||
|
||||
if (!isCompatibleBaseModel) {
|
||||
return {
|
||||
lora: null,
|
||||
error: 'LoRA incompatible with currently-selected model',
|
||||
};
|
||||
}
|
||||
|
||||
return { lora: matchingLoRA, error: null };
|
||||
},
|
||||
[loras, model?.base_model]
|
||||
);
|
||||
|
||||
const recallLoRA = useCallback(
|
||||
(loraMetadataItem: LoRAMetadataItem) => {
|
||||
const result = prepareLoRAMetadataItem(loraMetadataItem);
|
||||
|
||||
if (!result.lora) {
|
||||
parameterNotSetToast(result.error);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
loraRecalled({ ...matchingLoRA, weight: loraMetadataItem.weight })
|
||||
loraRecalled({ ...result.lora, weight: loraMetadataItem.weight })
|
||||
);
|
||||
|
||||
parameterSetToast();
|
||||
},
|
||||
[loras, dispatch, parameterSetToast, parameterNotSetToast]
|
||||
[prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/*
|
||||
@ -389,6 +427,7 @@ export const useRecallParameters = () => {
|
||||
refiner_positive_aesthetic_score,
|
||||
refiner_negative_aesthetic_score,
|
||||
refiner_start,
|
||||
loras,
|
||||
} = metadata;
|
||||
|
||||
if (isValidCfgScale(cfg_scale)) {
|
||||
@ -470,9 +509,21 @@ export const useRecallParameters = () => {
|
||||
dispatch(setRefinerStart(refiner_start));
|
||||
}
|
||||
|
||||
loras?.forEach((lora) => {
|
||||
const result = prepareLoRAMetadataItem(lora);
|
||||
if (result.lora) {
|
||||
dispatch(loraRecalled({ ...result.lora, weight: lora.weight }));
|
||||
}
|
||||
});
|
||||
|
||||
allParameterSetToast();
|
||||
},
|
||||
[allParameterNotSetToast, allParameterSetToast, dispatch]
|
||||
[
|
||||
allParameterNotSetToast,
|
||||
allParameterSetToast,
|
||||
dispatch,
|
||||
prepareLoRAMetadataItem,
|
||||
]
|
||||
);
|
||||
|
||||
return {
|
||||
|
Loading…
Reference in New Issue
Block a user