From 5a961bb58e695f306a95b2ea2ce300dfeae5e46e Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Fri, 15 Sep 2023 16:52:30 -0400 Subject: [PATCH] first pass to recall LoRAs --- .../ImageMetadataActions.tsx | 79 ++++++------------- .../web/src/features/lora/store/loraSlice.ts | 8 ++ .../web/src/features/nodes/types/types.ts | 16 ++-- .../parameters/hooks/useRecallParameters.ts | 52 +++++++++++- 4 files changed, 90 insertions(+), 65 deletions(-) diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index c1124477e2..a0ea69157c 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -1,8 +1,9 @@ -import { CoreMetadata } from 'features/nodes/types/types'; +import { CoreMetadata, LoRAMetadataType } from 'features/nodes/types/types'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { memo, useCallback } from 'react'; import ImageMetadataItem from './ImageMetadataItem'; import { useTranslation } from 'react-i18next'; +import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas'; type Props = { metadata?: CoreMetadata; @@ -24,6 +25,7 @@ const ImageMetadataActions = (props: Props) => { recallWidth, recallHeight, recallStrength, + recallLoRA, } = useRecallParameters(); const handleRecallPositivePrompt = useCallback(() => { @@ -66,6 +68,13 @@ const ImageMetadataActions = (props: Props) => { recallStrength(metadata?.strength); }, [metadata?.strength, recallStrength]); + const handleRecallLoRA = useCallback( + (lora: LoRAMetadataType) => { + recallLoRA(lora); + }, + [recallLoRA] + ); + if (!metadata || Object.keys(metadata).length === 0) { return null; } @@ -130,20 +139,6 @@ const ImageMetadataActions = (props: Props) => { onClick={handleRecallHeight} /> )} - {/* {metadata.threshold !== undefined && ( - dispatch(setThreshold(Number(metadata.threshold)))} - /> - )} - {metadata.perlin !== undefined && ( - dispatch(setPerlin(Number(metadata.perlin)))} - /> - )} */} {metadata.scheduler && ( { onClick={handleRecallCfgScale} /> )} - {/* {metadata.variations && metadata.variations.length > 0 && ( - handleRecallLoRA(lora)} + /> + ); + } + })} ); }; diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts index 10a1671933..bbe019b1c3 100644 --- a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -27,6 +27,13 @@ export const loraSlice = createSlice({ const { model_name, id, base_model } = action.payload; state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig }; }, + loraRecalled: ( + state, + action: PayloadAction + ) => { + const { model_name, id, base_model, weight } = action.payload; + state.loras[id] = { id, model_name, base_model, weight }; + }, loraRemoved: (state, action: PayloadAction) => { const id = action.payload; delete state.loras[id]; @@ -62,6 +69,7 @@ export const { loraWeightChanged, loraWeightReset, lorasCleared, + loraRecalled, } = loraSlice.actions; export default loraSlice.reducer; diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index a4b71457f7..6a03c07d57 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -1057,6 +1057,13 @@ export const isInvocationFieldSchema = ( export type InvocationEdgeExtra = { type: 'default' | 'collapsed' }; +const zLoRAObject = z.object({ + lora: zLoRAModelField.deepPartial(), + weight: z.number(), +}); + +export type LoRAMetadataType = z.infer; + export const zCoreMetadata = z .object({ app_version: z.string().nullish(), @@ -1076,14 +1083,7 @@ export const zCoreMetadata = z .union([zMainModel.deepPartial(), zOnnxModel.deepPartial()]) .nullish(), controlnets: z.array(zControlField.deepPartial()).nullish(), - loras: z - .array( - z.object({ - lora: zLoRAModelField.deepPartial(), - weight: z.number(), - }) - ) - .nullish(), + loras: z.array(zLoRAObject).nullish(), vae: zVaeModelField.nullish(), strength: z.number().nullish(), init_image: z.string().nullish(), diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index 203ff2cb1b..d7fcc2c9a4 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -1,6 +1,10 @@ import { useAppToaster } from 'app/components/Toaster'; import { useAppDispatch } from 'app/store/storeHooks'; -import { CoreMetadata } from 'features/nodes/types/types'; +import { + CoreMetadata, + LoRAMetadataType, + LoraInfo, +} from 'features/nodes/types/types'; import { refinerModelChanged, setNegativeStylePromptSDXL, @@ -30,6 +34,7 @@ import { import { isValidCfgScale, isValidHeight, + isValidLoRAModel, isValidMainModel, isValidNegativePrompt, isValidPositivePrompt, @@ -45,6 +50,8 @@ import { isValidStrength, isValidWidth, } from '../types/parameterSchemas'; +import { loraRecalled } from '../../lora/store/loraSlice'; +import { useGetLoRAModelsQuery } from '../../../services/api/endpoints/models'; export const useRecallParameters = () => { const dispatch = useAppDispatch(); @@ -307,6 +314,48 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); + /** + * Recall LoRA with toast + */ + + const { data: loraModels } = useGetLoRAModelsQuery(); + + const recallLoRA = useCallback( + (lora: LoRAMetadataType) => { + if (!isValidLoRAModel(lora.lora)) { + parameterNotSetToast(); + return; + } + + if (!loraModels || !loraModels.entities) { + return; + } + + const matchingId = Object.keys(loraModels.entities).find((loraId) => { + const matchesBaseModel = + loraModels.entities[loraId]?.base_model === lora.lora.base_model; + const matchesModelName = + loraModels.entities[loraId]?.model_name === lora.lora.model_name; + return matchesBaseModel && matchesModelName; + }); + + if (!matchingId) { +return; +} + + const fullLoRA = loraModels.entities[matchingId]; + + if (!fullLoRA) { +return; +} + + dispatch(loraRecalled({ ...fullLoRA, weight: lora.weight })); + + parameterSetToast(); + }, + [dispatch, parameterSetToast, parameterNotSetToast, loraModels] + ); + /* * Sets image as initial image with toast */ @@ -444,6 +493,7 @@ export const useRecallParameters = () => { recallWidth, recallHeight, recallStrength, + recallLoRA, recallAllParameters, sendToImageToImage, };