clean up old model manager components and endpoints

This commit is contained in:
Mary Hipp 2024-02-23 16:03:56 -05:00 committed by psychedelicious
parent 7b1b6d3235
commit cfcb68696c
35 changed files with 11 additions and 2636 deletions

View File

@ -15,8 +15,7 @@ import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynam
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice';
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
import { modelManagerPersistConfig, modelManagerSlice } from 'features/modelManager/store/modelManagerSlice';
import { modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice';
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
@ -55,7 +54,6 @@ const allReducers = {
[deleteImageModalSlice.name]: deleteImageModalSlice.reducer,
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
[loraSlice.name]: loraSlice.reducer,
[modelManagerSlice.name]: modelManagerSlice.reducer,
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer,
[sdxlSlice.name]: sdxlSlice.reducer,
[queueSlice.name]: queueSlice.reducer,
@ -103,7 +101,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[dynamicPromptsPersistConfig.name]: dynamicPromptsPersistConfig,
[sdxlPersistConfig.name]: sdxlPersistConfig,
[loraPersistConfig.name]: loraPersistConfig,
[modelManagerPersistConfig.name]: modelManagerPersistConfig,
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
[hrfPersistConfig.name]: hrfPersistConfig,
};

View File

@ -1,32 +0,0 @@
import type { ButtonProps } from '@invoke-ai/ui-library';
import { Button } from '@invoke-ai/ui-library';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsClockwiseBold } from 'react-icons/pi';
import { useSyncModels } from './useSyncModels';
export const SyncModelsButton = memo((props: Omit<ButtonProps, 'children'>) => {
const { t } = useTranslation();
const { syncModels, isLoading } = useSyncModels();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
if (!isSyncModelEnabled) {
return null;
}
return (
<Button
isLoading={isLoading}
onClick={syncModels}
leftIcon={<PiArrowsClockwiseBold />}
minW="max-content"
{...props}
>
{t('modelManager.syncModels')}
</Button>
);
});
SyncModelsButton.displayName = 'SyncModelsButton';

View File

@ -1,33 +0,0 @@
import type { IconButtonProps } from '@invoke-ai/ui-library';
import { IconButton } from '@invoke-ai/ui-library';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsClockwiseBold } from 'react-icons/pi';
import { useSyncModels } from './useSyncModels';
export const SyncModelsIconButton = memo((props: Omit<IconButtonProps, 'aria-label'>) => {
const { t } = useTranslation();
const { syncModels, isLoading } = useSyncModels();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
if (!isSyncModelEnabled) {
return null;
}
return (
<IconButton
icon={<PiArrowsClockwiseBold />}
tooltip={t('modelManager.syncModels')}
aria-label={t('modelManager.syncModels')}
isLoading={isLoading}
onClick={syncModels}
size="sm"
variant="ghost"
{...props}
/>
);
});
SyncModelsIconButton.displayName = 'SyncModelsIconButton';

View File

@ -1,40 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useSyncModelsMutation } from 'services/api/endpoints/models';
export const useSyncModels = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [_syncModels, { isLoading }] = useSyncModelsMutation();
const syncModels = useCallback(() => {
_syncModels()
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelsSynced')}`,
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelSyncFailed')}`,
status: 'error',
})
)
);
}
});
}, [dispatch, _syncModels, t]);
return { syncModels, isLoading };
};

View File

@ -1,47 +0,0 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
type ModelManagerState = {
_version: 1;
searchFolder: string | null;
advancedAddScanModel: string | null;
};
export const initialModelManagerState: ModelManagerState = {
_version: 1,
searchFolder: null,
advancedAddScanModel: null,
};
export const modelManagerSlice = createSlice({
name: 'modelmanager',
initialState: initialModelManagerState,
reducers: {
setSearchFolder: (state, action: PayloadAction<string | null>) => {
state.searchFolder = action.payload;
},
setAdvancedAddScanModel: (state, action: PayloadAction<string | null>) => {
state.advancedAddScanModel = action.payload;
},
},
});
export const { setSearchFolder, setAdvancedAddScanModel } = modelManagerSlice.actions;
export const selectModelManagerSlice = (state: RootState) => state.modelmanager;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const migrateModelManagerState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const modelManagerPersistConfig: PersistConfig<ModelManagerState> = {
name: modelManagerSlice.name,
initialState: initialModelManagerState,
migrate: migrateModelManagerState,
persistDenylist: [],
};

View File

@ -1,35 +0,0 @@
import { Button, ButtonGroup, Flex, Text } from '@invoke-ai/ui-library';
import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetModelImportsQuery } from 'services/api/endpoints/models';
import AdvancedAddModels from './AdvancedAddModels';
import SimpleAddModels from './SimpleAddModels';
const AddModels = () => {
const { t } = useTranslation();
const [addModelMode, setAddModelMode] = useState<'simple' | 'advanced'>('simple');
const handleAddModelSimple = useCallback(() => setAddModelMode('simple'), []);
const handleAddModelAdvanced = useCallback(() => setAddModelMode('advanced'), []);
const { data } = useGetModelImportsQuery();
console.log({ data });
return (
<Flex flexDirection="column" width="100%" overflow="scroll" maxHeight={window.innerHeight - 250} gap={4}>
<ButtonGroup>
<Button size="sm" isChecked={addModelMode === 'simple'} onClick={handleAddModelSimple}>
{t('common.simple')}
</Button>
<Button size="sm" isChecked={addModelMode === 'advanced'} onClick={handleAddModelAdvanced}>
{t('common.advanced')}
</Button>
</ButtonGroup>
<Flex p={4} borderRadius={4} bg="base.800">
{addModelMode === 'simple' && <SimpleAddModels />}
{addModelMode === 'advanced' && <AdvancedAddModels />}
</Flex>
<Flex>{data?.map((model) => <Text key={model.id}>{model.status}</Text>)}</Flex>
</Flex>
);
};
export default memo(AddModels);

View File

@ -1,168 +0,0 @@
import { Button, Checkbox, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { setAdvancedAddScanModel } from 'features/modelManager/store/modelManagerSlice';
import BaseModelSelect from 'features/modelManager/subpanels/shared/BaseModelSelect';
import CheckpointConfigsSelect from 'features/modelManager/subpanels/shared/CheckpointConfigsSelect';
import ModelVariantSelect from 'features/modelManager/subpanels/shared/ModelVariantSelect';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import type { CSSProperties, FocusEventHandler } from 'react';
import { memo, useCallback, useState } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
import type { CheckpointModelConfig } from 'services/api/types';
import { getModelName } from './util';
type AdvancedAddCheckpointProps = {
model_path?: string;
};
const AdvancedAddCheckpoint = (props: AdvancedAddCheckpointProps) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { model_path } = props;
const {
register,
handleSubmit,
control,
getValues,
setValue,
formState: { errors },
reset,
} = useForm<CheckpointModelConfig>({
defaultValues: {
model_name: model_path ? getModelName(model_path) : '',
base_model: 'sd-1',
model_type: 'main',
path: model_path ? model_path : '',
description: '',
model_format: 'checkpoint',
error: undefined,
vae: '',
variant: 'normal',
config: 'configs\\stable-diffusion\\v1-inference.yaml',
},
mode: 'onChange',
});
const [addMainModel] = useAddMainModelsMutation();
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
const onSubmit = useCallback<SubmitHandler<CheckpointModelConfig>>(
(values) => {
addMainModel({
body: values,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelAdded', {
modelName: values.model_name,
}),
status: 'success',
})
)
);
reset();
// Close Advanced Panel in Scan Models tab
if (model_path) {
dispatch(setAdvancedAddScanModel(null));
}
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddFailed'),
status: 'error',
})
)
);
}
});
},
[addMainModel, dispatch, model_path, reset, t]
);
const onBlur: FocusEventHandler<HTMLInputElement> = useCallback(
(e) => {
if (getValues().model_name === '') {
const modelName = getModelName(e.currentTarget.value);
if (modelName) {
setValue('model_name', modelName as string);
}
}
},
[getValues, setValue]
);
const handleChangeUseCustomConfig = useCallback(() => setUseCustomConfig((prev) => !prev), []);
return (
<form onSubmit={handleSubmit(onSubmit)} style={formStyles}>
<Flex flexDirection="column" gap={2}>
<FormControl isInvalid={Boolean(errors.model_name)}>
<FormLabel>{t('modelManager.model')}</FormLabel>
<Input
{...register('model_name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
})}
/>
{errors.model_name?.message && <FormErrorMessage>{errors.model_name?.message}</FormErrorMessage>}
</FormControl>
<BaseModelSelect<CheckpointModelConfig> control={control} name="base_model" />
<FormControl isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
onBlur,
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</FormControl>
<FormControl>
<FormLabel>{t('modelManager.description')}</FormLabel>
<Input {...register('description')} />
</FormControl>
<FormControl>
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</FormControl>
<ModelVariantSelect<CheckpointModelConfig> control={control} name="variant" />
<Flex flexDirection="column" width="100%" gap={2}>
{!useCustomConfig ? (
<CheckpointConfigsSelect control={control} name="config" />
) : (
<FormControl isRequired>
<FormLabel>{t('modelManager.config')}</FormLabel>
<Input {...register('config')} />
</FormControl>
)}
<FormControl>
<FormLabel>{t('modelManager.useCustomConfig')}</FormLabel>
<Checkbox isChecked={useCustomConfig} onChange={handleChangeUseCustomConfig} />
</FormControl>
<Button mt={2} type="submit">
{t('modelManager.addModel')}
</Button>
</Flex>
</Flex>
</form>
);
};
const formStyles: CSSProperties = {
width: '100%',
};
export default memo(AdvancedAddCheckpoint);

