feat: Restore Model Merge functionality

This commit is contained in:
blessedcoolant 2023-07-12 22:43:06 +12:00
parent 683229e285
commit 3db1aa738c
5 changed files with 172 additions and 64 deletions

View File

@ -427,6 +427,7 @@
"customSaveLocation": "Custom Save Location", "customSaveLocation": "Custom Save Location",
"merge": "Merge", "merge": "Merge",
"modelsMerged": "Models Merged", "modelsMerged": "Models Merged",
"modelsMergeFailed": "Model Merge Failed",
"mergeModels": "Merge Models", "mergeModels": "Merge Models",
"modelOne": "Model 1", "modelOne": "Model 1",
"modelTwo": "Model 2", "modelTwo": "Model 2",
@ -447,7 +448,8 @@
"weightedSum": "Weighted Sum", "weightedSum": "Weighted Sum",
"none": "none", "none": "none",
"addDifference": "Add Difference", "addDifference": "Add Difference",
"pickModelType": "Pick Model Type" "pickModelType": "Pick Model Type",
"selectModel": "Select Model"
}, },
"parameters": { "parameters": {
"general": "General", "general": "General",

View File

@ -1,35 +1,74 @@
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react'; import {
import { RootState } from 'app/store/store'; Flex,
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; Radio,
RadioGroup,
Text,
Tooltip,
useColorMode,
} from '@chakra-ui/react';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
import IAISelect from 'common/components/IAISelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { addToast } from 'features/system/store/systemSlice';
import { pickBy } from 'lodash-es'; import { pickBy } from 'lodash-es';
import { useState } from 'react'; import { useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import {
useGetMainModelsQuery,
useMergeMainModelsMutation,
} from 'services/api/endpoints/models';
import { BaseModelType, MergeModelConfig } from 'services/api/types';
import { mode } from 'theme/util/mode';
const baseModelTypeSelectData = [
{ label: 'Stable Diffusion 1', value: 'sd-1' },
{ label: 'Stable Diffusion 2', value: 'sd-2' },
];
export default function MergeModelsPanel() { export default function MergeModelsPanel() {
const { t } = useTranslation(); const { t } = useTranslation();
const { colorMode } = useColorMode();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data } = useGetMainModelsQuery(); const { data } = useGetMainModelsQuery();
const diffusersModels = pickBy( const [mergeModels, { isLoading, error, data: mergedModelData }] =
useMergeMainModelsMutation();
const [baseModel, setBaseModel] = useState<BaseModelType>('sd-1');
const sd1DiffusersModels = pickBy(
data?.entities, data?.entities,
(value, _) => value?.model_format === 'diffusers' (value, _) =>
value?.model_format === 'diffusers' && value?.base_model === 'sd-1'
); );
const [modelOne, setModelOne] = useState<string>( const sd2DiffusersModels = pickBy(
Object.keys(diffusersModels)[0] data?.entities,
(value, _) =>
value?.model_format === 'diffusers' && value?.base_model === 'sd-2'
); );
const [modelTwo, setModelTwo] = useState<string>(
Object.keys(diffusersModels)[1] const modelsMap = useMemo(() => {
return {
'sd-1': sd1DiffusersModels,
'sd-2': sd2DiffusersModels,
};
}, [sd1DiffusersModels, sd2DiffusersModels]);
const [modelOne, setModelOne] = useState<string | null>(
Object.keys(modelsMap[baseModel])[0]
); );
const [modelThree, setModelThree] = useState<string>('none'); const [modelTwo, setModelTwo] = useState<string | null>(
Object.keys(modelsMap[baseModel])[1]
);
const [modelThree, setModelThree] = useState<string | null>(null);
const [mergedModelName, setMergedModelName] = useState<string>(''); const [mergedModelName, setMergedModelName] = useState<string>('');
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5); const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
@ -47,41 +86,72 @@ export default function MergeModelsPanel() {
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false); const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
const modelOneList = Object.keys(diffusersModels).filter( const modelOneList = Object.keys(
(model) => model !== modelTwo && model !== modelThree modelsMap[baseModel as keyof typeof modelsMap]
).filter((model) => model !== modelTwo && model !== modelThree);
const modelTwoList = Object.keys(
modelsMap[baseModel as keyof typeof modelsMap]
).filter((model) => model !== modelOne && model !== modelThree);
const modelThreeList = Object.keys(modelsMap[baseModel]).filter(
(model) => model !== modelOne && model !== modelTwo
); );
const modelTwoList = Object.keys(diffusersModels).filter( const handleBaseModelChange = (v: string) => {
(model) => model !== modelOne && model !== modelThree setBaseModel(v as BaseModelType);
); setModelOne(null);
setModelTwo(null);
const modelThreeList = [ };
{ key: t('modelManager.none'), value: 'none' },
...Object.keys(diffusersModels)
.filter((model) => model !== modelOne && model !== modelTwo)
.map((model) => ({ key: model, value: model })),
];
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const mergeModelsHandler = () => { const mergeModelsHandler = () => {
let modelsToMerge: string[] = [modelOne, modelTwo, modelThree]; const models_names: string[] = [];
modelsToMerge = modelsToMerge.filter((model) => model !== 'none');
const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = { let modelsToMerge: (string | null)[] = [modelOne, modelTwo, modelThree];
models_to_merge: modelsToMerge, modelsToMerge = modelsToMerge.filter((model) => model !== null);
modelsToMerge.forEach((model) => {
if (model) {
models_names.push(model?.split('/')[2]);
}
});
const mergeModelsInfo: MergeModelConfig = {
model_names: models_names,
merged_model_name: merged_model_name:
mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'), mergedModelName !== '' ? mergedModelName : models_names.join('-'),
alpha: modelMergeAlpha, alpha: modelMergeAlpha,
interp: modelMergeInterp, interp: modelMergeInterp,
model_merge_save_path: // model_merge_save_path:
modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc, // modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
force: modelMergeForce, force: modelMergeForce,
}; };
dispatch(mergeDiffusersModels(mergeModelsInfo)); mergeModels({
base_model: baseModel,
body: mergeModelsInfo,
});
if (error) {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelsMergeFailed'),
status: 'error',
})
)
);
}
if (mergedModelData) {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelsMerged'),
status: 'success',
})
)
);
}
}; };
return ( return (
@ -90,7 +160,6 @@ export default function MergeModelsPanel() {
sx={{ sx={{
flexDirection: 'column', flexDirection: 'column',
rowGap: 1, rowGap: 1,
bg: 'base.900',
}} }}
> >
<Text>{t('modelManager.modelMergeHeaderHelp1')}</Text> <Text>{t('modelManager.modelMergeHeaderHelp1')}</Text>
@ -98,26 +167,43 @@ export default function MergeModelsPanel() {
{t('modelManager.modelMergeHeaderHelp2')} {t('modelManager.modelMergeHeaderHelp2')}
</Text> </Text>
</Flex> </Flex>
<Flex columnGap={4}> <Flex columnGap={4}>
<IAISelect <IAIMantineSelect
label="Model Type"
w="100%"
data={baseModelTypeSelectData}
value={baseModel}
onChange={handleBaseModelChange}
/>
<IAIMantineSelect
label={t('modelManager.modelOne')} label={t('modelManager.modelOne')}
validValues={modelOneList} w="100%"
onChange={(e) => setModelOne(e.target.value)} value={modelOne}
placeholder={t('modelManager.selectModel')}
data={modelOneList}
onChange={(v) => setModelOne(v)}
/> />
<IAISelect <IAIMantineSelect
label={t('modelManager.modelTwo')} label={t('modelManager.modelTwo')}
validValues={modelTwoList} w="100%"
onChange={(e) => setModelTwo(e.target.value)} placeholder={t('modelManager.selectModel')}
value={modelTwo}
data={modelTwoList}
onChange={(v) => setModelTwo(v)}
/> />
<IAISelect <IAIMantineSelect
label={t('modelManager.modelThree')} label={t('modelManager.modelThree')}
validValues={modelThreeList} data={modelThreeList}
onChange={(e) => { w="100%"
if (e.target.value !== 'none') { placeholder={t('modelManager.selectModel')}
setModelThree(e.target.value); clearable
onChange={(v) => {
if (!v) {
setModelThree(null);
setModelMergeInterp('add_difference'); setModelMergeInterp('add_difference');
} else { } else {
setModelThree('none'); setModelThree(v);
setModelMergeInterp('weighted_sum'); setModelMergeInterp('weighted_sum');
} }
}} }}
@ -136,7 +222,7 @@ export default function MergeModelsPanel() {
padding: 4, padding: 4,
borderRadius: 'base', borderRadius: 'base',
gap: 4, gap: 4,
bg: 'base.900', bg: mode('base.100', 'base.800')(colorMode),
}} }}
> >
<IAISlider <IAISlider
@ -161,7 +247,7 @@ export default function MergeModelsPanel() {
padding: 4, padding: 4,
borderRadius: 'base', borderRadius: 'base',
gap: 4, gap: 4,
bg: 'base.900', bg: mode('base.100', 'base.800')(colorMode),
}} }}
> >
<Text fontWeight={500} fontSize="sm" variant="subtext"> <Text fontWeight={500} fontSize="sm" variant="subtext">
@ -174,7 +260,7 @@ export default function MergeModelsPanel() {
) => setModelMergeInterp(v)} ) => setModelMergeInterp(v)}
> >
<Flex columnGap={4}> <Flex columnGap={4}>
{modelThree === 'none' ? ( {modelThree === null ? (
<> <>
<Radio value="weighted_sum"> <Radio value="weighted_sum">
<Text fontSize="sm">{t('modelManager.weightedSum')}</Text> <Text fontSize="sm">{t('modelManager.weightedSum')}</Text>
@ -199,7 +285,7 @@ export default function MergeModelsPanel() {
</RadioGroup> </RadioGroup>
</Flex> </Flex>
<Flex {/* <Flex
sx={{ sx={{
flexDirection: 'column', flexDirection: 'column',
padding: 4, padding: 4,
@ -235,7 +321,7 @@ export default function MergeModelsPanel() {
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)} onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
/> />
)} )}
</Flex> </Flex> */}
<IAISimpleCheckbox <IAISimpleCheckbox
label={t('modelManager.ignoreMismatch')} label={t('modelManager.ignoreMismatch')}
@ -246,10 +332,8 @@ export default function MergeModelsPanel() {
<IAIButton <IAIButton
onClick={mergeModelsHandler} onClick={mergeModelsHandler}
isLoading={isProcessing} isLoading={isLoading}
isDisabled={ isDisabled={modelOne === null || modelTwo === null}
modelMergeSaveLocType === 'custom' && modelMergeCustomSaveLoc === ''
}
> >
{t('modelManager.merge')} {t('modelManager.merge')}
</IAIButton> </IAIButton>

View File

@ -7,7 +7,8 @@ import IAIButton from 'common/components/IAIButton';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { useEffect, useState } from 'react'; import { useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useConvertMainModelMutation } from 'services/api/endpoints/models';
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
import { CheckpointModelConfig } from './CheckpointModelEdit'; import { CheckpointModelConfig } from './CheckpointModelEdit';
interface ModelConvertProps { interface ModelConvertProps {
@ -21,7 +22,7 @@ export default function ModelConvert(props: ModelConvertProps) {
const { t } = useTranslation(); const { t } = useTranslation();
const [convertModel, { isLoading, error, data }] = const [convertModel, { isLoading, error, data }] =
useConvertMainModelMutation(); useConvertMainModelsMutation();
const [saveLocation, setSaveLocation] = useState<string>('same'); const [saveLocation, setSaveLocation] = useState<string>('same');
const [customSaveLocation, setCustomSaveLocation] = useState<string>(''); const [customSaveLocation, setCustomSaveLocation] = useState<string>('');

View File

@ -6,6 +6,7 @@ import {
ControlNetModelConfig, ControlNetModelConfig,
LoRAModelConfig, LoRAModelConfig,
MainModelConfig, MainModelConfig,
MergeModelConfig,
TextualInversionModelConfig, TextualInversionModelConfig,
VaeModelConfig, VaeModelConfig,
} from 'services/api/types'; } from 'services/api/types';
@ -49,6 +50,11 @@ type ConvertMainModelQuery = {
model_name: string; model_name: string;
}; };
type MergeMainModelQuery = {
base_model: BaseModelType;
body: MergeModelConfig;
};
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({ const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.name.localeCompare(b.name),
}); });
@ -143,7 +149,7 @@ export const modelsApi = api.injectEndpoints({
}, },
invalidatesTags: ['MainModel'], invalidatesTags: ['MainModel'],
}), }),
convertMainModel: build.mutation< convertMainModels: build.mutation<
EntityState<MainModelConfigEntity>, EntityState<MainModelConfigEntity>,
ConvertMainModelQuery ConvertMainModelQuery
>({ >({
@ -155,6 +161,19 @@ export const modelsApi = api.injectEndpoints({
}, },
invalidatesTags: ['MainModel'], invalidatesTags: ['MainModel'],
}), }),
mergeMainModels: build.mutation<
EntityState<MainModelConfigEntity>,
MergeMainModelQuery
>({
query: ({ base_model, body }) => {
return {
url: `models/merge/${base_model}`,
method: 'PUT',
body: body,
};
},
invalidatesTags: ['MainModel'],
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({ getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }), query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
@ -300,5 +319,6 @@ export const {
useGetVaeModelsQuery, useGetVaeModelsQuery,
useUpdateMainModelsMutation, useUpdateMainModelsMutation,
useDeleteMainModelsMutation, useDeleteMainModelsMutation,
useConvertMainModelMutation, useConvertMainModelsMutation,
useMergeMainModelsMutation,
} = modelsApi; } = modelsApi;

View File

@ -50,6 +50,7 @@ export type AnyModelConfig =
| ControlNetModelConfig | ControlNetModelConfig
| TextualInversionModelConfig | TextualInversionModelConfig
| MainModelConfig; | MainModelConfig;
export type MergeModelConfig = components['schemas']['Body_merge_models'];
// Graphs // Graphs
export type Graph = components['schemas']['Graph']; export type Graph = components['schemas']['Graph'];