diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index a0ea69157c..cf58334667 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -1,9 +1,9 @@ import { CoreMetadata, LoRAMetadataType } from 'features/nodes/types/types'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { memo, useCallback } from 'react'; -import ImageMetadataItem from './ImageMetadataItem'; import { useTranslation } from 'react-i18next'; import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas'; +import ImageMetadataItem from './ImageMetadataItem'; type Props = { metadata?: CoreMetadata; diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index 9045646ad5..fc33200bd3 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -1,10 +1,6 @@ import { useAppToaster } from 'app/components/Toaster'; import { useAppDispatch } from 'app/store/storeHooks'; -import { - CoreMetadata, - LoRAMetadataType, - LoraInfo, -} from 'features/nodes/types/types'; +import { CoreMetadata, LoRAMetadataType } from 'features/nodes/types/types'; import { refinerModelChanged, setNegativeStylePromptSDXL, @@ -19,6 +15,11 @@ import { import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { ImageDTO } from 'services/api/types'; +import { + loraModelsAdapter, + useGetLoRAModelsQuery, +} from '../../../services/api/endpoints/models'; +import { loraRecalled } from '../../lora/store/loraSlice'; import { initialImageSelected, modelSelected } from '../store/actions'; import { setCfgScale, @@ -50,8 +51,6 @@ import { isValidStrength, isValidWidth, } from '../types/parameterSchemas'; -import { loraRecalled } from '../../lora/store/loraSlice'; -import { useGetLoRAModelsQuery } from '../../../services/api/endpoints/models'; export const useRecallParameters = () => { const dispatch = useAppDispatch(); @@ -318,45 +317,39 @@ export const useRecallParameters = () => { * Recall LoRA with toast */ - const { data: loraModels } = useGetLoRAModelsQuery(); + const { loras } = useGetLoRAModelsQuery(undefined, { + selectFromResult: (result) => ({ + loras: result.data + ? loraModelsAdapter.getSelectors().selectAll(result.data) + : [], + }), + }); const recallLoRA = useCallback( - (lora: LoRAMetadataType) => { - if (!isValidLoRAModel(lora.lora)) { + (loraMetadataItem: LoRAMetadataType) => { + if (!isValidLoRAModel(loraMetadataItem.lora)) { parameterNotSetToast(); return; } - if (!loraModels || !loraModels.entities) { + const { base_model, model_name } = loraMetadataItem.lora; + + const matchingLoRA = loras.find( + (l) => l.base_model === base_model && l.model_name === model_name + ); + + if (!matchingLoRA) { parameterNotSetToast(); return; } - const matchingId = Object.keys(loraModels.entities).find((loraId) => { - const matchesBaseModel = - loraModels.entities[loraId]?.base_model === lora.lora.base_model; - const matchesModelName = - loraModels.entities[loraId]?.model_name === lora.lora.model_name; - return matchesBaseModel && matchesModelName; - }); - - if (!matchingId) { - parameterNotSetToast(); - return; - } - - const fullLoRA = loraModels.entities[matchingId]; - - if (!fullLoRA) { - parameterNotSetToast(); - return; - } - - dispatch(loraRecalled({ ...fullLoRA, weight: lora.weight })); + dispatch( + loraRecalled({ ...matchingLoRA, weight: loraMetadataItem.weight }) + ); parameterSetToast(); }, - [dispatch, parameterSetToast, parameterNotSetToast, loraModels] + [loras, dispatch, parameterSetToast, parameterNotSetToast] ); /* diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 9be8bd13f6..9db7762344 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -128,7 +128,7 @@ export const mainModelsAdapter = createEntityAdapter({ const onnxModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), }); -const loraModelsAdapter = createEntityAdapter({ +export const loraModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), }); export const controlNetModelsAdapter =