View File

@ -1,148 +0,0 @@
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { setAdvancedAddScanModel } from 'features/modelManager/store/modelManagerSlice';
import BaseModelSelect from 'features/modelManager/subpanels/shared/BaseModelSelect';
import ModelVariantSelect from 'features/modelManager/subpanels/shared/ModelVariantSelect';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import type { CSSProperties, FocusEventHandler } from 'react';
import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
import type { DiffusersModelConfig } from 'services/api/types';
import { getModelName } from './util';
type AdvancedAddDiffusersProps = {
model_path?: string;
};
const AdvancedAddDiffusers = (props: AdvancedAddDiffusersProps) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { model_path } = props;
const [addMainModel] = useAddMainModelsMutation();
const {
register,
handleSubmit,
control,
getValues,
setValue,
formState: { errors },
reset,
} = useForm<DiffusersModelConfig>({
defaultValues: {
model_name: model_path ? getModelName(model_path, false) : '',
base_model: 'sd-1',
model_type: 'main',
path: model_path ? model_path : '',
description: '',
model_format: 'diffusers',
error: undefined,
vae: '',
variant: 'normal',
},
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>(
(values) => {
addMainModel({
body: values,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelAdded', {
modelName: values.model_name,
}),
status: 'success',
})
)
);
reset();
// Close Advanced Panel in Scan Models tab
if (model_path) {
dispatch(setAdvancedAddScanModel(null));
}
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddFailed'),
status: 'error',
})
)
);
}
});
},
[addMainModel, dispatch, model_path, reset, t]
);
const onBlur: FocusEventHandler<HTMLInputElement> = useCallback(
(e) => {
if (getValues().model_name === '') {
const modelName = getModelName(e.currentTarget.value, false);
if (modelName) {
setValue('model_name', modelName as string);
}
}
},
[getValues, setValue]
);
return (
<form onSubmit={handleSubmit(onSubmit)} style={formStyles}>
<Flex flexDirection="column" gap={2}>
<FormControl isInvalid={Boolean(errors.model_name)}>
<FormLabel>{t('modelManager.name')}</FormLabel>
<Input
{...register('model_name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
})}
/>
{errors.model_name?.message && <FormErrorMessage>{errors.model_name?.message}</FormErrorMessage>}
</FormControl>
<BaseModelSelect<DiffusersModelConfig> control={control} name="base_model" />
<FormControl isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
onBlur,
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</FormControl>
<FormControl>
<FormLabel>{t('modelManager.description')}</FormLabel>
<Input {...register('description')} />
</FormControl>
<FormControl>
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</FormControl>
<ModelVariantSelect<DiffusersModelConfig> control={control} name="variant" />
<Button mt={2} type="submit">
{t('modelManager.addModel')}
</Button>
</Flex>
</form>
);
};
const formStyles: CSSProperties = {
width: '100%',
};
export default memo(AdvancedAddDiffusers);

View File

@ -1,50 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { z } from 'zod';
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
import AdvancedAddDiffusers from './AdvancedAddDiffusers';
export const zManualAddMode = z.enum(['diffusers', 'checkpoint']);
export type ManualAddMode = z.infer<typeof zManualAddMode>;
export const isManualAddMode = (v: unknown): v is ManualAddMode => zManualAddMode.safeParse(v).success;
const AdvancedAddModels = () => {
const [advancedAddMode, setAdvancedAddMode] = useState<ManualAddMode>('diffusers');
const { t } = useTranslation();
const handleChange: ComboboxOnChange = useCallback((v) => {
if (!isManualAddMode(v?.value)) {
return;
}
setAdvancedAddMode(v.value);
}, []);
const options: ComboboxOption[] = useMemo(
() => [
{ label: t('modelManager.diffusersModels'), value: 'diffusers' },
{ label: t('modelManager.checkpointOrSafetensors'), value: 'checkpoint' },
],
[t]
);
const value = useMemo(() => options.find((o) => o.value === advancedAddMode), [options, advancedAddMode]);
return (
<Flex flexDirection="column" gap={4} width="100%">
<FormControl>
<FormLabel>{t('modelManager.modelType')}</FormLabel>
<Combobox value={value} options={options} onChange={handleChange} />
</FormControl>
<Flex p={4} borderRadius={4} bg="base.850">
{advancedAddMode === 'diffusers' && <AdvancedAddDiffusers />}
{advancedAddMode === 'checkpoint' && <AdvancedAddCheckpoint />}
</Flex>
</Flex>
);
};
export default memo(AdvancedAddModels);

View File

@ -1,176 +0,0 @@
import { Button, Flex, FormControl, FormLabel, Input, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { setAdvancedAddScanModel } from 'features/modelManager/store/modelManagerSlice';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { difference, forEach, intersection, map, values } from 'lodash-es';
import type { ChangeEvent, MouseEvent } from 'react';
import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import type { SearchFolderResponse } from 'services/api/endpoints/models';
import {
useGetMainModelsQuery,
useGetModelsInFolderQuery,
useImportMainModelsMutation,
} from 'services/api/endpoints/models';
const FoundModelsList = () => {
const searchFolder = useAppSelector((s) => s.modelmanager.searchFolder);
const [nameFilter, setNameFilter] = useState<string>('');
// Get paths of models that are already installed
const { data: installedModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
// Get all model paths from a given directory
const { foundModels, alreadyInstalled, filteredModels } = useGetModelsInFolderQuery(
{
search_path: searchFolder ? searchFolder : '',
},
{
selectFromResult: ({ data }) => {
const installedModelValues = values(installedModels?.entities);
const installedModelPaths = map(installedModelValues, 'path');
// Only select models those that aren't already installed to Invoke
const notInstalledModels = difference(data, installedModelPaths);
const alreadyInstalled = intersection(data, installedModelPaths);
return {
foundModels: data,
alreadyInstalled: foundModelsFilter(alreadyInstalled, nameFilter),
filteredModels: foundModelsFilter(notInstalledModels, nameFilter),
};
},
}
);
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const quickAddHandler = useCallback(
(e: MouseEvent<HTMLButtonElement>) => {
const model_name = e.currentTarget.id.split('\\').splice(-1)[0];
importMainModel({
body: {
location: e.currentTarget.id,
},
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `Added Model: ${model_name}`,
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddFailed'),
status: 'error',
})
)
);
}
});
},
[dispatch, importMainModel, t]
);
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setNameFilter(e.target.value);
}, []);
const handleClickSetAdvanced = useCallback((model: string) => dispatch(setAdvancedAddScanModel(model)), [dispatch]);
const renderModels = ({ models, showActions = true }: { models: string[]; showActions?: boolean }) => {
return models.map((model) => {
return (
<Flex key={model} p={4} gap={4} alignItems="center" borderRadius={4} bg="base.800">
<Flex w="full" minW="25%" flexDir="column">
<Text fontWeight="semibold">{model.split('\\').slice(-1)[0]}</Text>
<Text fontSize="sm" color="base.400">
{model}
</Text>
</Flex>
{showActions ? (
<Flex gap={2}>
<Button id={model} onClick={quickAddHandler} isLoading={isLoading}>
{t('modelManager.quickAdd')}
</Button>
<Button onClick={handleClickSetAdvanced.bind(null, model)} isLoading={isLoading}>
{t('modelManager.advanced')}
</Button>
</Flex>
) : (
<Text fontWeight="semibold" p={2} borderRadius={4} color="invokeBlue.100" bg="invokeBlue.600">
{t('common.installed')}
</Text>
)}
</Flex>
);
});
};
const renderFoundModels = () => {
if (!searchFolder) {
return null;
}
if (!foundModels || foundModels.length === 0) {
return (
<Flex w="full" h="full" justifyContent="center" alignItems="center" height={96} userSelect="none" bg="base.900">
<Text variant="subtext">{t('modelManager.noModels')}</Text>
</Flex>
);
}
return (
<Flex flexDirection="column" gap={2} w="100%" minW="50%">
<FormControl>
<FormLabel>{t('modelManager.search')}</FormLabel>
<Input onChange={handleSearchFilter} />
</FormControl>
<Flex p={2} gap={2}>
<Text fontWeight="semibold">
{t('modelManager.modelsFound')}: {foundModels.length}
</Text>
<Text fontWeight="semibold" color="invokeBlue.200">
{t('common.notInstalled')}: {filteredModels.length}
</Text>
</Flex>
<ScrollableContent>
<Flex gap={2} flexDirection="column">
{renderModels({ models: filteredModels })}
{renderModels({ models: alreadyInstalled, showActions: false })}
</Flex>
</ScrollableContent>
</Flex>
);
};
return renderFoundModels();
};
const foundModelsFilter = (data: SearchFolderResponse | undefined, nameFilter: string) => {
const filteredModels: SearchFolderResponse = [];
forEach(data, (model) => {
if (!model) {
return null;
}
if (model.includes(nameFilter)) {
filteredModels.push(model);
}
});
return filteredModels;
};
export default memo(FoundModelsList);

