mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): simplify lora recall check
This commit is contained in:
parent
fdf9833c39
commit
cc0482ae8b
@ -1,9 +1,9 @@
|
||||
import { CoreMetadata, LoRAMetadataType } from 'features/nodes/types/types';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { memo, useCallback } from 'react';
|
||||
import ImageMetadataItem from './ImageMetadataItem';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas';
|
||||
import ImageMetadataItem from './ImageMetadataItem';
|
||||
|
||||
type Props = {
|
||||
metadata?: CoreMetadata;
|
||||
|
@ -1,10 +1,6 @@
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import {
|
||||
CoreMetadata,
|
||||
LoRAMetadataType,
|
||||
LoraInfo,
|
||||
} from 'features/nodes/types/types';
|
||||
import { CoreMetadata, LoRAMetadataType } from 'features/nodes/types/types';
|
||||
import {
|
||||
refinerModelChanged,
|
||||
setNegativeStylePromptSDXL,
|
||||
@ -19,6 +15,11 @@ import {
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import {
|
||||
loraModelsAdapter,
|
||||
useGetLoRAModelsQuery,
|
||||
} from '../../../services/api/endpoints/models';
|
||||
import { loraRecalled } from '../../lora/store/loraSlice';
|
||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||
import {
|
||||
setCfgScale,
|
||||
@ -50,8 +51,6 @@ import {
|
||||
isValidStrength,
|
||||
isValidWidth,
|
||||
} from '../types/parameterSchemas';
|
||||
import { loraRecalled } from '../../lora/store/loraSlice';
|
||||
import { useGetLoRAModelsQuery } from '../../../services/api/endpoints/models';
|
||||
|
||||
export const useRecallParameters = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
@ -318,45 +317,39 @@ export const useRecallParameters = () => {
|
||||
* Recall LoRA with toast
|
||||
*/
|
||||
|
||||
const { data: loraModels } = useGetLoRAModelsQuery();
|
||||
const { loras } = useGetLoRAModelsQuery(undefined, {
|
||||
selectFromResult: (result) => ({
|
||||
loras: result.data
|
||||
? loraModelsAdapter.getSelectors().selectAll(result.data)
|
||||
: [],
|
||||
}),
|
||||
});
|
||||
|
||||
const recallLoRA = useCallback(
|
||||
(lora: LoRAMetadataType) => {
|
||||
if (!isValidLoRAModel(lora.lora)) {
|
||||
(loraMetadataItem: LoRAMetadataType) => {
|
||||
if (!isValidLoRAModel(loraMetadataItem.lora)) {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
|
||||
if (!loraModels || !loraModels.entities) {
|
||||
const { base_model, model_name } = loraMetadataItem.lora;
|
||||
|
||||
const matchingLoRA = loras.find(
|
||||
(l) => l.base_model === base_model && l.model_name === model_name
|
||||
);
|
||||
|
||||
if (!matchingLoRA) {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
|
||||
const matchingId = Object.keys(loraModels.entities).find((loraId) => {
|
||||
const matchesBaseModel =
|
||||
loraModels.entities[loraId]?.base_model === lora.lora.base_model;
|
||||
const matchesModelName =
|
||||
loraModels.entities[loraId]?.model_name === lora.lora.model_name;
|
||||
return matchesBaseModel && matchesModelName;
|
||||
});
|
||||
|
||||
if (!matchingId) {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
|
||||
const fullLoRA = loraModels.entities[matchingId];
|
||||
|
||||
if (!fullLoRA) {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(loraRecalled({ ...fullLoRA, weight: lora.weight }));
|
||||
dispatch(
|
||||
loraRecalled({ ...matchingLoRA, weight: loraMetadataItem.weight })
|
||||
);
|
||||
|
||||
parameterSetToast();
|
||||
},
|
||||
[dispatch, parameterSetToast, parameterNotSetToast, loraModels]
|
||||
[loras, dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/*
|
||||
|
@ -128,7 +128,7 @@ export const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||
const onnxModelsAdapter = createEntityAdapter<OnnxModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
||||
export const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
export const controlNetModelsAdapter =
|
||||
|
Loading…
Reference in New Issue
Block a user