mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
first pass to recall LoRAs
This commit is contained in:
parent
627750eded
commit
5a961bb58e
@ -1,8 +1,9 @@
|
||||
import { CoreMetadata } from 'features/nodes/types/types';
|
||||
import { CoreMetadata, LoRAMetadataType } from 'features/nodes/types/types';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { memo, useCallback } from 'react';
|
||||
import ImageMetadataItem from './ImageMetadataItem';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas';
|
||||
|
||||
type Props = {
|
||||
metadata?: CoreMetadata;
|
||||
@ -24,6 +25,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
recallWidth,
|
||||
recallHeight,
|
||||
recallStrength,
|
||||
recallLoRA,
|
||||
} = useRecallParameters();
|
||||
|
||||
const handleRecallPositivePrompt = useCallback(() => {
|
||||
@ -66,6 +68,13 @@ const ImageMetadataActions = (props: Props) => {
|
||||
recallStrength(metadata?.strength);
|
||||
}, [metadata?.strength, recallStrength]);
|
||||
|
||||
const handleRecallLoRA = useCallback(
|
||||
(lora: LoRAMetadataType) => {
|
||||
recallLoRA(lora);
|
||||
},
|
||||
[recallLoRA]
|
||||
);
|
||||
|
||||
if (!metadata || Object.keys(metadata).length === 0) {
|
||||
return null;
|
||||
}
|
||||
@ -130,20 +139,6 @@ const ImageMetadataActions = (props: Props) => {
|
||||
onClick={handleRecallHeight}
|
||||
/>
|
||||
)}
|
||||
{/* {metadata.threshold !== undefined && (
|
||||
<MetadataItem
|
||||
label={t('metadata.threshold')}
|
||||
value={metadata.threshold}
|
||||
onClick={() => dispatch(setThreshold(Number(metadata.threshold)))}
|
||||
/>
|
||||
)}
|
||||
{metadata.perlin !== undefined && (
|
||||
<MetadataItem
|
||||
label={t('metadata.perlin')}
|
||||
value={metadata.perlin}
|
||||
onClick={() => dispatch(setPerlin(Number(metadata.perlin)))}
|
||||
/>
|
||||
)} */}
|
||||
{metadata.scheduler && (
|
||||
<ImageMetadataItem
|
||||
label={t('metadata.scheduler')}
|
||||
@ -165,40 +160,6 @@ const ImageMetadataActions = (props: Props) => {
|
||||
onClick={handleRecallCfgScale}
|
||||
/>
|
||||
)}
|
||||
{/* {metadata.variations && metadata.variations.length > 0 && (
|
||||
<MetadataItem
|
||||
label="{t('metadata.variations')}
|
||||
value={seedWeightsToString(metadata.variations)}
|
||||
onClick={() =>
|
||||
dispatch(
|
||||
setSeedWeights(seedWeightsToString(metadata.variations))
|
||||
)
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{metadata.seamless && (
|
||||
<MetadataItem
|
||||
label={t('metadata.seamless')}
|
||||
value={metadata.seamless}
|
||||
onClick={() => dispatch(setSeamless(metadata.seamless))}
|
||||
/>
|
||||
)}
|
||||
{metadata.hires_fix && (
|
||||
<MetadataItem
|
||||
label={t('metadata.hiresFix')}
|
||||
value={metadata.hires_fix}
|
||||
onClick={() => dispatch(setHiresFix(metadata.hires_fix))}
|
||||
/>
|
||||
)} */}
|
||||
|
||||
{/* {init_image_path && (
|
||||
<MetadataItem
|
||||
label={t('metadata.initImage')}
|
||||
value={init_image_path}
|
||||
isLink
|
||||
onClick={() => dispatch(setInitialImage(init_image_path))}
|
||||
/>
|
||||
)} */}
|
||||
{metadata.strength && (
|
||||
<ImageMetadataItem
|
||||
label={t('metadata.strength')}
|
||||
@ -206,13 +167,19 @@ const ImageMetadataActions = (props: Props) => {
|
||||
onClick={handleRecallStrength}
|
||||
/>
|
||||
)}
|
||||
{/* {metadata.fit && (
|
||||
<MetadataItem
|
||||
label={t('metadata.fit')}
|
||||
value={metadata.fit}
|
||||
onClick={() => dispatch(setShouldFitToWidthHeight(metadata.fit))}
|
||||
/>
|
||||
)} */}
|
||||
{metadata.loras &&
|
||||
metadata.loras.map((lora, index) => {
|
||||
if (isValidLoRAModel(lora.lora)) {
|
||||
return (
|
||||
<ImageMetadataItem
|
||||
key={index}
|
||||
label="LoRA"
|
||||
value={`${lora.lora.model_name} - ${lora.weight}`}
|
||||
onClick={() => handleRecallLoRA(lora)}
|
||||
/>
|
||||
);
|
||||
}
|
||||
})}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
@ -27,6 +27,13 @@ export const loraSlice = createSlice({
|
||||
const { model_name, id, base_model } = action.payload;
|
||||
state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig };
|
||||
},
|
||||
loraRecalled: (
|
||||
state,
|
||||
action: PayloadAction<LoRAModelConfigEntity & { weight: number }>
|
||||
) => {
|
||||
const { model_name, id, base_model, weight } = action.payload;
|
||||
state.loras[id] = { id, model_name, base_model, weight };
|
||||
},
|
||||
loraRemoved: (state, action: PayloadAction<string>) => {
|
||||
const id = action.payload;
|
||||
delete state.loras[id];
|
||||
@ -62,6 +69,7 @@ export const {
|
||||
loraWeightChanged,
|
||||
loraWeightReset,
|
||||
lorasCleared,
|
||||
loraRecalled,
|
||||
} = loraSlice.actions;
|
||||
|
||||
export default loraSlice.reducer;
|
||||
|
@ -1057,6 +1057,13 @@ export const isInvocationFieldSchema = (
|
||||
|
||||
export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
|
||||
|
||||
const zLoRAObject = z.object({
|
||||
lora: zLoRAModelField.deepPartial(),
|
||||
weight: z.number(),
|
||||
});
|
||||
|
||||
export type LoRAMetadataType = z.infer<typeof zLoRAObject>;
|
||||
|
||||
export const zCoreMetadata = z
|
||||
.object({
|
||||
app_version: z.string().nullish(),
|
||||
@ -1076,14 +1083,7 @@ export const zCoreMetadata = z
|
||||
.union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
|
||||
.nullish(),
|
||||
controlnets: z.array(zControlField.deepPartial()).nullish(),
|
||||
loras: z
|
||||
.array(
|
||||
z.object({
|
||||
lora: zLoRAModelField.deepPartial(),
|
||||
weight: z.number(),
|
||||
})
|
||||
)
|
||||
.nullish(),
|
||||
loras: z.array(zLoRAObject).nullish(),
|
||||
vae: zVaeModelField.nullish(),
|
||||
strength: z.number().nullish(),
|
||||
init_image: z.string().nullish(),
|
||||
|
@ -1,6 +1,10 @@
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { CoreMetadata } from 'features/nodes/types/types';
|
||||
import {
|
||||
CoreMetadata,
|
||||
LoRAMetadataType,
|
||||
LoraInfo,
|
||||
} from 'features/nodes/types/types';
|
||||
import {
|
||||
refinerModelChanged,
|
||||
setNegativeStylePromptSDXL,
|
||||
@ -30,6 +34,7 @@ import {
|
||||
import {
|
||||
isValidCfgScale,
|
||||
isValidHeight,
|
||||
isValidLoRAModel,
|
||||
isValidMainModel,
|
||||
isValidNegativePrompt,
|
||||
isValidPositivePrompt,
|
||||
@ -45,6 +50,8 @@ import {
|
||||
isValidStrength,
|
||||
isValidWidth,
|
||||
} from '../types/parameterSchemas';
|
||||
import { loraRecalled } from '../../lora/store/loraSlice';
|
||||
import { useGetLoRAModelsQuery } from '../../../services/api/endpoints/models';
|
||||
|
||||
export const useRecallParameters = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
@ -307,6 +314,48 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall LoRA with toast
|
||||
*/
|
||||
|
||||
const { data: loraModels } = useGetLoRAModelsQuery();
|
||||
|
||||
const recallLoRA = useCallback(
|
||||
(lora: LoRAMetadataType) => {
|
||||
if (!isValidLoRAModel(lora.lora)) {
|
||||
parameterNotSetToast();
|
||||
return;
|
||||
}
|
||||
|
||||
if (!loraModels || !loraModels.entities) {
|
||||
return;
|
||||
}
|
||||
|
||||
const matchingId = Object.keys(loraModels.entities).find((loraId) => {
|
||||
const matchesBaseModel =
|
||||
loraModels.entities[loraId]?.base_model === lora.lora.base_model;
|
||||
const matchesModelName =
|
||||
loraModels.entities[loraId]?.model_name === lora.lora.model_name;
|
||||
return matchesBaseModel && matchesModelName;
|
||||
});
|
||||
|
||||
if (!matchingId) {
|
||||
return;
|
||||
}
|
||||
|
||||
const fullLoRA = loraModels.entities[matchingId];
|
||||
|
||||
if (!fullLoRA) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(loraRecalled({ ...fullLoRA, weight: lora.weight }));
|
||||
|
||||
parameterSetToast();
|
||||
},
|
||||
[dispatch, parameterSetToast, parameterNotSetToast, loraModels]
|
||||
);
|
||||
|
||||
/*
|
||||
* Sets image as initial image with toast
|
||||
*/
|
||||
@ -444,6 +493,7 @@ export const useRecallParameters = () => {
|
||||
recallWidth,
|
||||
recallHeight,
|
||||
recallStrength,
|
||||
recallLoRA,
|
||||
recallAllParameters,
|
||||
sendToImageToImage,
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user