View File

@ -1,95 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Box, Combobox, Flex, FormControl, FormLabel, IconButton, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setAdvancedAddScanModel } from 'features/modelManager/store/modelManagerSlice';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
import AdvancedAddDiffusers from './AdvancedAddDiffusers';
import type { ManualAddMode } from './AdvancedAddModels';
import { isManualAddMode } from './AdvancedAddModels';
const ScanAdvancedAddModels = () => {
const advancedAddScanModel = useAppSelector((s) => s.modelmanager.advancedAddScanModel);
const { t } = useTranslation();
const options: ComboboxOption[] = useMemo(
() => [
{ label: t('modelManager.diffusersModels'), value: 'diffusers' },
{ label: t('modelManager.checkpointOrSafetensors'), value: 'checkpoint' },
],
[t]
);
const [advancedAddMode, setAdvancedAddMode] = useState<ManualAddMode>('diffusers');
const [isCheckpoint, setIsCheckpoint] = useState<boolean>(true);
useEffect(() => {
advancedAddScanModel && ['.ckpt', '.safetensors', '.pth', '.pt'].some((ext) => advancedAddScanModel.endsWith(ext))
? setAdvancedAddMode('checkpoint')
: setAdvancedAddMode('diffusers');
}, [advancedAddScanModel, setAdvancedAddMode, isCheckpoint]);
const dispatch = useAppDispatch();
const handleClickSetAdvanced = useCallback(() => dispatch(setAdvancedAddScanModel(null)), [dispatch]);
const handleChangeAddMode = useCallback<ComboboxOnChange>((v) => {
if (!isManualAddMode(v?.value)) {
return;
}
setAdvancedAddMode(v.value);
if (v.value === 'checkpoint') {
setIsCheckpoint(true);
} else {
setIsCheckpoint(false);
}
}, []);
const value = useMemo(() => options.find((o) => o.value === advancedAddMode), [options, advancedAddMode]);
if (!advancedAddScanModel) {
return null;
}
return (
<Box
display="flex"
flexDirection="column"
minWidth="40%"
maxHeight="calc(100vh - 300px)"
overflow="scroll"
p={4}
gap={4}
borderRadius={4}
bg="base.800"
>
<Flex justifyContent="space-between" alignItems="center">
<Text size="xl" fontWeight="semibold">
{isCheckpoint || advancedAddMode === 'checkpoint' ? 'Add Checkpoint Model' : 'Add Diffusers Model'}
</Text>
<IconButton
icon={<PiXBold />}
aria-label={t('modelManager.closeAdvanced')}
onClick={handleClickSetAdvanced}
size="sm"
/>
</Flex>
<FormControl>
<FormLabel>{t('modelManager.modelType')}</FormLabel>
<Combobox value={value} options={options} onChange={handleChangeAddMode} />
</FormControl>
{isCheckpoint ? (
<AdvancedAddCheckpoint key={advancedAddScanModel} model_path={advancedAddScanModel} />
) : (
<AdvancedAddDiffusers key={advancedAddScanModel} model_path={advancedAddScanModel} />
)}
</Box>
);
};
export default memo(ScanAdvancedAddModels);

View File

@ -1,22 +0,0 @@
import { Flex } from '@invoke-ai/ui-library';
import { memo } from 'react';
import FoundModelsList from './FoundModelsList';
import ScanAdvancedAddModels from './ScanAdvancedAddModels';
import SearchFolderForm from './SearchFolderForm';
const ScanModels = () => {
return (
<Flex flexDirection="column" w="100%" h="full" gap={4}>
<SearchFolderForm />
<Flex gap={4}>
<Flex overflow="scroll" gap={4} w="100%" h="full">
<FoundModelsList />
</Flex>
<ScanAdvancedAddModels />
</Flex>
</Flex>
);
};
export default memo(ScanModels);

View File

@ -1,103 +0,0 @@
import { Flex, IconButton, Input, Text } from '@invoke-ai/ui-library';
import { useForm } from '@mantine/form';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setAdvancedAddScanModel, setSearchFolder } from 'features/modelManager/store/modelManagerSlice';
import type { CSSProperties } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsCounterClockwiseBold, PiMagnifyingGlassBold, PiTrashSimpleBold } from 'react-icons/pi';
import { useGetModelsInFolderQuery } from 'services/api/endpoints/models';
type SearchFolderForm = {
folder: string;
};
function SearchFolderForm() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const searchFolder = useAppSelector((s) => s.modelmanager.searchFolder);
const { refetch: refetchFoundModels } = useGetModelsInFolderQuery({
search_path: searchFolder ? searchFolder : '',
});
const searchFolderForm = useForm<SearchFolderForm>({
initialValues: {
folder: '',
},
});
const searchFolderFormSubmitHandler = useCallback(
(values: SearchFolderForm) => {
dispatch(setSearchFolder(values.folder));
},
[dispatch]
);
const scanAgainHandler = useCallback(() => {
refetchFoundModels();
}, [refetchFoundModels]);
const handleClickClearCheckpointFolder = useCallback(() => {
dispatch(setSearchFolder(null));
dispatch(setAdvancedAddScanModel(null));
}, [dispatch]);
return (
<form onSubmit={searchFolderForm.onSubmit((values) => searchFolderFormSubmitHandler(values))} style={formStyles}>
<Flex w="100%" gap={2} borderRadius={4} alignItems="center">
<Flex w="100%" alignItems="center" gap={4} minH={12}>
<Text fontSize="sm" fontWeight="semibold" color="base.300" minW="max-content">
{t('common.folder')}
</Text>
{!searchFolder ? (
<Input w="100%" size="md" {...searchFolderForm.getInputProps('folder')} />
) : (
<Flex w="100%" p={2} px={4} bg="base.700" borderRadius={4} fontSize="sm" fontWeight="bold">
{searchFolder}
</Flex>
)}
</Flex>
<Flex gap={2}>
{!searchFolder ? (
<IconButton
aria-label={t('modelManager.findModels')}
tooltip={t('modelManager.findModels')}
icon={<PiMagnifyingGlassBold />}
fontSize={18}
size="sm"
type="submit"
/>
) : (
<IconButton
aria-label={t('modelManager.scanAgain')}
tooltip={t('modelManager.scanAgain')}
icon={<PiArrowsCounterClockwiseBold />}
onClick={scanAgainHandler}
fontSize={18}
size="sm"
/>
)}
<IconButton
aria-label={t('modelManager.clearCheckpointFolder')}
tooltip={t('modelManager.clearCheckpointFolder')}
icon={<PiTrashSimpleBold />}
size="sm"
onClick={handleClickClearCheckpointFolder}
isDisabled={!searchFolder}
colorScheme="red"
/>
</Flex>
</Flex>
</form>
);
}
export default memo(SearchFolderForm);
const formStyles: CSSProperties = {
width: '100%',
};

View File

