first pass to recall LoRAs

This commit is contained in:
Mary Hipp 2023-09-15 16:52:30 -04:00 committed by psychedelicious
parent 627750eded
commit 5a961bb58e
4 changed files with 90 additions and 65 deletions

View File

@ -1,8 +1,9 @@
import { CoreMetadata } from 'features/nodes/types/types';
import { CoreMetadata, LoRAMetadataType } from 'features/nodes/types/types';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { memo, useCallback } from 'react';
import ImageMetadataItem from './ImageMetadataItem';
import { useTranslation } from 'react-i18next';
import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas';
type Props = {
metadata?: CoreMetadata;
@ -24,6 +25,7 @@ const ImageMetadataActions = (props: Props) => {
recallWidth,
recallHeight,
recallStrength,
recallLoRA,
} = useRecallParameters();
const handleRecallPositivePrompt = useCallback(() => {
@ -66,6 +68,13 @@ const ImageMetadataActions = (props: Props) => {
recallStrength(metadata?.strength);
}, [metadata?.strength, recallStrength]);
const handleRecallLoRA = useCallback(
(lora: LoRAMetadataType) => {
recallLoRA(lora);
},
[recallLoRA]
);
if (!metadata || Object.keys(metadata).length === 0) {
return null;
}
@ -130,20 +139,6 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallHeight}
/>
)}
{/* {metadata.threshold !== undefined && (
<MetadataItem
label={t('metadata.threshold')}
value={metadata.threshold}
onClick={() => dispatch(setThreshold(Number(metadata.threshold)))}
/>
)}
{metadata.perlin !== undefined && (
<MetadataItem
label={t('metadata.perlin')}
value={metadata.perlin}
onClick={() => dispatch(setPerlin(Number(metadata.perlin)))}
/>
)} */}
{metadata.scheduler && (
<ImageMetadataItem
label={t('metadata.scheduler')}
@ -165,40 +160,6 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallCfgScale}
/>
)}
{/* {metadata.variations && metadata.variations.length > 0 && (
<MetadataItem
label="{t('metadata.variations')}
value={seedWeightsToString(metadata.variations)}
onClick={() =>
dispatch(
setSeedWeights(seedWeightsToString(metadata.variations))
)
}
/>
)}
{metadata.seamless && (
<MetadataItem
label={t('metadata.seamless')}
value={metadata.seamless}
onClick={() => dispatch(setSeamless(metadata.seamless))}
/>
)}
{metadata.hires_fix && (
<MetadataItem
label={t('metadata.hiresFix')}
value={metadata.hires_fix}
onClick={() => dispatch(setHiresFix(metadata.hires_fix))}
/>
)} */}
{/* {init_image_path && (
<MetadataItem
label={t('metadata.initImage')}
value={init_image_path}
isLink
onClick={() => dispatch(setInitialImage(init_image_path))}
/>
)} */}
{metadata.strength && (
<ImageMetadataItem
label={t('metadata.strength')}
@ -206,13 +167,19 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallStrength}
/>
)}
{/* {metadata.fit && (
<MetadataItem
label={t('metadata.fit')}
value={metadata.fit}
onClick={() => dispatch(setShouldFitToWidthHeight(metadata.fit))}
{metadata.loras &&
metadata.loras.map((lora, index) => {
if (isValidLoRAModel(lora.lora)) {
return (
<ImageMetadataItem
key={index}
label="LoRA"
value={`${lora.lora.model_name} - ${lora.weight}`}
onClick={() => handleRecallLoRA(lora)}
/>
)} */}
);
}
})}
</>
);
};

View File

@ -27,6 +27,13 @@ export const loraSlice = createSlice({
const { model_name, id, base_model } = action.payload;
state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig };
},
loraRecalled: (
state,
action: PayloadAction<LoRAModelConfigEntity & { weight: number }>
) => {
const { model_name, id, base_model, weight } = action.payload;
state.loras[id] = { id, model_name, base_model, weight };
},
loraRemoved: (state, action: PayloadAction<string>) => {
const id = action.payload;
delete state.loras[id];
@ -62,6 +69,7 @@ export const {
loraWeightChanged,
loraWeightReset,
lorasCleared,
loraRecalled,
} = loraSlice.actions;
export default loraSlice.reducer;

View File

@ -1057,6 +1057,13 @@ export const isInvocationFieldSchema = (
export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
const zLoRAObject = z.object({
lora: zLoRAModelField.deepPartial(),
weight: z.number(),
});
export type LoRAMetadataType = z.infer<typeof zLoRAObject>;
export const zCoreMetadata = z
.object({
app_version: z.string().nullish(),
@ -1076,14 +1083,7 @@ export const zCoreMetadata = z
.union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
.nullish(),
controlnets: z.array(zControlField.deepPartial()).nullish(),
loras: z
.array(
z.object({
lora: zLoRAModelField.deepPartial(),
weight: z.number(),
})
)
.nullish(),
loras: z.array(zLoRAObject).nullish(),
vae: zVaeModelField.nullish(),
strength: z.number().nullish(),
init_image: z.string().nullish(),

View File

@ -1,6 +1,10 @@
import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import { CoreMetadata } from 'features/nodes/types/types';
import {
CoreMetadata,
LoRAMetadataType,
LoraInfo,
} from 'features/nodes/types/types';
import {
refinerModelChanged,
setNegativeStylePromptSDXL,
@ -30,6 +34,7 @@ import {
import {
isValidCfgScale,
isValidHeight,
isValidLoRAModel,
isValidMainModel,
isValidNegativePrompt,
isValidPositivePrompt,
@ -45,6 +50,8 @@ import {
isValidStrength,
isValidWidth,
} from '../types/parameterSchemas';
import { loraRecalled } from '../../lora/store/loraSlice';
import { useGetLoRAModelsQuery } from '../../../services/api/endpoints/models';
export const useRecallParameters = () => {
const dispatch = useAppDispatch();
@ -307,6 +314,48 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast]
);
/**
* Recall LoRA with toast
*/
const { data: loraModels } = useGetLoRAModelsQuery();
const recallLoRA = useCallback(
(lora: LoRAMetadataType) => {
if (!isValidLoRAModel(lora.lora)) {
parameterNotSetToast();
return;
}
if (!loraModels || !loraModels.entities) {
return;
}
const matchingId = Object.keys(loraModels.entities).find((loraId) => {
const matchesBaseModel =
loraModels.entities[loraId]?.base_model === lora.lora.base_model;
const matchesModelName =
loraModels.entities[loraId]?.model_name === lora.lora.model_name;
return matchesBaseModel && matchesModelName;
});
if (!matchingId) {
return;
}
const fullLoRA = loraModels.entities[matchingId];
if (!fullLoRA) {
return;
}
dispatch(loraRecalled({ ...fullLoRA, weight: lora.weight }));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast, loraModels]
);
/*
* Sets image as initial image with toast
*/
@ -444,6 +493,7 @@ export const useRecallParameters = () => {
recallWidth,
recallHeight,
recallStrength,
recallLoRA,
recallAllParameters,
sendToImageToImage,
};