diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 23e13fbbc7..e2e0e4ae95 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -427,6 +427,7 @@ "customSaveLocation": "Custom Save Location", "merge": "Merge", "modelsMerged": "Models Merged", + "modelsMergeFailed": "Model Merge Failed", "mergeModels": "Merge Models", "modelOne": "Model 1", "modelTwo": "Model 2", @@ -447,7 +448,8 @@ "weightedSum": "Weighted Sum", "none": "none", "addDifference": "Add Difference", - "pickModelType": "Pick Model Type" + "pickModelType": "Pick Model Type", + "selectModel": "Select Model" }, "parameters": { "general": "General", diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx index b71b5636b4..693a0130a7 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx @@ -1,35 +1,74 @@ -import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react'; -import { RootState } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { + Flex, + Radio, + RadioGroup, + Text, + Tooltip, + useColorMode, +} from '@chakra-ui/react'; +import { makeToast } from 'app/components/Toaster'; +import { useAppDispatch } from 'app/store/storeHooks'; import IAIButton from 'common/components/IAIButton'; import IAIInput from 'common/components/IAIInput'; -import IAISelect from 'common/components/IAISelect'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; import IAISlider from 'common/components/IAISlider'; +import { addToast } from 'features/system/store/systemSlice'; import { pickBy } from 'lodash-es'; -import { useState } from 'react'; +import { useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; -import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import { + useGetMainModelsQuery, + useMergeMainModelsMutation, +} from 'services/api/endpoints/models'; +import { BaseModelType, MergeModelConfig } from 'services/api/types'; +import { mode } from 'theme/util/mode'; + +const baseModelTypeSelectData = [ + { label: 'Stable Diffusion 1', value: 'sd-1' }, + { label: 'Stable Diffusion 2', value: 'sd-2' }, +]; export default function MergeModelsPanel() { const { t } = useTranslation(); + const { colorMode } = useColorMode(); const dispatch = useAppDispatch(); const { data } = useGetMainModelsQuery(); - const diffusersModels = pickBy( + const [mergeModels, { isLoading, error, data: mergedModelData }] = + useMergeMainModelsMutation(); + + const [baseModel, setBaseModel] = useState('sd-1'); + + const sd1DiffusersModels = pickBy( data?.entities, - (value, _) => value?.model_format === 'diffusers' + (value, _) => + value?.model_format === 'diffusers' && value?.base_model === 'sd-1' ); - const [modelOne, setModelOne] = useState( - Object.keys(diffusersModels)[0] + const sd2DiffusersModels = pickBy( + data?.entities, + (value, _) => + value?.model_format === 'diffusers' && value?.base_model === 'sd-2' ); - const [modelTwo, setModelTwo] = useState( - Object.keys(diffusersModels)[1] + + const modelsMap = useMemo(() => { + return { + 'sd-1': sd1DiffusersModels, + 'sd-2': sd2DiffusersModels, + }; + }, [sd1DiffusersModels, sd2DiffusersModels]); + + const [modelOne, setModelOne] = useState( + Object.keys(modelsMap[baseModel])[0] ); - const [modelThree, setModelThree] = useState('none'); + const [modelTwo, setModelTwo] = useState( + Object.keys(modelsMap[baseModel])[1] + ); + + const [modelThree, setModelThree] = useState(null); const [mergedModelName, setMergedModelName] = useState(''); const [modelMergeAlpha, setModelMergeAlpha] = useState(0.5); @@ -47,41 +86,72 @@ export default function MergeModelsPanel() { const [modelMergeForce, setModelMergeForce] = useState(false); - const modelOneList = Object.keys(diffusersModels).filter( - (model) => model !== modelTwo && model !== modelThree + const modelOneList = Object.keys( + modelsMap[baseModel as keyof typeof modelsMap] + ).filter((model) => model !== modelTwo && model !== modelThree); + + const modelTwoList = Object.keys( + modelsMap[baseModel as keyof typeof modelsMap] + ).filter((model) => model !== modelOne && model !== modelThree); + + const modelThreeList = Object.keys(modelsMap[baseModel]).filter( + (model) => model !== modelOne && model !== modelTwo ); - const modelTwoList = Object.keys(diffusersModels).filter( - (model) => model !== modelOne && model !== modelThree - ); - - const modelThreeList = [ - { key: t('modelManager.none'), value: 'none' }, - ...Object.keys(diffusersModels) - .filter((model) => model !== modelOne && model !== modelTwo) - .map((model) => ({ key: model, value: model })), - ]; - - const isProcessing = useAppSelector( - (state: RootState) => state.system.isProcessing - ); + const handleBaseModelChange = (v: string) => { + setBaseModel(v as BaseModelType); + setModelOne(null); + setModelTwo(null); + }; const mergeModelsHandler = () => { - let modelsToMerge: string[] = [modelOne, modelTwo, modelThree]; - modelsToMerge = modelsToMerge.filter((model) => model !== 'none'); + const models_names: string[] = []; - const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = { - models_to_merge: modelsToMerge, + let modelsToMerge: (string | null)[] = [modelOne, modelTwo, modelThree]; + modelsToMerge = modelsToMerge.filter((model) => model !== null); + modelsToMerge.forEach((model) => { + if (model) { + models_names.push(model?.split('/')[2]); + } + }); + + const mergeModelsInfo: MergeModelConfig = { + model_names: models_names, merged_model_name: - mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'), + mergedModelName !== '' ? mergedModelName : models_names.join('-'), alpha: modelMergeAlpha, interp: modelMergeInterp, - model_merge_save_path: - modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc, + // model_merge_save_path: + // modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc, force: modelMergeForce, }; - dispatch(mergeDiffusersModels(mergeModelsInfo)); + mergeModels({ + base_model: baseModel, + body: mergeModelsInfo, + }); + + if (error) { + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelsMergeFailed'), + status: 'error', + }) + ) + ); + } + + if (mergedModelData) { + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelsMerged'), + status: 'success', + }) + ) + ); + } }; return ( @@ -90,7 +160,6 @@ export default function MergeModelsPanel() { sx={{ flexDirection: 'column', rowGap: 1, - bg: 'base.900', }} > {t('modelManager.modelMergeHeaderHelp1')} @@ -98,26 +167,43 @@ export default function MergeModelsPanel() { {t('modelManager.modelMergeHeaderHelp2')} + - + setModelOne(e.target.value)} + w="100%" + value={modelOne} + placeholder={t('modelManager.selectModel')} + data={modelOneList} + onChange={(v) => setModelOne(v)} /> - setModelTwo(e.target.value)} + w="100%" + placeholder={t('modelManager.selectModel')} + value={modelTwo} + data={modelTwoList} + onChange={(v) => setModelTwo(v)} /> - { - if (e.target.value !== 'none') { - setModelThree(e.target.value); + data={modelThreeList} + w="100%" + placeholder={t('modelManager.selectModel')} + clearable + onChange={(v) => { + if (!v) { + setModelThree(null); setModelMergeInterp('add_difference'); } else { - setModelThree('none'); + setModelThree(v); setModelMergeInterp('weighted_sum'); } }} @@ -136,7 +222,7 @@ export default function MergeModelsPanel() { padding: 4, borderRadius: 'base', gap: 4, - bg: 'base.900', + bg: mode('base.100', 'base.800')(colorMode), }} > @@ -174,7 +260,7 @@ export default function MergeModelsPanel() { ) => setModelMergeInterp(v)} > - {modelThree === 'none' ? ( + {modelThree === null ? ( <> {t('modelManager.weightedSum')} @@ -199,7 +285,7 @@ export default function MergeModelsPanel() { - setModelMergeCustomSaveLoc(e.target.value)} /> )} - + */} {t('modelManager.merge')} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx index 3e7ca7469a..63e5767585 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx @@ -7,7 +7,8 @@ import IAIButton from 'common/components/IAIButton'; import { addToast } from 'features/system/store/systemSlice'; import { useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; -import { useConvertMainModelMutation } from 'services/api/endpoints/models'; + +import { useConvertMainModelsMutation } from 'services/api/endpoints/models'; import { CheckpointModelConfig } from './CheckpointModelEdit'; interface ModelConvertProps { @@ -21,7 +22,7 @@ export default function ModelConvert(props: ModelConvertProps) { const { t } = useTranslation(); const [convertModel, { isLoading, error, data }] = - useConvertMainModelMutation(); + useConvertMainModelsMutation(); const [saveLocation, setSaveLocation] = useState('same'); const [customSaveLocation, setCustomSaveLocation] = useState(''); diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index d5f08c864c..cea586a3ff 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -6,6 +6,7 @@ import { ControlNetModelConfig, LoRAModelConfig, MainModelConfig, + MergeModelConfig, TextualInversionModelConfig, VaeModelConfig, } from 'services/api/types'; @@ -49,6 +50,11 @@ type ConvertMainModelQuery = { model_name: string; }; +type MergeMainModelQuery = { + base_model: BaseModelType; + body: MergeModelConfig; +}; + const mainModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.name.localeCompare(b.name), }); @@ -143,7 +149,7 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['MainModel'], }), - convertMainModel: build.mutation< + convertMainModels: build.mutation< EntityState, ConvertMainModelQuery >({ @@ -155,6 +161,19 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['MainModel'], }), + mergeMainModels: build.mutation< + EntityState, + MergeMainModelQuery + >({ + query: ({ base_model, body }) => { + return { + url: `models/merge/${base_model}`, + method: 'PUT', + body: body, + }; + }, + invalidatesTags: ['MainModel'], + }), getLoRAModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 'lora' } }), providesTags: (result, error, arg) => { @@ -300,5 +319,6 @@ export const { useGetVaeModelsQuery, useUpdateMainModelsMutation, useDeleteMainModelsMutation, - useConvertMainModelMutation, + useConvertMainModelsMutation, + useMergeMainModelsMutation, } = modelsApi; diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts index ab8214a903..9c154cbc46 100644 --- a/invokeai/frontend/web/src/services/api/types.d.ts +++ b/invokeai/frontend/web/src/services/api/types.d.ts @@ -50,6 +50,7 @@ export type AnyModelConfig = | ControlNetModelConfig | TextualInversionModelConfig | MainModelConfig; +export type MergeModelConfig = components['schemas']['Body_merge_models']; // Graphs export type Graph = components['schemas']['Graph'];