@ -1,92 +0,0 @@
import type { ComboboxOption } from '@invoke-ai/ui-library';
import { Button, Combobox, Flex, FormControl, FormLabel, Input } from '@invoke-ai/ui-library';
import { useForm } from '@mantine/form';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import type { CSSProperties } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
const options: ComboboxOption[] = [
{ label: 'None', value: 'none' },
{ label: 'v_prediction', value: 'v_prediction' },
{ label: 'epsilon', value: 'epsilon' },
{ label: 'sample', value: 'sample' },
];
type ExtendedImportModelConfig = {
location: string;
prediction_type?: 'v_prediction' | 'epsilon' | 'sample' | 'none' | undefined;
};
const SimpleAddModels = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
const addModelForm = useForm<ExtendedImportModelConfig>({
initialValues: {
location: '',
prediction_type: undefined,
},
});
const handleAddModelSubmit = (values: ExtendedImportModelConfig) => {
const importModelResponseBody = {
config: values.prediction_type === 'none' ? undefined : values.prediction_type,
};
importMainModel({ source: values.location, config: importModelResponseBody })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
addModelForm.reset();
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
};
return (
<form onSubmit={addModelForm.onSubmit((v) => handleAddModelSubmit(v))} style={formStyles}>
<Flex flexDirection="column" width="100%" gap={4}>
<FormControl>
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
<Input placeholder={t('modelManager.simpleModelDesc')} w="100%" {...addModelForm.getInputProps('location')} />
</FormControl>
<FormControl>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<Combobox options={options} defaultValue={options[0]} {...addModelForm.getInputProps('prediction_type')} />
</FormControl>
<Button type="submit" isLoading={isLoading}>
{t('modelManager.addModel')}
</Button>
</Flex>
</form>
);
};
const formStyles: CSSProperties = {
width: '100%',
};
export default memo(SimpleAddModels);

View File

@ -1,15 +0,0 @@
export function getModelName(filepath: string, isCheckpoint: boolean = true) {
let regex;
if (isCheckpoint) {
regex = new RegExp('[^\\\\/]+(?=\\.)');
} else {
regex = new RegExp('[^\\\\/]+(?=[\\\\/]?$)');
}
const match = filepath.match(regex);
if (match) {
return match[0];
} else {
return '';
}
}

View File

@ -1,34 +0,0 @@
import { Button, ButtonGroup, Flex } from '@invoke-ai/ui-library';
import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import AddModels from './AddModelsPanel/AddModels';
import ScanModels from './AddModelsPanel/ScanModels';
type AddModelTabs = 'add' | 'scan';
const ImportModelsPanel = () => {
const [addModelTab, setAddModelTab] = useState<AddModelTabs>('add');
const { t } = useTranslation();
const handleClickAddTab = useCallback(() => setAddModelTab('add'), []);
const handleClickScanTab = useCallback(() => setAddModelTab('scan'), []);
return (
<Flex flexDirection="column" gap={4} h="full">
<ButtonGroup>
<Button onClick={handleClickAddTab} isChecked={addModelTab === 'add'} size="sm" width="100%">
{t('modelManager.addModel')}
</Button>
<Button onClick={handleClickScanTab} isChecked={addModelTab === 'scan'} size="sm" width="100%">
{t('modelManager.scanForModels')}
</Button>
</ButtonGroup>
{addModelTab === 'add' && <AddModels />}
{addModelTab === 'scan' && <ScanModels />}
</Flex>
);
};
export default memo(ImportModelsPanel);

View File

@ -1,352 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import {
Button,
Checkbox,
Combobox,
CompositeNumberInput,
CompositeSlider,
Flex,
FormControl,
FormHelperText,
FormLabel,
Input,
Radio,
RadioGroup,
Text,
Tooltip,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { pickBy } from 'lodash-es';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery, useMergeMainModelsMutation } from 'services/api/endpoints/models';
import type { BaseModelType, MergeModelConfig } from 'services/api/types';
const baseModelTypeSelectOptions: ComboboxOption[] = [
{ label: 'Stable Diffusion 1', value: 'sd-1' },
{ label: 'Stable Diffusion 2', value: 'sd-2' },
];
type MergeInterpolationMethods = 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
const MergeModelsPanel = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { data } = useGetMainModelsQuery(ALL_BASE_MODELS);
const [mergeModels, { isLoading }] = useMergeMainModelsMutation();
const [baseModel, setBaseModel] = useState<BaseModelType>('sd-1');
const valueBaseModel = useMemo(() => baseModelTypeSelectOptions.find((o) => o.value === baseModel), [baseModel]);
const sd1DiffusersModels = pickBy(
data?.entities,
(value, _) => value?.model_format === 'diffusers' && value?.base_model === 'sd-1'
);
const sd2DiffusersModels = pickBy(
data?.entities,
(value, _) => value?.model_format === 'diffusers' && value?.base_model === 'sd-2'
);
const modelsMap = useMemo(() => {
return {
'sd-1': sd1DiffusersModels,
'sd-2': sd2DiffusersModels,
};
}, [sd1DiffusersModels, sd2DiffusersModels]);
const [modelOne, setModelOne] = useState<string | null>(
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])?.[0] ?? null
);
const [modelTwo, setModelTwo] = useState<string | null>(
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])?.[1] ?? null
);
const [modelThree, setModelThree] = useState<string | null>(null);
const [mergedModelName, setMergedModelName] = useState<string>('');
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
const [modelMergeInterp, setModelMergeInterp] = useState<MergeInterpolationMethods>('weighted_sum');
const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState<'root' | 'custom'>('root');
const [modelMergeCustomSaveLoc, setModelMergeCustomSaveLoc] = useState<string>('');
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
const optionsModelOne = useMemo(
() =>
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])
.filter((model) => model !== modelTwo && model !== modelThree)
.map((model) => ({ label: model, value: model })),
[modelsMap, baseModel, modelTwo, modelThree]
);
const optionsModelTwo = useMemo(
() =>
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])
.filter((model) => model !== modelOne && model !== modelThree)
.map((model) => ({ label: model, value: model })),
[modelsMap, baseModel, modelOne, modelThree]
);
const optionsModelThree = useMemo(
() =>
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])
.filter((model) => model !== modelOne && model !== modelTwo)
.map((model) => ({ label: model, value: model })),
[modelsMap, baseModel, modelOne, modelTwo]
);
const onChangeBaseModel = useCallback<ComboboxOnChange>((v) => {
if (!v) {
return;
}
if (!(v.value === 'sd-1' || v.value === 'sd-2')) {
return;
}
setBaseModel(v.value);
setModelOne(null);
setModelTwo(null);
}, []);
const onChangeModelOne = useCallback<ComboboxOnChange>((v) => {
if (!v) {
return;
}
setModelOne(v.value);
}, []);
const onChangeModelTwo = useCallback<ComboboxOnChange>((v) => {
if (!v) {
return;
}
setModelTwo(v.value);
}, []);
const onChangeModelThree = useCallback<ComboboxOnChange>((v) => {
if (!v) {
setModelThree(null);
setModelMergeInterp('add_difference');
} else {
setModelThree(v.value);
setModelMergeInterp('weighted_sum');
}
}, []);
const valueModelOne = useMemo(() => optionsModelOne.find((o) => o.value === modelOne), [modelOne, optionsModelOne]);
const valueModelTwo = useMemo(() => optionsModelTwo.find((o) => o.value === modelTwo), [modelTwo, optionsModelTwo]);
const valueModelThree = useMemo(
() => optionsModelThree.find((o) => o.value === modelThree),
[modelThree, optionsModelThree]
);
const handleChangeMergedModelName = useCallback(
(e: ChangeEvent<HTMLInputElement>) => setMergedModelName(e.target.value),
[]
);
const handleChangeModelMergeAlpha = useCallback((v: number) => setModelMergeAlpha(v), []);
const handleResetModelMergeAlpha = useCallback(() => setModelMergeAlpha(0.5), []);
const handleChangeMergeInterp = useCallback((v: MergeInterpolationMethods) => setModelMergeInterp(v), []);
const handleChangeMergeSaveLocType = useCallback((v: 'root' | 'custom') => setModelMergeSaveLocType(v), []);
const handleChangeMergeCustomSaveLoc = useCallback(
(e: ChangeEvent<HTMLInputElement>) => setModelMergeCustomSaveLoc(e.target.value),
[]
);
const handleChangeModelMergeForce = useCallback(
(e: ChangeEvent<HTMLInputElement>) => setModelMergeForce(e.target.checked),
[]
);
const mergeModelsHandler = useCallback(() => {
const models_names: string[] = [];
let modelsToMerge: (string | null)[] = [modelOne, modelTwo, modelThree];
modelsToMerge = modelsToMerge.filter((model) => model !== null);
modelsToMerge.forEach((model) => {
const n = model?.split('/')?.[2];
if (n) {
models_names.push(n);
}
});
const mergeModelsInfo: MergeModelConfig['body'] = {
model_names: models_names,
merged_model_name: mergedModelName !== '' ? mergedModelName : models_names.join('-'),
alpha: modelMergeAlpha,
interp: modelMergeInterp,
force: modelMergeForce,
merge_dest_directory: modelMergeSaveLocType === 'root' ? undefined : modelMergeCustomSaveLoc,
};
mergeModels({
base_model: baseModel,
body: { body: mergeModelsInfo },
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelsMerged'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelsMergeFailed'),
status: 'error',
})
)
);
}
});
}, [
baseModel,
dispatch,
mergeModels,
mergedModelName,
modelMergeAlpha,
modelMergeCustomSaveLoc,
modelMergeForce,
modelMergeInterp,
modelMergeSaveLocType,
modelOne,
modelThree,
modelTwo,
t,
]);
return (
<Flex flexDir="column" gap={4}>
<Flex flexDir="column" gap={1}>
<Text>{t('modelManager.modelMergeHeaderHelp1')}</Text>
<Text fontSize="sm" variant="subtext">
{t('modelManager.modelMergeHeaderHelp2')}
</Text>
</Flex>
<Flex columnGap={4}>
<FormControl w="full">
<FormLabel>{t('modelManager.modelType')}</FormLabel>
<Combobox options={baseModelTypeSelectOptions} value={valueBaseModel} onChange={onChangeBaseModel} />
</FormControl>
<FormControl w="full">
<FormLabel>{t('modelManager.modelOne')}</FormLabel>
<Combobox options={optionsModelOne} value={valueModelOne} onChange={onChangeModelOne} />
</FormControl>
<FormControl w="full">
<FormLabel>{t('modelManager.modelTwo')}</FormLabel>
<Combobox options={optionsModelTwo} value={valueModelTwo} onChange={onChangeModelTwo} />
</FormControl>
<FormControl w="full">
<FormLabel>{t('modelManager.modelThree')}</FormLabel>
<Combobox options={optionsModelThree} value={valueModelThree} onChange={onChangeModelThree} isClearable />
</FormControl>
</Flex>
<FormControl>
<FormLabel>{t('modelManager.mergedModelName')}</FormLabel>
<Input value={mergedModelName} onChange={handleChangeMergedModelName} />
</FormControl>
<Flex flexDirection="column" padding={4} borderRadius="base" gap={4} bg="base.800">
<FormControl>
<FormLabel>{t('modelManager.alpha')}</FormLabel>
<CompositeSlider
min={0.01}
max={0.99}
step={0.01}
value={modelMergeAlpha}
onChange={handleChangeModelMergeAlpha}
onReset={handleResetModelMergeAlpha}
marks
/>
<CompositeNumberInput
min={0.01}
max={0.99}
step={0.01}
value={modelMergeAlpha}
onChange={handleChangeModelMergeAlpha}
onReset={handleResetModelMergeAlpha}
/>
<FormHelperText>{t('modelManager.modelMergeAlphaHelp')}</FormHelperText>
</FormControl>
</Flex>
<Flex padding={4} gap={4} borderRadius="base" bg="base.800">
<Text fontSize="sm" variant="subtext">
{t('modelManager.interpolationType')}
</Text>
<RadioGroup value={modelMergeInterp} onChange={handleChangeMergeInterp}>
<Flex columnGap={4}>
{modelThree === null ? (
<>
<Radio value="weighted_sum">
<Text fontSize="sm">{t('modelManager.weightedSum')}</Text>
</Radio>
<Radio value="sigmoid">
<Text fontSize="sm">{t('modelManager.sigmoid')}</Text>
</Radio>
<Radio value="inv_sigmoid">
<Text fontSize="sm">{t('modelManager.inverseSigmoid')}</Text>
</Radio>
</>
) : (
<Radio value="add_difference">
<Tooltip label={t('modelManager.modelMergeInterpAddDifferenceHelp')}>
<Text fontSize="sm">{t('modelManager.addDifference')}</Text>
</Tooltip>
</Radio>
)}
</Flex>
</RadioGroup>
</Flex>
<Flex flexDirection="column" padding={4} borderRadius="base" gap={4} bg="base.900">
<Flex columnGap={4}>
<Text fontSize="sm" variant="subtext">
{t('modelManager.mergedModelSaveLocation')}
</Text>
<RadioGroup value={modelMergeSaveLocType} onChange={handleChangeMergeSaveLocType}>
<Flex columnGap={4}>
<Radio value="root">
<Text fontSize="sm">{t('modelManager.invokeAIFolder')}</Text>
</Radio>
<Radio value="custom">
<Text fontSize="sm">{t('modelManager.custom')}</Text>
</Radio>
</Flex>
</RadioGroup>
</Flex>
{modelMergeSaveLocType === 'custom' && (
<FormControl>
<FormLabel>{t('modelManager.mergedModelCustomSaveLocation')}</FormLabel>
<Input value={modelMergeCustomSaveLoc} onChange={handleChangeMergeCustomSaveLoc} />
</FormControl>
)}
</Flex>
<FormControl>
<FormLabel>{t('modelManager.ignoreMismatch')}</FormLabel>
<Checkbox isChecked={modelMergeForce} onChange={handleChangeModelMergeForce} />
</FormControl>
<Button onClick={mergeModelsHandler} isLoading={isLoading} isDisabled={modelOne === null || modelTwo === null}>
{t('modelManager.merge')}
</Button>
</Flex>
);
};
export default memo(MergeModelsPanel);

