mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Restore Model Merge functionality
This commit is contained in:
parent
683229e285
commit
3db1aa738c
@ -427,6 +427,7 @@
|
|||||||
"customSaveLocation": "Custom Save Location",
|
"customSaveLocation": "Custom Save Location",
|
||||||
"merge": "Merge",
|
"merge": "Merge",
|
||||||
"modelsMerged": "Models Merged",
|
"modelsMerged": "Models Merged",
|
||||||
|
"modelsMergeFailed": "Model Merge Failed",
|
||||||
"mergeModels": "Merge Models",
|
"mergeModels": "Merge Models",
|
||||||
"modelOne": "Model 1",
|
"modelOne": "Model 1",
|
||||||
"modelTwo": "Model 2",
|
"modelTwo": "Model 2",
|
||||||
@ -447,7 +448,8 @@
|
|||||||
"weightedSum": "Weighted Sum",
|
"weightedSum": "Weighted Sum",
|
||||||
"none": "none",
|
"none": "none",
|
||||||
"addDifference": "Add Difference",
|
"addDifference": "Add Difference",
|
||||||
"pickModelType": "Pick Model Type"
|
"pickModelType": "Pick Model Type",
|
||||||
|
"selectModel": "Select Model"
|
||||||
},
|
},
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"general": "General",
|
"general": "General",
|
||||||
|
@ -1,35 +1,74 @@
|
|||||||
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
|
import {
|
||||||
import { RootState } from 'app/store/store';
|
Flex,
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
Radio,
|
||||||
|
RadioGroup,
|
||||||
|
Text,
|
||||||
|
Tooltip,
|
||||||
|
useColorMode,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { makeToast } from 'app/components/Toaster';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import IAIInput from 'common/components/IAIInput';
|
import IAIInput from 'common/components/IAIInput';
|
||||||
import IAISelect from 'common/components/IAISelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { pickBy } from 'lodash-es';
|
import { pickBy } from 'lodash-es';
|
||||||
import { useState } from 'react';
|
import { useMemo, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
import {
|
||||||
|
useGetMainModelsQuery,
|
||||||
|
useMergeMainModelsMutation,
|
||||||
|
} from 'services/api/endpoints/models';
|
||||||
|
import { BaseModelType, MergeModelConfig } from 'services/api/types';
|
||||||
|
import { mode } from 'theme/util/mode';
|
||||||
|
|
||||||
|
const baseModelTypeSelectData = [
|
||||||
|
{ label: 'Stable Diffusion 1', value: 'sd-1' },
|
||||||
|
{ label: 'Stable Diffusion 2', value: 'sd-2' },
|
||||||
|
];
|
||||||
|
|
||||||
export default function MergeModelsPanel() {
|
export default function MergeModelsPanel() {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
const { colorMode } = useColorMode();
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { data } = useGetMainModelsQuery();
|
const { data } = useGetMainModelsQuery();
|
||||||
|
|
||||||
const diffusersModels = pickBy(
|
const [mergeModels, { isLoading, error, data: mergedModelData }] =
|
||||||
|
useMergeMainModelsMutation();
|
||||||
|
|
||||||
|
const [baseModel, setBaseModel] = useState<BaseModelType>('sd-1');
|
||||||
|
|
||||||
|
const sd1DiffusersModels = pickBy(
|
||||||
data?.entities,
|
data?.entities,
|
||||||
(value, _) => value?.model_format === 'diffusers'
|
(value, _) =>
|
||||||
|
value?.model_format === 'diffusers' && value?.base_model === 'sd-1'
|
||||||
);
|
);
|
||||||
|
|
||||||
const [modelOne, setModelOne] = useState<string>(
|
const sd2DiffusersModels = pickBy(
|
||||||
Object.keys(diffusersModels)[0]
|
data?.entities,
|
||||||
|
(value, _) =>
|
||||||
|
value?.model_format === 'diffusers' && value?.base_model === 'sd-2'
|
||||||
);
|
);
|
||||||
const [modelTwo, setModelTwo] = useState<string>(
|
|
||||||
Object.keys(diffusersModels)[1]
|
const modelsMap = useMemo(() => {
|
||||||
|
return {
|
||||||
|
'sd-1': sd1DiffusersModels,
|
||||||
|
'sd-2': sd2DiffusersModels,
|
||||||
|
};
|
||||||
|
}, [sd1DiffusersModels, sd2DiffusersModels]);
|
||||||
|
|
||||||
|
const [modelOne, setModelOne] = useState<string | null>(
|
||||||
|
Object.keys(modelsMap[baseModel])[0]
|
||||||
);
|
);
|
||||||
const [modelThree, setModelThree] = useState<string>('none');
|
const [modelTwo, setModelTwo] = useState<string | null>(
|
||||||
|
Object.keys(modelsMap[baseModel])[1]
|
||||||
|
);
|
||||||
|
|
||||||
|
const [modelThree, setModelThree] = useState<string | null>(null);
|
||||||
|
|
||||||
const [mergedModelName, setMergedModelName] = useState<string>('');
|
const [mergedModelName, setMergedModelName] = useState<string>('');
|
||||||
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
|
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
|
||||||
@ -47,41 +86,72 @@ export default function MergeModelsPanel() {
|
|||||||
|
|
||||||
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
|
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
|
||||||
|
|
||||||
const modelOneList = Object.keys(diffusersModels).filter(
|
const modelOneList = Object.keys(
|
||||||
(model) => model !== modelTwo && model !== modelThree
|
modelsMap[baseModel as keyof typeof modelsMap]
|
||||||
|
).filter((model) => model !== modelTwo && model !== modelThree);
|
||||||
|
|
||||||
|
const modelTwoList = Object.keys(
|
||||||
|
modelsMap[baseModel as keyof typeof modelsMap]
|
||||||
|
).filter((model) => model !== modelOne && model !== modelThree);
|
||||||
|
|
||||||
|
const modelThreeList = Object.keys(modelsMap[baseModel]).filter(
|
||||||
|
(model) => model !== modelOne && model !== modelTwo
|
||||||
);
|
);
|
||||||
|
|
||||||
const modelTwoList = Object.keys(diffusersModels).filter(
|
const handleBaseModelChange = (v: string) => {
|
||||||
(model) => model !== modelOne && model !== modelThree
|
setBaseModel(v as BaseModelType);
|
||||||
);
|
setModelOne(null);
|
||||||
|
setModelTwo(null);
|
||||||
const modelThreeList = [
|
};
|
||||||
{ key: t('modelManager.none'), value: 'none' },
|
|
||||||
...Object.keys(diffusersModels)
|
|
||||||
.filter((model) => model !== modelOne && model !== modelTwo)
|
|
||||||
.map((model) => ({ key: model, value: model })),
|
|
||||||
];
|
|
||||||
|
|
||||||
const isProcessing = useAppSelector(
|
|
||||||
(state: RootState) => state.system.isProcessing
|
|
||||||
);
|
|
||||||
|
|
||||||
const mergeModelsHandler = () => {
|
const mergeModelsHandler = () => {
|
||||||
let modelsToMerge: string[] = [modelOne, modelTwo, modelThree];
|
const models_names: string[] = [];
|
||||||
modelsToMerge = modelsToMerge.filter((model) => model !== 'none');
|
|
||||||
|
|
||||||
const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = {
|
let modelsToMerge: (string | null)[] = [modelOne, modelTwo, modelThree];
|
||||||
models_to_merge: modelsToMerge,
|
modelsToMerge = modelsToMerge.filter((model) => model !== null);
|
||||||
|
modelsToMerge.forEach((model) => {
|
||||||
|
if (model) {
|
||||||
|
models_names.push(model?.split('/')[2]);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
const mergeModelsInfo: MergeModelConfig = {
|
||||||
|
model_names: models_names,
|
||||||
merged_model_name:
|
merged_model_name:
|
||||||
mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'),
|
mergedModelName !== '' ? mergedModelName : models_names.join('-'),
|
||||||
alpha: modelMergeAlpha,
|
alpha: modelMergeAlpha,
|
||||||
interp: modelMergeInterp,
|
interp: modelMergeInterp,
|
||||||
model_merge_save_path:
|
// model_merge_save_path:
|
||||||
modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
|
// modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
|
||||||
force: modelMergeForce,
|
force: modelMergeForce,
|
||||||
};
|
};
|
||||||
|
|
||||||
dispatch(mergeDiffusersModels(mergeModelsInfo));
|
mergeModels({
|
||||||
|
base_model: baseModel,
|
||||||
|
body: mergeModelsInfo,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: t('modelManager.modelsMergeFailed'),
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mergedModelData) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: t('modelManager.modelsMerged'),
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -90,7 +160,6 @@ export default function MergeModelsPanel() {
|
|||||||
sx={{
|
sx={{
|
||||||
flexDirection: 'column',
|
flexDirection: 'column',
|
||||||
rowGap: 1,
|
rowGap: 1,
|
||||||
bg: 'base.900',
|
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Text>{t('modelManager.modelMergeHeaderHelp1')}</Text>
|
<Text>{t('modelManager.modelMergeHeaderHelp1')}</Text>
|
||||||
@ -98,26 +167,43 @@ export default function MergeModelsPanel() {
|
|||||||
{t('modelManager.modelMergeHeaderHelp2')}
|
{t('modelManager.modelMergeHeaderHelp2')}
|
||||||
</Text>
|
</Text>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
<Flex columnGap={4}>
|
<Flex columnGap={4}>
|
||||||
<IAISelect
|
<IAIMantineSelect
|
||||||
|
label="Model Type"
|
||||||
|
w="100%"
|
||||||
|
data={baseModelTypeSelectData}
|
||||||
|
value={baseModel}
|
||||||
|
onChange={handleBaseModelChange}
|
||||||
|
/>
|
||||||
|
<IAIMantineSelect
|
||||||
label={t('modelManager.modelOne')}
|
label={t('modelManager.modelOne')}
|
||||||
validValues={modelOneList}
|
w="100%"
|
||||||
onChange={(e) => setModelOne(e.target.value)}
|
value={modelOne}
|
||||||
|
placeholder={t('modelManager.selectModel')}
|
||||||
|
data={modelOneList}
|
||||||
|
onChange={(v) => setModelOne(v)}
|
||||||
/>
|
/>
|
||||||
<IAISelect
|
<IAIMantineSelect
|
||||||
label={t('modelManager.modelTwo')}
|
label={t('modelManager.modelTwo')}
|
||||||
validValues={modelTwoList}
|
w="100%"
|
||||||
onChange={(e) => setModelTwo(e.target.value)}
|
placeholder={t('modelManager.selectModel')}
|
||||||
|
value={modelTwo}
|
||||||
|
data={modelTwoList}
|
||||||
|
onChange={(v) => setModelTwo(v)}
|
||||||
/>
|
/>
|
||||||
<IAISelect
|
<IAIMantineSelect
|
||||||
label={t('modelManager.modelThree')}
|
label={t('modelManager.modelThree')}
|
||||||
validValues={modelThreeList}
|
data={modelThreeList}
|
||||||
onChange={(e) => {
|
w="100%"
|
||||||
if (e.target.value !== 'none') {
|
placeholder={t('modelManager.selectModel')}
|
||||||
setModelThree(e.target.value);
|
clearable
|
||||||
|
onChange={(v) => {
|
||||||
|
if (!v) {
|
||||||
|
setModelThree(null);
|
||||||
setModelMergeInterp('add_difference');
|
setModelMergeInterp('add_difference');
|
||||||
} else {
|
} else {
|
||||||
setModelThree('none');
|
setModelThree(v);
|
||||||
setModelMergeInterp('weighted_sum');
|
setModelMergeInterp('weighted_sum');
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
@ -136,7 +222,7 @@ export default function MergeModelsPanel() {
|
|||||||
padding: 4,
|
padding: 4,
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
gap: 4,
|
gap: 4,
|
||||||
bg: 'base.900',
|
bg: mode('base.100', 'base.800')(colorMode),
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
@ -161,7 +247,7 @@ export default function MergeModelsPanel() {
|
|||||||
padding: 4,
|
padding: 4,
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
gap: 4,
|
gap: 4,
|
||||||
bg: 'base.900',
|
bg: mode('base.100', 'base.800')(colorMode),
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Text fontWeight={500} fontSize="sm" variant="subtext">
|
<Text fontWeight={500} fontSize="sm" variant="subtext">
|
||||||
@ -174,7 +260,7 @@ export default function MergeModelsPanel() {
|
|||||||
) => setModelMergeInterp(v)}
|
) => setModelMergeInterp(v)}
|
||||||
>
|
>
|
||||||
<Flex columnGap={4}>
|
<Flex columnGap={4}>
|
||||||
{modelThree === 'none' ? (
|
{modelThree === null ? (
|
||||||
<>
|
<>
|
||||||
<Radio value="weighted_sum">
|
<Radio value="weighted_sum">
|
||||||
<Text fontSize="sm">{t('modelManager.weightedSum')}</Text>
|
<Text fontSize="sm">{t('modelManager.weightedSum')}</Text>
|
||||||
@ -199,7 +285,7 @@ export default function MergeModelsPanel() {
|
|||||||
</RadioGroup>
|
</RadioGroup>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
<Flex
|
{/* <Flex
|
||||||
sx={{
|
sx={{
|
||||||
flexDirection: 'column',
|
flexDirection: 'column',
|
||||||
padding: 4,
|
padding: 4,
|
||||||
@ -235,7 +321,7 @@ export default function MergeModelsPanel() {
|
|||||||
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
|
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex> */}
|
||||||
|
|
||||||
<IAISimpleCheckbox
|
<IAISimpleCheckbox
|
||||||
label={t('modelManager.ignoreMismatch')}
|
label={t('modelManager.ignoreMismatch')}
|
||||||
@ -246,10 +332,8 @@ export default function MergeModelsPanel() {
|
|||||||
|
|
||||||
<IAIButton
|
<IAIButton
|
||||||
onClick={mergeModelsHandler}
|
onClick={mergeModelsHandler}
|
||||||
isLoading={isProcessing}
|
isLoading={isLoading}
|
||||||
isDisabled={
|
isDisabled={modelOne === null || modelTwo === null}
|
||||||
modelMergeSaveLocType === 'custom' && modelMergeCustomSaveLoc === ''
|
|
||||||
}
|
|
||||||
>
|
>
|
||||||
{t('modelManager.merge')}
|
{t('modelManager.merge')}
|
||||||
</IAIButton>
|
</IAIButton>
|
||||||
|
@ -7,7 +7,8 @@ import IAIButton from 'common/components/IAIButton';
|
|||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { useEffect, useState } from 'react';
|
import { useEffect, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useConvertMainModelMutation } from 'services/api/endpoints/models';
|
|
||||||
|
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
|
||||||
import { CheckpointModelConfig } from './CheckpointModelEdit';
|
import { CheckpointModelConfig } from './CheckpointModelEdit';
|
||||||
|
|
||||||
interface ModelConvertProps {
|
interface ModelConvertProps {
|
||||||
@ -21,7 +22,7 @@ export default function ModelConvert(props: ModelConvertProps) {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const [convertModel, { isLoading, error, data }] =
|
const [convertModel, { isLoading, error, data }] =
|
||||||
useConvertMainModelMutation();
|
useConvertMainModelsMutation();
|
||||||
|
|
||||||
const [saveLocation, setSaveLocation] = useState<string>('same');
|
const [saveLocation, setSaveLocation] = useState<string>('same');
|
||||||
const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
|
const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
|
||||||
|
@ -6,6 +6,7 @@ import {
|
|||||||
ControlNetModelConfig,
|
ControlNetModelConfig,
|
||||||
LoRAModelConfig,
|
LoRAModelConfig,
|
||||||
MainModelConfig,
|
MainModelConfig,
|
||||||
|
MergeModelConfig,
|
||||||
TextualInversionModelConfig,
|
TextualInversionModelConfig,
|
||||||
VaeModelConfig,
|
VaeModelConfig,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
@ -49,6 +50,11 @@ type ConvertMainModelQuery = {
|
|||||||
model_name: string;
|
model_name: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
type MergeMainModelQuery = {
|
||||||
|
base_model: BaseModelType;
|
||||||
|
body: MergeModelConfig;
|
||||||
|
};
|
||||||
|
|
||||||
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
});
|
});
|
||||||
@ -143,7 +149,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
invalidatesTags: ['MainModel'],
|
invalidatesTags: ['MainModel'],
|
||||||
}),
|
}),
|
||||||
convertMainModel: build.mutation<
|
convertMainModels: build.mutation<
|
||||||
EntityState<MainModelConfigEntity>,
|
EntityState<MainModelConfigEntity>,
|
||||||
ConvertMainModelQuery
|
ConvertMainModelQuery
|
||||||
>({
|
>({
|
||||||
@ -155,6 +161,19 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
invalidatesTags: ['MainModel'],
|
invalidatesTags: ['MainModel'],
|
||||||
}),
|
}),
|
||||||
|
mergeMainModels: build.mutation<
|
||||||
|
EntityState<MainModelConfigEntity>,
|
||||||
|
MergeMainModelQuery
|
||||||
|
>({
|
||||||
|
query: ({ base_model, body }) => {
|
||||||
|
return {
|
||||||
|
url: `models/merge/${base_model}`,
|
||||||
|
method: 'PUT',
|
||||||
|
body: body,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
invalidatesTags: ['MainModel'],
|
||||||
|
}),
|
||||||
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
||||||
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
|
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
|
||||||
providesTags: (result, error, arg) => {
|
providesTags: (result, error, arg) => {
|
||||||
@ -300,5 +319,6 @@ export const {
|
|||||||
useGetVaeModelsQuery,
|
useGetVaeModelsQuery,
|
||||||
useUpdateMainModelsMutation,
|
useUpdateMainModelsMutation,
|
||||||
useDeleteMainModelsMutation,
|
useDeleteMainModelsMutation,
|
||||||
useConvertMainModelMutation,
|
useConvertMainModelsMutation,
|
||||||
|
useMergeMainModelsMutation,
|
||||||
} = modelsApi;
|
} = modelsApi;
|
||||||
|
@ -50,6 +50,7 @@ export type AnyModelConfig =
|
|||||||
| ControlNetModelConfig
|
| ControlNetModelConfig
|
||||||
| TextualInversionModelConfig
|
| TextualInversionModelConfig
|
||||||
| MainModelConfig;
|
| MainModelConfig;
|
||||||
|
export type MergeModelConfig = components['schemas']['Body_merge_models'];
|
||||||
|
|
||||||
// Graphs
|
// Graphs
|
||||||
export type Graph = components['schemas']['Graph'];
|
export type Graph = components['schemas']['Graph'];
|
||||||
|
Loading…
Reference in New Issue
Block a user