mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Restore Model Merge functionality
This commit is contained in:
parent
683229e285
commit
3db1aa738c
@ -427,6 +427,7 @@
|
||||
"customSaveLocation": "Custom Save Location",
|
||||
"merge": "Merge",
|
||||
"modelsMerged": "Models Merged",
|
||||
"modelsMergeFailed": "Model Merge Failed",
|
||||
"mergeModels": "Merge Models",
|
||||
"modelOne": "Model 1",
|
||||
"modelTwo": "Model 2",
|
||||
@ -447,7 +448,8 @@
|
||||
"weightedSum": "Weighted Sum",
|
||||
"none": "none",
|
||||
"addDifference": "Add Difference",
|
||||
"pickModelType": "Pick Model Type"
|
||||
"pickModelType": "Pick Model Type",
|
||||
"selectModel": "Select Model"
|
||||
},
|
||||
"parameters": {
|
||||
"general": "General",
|
||||
|
@ -1,35 +1,74 @@
|
||||
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
Flex,
|
||||
Radio,
|
||||
RadioGroup,
|
||||
Text,
|
||||
Tooltip,
|
||||
useColorMode,
|
||||
} from '@chakra-ui/react';
|
||||
import { makeToast } from 'app/components/Toaster';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import IAISelect from 'common/components/IAISelect';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { pickBy } from 'lodash-es';
|
||||
import { useState } from 'react';
|
||||
import { useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import {
|
||||
useGetMainModelsQuery,
|
||||
useMergeMainModelsMutation,
|
||||
} from 'services/api/endpoints/models';
|
||||
import { BaseModelType, MergeModelConfig } from 'services/api/types';
|
||||
import { mode } from 'theme/util/mode';
|
||||
|
||||
const baseModelTypeSelectData = [
|
||||
{ label: 'Stable Diffusion 1', value: 'sd-1' },
|
||||
{ label: 'Stable Diffusion 2', value: 'sd-2' },
|
||||
];
|
||||
|
||||
export default function MergeModelsPanel() {
|
||||
const { t } = useTranslation();
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { data } = useGetMainModelsQuery();
|
||||
|
||||
const diffusersModels = pickBy(
|
||||
const [mergeModels, { isLoading, error, data: mergedModelData }] =
|
||||
useMergeMainModelsMutation();
|
||||
|
||||
const [baseModel, setBaseModel] = useState<BaseModelType>('sd-1');
|
||||
|
||||
const sd1DiffusersModels = pickBy(
|
||||
data?.entities,
|
||||
(value, _) => value?.model_format === 'diffusers'
|
||||
(value, _) =>
|
||||
value?.model_format === 'diffusers' && value?.base_model === 'sd-1'
|
||||
);
|
||||
|
||||
const [modelOne, setModelOne] = useState<string>(
|
||||
Object.keys(diffusersModels)[0]
|
||||
const sd2DiffusersModels = pickBy(
|
||||
data?.entities,
|
||||
(value, _) =>
|
||||
value?.model_format === 'diffusers' && value?.base_model === 'sd-2'
|
||||
);
|
||||
const [modelTwo, setModelTwo] = useState<string>(
|
||||
Object.keys(diffusersModels)[1]
|
||||
|
||||
const modelsMap = useMemo(() => {
|
||||
return {
|
||||
'sd-1': sd1DiffusersModels,
|
||||
'sd-2': sd2DiffusersModels,
|
||||
};
|
||||
}, [sd1DiffusersModels, sd2DiffusersModels]);
|
||||
|
||||
const [modelOne, setModelOne] = useState<string | null>(
|
||||
Object.keys(modelsMap[baseModel])[0]
|
||||
);
|
||||
const [modelThree, setModelThree] = useState<string>('none');
|
||||
const [modelTwo, setModelTwo] = useState<string | null>(
|
||||
Object.keys(modelsMap[baseModel])[1]
|
||||
);
|
||||
|
||||
const [modelThree, setModelThree] = useState<string | null>(null);
|
||||
|
||||
const [mergedModelName, setMergedModelName] = useState<string>('');
|
||||
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
|
||||
@ -47,41 +86,72 @@ export default function MergeModelsPanel() {
|
||||
|
||||
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
|
||||
|
||||
const modelOneList = Object.keys(diffusersModels).filter(
|
||||
(model) => model !== modelTwo && model !== modelThree
|
||||
const modelOneList = Object.keys(
|
||||
modelsMap[baseModel as keyof typeof modelsMap]
|
||||
).filter((model) => model !== modelTwo && model !== modelThree);
|
||||
|
||||
const modelTwoList = Object.keys(
|
||||
modelsMap[baseModel as keyof typeof modelsMap]
|
||||
).filter((model) => model !== modelOne && model !== modelThree);
|
||||
|
||||
const modelThreeList = Object.keys(modelsMap[baseModel]).filter(
|
||||
(model) => model !== modelOne && model !== modelTwo
|
||||
);
|
||||
|
||||
const modelTwoList = Object.keys(diffusersModels).filter(
|
||||
(model) => model !== modelOne && model !== modelThree
|
||||
);
|
||||
|
||||
const modelThreeList = [
|
||||
{ key: t('modelManager.none'), value: 'none' },
|
||||
...Object.keys(diffusersModels)
|
||||
.filter((model) => model !== modelOne && model !== modelTwo)
|
||||
.map((model) => ({ key: model, value: model })),
|
||||
];
|
||||
|
||||
const isProcessing = useAppSelector(
|
||||
(state: RootState) => state.system.isProcessing
|
||||
);
|
||||
const handleBaseModelChange = (v: string) => {
|
||||
setBaseModel(v as BaseModelType);
|
||||
setModelOne(null);
|
||||
setModelTwo(null);
|
||||
};
|
||||
|
||||
const mergeModelsHandler = () => {
|
||||
let modelsToMerge: string[] = [modelOne, modelTwo, modelThree];
|
||||
modelsToMerge = modelsToMerge.filter((model) => model !== 'none');
|
||||
const models_names: string[] = [];
|
||||
|
||||
const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = {
|
||||
models_to_merge: modelsToMerge,
|
||||
let modelsToMerge: (string | null)[] = [modelOne, modelTwo, modelThree];
|
||||
modelsToMerge = modelsToMerge.filter((model) => model !== null);
|
||||
modelsToMerge.forEach((model) => {
|
||||
if (model) {
|
||||
models_names.push(model?.split('/')[2]);
|
||||
}
|
||||
});
|
||||
|
||||
const mergeModelsInfo: MergeModelConfig = {
|
||||
model_names: models_names,
|
||||
merged_model_name:
|
||||
mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'),
|
||||
mergedModelName !== '' ? mergedModelName : models_names.join('-'),
|
||||
alpha: modelMergeAlpha,
|
||||
interp: modelMergeInterp,
|
||||
model_merge_save_path:
|
||||
modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
|
||||
// model_merge_save_path:
|
||||
// modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
|
||||
force: modelMergeForce,
|
||||
};
|
||||
|
||||
dispatch(mergeDiffusersModels(mergeModelsInfo));
|
||||
mergeModels({
|
||||
base_model: baseModel,
|
||||
body: mergeModelsInfo,
|
||||
});
|
||||
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelsMergeFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (mergedModelData) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelsMerged'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
@ -90,7 +160,6 @@ export default function MergeModelsPanel() {
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
rowGap: 1,
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<Text>{t('modelManager.modelMergeHeaderHelp1')}</Text>
|
||||
@ -98,26 +167,43 @@ export default function MergeModelsPanel() {
|
||||
{t('modelManager.modelMergeHeaderHelp2')}
|
||||
</Text>
|
||||
</Flex>
|
||||
|
||||
<Flex columnGap={4}>
|
||||
<IAISelect
|
||||
<IAIMantineSelect
|
||||
label="Model Type"
|
||||
w="100%"
|
||||
data={baseModelTypeSelectData}
|
||||
value={baseModel}
|
||||
onChange={handleBaseModelChange}
|
||||
/>
|
||||
<IAIMantineSelect
|
||||
label={t('modelManager.modelOne')}
|
||||
validValues={modelOneList}
|
||||
onChange={(e) => setModelOne(e.target.value)}
|
||||
w="100%"
|
||||
value={modelOne}
|
||||
placeholder={t('modelManager.selectModel')}
|
||||
data={modelOneList}
|
||||
onChange={(v) => setModelOne(v)}
|
||||
/>
|
||||
<IAISelect
|
||||
<IAIMantineSelect
|
||||
label={t('modelManager.modelTwo')}
|
||||
validValues={modelTwoList}
|
||||
onChange={(e) => setModelTwo(e.target.value)}
|
||||
w="100%"
|
||||
placeholder={t('modelManager.selectModel')}
|
||||
value={modelTwo}
|
||||
data={modelTwoList}
|
||||
onChange={(v) => setModelTwo(v)}
|
||||
/>
|
||||
<IAISelect
|
||||
<IAIMantineSelect
|
||||
label={t('modelManager.modelThree')}
|
||||
validValues={modelThreeList}
|
||||
onChange={(e) => {
|
||||
if (e.target.value !== 'none') {
|
||||
setModelThree(e.target.value);
|
||||
data={modelThreeList}
|
||||
w="100%"
|
||||
placeholder={t('modelManager.selectModel')}
|
||||
clearable
|
||||
onChange={(v) => {
|
||||
if (!v) {
|
||||
setModelThree(null);
|
||||
setModelMergeInterp('add_difference');
|
||||
} else {
|
||||
setModelThree('none');
|
||||
setModelThree(v);
|
||||
setModelMergeInterp('weighted_sum');
|
||||
}
|
||||
}}
|
||||
@ -136,7 +222,7 @@ export default function MergeModelsPanel() {
|
||||
padding: 4,
|
||||
borderRadius: 'base',
|
||||
gap: 4,
|
||||
bg: 'base.900',
|
||||
bg: mode('base.100', 'base.800')(colorMode),
|
||||
}}
|
||||
>
|
||||
<IAISlider
|
||||
@ -161,7 +247,7 @@ export default function MergeModelsPanel() {
|
||||
padding: 4,
|
||||
borderRadius: 'base',
|
||||
gap: 4,
|
||||
bg: 'base.900',
|
||||
bg: mode('base.100', 'base.800')(colorMode),
|
||||
}}
|
||||
>
|
||||
<Text fontWeight={500} fontSize="sm" variant="subtext">
|
||||
@ -174,7 +260,7 @@ export default function MergeModelsPanel() {
|
||||
) => setModelMergeInterp(v)}
|
||||
>
|
||||
<Flex columnGap={4}>
|
||||
{modelThree === 'none' ? (
|
||||
{modelThree === null ? (
|
||||
<>
|
||||
<Radio value="weighted_sum">
|
||||
<Text fontSize="sm">{t('modelManager.weightedSum')}</Text>
|
||||
@ -199,7 +285,7 @@ export default function MergeModelsPanel() {
|
||||
</RadioGroup>
|
||||
</Flex>
|
||||
|
||||
<Flex
|
||||
{/* <Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
padding: 4,
|
||||
@ -235,7 +321,7 @@ export default function MergeModelsPanel() {
|
||||
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex> */}
|
||||
|
||||
<IAISimpleCheckbox
|
||||
label={t('modelManager.ignoreMismatch')}
|
||||
@ -246,10 +332,8 @@ export default function MergeModelsPanel() {
|
||||
|
||||
<IAIButton
|
||||
onClick={mergeModelsHandler}
|
||||
isLoading={isProcessing}
|
||||
isDisabled={
|
||||
modelMergeSaveLocType === 'custom' && modelMergeCustomSaveLoc === ''
|
||||
}
|
||||
isLoading={isLoading}
|
||||
isDisabled={modelOne === null || modelTwo === null}
|
||||
>
|
||||
{t('modelManager.merge')}
|
||||
</IAIButton>
|
||||
|
@ -7,7 +7,8 @@ import IAIButton from 'common/components/IAIButton';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useConvertMainModelMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import { CheckpointModelConfig } from './CheckpointModelEdit';
|
||||
|
||||
interface ModelConvertProps {
|
||||
@ -21,7 +22,7 @@ export default function ModelConvert(props: ModelConvertProps) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [convertModel, { isLoading, error, data }] =
|
||||
useConvertMainModelMutation();
|
||||
useConvertMainModelsMutation();
|
||||
|
||||
const [saveLocation, setSaveLocation] = useState<string>('same');
|
||||
const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
|
||||
|
@ -6,6 +6,7 @@ import {
|
||||
ControlNetModelConfig,
|
||||
LoRAModelConfig,
|
||||
MainModelConfig,
|
||||
MergeModelConfig,
|
||||
TextualInversionModelConfig,
|
||||
VaeModelConfig,
|
||||
} from 'services/api/types';
|
||||
@ -49,6 +50,11 @@ type ConvertMainModelQuery = {
|
||||
model_name: string;
|
||||
};
|
||||
|
||||
type MergeMainModelQuery = {
|
||||
base_model: BaseModelType;
|
||||
body: MergeModelConfig;
|
||||
};
|
||||
|
||||
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
@ -143,7 +149,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['MainModel'],
|
||||
}),
|
||||
convertMainModel: build.mutation<
|
||||
convertMainModels: build.mutation<
|
||||
EntityState<MainModelConfigEntity>,
|
||||
ConvertMainModelQuery
|
||||
>({
|
||||
@ -155,6 +161,19 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['MainModel'],
|
||||
}),
|
||||
mergeMainModels: build.mutation<
|
||||
EntityState<MainModelConfigEntity>,
|
||||
MergeMainModelQuery
|
||||
>({
|
||||
query: ({ base_model, body }) => {
|
||||
return {
|
||||
url: `models/merge/${base_model}`,
|
||||
method: 'PUT',
|
||||
body: body,
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['MainModel'],
|
||||
}),
|
||||
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
|
||||
providesTags: (result, error, arg) => {
|
||||
@ -300,5 +319,6 @@ export const {
|
||||
useGetVaeModelsQuery,
|
||||
useUpdateMainModelsMutation,
|
||||
useDeleteMainModelsMutation,
|
||||
useConvertMainModelMutation,
|
||||
useConvertMainModelsMutation,
|
||||
useMergeMainModelsMutation,
|
||||
} = modelsApi;
|
||||
|
@ -50,6 +50,7 @@ export type AnyModelConfig =
|
||||
| ControlNetModelConfig
|
||||
| TextualInversionModelConfig
|
||||
| MainModelConfig;
|
||||
export type MergeModelConfig = components['schemas']['Body_merge_models'];
|
||||
|
||||
// Graphs
|
||||
export type Graph = components['schemas']['Graph'];
|
||||
|
Loading…
Reference in New Issue
Block a user