mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into feat/nodes/invocation-cache
This commit is contained in:
commit
b9ebce9bdd
@ -1,8 +1,9 @@
|
|||||||
import { CoreMetadata } from 'features/nodes/types/types';
|
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
|
||||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import ImageMetadataItem from './ImageMetadataItem';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas';
|
||||||
|
import ImageMetadataItem from './ImageMetadataItem';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
metadata?: CoreMetadata;
|
metadata?: CoreMetadata;
|
||||||
@ -24,6 +25,7 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
recallWidth,
|
recallWidth,
|
||||||
recallHeight,
|
recallHeight,
|
||||||
recallStrength,
|
recallStrength,
|
||||||
|
recallLoRA,
|
||||||
} = useRecallParameters();
|
} = useRecallParameters();
|
||||||
|
|
||||||
const handleRecallPositivePrompt = useCallback(() => {
|
const handleRecallPositivePrompt = useCallback(() => {
|
||||||
@ -66,6 +68,13 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
recallStrength(metadata?.strength);
|
recallStrength(metadata?.strength);
|
||||||
}, [metadata?.strength, recallStrength]);
|
}, [metadata?.strength, recallStrength]);
|
||||||
|
|
||||||
|
const handleRecallLoRA = useCallback(
|
||||||
|
(lora: LoRAMetadataItem) => {
|
||||||
|
recallLoRA(lora);
|
||||||
|
},
|
||||||
|
[recallLoRA]
|
||||||
|
);
|
||||||
|
|
||||||
if (!metadata || Object.keys(metadata).length === 0) {
|
if (!metadata || Object.keys(metadata).length === 0) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
@ -130,20 +139,6 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
onClick={handleRecallHeight}
|
onClick={handleRecallHeight}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{/* {metadata.threshold !== undefined && (
|
|
||||||
<MetadataItem
|
|
||||||
label={t('metadata.threshold')}
|
|
||||||
value={metadata.threshold}
|
|
||||||
onClick={() => dispatch(setThreshold(Number(metadata.threshold)))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{metadata.perlin !== undefined && (
|
|
||||||
<MetadataItem
|
|
||||||
label={t('metadata.perlin')}
|
|
||||||
value={metadata.perlin}
|
|
||||||
onClick={() => dispatch(setPerlin(Number(metadata.perlin)))}
|
|
||||||
/>
|
|
||||||
)} */}
|
|
||||||
{metadata.scheduler && (
|
{metadata.scheduler && (
|
||||||
<ImageMetadataItem
|
<ImageMetadataItem
|
||||||
label={t('metadata.scheduler')}
|
label={t('metadata.scheduler')}
|
||||||
@ -165,40 +160,6 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
onClick={handleRecallCfgScale}
|
onClick={handleRecallCfgScale}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{/* {metadata.variations && metadata.variations.length > 0 && (
|
|
||||||
<MetadataItem
|
|
||||||
label="{t('metadata.variations')}
|
|
||||||
value={seedWeightsToString(metadata.variations)}
|
|
||||||
onClick={() =>
|
|
||||||
dispatch(
|
|
||||||
setSeedWeights(seedWeightsToString(metadata.variations))
|
|
||||||
)
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{metadata.seamless && (
|
|
||||||
<MetadataItem
|
|
||||||
label={t('metadata.seamless')}
|
|
||||||
value={metadata.seamless}
|
|
||||||
onClick={() => dispatch(setSeamless(metadata.seamless))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{metadata.hires_fix && (
|
|
||||||
<MetadataItem
|
|
||||||
label={t('metadata.hiresFix')}
|
|
||||||
value={metadata.hires_fix}
|
|
||||||
onClick={() => dispatch(setHiresFix(metadata.hires_fix))}
|
|
||||||
/>
|
|
||||||
)} */}
|
|
||||||
|
|
||||||
{/* {init_image_path && (
|
|
||||||
<MetadataItem
|
|
||||||
label={t('metadata.initImage')}
|
|
||||||
value={init_image_path}
|
|
||||||
isLink
|
|
||||||
onClick={() => dispatch(setInitialImage(init_image_path))}
|
|
||||||
/>
|
|
||||||
)} */}
|
|
||||||
{metadata.strength && (
|
{metadata.strength && (
|
||||||
<ImageMetadataItem
|
<ImageMetadataItem
|
||||||
label={t('metadata.strength')}
|
label={t('metadata.strength')}
|
||||||
@ -206,13 +167,19 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
onClick={handleRecallStrength}
|
onClick={handleRecallStrength}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{/* {metadata.fit && (
|
{metadata.loras &&
|
||||||
<MetadataItem
|
metadata.loras.map((lora, index) => {
|
||||||
label={t('metadata.fit')}
|
if (isValidLoRAModel(lora.lora)) {
|
||||||
value={metadata.fit}
|
return (
|
||||||
onClick={() => dispatch(setShouldFitToWidthHeight(metadata.fit))}
|
<ImageMetadataItem
|
||||||
/>
|
key={index}
|
||||||
)} */}
|
label="LoRA"
|
||||||
|
value={`${lora.lora.model_name} - ${lora.weight}`}
|
||||||
|
onClick={() => handleRecallLoRA(lora)}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
})}
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -27,6 +27,13 @@ export const loraSlice = createSlice({
|
|||||||
const { model_name, id, base_model } = action.payload;
|
const { model_name, id, base_model } = action.payload;
|
||||||
state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig };
|
state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig };
|
||||||
},
|
},
|
||||||
|
loraRecalled: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<LoRAModelConfigEntity & { weight: number }>
|
||||||
|
) => {
|
||||||
|
const { model_name, id, base_model, weight } = action.payload;
|
||||||
|
state.loras[id] = { id, model_name, base_model, weight };
|
||||||
|
},
|
||||||
loraRemoved: (state, action: PayloadAction<string>) => {
|
loraRemoved: (state, action: PayloadAction<string>) => {
|
||||||
const id = action.payload;
|
const id = action.payload;
|
||||||
delete state.loras[id];
|
delete state.loras[id];
|
||||||
@ -62,6 +69,7 @@ export const {
|
|||||||
loraWeightChanged,
|
loraWeightChanged,
|
||||||
loraWeightReset,
|
loraWeightReset,
|
||||||
lorasCleared,
|
lorasCleared,
|
||||||
|
loraRecalled,
|
||||||
} = loraSlice.actions;
|
} = loraSlice.actions;
|
||||||
|
|
||||||
export default loraSlice.reducer;
|
export default loraSlice.reducer;
|
||||||
|
@ -1065,6 +1065,13 @@ export const isInvocationFieldSchema = (
|
|||||||
|
|
||||||
export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
|
export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
|
||||||
|
|
||||||
|
const zLoRAMetadataItem = z.object({
|
||||||
|
lora: zLoRAModelField.deepPartial(),
|
||||||
|
weight: z.number(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
|
||||||
|
|
||||||
export const zCoreMetadata = z
|
export const zCoreMetadata = z
|
||||||
.object({
|
.object({
|
||||||
app_version: z.string().nullish(),
|
app_version: z.string().nullish(),
|
||||||
@ -1084,14 +1091,7 @@ export const zCoreMetadata = z
|
|||||||
.union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
|
.union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
|
||||||
.nullish(),
|
.nullish(),
|
||||||
controlnets: z.array(zControlField.deepPartial()).nullish(),
|
controlnets: z.array(zControlField.deepPartial()).nullish(),
|
||||||
loras: z
|
loras: z.array(zLoRAMetadataItem).nullish(),
|
||||||
.array(
|
|
||||||
z.object({
|
|
||||||
lora: zLoRAModelField.deepPartial(),
|
|
||||||
weight: z.number(),
|
|
||||||
})
|
|
||||||
)
|
|
||||||
.nullish(),
|
|
||||||
vae: zVaeModelField.nullish(),
|
vae: zVaeModelField.nullish(),
|
||||||
strength: z.number().nullish(),
|
strength: z.number().nullish(),
|
||||||
init_image: z.string().nullish(),
|
init_image: z.string().nullish(),
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppToaster } from 'app/components/Toaster';
|
import { useAppToaster } from 'app/components/Toaster';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { CoreMetadata } from 'features/nodes/types/types';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
|
||||||
import {
|
import {
|
||||||
refinerModelChanged,
|
refinerModelChanged,
|
||||||
setNegativeStylePromptSDXL,
|
setNegativeStylePromptSDXL,
|
||||||
@ -15,6 +17,11 @@ import {
|
|||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
|
import {
|
||||||
|
loraModelsAdapter,
|
||||||
|
useGetLoRAModelsQuery,
|
||||||
|
} from '../../../services/api/endpoints/models';
|
||||||
|
import { loraRecalled } from '../../lora/store/loraSlice';
|
||||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||||
import {
|
import {
|
||||||
setCfgScale,
|
setCfgScale,
|
||||||
@ -30,6 +37,7 @@ import {
|
|||||||
import {
|
import {
|
||||||
isValidCfgScale,
|
isValidCfgScale,
|
||||||
isValidHeight,
|
isValidHeight,
|
||||||
|
isValidLoRAModel,
|
||||||
isValidMainModel,
|
isValidMainModel,
|
||||||
isValidNegativePrompt,
|
isValidNegativePrompt,
|
||||||
isValidPositivePrompt,
|
isValidPositivePrompt,
|
||||||
@ -46,10 +54,16 @@ import {
|
|||||||
isValidWidth,
|
isValidWidth,
|
||||||
} from '../types/parameterSchemas';
|
} from '../types/parameterSchemas';
|
||||||
|
|
||||||
|
const selector = createSelector(stateSelector, ({ generation }) => {
|
||||||
|
const { model } = generation;
|
||||||
|
return { model };
|
||||||
|
});
|
||||||
|
|
||||||
export const useRecallParameters = () => {
|
export const useRecallParameters = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const toaster = useAppToaster();
|
const toaster = useAppToaster();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
const { model } = useAppSelector(selector);
|
||||||
|
|
||||||
const parameterSetToast = useCallback(() => {
|
const parameterSetToast = useCallback(() => {
|
||||||
toaster({
|
toaster({
|
||||||
@ -60,14 +74,18 @@ export const useRecallParameters = () => {
|
|||||||
});
|
});
|
||||||
}, [t, toaster]);
|
}, [t, toaster]);
|
||||||
|
|
||||||
const parameterNotSetToast = useCallback(() => {
|
const parameterNotSetToast = useCallback(
|
||||||
toaster({
|
(description?: string) => {
|
||||||
title: t('toast.parameterNotSet'),
|
toaster({
|
||||||
status: 'warning',
|
title: t('toast.parameterNotSet'),
|
||||||
duration: 2500,
|
description,
|
||||||
isClosable: true,
|
status: 'warning',
|
||||||
});
|
duration: 2500,
|
||||||
}, [t, toaster]);
|
isClosable: true,
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[t, toaster]
|
||||||
|
);
|
||||||
|
|
||||||
const allParameterSetToast = useCallback(() => {
|
const allParameterSetToast = useCallback(() => {
|
||||||
toaster({
|
toaster({
|
||||||
@ -78,14 +96,18 @@ export const useRecallParameters = () => {
|
|||||||
});
|
});
|
||||||
}, [t, toaster]);
|
}, [t, toaster]);
|
||||||
|
|
||||||
const allParameterNotSetToast = useCallback(() => {
|
const allParameterNotSetToast = useCallback(
|
||||||
toaster({
|
(description?: string) => {
|
||||||
title: t('toast.parametersNotSet'),
|
toaster({
|
||||||
status: 'warning',
|
title: t('toast.parametersNotSet'),
|
||||||
duration: 2500,
|
status: 'warning',
|
||||||
isClosable: true,
|
description,
|
||||||
});
|
duration: 2500,
|
||||||
}, [t, toaster]);
|
isClosable: true,
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[t, toaster]
|
||||||
|
);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Recall both prompts with toast
|
* Recall both prompts with toast
|
||||||
@ -307,6 +329,67 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Recall LoRA with toast
|
||||||
|
*/
|
||||||
|
|
||||||
|
const { loras } = useGetLoRAModelsQuery(undefined, {
|
||||||
|
selectFromResult: (result) => ({
|
||||||
|
loras: result.data
|
||||||
|
? loraModelsAdapter.getSelectors().selectAll(result.data)
|
||||||
|
: [],
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const prepareLoRAMetadataItem = useCallback(
|
||||||
|
(loraMetadataItem: LoRAMetadataItem) => {
|
||||||
|
if (!isValidLoRAModel(loraMetadataItem.lora)) {
|
||||||
|
return { lora: null, error: 'Invalid LoRA model' };
|
||||||
|
}
|
||||||
|
|
||||||
|
const { base_model, model_name } = loraMetadataItem.lora;
|
||||||
|
|
||||||
|
const matchingLoRA = loras.find(
|
||||||
|
(l) => l.base_model === base_model && l.model_name === model_name
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!matchingLoRA) {
|
||||||
|
return { lora: null, error: 'LoRA model is not installed' };
|
||||||
|
}
|
||||||
|
|
||||||
|
const isCompatibleBaseModel =
|
||||||
|
matchingLoRA?.base_model === model?.base_model;
|
||||||
|
|
||||||
|
if (!isCompatibleBaseModel) {
|
||||||
|
return {
|
||||||
|
lora: null,
|
||||||
|
error: 'LoRA incompatible with currently-selected model',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return { lora: matchingLoRA, error: null };
|
||||||
|
},
|
||||||
|
[loras, model?.base_model]
|
||||||
|
);
|
||||||
|
|
||||||
|
const recallLoRA = useCallback(
|
||||||
|
(loraMetadataItem: LoRAMetadataItem) => {
|
||||||
|
const result = prepareLoRAMetadataItem(loraMetadataItem);
|
||||||
|
|
||||||
|
if (!result.lora) {
|
||||||
|
parameterNotSetToast(result.error);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
loraRecalled({ ...result.lora, weight: loraMetadataItem.weight })
|
||||||
|
);
|
||||||
|
|
||||||
|
parameterSetToast();
|
||||||
|
},
|
||||||
|
[prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
|
);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Sets image as initial image with toast
|
* Sets image as initial image with toast
|
||||||
*/
|
*/
|
||||||
@ -344,6 +427,7 @@ export const useRecallParameters = () => {
|
|||||||
refiner_positive_aesthetic_score,
|
refiner_positive_aesthetic_score,
|
||||||
refiner_negative_aesthetic_score,
|
refiner_negative_aesthetic_score,
|
||||||
refiner_start,
|
refiner_start,
|
||||||
|
loras,
|
||||||
} = metadata;
|
} = metadata;
|
||||||
|
|
||||||
if (isValidCfgScale(cfg_scale)) {
|
if (isValidCfgScale(cfg_scale)) {
|
||||||
@ -425,9 +509,21 @@ export const useRecallParameters = () => {
|
|||||||
dispatch(setRefinerStart(refiner_start));
|
dispatch(setRefinerStart(refiner_start));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
loras?.forEach((lora) => {
|
||||||
|
const result = prepareLoRAMetadataItem(lora);
|
||||||
|
if (result.lora) {
|
||||||
|
dispatch(loraRecalled({ ...result.lora, weight: lora.weight }));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
allParameterSetToast();
|
allParameterSetToast();
|
||||||
},
|
},
|
||||||
[allParameterNotSetToast, allParameterSetToast, dispatch]
|
[
|
||||||
|
allParameterNotSetToast,
|
||||||
|
allParameterSetToast,
|
||||||
|
dispatch,
|
||||||
|
prepareLoRAMetadataItem,
|
||||||
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -444,6 +540,7 @@ export const useRecallParameters = () => {
|
|||||||
recallWidth,
|
recallWidth,
|
||||||
recallHeight,
|
recallHeight,
|
||||||
recallStrength,
|
recallStrength,
|
||||||
|
recallLoRA,
|
||||||
recallAllParameters,
|
recallAllParameters,
|
||||||
sendToImageToImage,
|
sendToImageToImage,
|
||||||
};
|
};
|
||||||
|
@ -128,7 +128,7 @@ export const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
|||||||
const onnxModelsAdapter = createEntityAdapter<OnnxModelConfigEntity>({
|
const onnxModelsAdapter = createEntityAdapter<OnnxModelConfigEntity>({
|
||||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||||
});
|
});
|
||||||
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
export const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
||||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||||
});
|
});
|
||||||
export const controlNetModelsAdapter =
|
export const controlNetModelsAdapter =
|
||||||
|
@ -198,6 +198,13 @@ output = "coverage/index.xml"
|
|||||||
max-line-length = 120
|
max-line-length = 120
|
||||||
ignore = ["E203", "E266", "E501", "W503"]
|
ignore = ["E203", "E266", "E501", "W503"]
|
||||||
select = ["B", "C", "E", "F", "W", "T4"]
|
select = ["B", "C", "E", "F", "W", "T4"]
|
||||||
|
exclude = [
|
||||||
|
".git",
|
||||||
|
"__pycache__",
|
||||||
|
"build",
|
||||||
|
"dist",
|
||||||
|
"invokeai/frontend/web/node_modules/"
|
||||||
|
]
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 120
|
line-length = 120
|
||||||
|
Loading…
Reference in New Issue
Block a user