Merge branch 'main' into feat/nodes/invocation-cache

This commit is contained in:
Jonathan 2023-09-18 19:54:14 -05:00 committed by GitHub
commit b9ebce9bdd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 164 additions and 85 deletions

View File

@ -1,8 +1,9 @@
import { CoreMetadata } from 'features/nodes/types/types'; import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import ImageMetadataItem from './ImageMetadataItem';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas';
import ImageMetadataItem from './ImageMetadataItem';
type Props = { type Props = {
metadata?: CoreMetadata; metadata?: CoreMetadata;
@ -24,6 +25,7 @@ const ImageMetadataActions = (props: Props) => {
recallWidth, recallWidth,
recallHeight, recallHeight,
recallStrength, recallStrength,
recallLoRA,
} = useRecallParameters(); } = useRecallParameters();
const handleRecallPositivePrompt = useCallback(() => { const handleRecallPositivePrompt = useCallback(() => {
@ -66,6 +68,13 @@ const ImageMetadataActions = (props: Props) => {
recallStrength(metadata?.strength); recallStrength(metadata?.strength);
}, [metadata?.strength, recallStrength]); }, [metadata?.strength, recallStrength]);
const handleRecallLoRA = useCallback(
(lora: LoRAMetadataItem) => {
recallLoRA(lora);
},
[recallLoRA]
);
if (!metadata || Object.keys(metadata).length === 0) { if (!metadata || Object.keys(metadata).length === 0) {
return null; return null;
} }
@ -130,20 +139,6 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallHeight} onClick={handleRecallHeight}
/> />
)} )}
{/* {metadata.threshold !== undefined && (
<MetadataItem
label={t('metadata.threshold')}
value={metadata.threshold}
onClick={() => dispatch(setThreshold(Number(metadata.threshold)))}
/>
)}
{metadata.perlin !== undefined && (
<MetadataItem
label={t('metadata.perlin')}
value={metadata.perlin}
onClick={() => dispatch(setPerlin(Number(metadata.perlin)))}
/>
)} */}
{metadata.scheduler && ( {metadata.scheduler && (
<ImageMetadataItem <ImageMetadataItem
label={t('metadata.scheduler')} label={t('metadata.scheduler')}
@ -165,40 +160,6 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallCfgScale} onClick={handleRecallCfgScale}
/> />
)} )}
{/* {metadata.variations && metadata.variations.length > 0 && (
<MetadataItem
label="{t('metadata.variations')}
value={seedWeightsToString(metadata.variations)}
onClick={() =>
dispatch(
setSeedWeights(seedWeightsToString(metadata.variations))
)
}
/>
)}
{metadata.seamless && (
<MetadataItem
label={t('metadata.seamless')}
value={metadata.seamless}
onClick={() => dispatch(setSeamless(metadata.seamless))}
/>
)}
{metadata.hires_fix && (
<MetadataItem
label={t('metadata.hiresFix')}
value={metadata.hires_fix}
onClick={() => dispatch(setHiresFix(metadata.hires_fix))}
/>
)} */}
{/* {init_image_path && (
<MetadataItem
label={t('metadata.initImage')}
value={init_image_path}
isLink
onClick={() => dispatch(setInitialImage(init_image_path))}
/>
)} */}
{metadata.strength && ( {metadata.strength && (
<ImageMetadataItem <ImageMetadataItem
label={t('metadata.strength')} label={t('metadata.strength')}
@ -206,13 +167,19 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallStrength} onClick={handleRecallStrength}
/> />
)} )}
{/* {metadata.fit && ( {metadata.loras &&
<MetadataItem metadata.loras.map((lora, index) => {
label={t('metadata.fit')} if (isValidLoRAModel(lora.lora)) {
value={metadata.fit} return (
onClick={() => dispatch(setShouldFitToWidthHeight(metadata.fit))} <ImageMetadataItem
/> key={index}
)} */} label="LoRA"
value={`${lora.lora.model_name} - ${lora.weight}`}
onClick={() => handleRecallLoRA(lora)}
/>
);
}
})}
</> </>
); );
}; };

View File