View File

@ -1,63 +0,0 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { memo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
import type { DiffusersModelConfig, LoRAConfig, MainModelConfig } from 'services/api/types';
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit';
import ModelList from './ModelManagerPanel/ModelList';
const ModelManagerPanel = () => {
const [selectedModelId, setSelectedModelId] = useState<string>();
const { mainModel } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
mainModel: selectedModelId ? data?.entities[selectedModelId] : undefined,
}),
});
const { loraModel } = useGetLoRAModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
loraModel: selectedModelId ? data?.entities[selectedModelId] : undefined,
}),
});
const model = mainModel ? mainModel : loraModel;
return (
<Flex gap={8} w="full" h="full">
<ModelList selectedModelId={selectedModelId} setSelectedModelId={setSelectedModelId} />
<ModelEdit model={model} />
</Flex>
);
};
type ModelEditProps = {
model: MainModelConfig | LoRAConfig | undefined;
};
const ModelEdit = (props: ModelEditProps) => {
const { t } = useTranslation();
const { model } = props;
if (model?.format === 'checkpoint') {
return <CheckpointModelEdit key={model.key} model={model} />;
}
if (model?.format === 'diffusers') {
return <DiffusersModelEdit key={model.key} model={model as DiffusersModelConfig} />;
}
if (model?.type === 'lora') {
return <LoRAModelEdit key={model.key} model={model} />;
}
return (
<Flex w="full" h="full" justifyContent="center" alignItems="center" maxH={96} userSelect="none">
<Text variant="subtext">{t('modelManager.noModelSelected')}</Text>
</Flex>
);
};
export default memo(ModelManagerPanel);

View File

