feat: Restore Model Merge functionality

This commit is contained in:
blessedcoolant 2023-07-12 22:43:06 +12:00
parent 683229e285
commit 3db1aa738c
5 changed files with 172 additions and 64 deletions

View File

@ -427,6 +427,7 @@
"customSaveLocation": "Custom Save Location",
"merge": "Merge",
"modelsMerged": "Models Merged",
"modelsMergeFailed": "Model Merge Failed",
"mergeModels": "Merge Models",
"modelOne": "Model 1",
"modelTwo": "Model 2",
@ -447,7 +448,8 @@
"weightedSum": "Weighted Sum",
"none": "none",
"addDifference": "Add Difference",
"pickModelType": "Pick Model Type"
"pickModelType": "Pick Model Type",
"selectModel": "Select Model"
},
"parameters": {
"general": "General",

View File

@ -1,35 +1,74 @@
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
Flex,
Radio,
RadioGroup,
Text,
Tooltip,
useColorMode,
} from '@chakra-ui/react';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAISelect from 'common/components/IAISelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import IAISlider from 'common/components/IAISlider';
import { addToast } from 'features/system/store/systemSlice';
import { pickBy } from 'lodash-es';
import { useState } from 'react';
import { useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import {
useGetMainModelsQuery,
useMergeMainModelsMutation,
} from 'services/api/endpoints/models';
import { BaseModelType, MergeModelConfig } from 'services/api/types';
import { mode } from 'theme/util/mode';
const baseModelTypeSelectData = [
{ label: 'Stable Diffusion 1', value: 'sd-1' },
{ label: 'Stable Diffusion 2', value: 'sd-2' },
];
export default function MergeModelsPanel() {
const { t } = useTranslation();
const { colorMode } = useColorMode();
const dispatch = useAppDispatch();
const { data } = useGetMainModelsQuery();
const diffusersModels = pickBy(
const [mergeModels, { isLoading, error, data: mergedModelData }] =
useMergeMainModelsMutation();
const [baseModel, setBaseModel] = useState<BaseModelType>('sd-1');
const sd1DiffusersModels = pickBy(
data?.entities,
(value, _) => value?.model_format === 'diffusers'
(value, _) =>
value?.model_format === 'diffusers' && value?.base_model === 'sd-1'
);
const [modelOne, setModelOne] = useState<string>(
Object.keys(diffusersModels)[0]
const sd2DiffusersModels = pickBy(
data?.entities,
(value, _) =>
value?.model_format === 'diffusers' && value?.base_model === 'sd-2'
);
const [modelTwo, setModelTwo] = useState<string>(
Object.keys(diffusersModels)[1]
const modelsMap = useMemo(() => {
return {
'sd-1': sd1DiffusersModels,
'sd-2': sd2DiffusersModels,
};
}, [sd1DiffusersModels, sd2DiffusersModels]);
const [modelOne, setModelOne] = useState<string | null>(
Object.keys(modelsMap[baseModel])[0]
);
const [modelThree, setModelThree] = useState<string>('none');
const [modelTwo, setModelTwo] = useState<string | null>(
Object.keys(modelsMap[baseModel])[1]
);
const [modelThree, setModelThree] = useState<string | null>(null);
const [mergedModelName, setMergedModelName] = useState<string>('');
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
@ -47,41 +86,72 @@ export default function MergeModelsPanel() {
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
const modelOneList = Object.keys(diffusersModels).filter(
(model) => model !== modelTwo && model !== modelThree
const modelOneList = Object.keys(
modelsMap[baseModel as keyof typeof modelsMap]
).filter((model) => model !== modelTwo && model !== modelThree);
const modelTwoList = Object.keys(
modelsMap[baseModel as keyof typeof modelsMap]
).filter((model) => model !== modelOne && model !== modelThree);
const modelThreeList = Object.keys(modelsMap[baseModel]).filter(
(model) => model !== modelOne && model !== modelTwo
);
const modelTwoList = Object.keys(diffusersModels).filter(
(model) => model !== modelOne && model !== modelThree
);
const modelThreeList = [
{ key: t('modelManager.none'), value: 'none' },
...Object.keys(diffusersModels)
.filter((model) => model !== modelOne && model !== modelTwo)
.map((model) => ({ key: model, value: model })),
];
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const handleBaseModelChange = (v: string) => {
setBaseModel(v as BaseModelType);
setModelOne(null);
setModelTwo(null);
};
const mergeModelsHandler = () => {
let modelsToMerge: string[] = [modelOne, modelTwo, modelThree];
modelsToMerge = modelsToMerge.filter((model) => model !== 'none');
const models_names: string[] = [];
const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = {
models_to_merge: modelsToMerge,
let modelsToMerge: (string | null)[] = [modelOne, modelTwo, modelThree];
modelsToMerge = modelsToMerge.filter((model) => model !== null);
modelsToMerge.forEach((model) => {
if (model) {
models_names.push(model?.split('/')[2]);
}
});
const mergeModelsInfo: MergeModelConfig = {
model_names: models_names,
merged_model_name:
mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'),
mergedModelName !== '' ? mergedModelName : models_names.join('-'),
alpha: modelMergeAlpha,
interp: modelMergeInterp,
model_merge_save_path:
modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
// model_merge_save_path:
// modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
force: modelMergeForce,
};
dispatch(mergeDiffusersModels(mergeModelsInfo));
mergeModels({
base_model: baseModel,
body: mergeModelsInfo,
});
if (error) {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelsMergeFailed'),
status: 'error',
})
)
);
}
if (mergedModelData) {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelsMerged'),
status: 'success',
})
)
);
}
};
return (
@ -90,7 +160,6 @@ export default function MergeModelsPanel() {
sx={{
flexDirection: 'column',
rowGap: 1,
bg: 'base.900',
}}
>
<Text>{t('modelManager.modelMergeHeaderHelp1')}</Text>
@ -98,26 +167,43 @@ export default function MergeModelsPanel() {
{t('modelManager.modelMergeHeaderHelp2')}
</Text>
</Flex>
<Flex columnGap={4}>
<IAISelect
<IAIMantineSelect
label="Model Type"
w="100%"
data={baseModelTypeSelectData}
value={baseModel}
onChange={handleBaseModelChange}
/>
<IAIMantineSelect
label={t('modelManager.modelOne')}
validValues={modelOneList}
onChange={(e) => setModelOne(e.target.value)}
w="100%"
value={modelOne}
placeholder={t('modelManager.selectModel')}
data={modelOneList}
onChange={(v) => setModelOne(v)}
/>
<IAISelect
<IAIMantineSelect
label={t('modelManager.modelTwo')}
validValues={modelTwoList}
onChange={(e) => setModelTwo(e.target.value)}
w="100%"
placeholder={t('modelManager.selectModel')}
value={modelTwo}
data={modelTwoList}
onChange={(v) => setModelTwo(v)}
/>
<IAISelect
<IAIMantineSelect
label={t('modelManager.modelThree')}
validValues={modelThreeList}
onChange={(e) => {
if (e.target.value !== 'none') {
setModelThree(e.target.value);
data={modelThreeList}
w="100%"
placeholder={t('modelManager.selectModel')}
clearable
onChange={(v) => {
if (!v) {
setModelThree(null);
setModelMergeInterp('add_difference');
} else {
setModelThree('none');
setModelThree(v);
setModelMergeInterp('weighted_sum');
}
}}
@ -136,7 +222,7 @@ export default function MergeModelsPanel() {
padding: 4,
borderRadius: 'base',
gap: 4,
bg: 'base.900',
bg: mode('base.100', 'base.800')(colorMode),
}}
>
<IAISlider
@ -161,7 +247,7 @@ export default function MergeModelsPanel() {
padding: 4,
borderRadius: 'base',
gap: 4,
bg: 'base.900',
bg: mode('base.100', 'base.800')(colorMode),
}}
>
<Text fontWeight={500} fontSize="sm" variant="subtext">
@ -174,7 +260,7 @@ export default function MergeModelsPanel() {
) => setModelMergeInterp(v)}
>
<Flex columnGap={4}>
{modelThree === 'none' ? (
{modelThree === null ? (
<>
<Radio value="weighted_sum">
<Text fontSize="sm">{t('modelManager.weightedSum')}</Text>
@ -199,7 +285,7 @@ export default function MergeModelsPanel() {
</RadioGroup>
</Flex>
<Flex
{/* <Flex
sx={{
flexDirection: 'column',
padding: 4,
@ -235,7 +321,7 @@ export default function MergeModelsPanel() {
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
/>
)}
</Flex>
</Flex> */}
<IAISimpleCheckbox
label={t('modelManager.ignoreMismatch')}
@ -246,10 +332,8 @@ export default function MergeModelsPanel() {
<IAIButton
onClick={mergeModelsHandler}
isLoading={isProcessing}
isDisabled={
modelMergeSaveLocType === 'custom' && modelMergeCustomSaveLoc === ''
}
isLoading={isLoading}
isDisabled={modelOne === null || modelTwo === null}
>
{t('modelManager.merge')}
</IAIButton>

View File

@ -7,7 +7,8 @@ import IAIButton from 'common/components/IAIButton';
import { addToast } from 'features/system/store/systemSlice';
import { useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useConvertMainModelMutation } from 'services/api/endpoints/models';
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
import { CheckpointModelConfig } from './CheckpointModelEdit';
interface ModelConvertProps {
@ -21,7 +22,7 @@ export default function ModelConvert(props: ModelConvertProps) {
const { t } = useTranslation();
const [convertModel, { isLoading, error, data }] =
useConvertMainModelMutation();
useConvertMainModelsMutation();
const [saveLocation, setSaveLocation] = useState<string>('same');
const [customSaveLocation, setCustomSaveLocation] = useState<string>('');

View File

@ -6,6 +6,7 @@ import {
ControlNetModelConfig,
LoRAModelConfig,
MainModelConfig,
MergeModelConfig,
TextualInversionModelConfig,
VaeModelConfig,
} from 'services/api/types';
@ -49,6 +50,11 @@ type ConvertMainModelQuery = {
model_name: string;
};
type MergeMainModelQuery = {
base_model: BaseModelType;
body: MergeModelConfig;
};
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
@ -143,7 +149,7 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['MainModel'],
}),
convertMainModel: build.mutation<
convertMainModels: build.mutation<
EntityState<MainModelConfigEntity>,
ConvertMainModelQuery
>({
@ -155,6 +161,19 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['MainModel'],
}),
mergeMainModels: build.mutation<
EntityState<MainModelConfigEntity>,
MergeMainModelQuery
>({
query: ({ base_model, body }) => {
return {
url: `models/merge/${base_model}`,
method: 'PUT',
body: body,
};
},
invalidatesTags: ['MainModel'],
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
providesTags: (result, error, arg) => {
@ -300,5 +319,6 @@ export const {
useGetVaeModelsQuery,
useUpdateMainModelsMutation,
useDeleteMainModelsMutation,
useConvertMainModelMutation,
useConvertMainModelsMutation,
useMergeMainModelsMutation,
} = modelsApi;

View File

@ -50,6 +50,7 @@ export type AnyModelConfig =
| ControlNetModelConfig
| TextualInversionModelConfig
| MainModelConfig;
export type MergeModelConfig = components['schemas']['Body_merge_models'];
// Graphs
export type Graph = components['schemas']['Graph'];