@ -27,6 +27,13 @@ export const loraSlice = createSlice({
const { model_name, id, base_model } = action.payload; const { model_name, id, base_model } = action.payload;
state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig }; state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig };
}, },
loraRecalled: (
state,
action: PayloadAction<LoRAModelConfigEntity & { weight: number }>
) => {
const { model_name, id, base_model, weight } = action.payload;
state.loras[id] = { id, model_name, base_model, weight };
},
loraRemoved: (state, action: PayloadAction<string>) => { loraRemoved: (state, action: PayloadAction<string>) => {
const id = action.payload; const id = action.payload;
delete state.loras[id]; delete state.loras[id];
@ -62,6 +69,7 @@ export const {
loraWeightChanged, loraWeightChanged,
loraWeightReset, loraWeightReset,
lorasCleared, lorasCleared,
loraRecalled,
} = loraSlice.actions; } = loraSlice.actions;
export default loraSlice.reducer; export default loraSlice.reducer;

View File

@ -1065,6 +1065,13 @@ export const isInvocationFieldSchema = (
export type InvocationEdgeExtra = { type: 'default' | 'collapsed' }; export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
const zLoRAMetadataItem = z.object({
lora: zLoRAModelField.deepPartial(),
weight: z.number(),
});
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
export const zCoreMetadata = z export const zCoreMetadata = z
.object({ .object({
app_version: z.string().nullish(), app_version: z.string().nullish(),
@ -1084,14 +1091,7 @@ export const zCoreMetadata = z
.union([zMainModel.deepPartial(), zOnnxModel.deepPartial()]) .union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
.nullish(), .nullish(),
controlnets: z.array(zControlField.deepPartial()).nullish(), controlnets: z.array(zControlField.deepPartial()).nullish(),
loras: z loras: z.array(zLoRAMetadataItem).nullish(),
.array(
z.object({
lora: zLoRAModelField.deepPartial(),
weight: z.number(),
})
)
.nullish(),
vae: zVaeModelField.nullish(), vae: zVaeModelField.nullish(),
strength: z.number().nullish(), strength: z.number().nullish(),
init_image: z.string().nullish(), init_image: z.string().nullish(),

View File

@ -1,6 +1,8 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks'; import { stateSelector } from 'app/store/store';
import { CoreMetadata } from 'features/nodes/types/types'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
import { import {
refinerModelChanged, refinerModelChanged,
setNegativeStylePromptSDXL, setNegativeStylePromptSDXL,
@ -15,6 +17,11 @@ import {
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import {
loraModelsAdapter,
useGetLoRAModelsQuery,
} from '../../../services/api/endpoints/models';
import { loraRecalled } from '../../lora/store/loraSlice';
import { initialImageSelected, modelSelected } from '../store/actions'; import { initialImageSelected, modelSelected } from '../store/actions';
import { import {
setCfgScale, setCfgScale,
@ -30,6 +37,7 @@ import {
import { import {
isValidCfgScale, isValidCfgScale,
isValidHeight, isValidHeight,
isValidLoRAModel,
isValidMainModel, isValidMainModel,
isValidNegativePrompt, isValidNegativePrompt,
isValidPositivePrompt, isValidPositivePrompt,
@ -46,10 +54,16 @@ import {
isValidWidth, isValidWidth,
} from '../types/parameterSchemas'; } from '../types/parameterSchemas';
const selector = createSelector(stateSelector, ({ generation }) => {
const { model } = generation;
return { model };
});
export const useRecallParameters = () => { export const useRecallParameters = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const toaster = useAppToaster(); const toaster = useAppToaster();
const { t } = useTranslation(); const { t } = useTranslation();
const { model } = useAppSelector(selector);
const parameterSetToast = useCallback(() => { const parameterSetToast = useCallback(() => {
toaster({ toaster({
@ -60,14 +74,18 @@ export const useRecallParameters = () => {
}); });
}, [t, toaster]); }, [t, toaster]);
const parameterNotSetToast = useCallback(() => { const parameterNotSetToast = useCallback(
toaster({ (description?: string) => {
title: t('toast.parameterNotSet'), toaster({
status: 'warning', title: t('toast.parameterNotSet'),
duration: 2500, description,
isClosable: true, status: 'warning',
}); duration: 2500,
}, [t, toaster]); isClosable: true,
});
},
[t, toaster]
);
const allParameterSetToast = useCallback(() => { const allParameterSetToast = useCallback(() => {
toaster({ toaster({
@ -78,14 +96,18 @@ export const useRecallParameters = () => {
}); });
}, [t, toaster]); }, [t, toaster]);
const allParameterNotSetToast = useCallback(() => { const allParameterNotSetToast = useCallback(
toaster({ (description?: string) => {
title: t('toast.parametersNotSet'), toaster({
status: 'warning', title: t('toast.parametersNotSet'),
duration: 2500, status: 'warning',
isClosable: true, description,
}); duration: 2500,
}, [t, toaster]); isClosable: true,
});
},
[t, toaster]
);
/** /**
* Recall both prompts with toast * Recall both prompts with toast
@ -307,6 +329,67 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall LoRA with toast
*/
const { loras } = useGetLoRAModelsQuery(undefined, {
selectFromResult: (result) => ({
loras: result.data
? loraModelsAdapter.getSelectors().selectAll(result.data)
: [],
}),
});
const prepareLoRAMetadataItem = useCallback(
(loraMetadataItem: LoRAMetadataItem) => {
if (!isValidLoRAModel(loraMetadataItem.lora)) {
return { lora: null, error: 'Invalid LoRA model' };
}
const { base_model, model_name } = loraMetadataItem.lora;
const matchingLoRA = loras.find(
(l) => l.base_model === base_model && l.model_name === model_name
);
if (!matchingLoRA) {
return { lora: null, error: 'LoRA model is not installed' };
}
const isCompatibleBaseModel =
matchingLoRA?.base_model === model?.base_model;
if (!isCompatibleBaseModel) {
return {
lora: null,
error: 'LoRA incompatible with currently-selected model',
};
}
return { lora: matchingLoRA, error: null };
},
[loras, model?.base_model]
);
const recallLoRA = useCallback(
(loraMetadataItem: LoRAMetadataItem) => {
const result = prepareLoRAMetadataItem(loraMetadataItem);
if (!result.lora) {
parameterNotSetToast(result.error);
return;
}
dispatch(
loraRecalled({ ...result.lora, weight: loraMetadataItem.weight })
);
parameterSetToast();
},
[prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
);
/* /*
* Sets image as initial image with toast * Sets image as initial image with toast
*/ */
@ -344,6 +427,7 @@ export const useRecallParameters = () => {
refiner_positive_aesthetic_score, refiner_positive_aesthetic_score,
refiner_negative_aesthetic_score, refiner_negative_aesthetic_score,
refiner_start, refiner_start,
loras,
} = metadata; } = metadata;
if (isValidCfgScale(cfg_scale)) { if (isValidCfgScale(cfg_scale)) {
@ -425,9 +509,21 @@ export const useRecallParameters = () => {
dispatch(setRefinerStart(refiner_start)); dispatch(setRefinerStart(refiner_start));
} }
loras?.forEach((lora) => {
const result = prepareLoRAMetadataItem(lora);
if (result.lora) {
dispatch(loraRecalled({ ...result.lora, weight: lora.weight }));
}
});
allParameterSetToast(); allParameterSetToast();
}, },
[allParameterNotSetToast, allParameterSetToast, dispatch] [
allParameterNotSetToast,
allParameterSetToast,
dispatch,
prepareLoRAMetadataItem,
]
); );
return { return {
@ -444,6 +540,7 @@ export const useRecallParameters = () => {
recallWidth, recallWidth,
recallHeight, recallHeight,
recallStrength, recallStrength,
recallLoRA,
recallAllParameters, recallAllParameters,
sendToImageToImage, sendToImageToImage,
}; };

View File

@ -128,7 +128,7 @@ export const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
const onnxModelsAdapter = createEntityAdapter<OnnxModelConfigEntity>({ const onnxModelsAdapter = createEntityAdapter<OnnxModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({ export const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
export const controlNetModelsAdapter = export const controlNetModelsAdapter =

View File

@ -198,6 +198,13 @@ output = "coverage/index.xml"
max-line-length = 120 max-line-length = 120
ignore = ["E203", "E266", "E501", "W503"] ignore = ["E203", "E266", "E501", "W503"]
select = ["B", "C", "E", "F", "W", "T4"] select = ["B", "C", "E", "F", "W", "T4"]
exclude = [
".git",
"__pycache__",
"build",
"dist",
"invokeai/frontend/web/node_modules/"
]
[tool.black] [tool.black]
line-length = 120 line-length = 120