@ -1,185 +0,0 @@
import {
Badge,
Button,
Checkbox,
Divider,
Flex,
FormControl,
FormErrorMessage,
FormLabel,
Input,
Text,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import BaseModelSelect from 'features/modelManager/subpanels/shared/BaseModelSelect';
import CheckpointConfigsSelect from 'features/modelManager/subpanels/shared/CheckpointConfigsSelect';
import ModelVariantSelect from 'features/modelManager/subpanels/shared/ModelVariantSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback, useEffect, useState } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useGetCheckpointConfigsQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
import type { CheckpointModelConfig } from 'services/api/types';
import ModelConvert from './ModelConvert';
type CheckpointModelEditProps = {
model: CheckpointModelConfig;
};
const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
const { model } = props;
const [updateModel, { isLoading }] = useUpdateModelsMutation();
const { data: availableCheckpointConfigs } = useGetCheckpointConfigsQuery();
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
useEffect(() => {
if (!availableCheckpointConfigs?.includes(model.config)) {
setUseCustomConfig(true);
}
}, [availableCheckpointConfigs, model.config]);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const {
register,
handleSubmit,
control,
formState: { errors },
reset,
} = useForm<CheckpointModelConfig>({
defaultValues: {
name: model.name ? model.name : '',
base: model.base,
type: 'main',
path: model.path ? model.path : '',
description: model.description ? model.description : '',
format: 'checkpoint',
vae: model.vae ? model.vae : '',
config: model.config ? model.config : '',
variant: model.variant,
},
mode: 'onChange',
});
const handleChangeUseCustomConfig = useCallback(() => setUseCustomConfig((prev) => !prev), []);
const onSubmit = useCallback<SubmitHandler<CheckpointModelConfig>>(
(values) => {
const responseBody = {
key: model.key,
body: values,
};
updateModel(responseBody)
.unwrap()
.then((payload) => {
reset(payload as CheckpointModelConfig, { keepDefaultValues: true });
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((_) => {
reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[dispatch, model.key, reset, t, updateModel]
);
return (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex justifyContent="space-between" alignItems="center">
<Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold">
{model.name}
</Text>
<Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[model.base]} {t('modelManager.model')}
</Text>
</Flex>
{![''].includes(model.base) ? (
<ModelConvert model={model} />
) : (
<Badge p={2} borderRadius={4} bg="error.400">
{t('modelManager.conversionNotSupported')}
</Badge>
)}
</Flex>
<Divider />
<Flex flexDirection="column" maxHeight={window.innerHeight - 270} overflowY="scroll">
<form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<FormControl isInvalid={Boolean(errors.name)}>
<FormLabel>{t('modelManager.name')}</FormLabel>
<Input
{...register('name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
})}
/>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</FormControl>
<FormControl>
<FormLabel>{t('modelManager.description')}</FormLabel>
<Input {...register('description')} />
</FormControl>
<BaseModelSelect<CheckpointModelConfig> control={control} name="base" />
<ModelVariantSelect<CheckpointModelConfig> control={control} name="variant" />
<FormControl isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</FormControl>
<FormControl>
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</FormControl>
<Flex flexDirection="column" gap={2}>
{!useCustomConfig ? (
<CheckpointConfigsSelect control={control} name="config" />
) : (
<FormControl isRequired>
<FormLabel>{t('modelManager.config')}</FormLabel>
<Input {...register('config')} />
</FormControl>
)}
<FormControl>
<FormLabel>{t('modelManager.useCustomConfig')}</FormLabel>
<Checkbox isChecked={useCustomConfig} onChange={handleChangeUseCustomConfig} />
</FormControl>
</Flex>
<Button type="submit" isLoading={isLoading}>
{t('modelManager.updateModel')}
</Button>
</Flex>
</form>
</Flex>
</Flex>
);
};
export default memo(CheckpointModelEdit);

View File

@ -1,133 +0,0 @@
import { Button, Divider, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import BaseModelSelect from 'features/modelManager/subpanels/shared/BaseModelSelect';
import ModelVariantSelect from 'features/modelManager/subpanels/shared/ModelVariantSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useUpdateModelsMutation } from 'services/api/endpoints/models';
import type { DiffusersModelConfig } from 'services/api/types';
type DiffusersModelEditProps = {
model: DiffusersModelConfig;
};
const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
const { model } = props;
const [updateModel, { isLoading }] = useUpdateModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const {
register,
handleSubmit,
control,
formState: { errors },
reset,
} = useForm<DiffusersModelConfig>({
defaultValues: {
name: model.name ? model.name : '',
base: model.base,
type: 'main',
path: model.path ? model.path : '',
description: model.description ? model.description : '',
format: 'diffusers',
vae: model.vae ? model.vae : '',
variant: model.variant,
},
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>(
(values) => {
const responseBody = {
key: model.key,
body: values,
};
updateModel(responseBody)
.unwrap()
.then((payload) => {
reset(payload as DiffusersModelConfig, { keepDefaultValues: true });
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((_) => {
reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[dispatch, model.key, reset, t, updateModel]
);
return (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold">
{model.name}
</Text>
<Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[model.base]} {t('modelManager.model')}
</Text>
</Flex>
<Divider />
<form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<FormControl isInvalid={Boolean(errors.name)}>
<FormLabel>{t('modelManager.name')}</FormLabel>
<Input
{...register('name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
})}
/>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</FormControl>
<FormControl>
<FormLabel>{t('modelManager.description')}</FormLabel>
<Input {...register('description')} />
</FormControl>
<BaseModelSelect<DiffusersModelConfig> control={control} name="base" />
<ModelVariantSelect<DiffusersModelConfig> control={control} name="variant" />
<FormControl isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</FormControl>
<FormControl>
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</FormControl>
<Button type="submit" isLoading={isLoading}>
{t('modelManager.updateModel')}
</Button>
</Flex>
</form>
</Flex>
);
};
export default memo(DiffusersModelEdit);

View File

@ -1,127 +0,0 @@
import { Button, Divider, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import BaseModelSelect from 'features/modelManager/subpanels/shared/BaseModelSelect';
import { LORA_MODEL_FORMAT_MAP, MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useUpdateModelsMutation } from 'services/api/endpoints/models';
import type { LoRAModelConfig } from 'services/api/types';
type LoRAModelEditProps = {
model: LoRAModelConfig;
};
const LoRAModelEdit = (props: LoRAModelEditProps) => {
const { model } = props;
const [updateModel, { isLoading }] = useUpdateModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const {
register,
handleSubmit,
control,
formState: { errors },
reset,
} = useForm<LoRAModelConfig>({
defaultValues: {
name: model.name ? model.name : '',
base: model.base,
type: 'lora',
path: model.path ? model.path : '',
description: model.description ? model.description : '',
format: model.format,
},
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<LoRAModelConfig>>(
(values) => {
const responseBody = {
key: model.key,
body: values,
};
updateModel(responseBody)
.unwrap()
.then((payload) => {
reset(payload as LoRAModelConfig, { keepDefaultValues: true });
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((_) => {
reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[dispatch, model.key, reset, t, updateModel]
);
return (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold">
{model.name}
</Text>
<Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[model.base]} {t('modelManager.model')} {LORA_MODEL_FORMAT_MAP[model.format]}{' '}
{t('common.format')}
</Text>
</Flex>
<Divider />
<form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<FormControl isInvalid={Boolean(errors.name)}>
<FormLabel>{t('modelManager.name')}</FormLabel>
<Input
{...register('name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
})}
/>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</FormControl>
<FormControl>
<FormLabel>{t('modelManager.description')}</FormLabel>
<Input {...register('description')} />
</FormControl>
<BaseModelSelect<LoRAModelConfig> control={control} name="base" />
<FormControl isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</FormControl>
<Button type="submit" isLoading={isLoading}>
{t('modelManager.updateModel')}
</Button>
</Flex>
</form>
</Flex>
);
};
export default memo(LoRAModelEdit);

View File

@ -1,165 +0,0 @@
import {
Button,
ConfirmationAlertDialog,
Flex,
FormControl,
FormLabel,
Input,
ListItem,
Radio,
RadioGroup,
Text,
Tooltip,
UnorderedList,
useDisclosure,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
import type { CheckpointModelConfig } from 'services/api/types';
interface ModelConvertProps {
model: CheckpointModelConfig;
}
type SaveLocation = 'InvokeAIRoot' | 'Custom';
const ModelConvert = (props: ModelConvertProps) => {
const { model } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [convertModel, { isLoading }] = useConvertMainModelsMutation();
const { isOpen, onOpen, onClose } = useDisclosure();
const [saveLocation, setSaveLocation] = useState<SaveLocation>('InvokeAIRoot');
const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
useEffect(() => {
setSaveLocation('InvokeAIRoot');
}, [model]);
const modelConvertCancelHandler = useCallback(() => {
setSaveLocation('InvokeAIRoot');
}, []);
const handleChangeSaveLocation = useCallback((v: string) => {
setSaveLocation(v as SaveLocation);
}, []);
const handleChangeCustomSaveLocation = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setCustomSaveLocation(e.target.value);
}, []);
const modelConvertHandler = useCallback(() => {
const queryArg = {
base_model: model.base,
model_name: model.name,
convert_dest_directory: saveLocation === 'Custom' ? customSaveLocation : undefined,
};
if (saveLocation === 'Custom' && customSaveLocation === '') {
dispatch(
addToast(
makeToast({
title: t('modelManager.noCustomLocationProvided'),
status: 'error',
})
)
);
return;
}
dispatch(
addToast(
makeToast({
title: `${t('modelManager.convertingModelBegin')}: ${model.name}`,
status: 'info',
})
)
);
convertModel(queryArg)
.unwrap()
.then(() => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelConverted')}: ${model.name}`,
status: 'success',
})
)
);
})
.catch(() => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelConversionFailed')}: ${model.name}`,
status: 'error',
})
)
);
});
}, [convertModel, customSaveLocation, dispatch, model.base, model.name, saveLocation, t]);
return (
<>
<Button
onClick={onOpen}
size="sm"
aria-label={t('modelManager.convertToDiffusers')}
className=" modal-close-btn"
isLoading={isLoading}
>
🧨 {t('modelManager.convertToDiffusers')}
</Button>
<ConfirmationAlertDialog
title={`${t('modelManager.convert')} ${model.name}`}
acceptCallback={modelConvertHandler}
cancelCallback={modelConvertCancelHandler}
acceptButtonText={`${t('modelManager.convert')}`}
isOpen={isOpen}
onClose={onClose}
>
<Flex flexDirection="column" rowGap={4}>
<Text>{t('modelManager.convertToDiffusersHelpText1')}</Text>
<UnorderedList>
<ListItem>{t('modelManager.convertToDiffusersHelpText2')}</ListItem>
<ListItem>{t('modelManager.convertToDiffusersHelpText3')}</ListItem>
<ListItem>{t('modelManager.convertToDiffusersHelpText4')}</ListItem>
<ListItem>{t('modelManager.convertToDiffusersHelpText5')}</ListItem>
</UnorderedList>
<Text>{t('modelManager.convertToDiffusersHelpText6')}</Text>
</Flex>
<Flex flexDir="column" gap={2}>
<Flex marginTop={4} flexDir="column" gap={2}>
<Text fontWeight="semibold">{t('modelManager.convertToDiffusersSaveLocation')}</Text>
<RadioGroup value={saveLocation} onChange={handleChangeSaveLocation}>
<Flex gap={4}>
<Radio value="InvokeAIRoot">
<Tooltip label="Save converted model in the InvokeAI root folder">
{t('modelManager.invokeRoot')}
</Tooltip>
</Radio>
<Radio value="Custom">
<Tooltip label="Save converted model in a custom folder">{t('modelManager.custom')}</Tooltip>
</Radio>
</Flex>
</RadioGroup>
</Flex>
{saveLocation === 'Custom' && (
<FormControl>
<FormLabel>{t('modelManager.customSaveLocation')}</FormLabel>
<Input width="full" value={customSaveLocation} onChange={handleChangeCustomSaveLocation} />
</FormControl>
)}
</Flex>
</ConfirmationAlertDialog>
</>
);
};
export default memo(ModelConvert);

View File

@ -1,205 +0,0 @@
import { Button, ButtonGroup, Flex, FormControl, FormLabel, Input, Spinner, Text } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { forEach } from 'lodash-es';
import type { ChangeEvent, PropsWithChildren } from 'react';
import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
// import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
import type { LoRAConfig, MainModelConfig } from 'services/api/types';
import ModelListItem from './ModelListItem';
type ModelListProps = {
selectedModelId: string | undefined;
setSelectedModelId: (name: string | undefined) => void;
};
type ModelFormat = 'all' | 'checkpoint' | 'diffusers';
type ModelType = 'main' | 'lora';
type CombinedModelFormat = ModelFormat | 'lora';
const ModelList = (props: ModelListProps) => {
const { selectedModelId, setSelectedModelId } = props;
const { t } = useTranslation();
const [nameFilter, setNameFilter] = useState<string>('');
const [modelFormatFilter, setModelFormatFilter] = useState<CombinedModelFormat>('all');
const { filteredDiffusersModels, isLoadingDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data, isLoading }) => ({
filteredDiffusersModels: modelsFilter(data, 'main', 'diffusers', nameFilter),
isLoadingDiffusersModels: isLoading,
}),
});
const { filteredCheckpointModels, isLoadingCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data, isLoading }) => ({
filteredCheckpointModels: modelsFilter(data, 'main', 'checkpoint', nameFilter),
isLoadingCheckpointModels: isLoading,
}),
});
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(undefined, {
selectFromResult: ({ data, isLoading }) => ({
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
isLoadingLoraModels: isLoading,
}),
});
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setNameFilter(e.target.value);
}, []);
return (
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
<ButtonGroup>
<Button onClick={setModelFormatFilter.bind(null, 'all')} isChecked={modelFormatFilter === 'all'} size="sm">
{t('modelManager.allModels')}
</Button>
<Button
size="sm"
onClick={setModelFormatFilter.bind(null, 'diffusers')}
isChecked={modelFormatFilter === 'diffusers'}
>
{t('modelManager.diffusersModels')}
</Button>
<Button
size="sm"
onClick={setModelFormatFilter.bind(null, 'checkpoint')}
isChecked={modelFormatFilter === 'checkpoint'}
>
{t('modelManager.checkpointModels')}
</Button>
<Button size="sm" onClick={setModelFormatFilter.bind(null, 'lora')} isChecked={modelFormatFilter === 'lora'}>
{t('modelManager.loraModels')}
</Button>
</ButtonGroup>
<FormControl>
<FormLabel>{t('modelManager.search')}</FormLabel>
<Input onChange={handleSearchFilter} />
</FormControl>
<Flex flexDirection="column" gap={4} maxHeight={window.innerHeight - 280} overflow="scroll">
{/* Diffusers List */}
{isLoadingDiffusersModels && <FetchingModelsLoader loadingMessage="Loading Diffusers..." />}
{['all', 'diffusers'].includes(modelFormatFilter) &&
!isLoadingDiffusersModels &&
filteredDiffusersModels.length > 0 && (
<ModelListWrapper
title="Diffusers"
modelList={filteredDiffusersModels}
selected={{ selectedModelId, setSelectedModelId }}
key="diffusers"
/>
)}
{/* Checkpoints List */}
{isLoadingCheckpointModels && <FetchingModelsLoader loadingMessage="Loading Checkpoints..." />}
{['all', 'checkpoint'].includes(modelFormatFilter) &&
!isLoadingCheckpointModels &&
filteredCheckpointModels.length > 0 && (
<ModelListWrapper
title="Checkpoints"
modelList={filteredCheckpointModels}
selected={{ selectedModelId, setSelectedModelId }}
key="checkpoints"
/>
)}
{/* LoRAs List */}
{isLoadingLoraModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
{['all', 'lora'].includes(modelFormatFilter) && !isLoadingLoraModels && filteredLoraModels.length > 0 && (
<ModelListWrapper
title="LoRAs"
modelList={filteredLoraModels}
selected={{ selectedModelId, setSelectedModelId }}
key="loras"
/>
)}
</Flex>
</Flex>
</Flex>
);
};
export default memo(ModelList);
const modelsFilter = <T extends MainModelConfig | LoRAConfig>(
data: EntityState<T, string> | undefined,
model_type: ModelType,
model_format: ModelFormat | undefined,
nameFilter: string
) => {
const filteredModels: T[] = [];
forEach(data?.entities, (model) => {
if (!model) {
return;
}
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
const matchesFormat = model_format === undefined || model.format === model_format;
const matchesType = model.type === model_type;
if (matchesFilter && matchesFormat && matchesType) {
filteredModels.push(model);
}
});
return filteredModels;
};
const StyledModelContainer = memo((props: PropsWithChildren) => {
return (
<Flex flexDirection="column" gap={4} borderRadius={4} p={4} bg="base.800">
{props.children}
</Flex>
);
});
StyledModelContainer.displayName = 'StyledModelContainer';
type ModelListWrapperProps = {
title: string;
modelList: MainModelConfig[] | LoRAConfig[];
selected: ModelListProps;
};
const ModelListWrapper = memo((props: ModelListWrapperProps) => {
const { title, modelList, selected } = props;
return (
<StyledModelContainer>
<Flex gap={2} flexDir="column">
<Text variant="subtext" fontSize="sm">
{title}
</Text>
{modelList.map((model) => (
<ModelListItem
key={model.key}
model={model}
isSelected={selected.selectedModelId === model.key}
setSelectedModelId={selected.setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
);
});
ModelListWrapper.displayName = 'ModelListWrapper';
const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => {
return (
<StyledModelContainer>
<Flex justifyContent="center" alignItems="center" flexDirection="column" p={4} gap={8}>
<Spinner />
<Text variant="subtext">{loadingMessage ? loadingMessage : 'Fetching...'}</Text>
</Flex>
</StyledModelContainer>
);
});
FetchingModelsLoader.displayName = 'FetchingModelsLoader';

View File

@ -1,111 +0,0 @@
import {
Badge,
Button,
ConfirmationAlertDialog,
Flex,
IconButton,
Text,
Tooltip,
useDisclosure,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
import type { LoRAConfig, MainModelConfig } from 'services/api/types';
type ModelListItemProps = {
model: MainModelConfig | LoRAConfig;
isSelected: boolean;
setSelectedModelId: (v: string | undefined) => void;
};
const ModelListItem = (props: ModelListItemProps) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const [deleteModel] = useDeleteModelsMutation();
const { isOpen, onOpen, onClose } = useDisclosure();
const { model, isSelected, setSelectedModelId } = props;
const handleSelectModel = useCallback(() => {
setSelectedModelId(model.key);
}, [model.key, setSelectedModelId]);
const handleModelDelete = useCallback(() => {
deleteModel({ key: model.key })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelDeleted')}: ${model.name}`,
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`,
status: 'error',
})
)
);
}
});
setSelectedModelId(undefined);
}, [deleteModel, model, setSelectedModelId, dispatch, t]);
return (
<Flex gap={2} alignItems="center" w="full">
<Flex
as={Button}
isChecked={isSelected}
variant={isSelected ? 'solid' : 'ghost'}
justifyContent="start"
p={2}
borderRadius="base"
w="full"
alignItems="center"
onClick={handleSelectModel}
>
<Flex gap={4} alignItems="center">
<Badge minWidth={14} p={0.5} fontSize="sm" variant="solid">
{MODEL_TYPE_SHORT_MAP[model.base as keyof typeof MODEL_TYPE_SHORT_MAP]}
</Badge>
<Tooltip label={model.description} placement="bottom">
<Text>{model.name}</Text>
</Tooltip>
</Flex>
</Flex>
<IconButton
onClick={onOpen}
icon={<PiTrashSimpleBold />}
aria-label={t('modelManager.deleteConfig')}
colorScheme="error"
/>
<ConfirmationAlertDialog
isOpen={isOpen}
onClose={onClose}
title={t('modelManager.deleteModel')}
acceptCallback={handleModelDelete}
acceptButtonText={t('modelManager.delete')}
>
<Flex rowGap={4} flexDirection="column">
<Text fontWeight="bold">{t('modelManager.deleteMsg1')}</Text>
<Text>{t('modelManager.deleteMsg2')}</Text>
</Flex>
</ConfirmationAlertDialog>
</Flex>
);
};
export default memo(ModelListItem);

View File

@ -1,14 +0,0 @@
import { Flex } from '@invoke-ai/ui-library';
import { memo } from 'react';
import SyncModels from './ModelManagerSettingsPanel/SyncModels';
const ModelManagerSettingsPanel = () => {
return (
<Flex>
<SyncModels />
</Flex>
);
};
export default memo(ModelManagerSettingsPanel);

View File

@ -1,22 +0,0 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { SyncModelsButton } from 'features/modelManager/components/SyncModels/SyncModelsButton';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const SyncModels = () => {
const { t } = useTranslation();
return (
<Flex w="full" p={4} borderRadius={4} gap={4} justifyContent="space-between" alignItems="center" bg="base.800">
<Flex flexDirection="column" gap={2}>
<Text fontWeight="semibold">{t('modelManager.syncModels')}</Text>
<Text fontSize="sm" variant="subtext">
{t('modelManager.syncModelsDesc')}
</Text>
</Flex>
<SyncModelsButton />
</Flex>
);
};
export default memo(SyncModels);

View File

@ -1,36 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
{ value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] },
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
];
const BaseModelSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { t } = useTranslation();
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
return (
<FormControl>
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<Combobox value={value} options={options} onChange={onChange} />
</FormControl>
);
};
export default typedMemo(BaseModelSelect);

View File

@ -1,32 +0,0 @@
import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { memo, useCallback, useMemo } from 'react';
import { useController, type UseControllerProps } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useGetCheckpointConfigsQuery } from 'services/api/endpoints/models';
import type { CheckpointModelConfig } from 'services/api/types';
const sx: ChakraProps['sx'] = { w: 'full' };
const CheckpointConfigsSelect = (props: UseControllerProps<CheckpointModelConfig>) => {
const { data } = useGetCheckpointConfigsQuery();
const { t } = useTranslation();
const options = useMemo<ComboboxOption[]>(() => (data ? data.map((i) => ({ label: i, value: i })) : []), [data]);
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value, options]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
return (
<FormControl>
<FormLabel>{t('modelManager.configFile')}</FormLabel>
<Combobox placeholder="Select A Config File" value={value} options={options} onChange={onChange} sx={sx} />
</FormControl>
);
};
export default memo(CheckpointConfigsSelect);

View File

@ -1,34 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { CheckpointModelConfig, DiffusersModelConfig } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'normal', label: 'Normal' },
{ value: 'inpaint', label: 'Inpaint' },
{ value: 'depth', label: 'Depth' },
];
const ModelVariantSelect = <T extends CheckpointModelConfig | DiffusersModelConfig>(props: UseControllerProps<T>) => {
const { t } = useTranslation();
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
return (
<FormControl>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<Combobox value={value} options={options} onChange={onChange} />
</FormControl>
);
};
export default typedMemo(ModelVariantSelect);

View File

@ -1,7 +1,7 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton';
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';

View File

@ -1,7 +1,7 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton';
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
import type {
SDXLRefinerModelFieldInputInstance,

View File

@ -1,7 +1,7 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton';
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';

View File

@ -1,7 +1,7 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton';
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { pick } from 'lodash-es';

View File

@ -16,7 +16,7 @@ import { EMPTY_ARRAY } from 'app/store/util';
import { LoRAList } from 'features/lora/components/LoRAList';
import LoRASelect from 'features/lora/components/LoRASelect';
import { selectLoraSlice } from 'features/lora/store/loraSlice';
import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton';
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
import ParamSteps from 'features/parameters/components/Core/ParamSteps';

View File

@ -10,7 +10,6 @@ import type {
IPAdapterModelConfig,
LoRAModelConfig,
MainModelConfig,
MergeModelConfig,
T2IAdapterModelConfig,
TextualInversionModelConfig,
VAEModelConfig,
@ -38,22 +37,9 @@ type DeleteMainModelArg = {
type DeleteMainModelResponse = void;
type ConvertMainModelArg = {
base_model: BaseModelType;
model_name: string;
convert_dest_directory?: string;
};
type ConvertMainModelResponse =
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
type MergeMainModelArg = {
base_model: BaseModelType;
body: MergeModelConfig;
};
type MergeMainModelResponse = paths['/api/v2/models/merge']['put']['responses']['200']['content']['application/json'];
type ImportMainModelArg = {
source: NonNullable<operations['heuristic_install_model']['parameters']['query']['source']>;
access_token?: operations['heuristic_install_model']['parameters']['query']['access_token'];
@ -72,20 +58,11 @@ type DeleteImportModelsResponse =
type PruneModelImportsResponse =
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
type ImportAdvancedModelArg = {
source: NonNullable<operations['import_model']['requestBody']['content']['application/json']['source']>;
config: NonNullable<operations['import_model']['requestBody']['content']['application/json']['config']>;
};
type ImportAdvancedModelResponse = paths['/api/v2/models/import']['post']['responses']['201']['content']['application/json'];
export type ScanFolderResponse =
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
type ScanFolderArg = operations['scan_for_models']['parameters']['query'];
type CheckpointConfigsResponse =
paths['/api/v2/models/ckpt_confs']['get']['responses']['200']['content']['application/json'];
export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
@ -199,16 +176,6 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model', 'ModelImports'],
}),
importAdvancedModel: build.mutation<ImportAdvancedModelResponse, ImportAdvancedModelArg>({
query: ({ source, config}) => {
return {
url: buildModelsUrl('install'),
method: 'POST',
body: { source, config },
};
},
invalidatesTags: ['Model', 'ModelImports'],
}),
deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
query: ({ key }) => {
return {
@ -218,25 +185,14 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
convertMainModels: build.mutation<ConvertMainModelResponse, ConvertMainModelArg>({
query: ({ base_model, model_name, convert_dest_directory }) => {
convertMainModels: build.mutation<ConvertMainModelResponse, string>({
query: (key) => {
return {
url: buildModelsUrl(`convert/${base_model}/main/${model_name}`),
url: buildModelsUrl(`convert/${key}`),
method: 'PUT',
params: { convert_dest_directory },
};
},
invalidatesTags: ['Model'],
}),
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
query: ({ base_model, body }) => {
return {
url: buildModelsUrl(`merge/${base_model}`),
method: 'PUT',
body: body,
};
},
invalidatesTags: ['Model'],
invalidatesTags: ['ModelConfig'],
}),
getModelConfig: build.query<GetModelConfigResponse, string>({
query: (key) => buildModelsUrl(`i/${key}`),
@ -323,13 +279,6 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['ModelImports'],
}),
getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
query: () => {
return {
url: buildModelsUrl(`ckpt_confs`),
};
},
}),
}),
});
@ -345,13 +294,10 @@ export const {
useDeleteModelsMutation,
useUpdateModelsMutation,
useImportMainModelsMutation,
useImportAdvancedModelMutation,
useConvertMainModelsMutation,
useMergeMainModelsMutation,
useSyncModelsMutation,
useScanModelsQuery,
useLazyScanModelsQuery,
useGetCheckpointConfigsQuery,
useGetModelImportsQuery,
useGetModelMetadataQuery,
useDeleteModelImportMutation,