mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): dropped model config cache breaking model edit UI
The model edit UI's composition allows for the model edit form to be instantiated before the model's config has been received. This results in the form having no values - all the fields are blank instead of populated by the model config. Part of the fix is to pass the model config around directly instead of relying on _all_ components to fetch the model directly. I also fixed a crapload of performance issues related to improper use of redux selectors.
This commit is contained in:
parent
74cef38bcf
commit
47414be1e6
@ -1,15 +1,10 @@
|
|||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
|
||||||
import { isNil } from 'lodash-es';
|
import { isNil } from 'lodash-es';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard';
|
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||||
import { isControlNetOrT2IAdapterModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
export const useControlNetOrT2IAdapterDefaultSettings = (modelKey?: string | null) => {
|
|
||||||
const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(
|
|
||||||
modelKey ?? skipToken,
|
|
||||||
isControlNetOrT2IAdapterModelConfig
|
|
||||||
);
|
|
||||||
|
|
||||||
|
export const useControlNetOrT2IAdapterDefaultSettings = (
|
||||||
|
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig
|
||||||
|
) => {
|
||||||
const defaultSettingsDefaults = useMemo(() => {
|
const defaultSettingsDefaults = useMemo(() => {
|
||||||
return {
|
return {
|
||||||
preprocessor: {
|
preprocessor: {
|
||||||
@ -19,5 +14,5 @@ export const useControlNetOrT2IAdapterDefaultSettings = (modelKey?: string | nul
|
|||||||
};
|
};
|
||||||
}, [modelConfig?.default_settings]);
|
}, [modelConfig?.default_settings]);
|
||||||
|
|
||||||
return { defaultSettingsDefaults, isLoading };
|
return defaultSettingsDefaults;
|
||||||
};
|
};
|
||||||
|
@ -1,12 +1,9 @@
|
|||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
|
||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
|
||||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||||
import { isNil } from 'lodash-es';
|
import { isNil } from 'lodash-es';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard';
|
import type { MainModelConfig } from 'services/api/types';
|
||||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
|
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
|
||||||
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision, width, height } = config.sd;
|
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision, width, height } = config.sd;
|
||||||
@ -22,9 +19,7 @@ const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config)
|
|||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
export const useMainModelDefaultSettings = (modelKey?: string | null) => {
|
export const useMainModelDefaultSettings = (modelConfig: MainModelConfig) => {
|
||||||
const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(modelKey ?? skipToken, isNonRefinerMainModelConfig);
|
|
||||||
|
|
||||||
const {
|
const {
|
||||||
initialSteps,
|
initialSteps,
|
||||||
initialCfg,
|
initialCfg,
|
||||||
@ -81,5 +76,5 @@ export const useMainModelDefaultSettings = (modelKey?: string | null) => {
|
|||||||
initialHeight,
|
initialHeight,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
return { defaultSettingsDefaults, isLoading, optimalDimension: getOptimalDimension(modelConfig) };
|
return defaultSettingsDefaults;
|
||||||
};
|
};
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import type { PersistConfig } from 'app/store/store';
|
import type { PersistConfig, RootState } from 'app/store/store';
|
||||||
import type { ModelType } from 'services/api/types';
|
import type { ModelType } from 'services/api/types';
|
||||||
|
|
||||||
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'> | 'refiner';
|
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'> | 'refiner';
|
||||||
@ -50,6 +50,8 @@ export const modelManagerV2Slice = createSlice({
|
|||||||
export const { setSelectedModelKey, setSearchTerm, setFilteredModelType, setSelectedModelMode, setScanPath } =
|
export const { setSelectedModelKey, setSearchTerm, setFilteredModelType, setSelectedModelMode, setScanPath } =
|
||||||
modelManagerV2Slice.actions;
|
modelManagerV2Slice.actions;
|
||||||
|
|
||||||
|
export const selectModelManagerV2Slice = (state: RootState) => state.modelmanagerV2;
|
||||||
|
|
||||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||||
const migrateModelManagerState = (state: any): any => {
|
const migrateModelManagerState = (state: any): any => {
|
||||||
if (!('_version' in state)) {
|
if (!('_version' in state)) {
|
||||||
|
@ -21,7 +21,8 @@ import { FetchingModelsLoader } from './FetchingModelsLoader';
|
|||||||
import { ModelListWrapper } from './ModelListWrapper';
|
import { ModelListWrapper } from './ModelListWrapper';
|
||||||
|
|
||||||
const ModelList = () => {
|
const ModelList = () => {
|
||||||
const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2);
|
const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType);
|
||||||
|
const searchTerm = useAppSelector((s) => s.modelmanagerV2.searchTerm);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const [mainModels, { isLoading: isLoadingMainModels }] = useMainModels();
|
const [mainModels, { isLoading: isLoadingMainModels }] = useMainModels();
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||||
import { ConfirmationAlertDialog, Flex, IconButton, Spacer, Text, useDisclosure } from '@invoke-ai/ui-library';
|
import { ConfirmationAlertDialog, Flex, IconButton, Spacer, Text, useDisclosure } from '@invoke-ai/ui-library';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
import { selectModelManagerV2Slice, setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
|
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
|
||||||
import ModelFormatBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge';
|
import ModelFormatBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge';
|
||||||
import { toast } from 'features/toast/toast';
|
import { toast } from 'features/toast/toast';
|
||||||
@ -23,15 +24,21 @@ const sx: SystemStyleObject = {
|
|||||||
"&[aria-selected='true']": { bg: 'base.700' },
|
"&[aria-selected='true']": { bg: 'base.700' },
|
||||||
};
|
};
|
||||||
|
|
||||||
const ModelListItem = (props: ModelListItemProps) => {
|
const ModelListItem = ({ model }: ModelListItemProps) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectIsSelected = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(
|
||||||
|
selectModelManagerV2Slice,
|
||||||
|
(modelManagerV2Slice) => modelManagerV2Slice.selectedModelKey === model.key
|
||||||
|
),
|
||||||
|
[model.key]
|
||||||
|
);
|
||||||
|
const isSelected = useAppSelector(selectIsSelected);
|
||||||
const [deleteModel] = useDeleteModelsMutation();
|
const [deleteModel] = useDeleteModelsMutation();
|
||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
|
|
||||||
const { model } = props;
|
|
||||||
|
|
||||||
const handleSelectModel = useCallback(() => {
|
const handleSelectModel = useCallback(() => {
|
||||||
dispatch(setSelectedModelKey(model.key));
|
dispatch(setSelectedModelKey(model.key));
|
||||||
}, [model.key, dispatch]);
|
}, [model.key, dispatch]);
|
||||||
@ -43,11 +50,6 @@ const ModelListItem = (props: ModelListItemProps) => {
|
|||||||
},
|
},
|
||||||
[onOpen]
|
[onOpen]
|
||||||
);
|
);
|
||||||
|
|
||||||
const isSelected = useMemo(() => {
|
|
||||||
return selectedModelKey === model.key;
|
|
||||||
}, [selectedModelKey, model.key]);
|
|
||||||
|
|
||||||
const handleModelDelete = useCallback(() => {
|
const handleModelDelete = useCallback(() => {
|
||||||
deleteModel({ key: model.key })
|
deleteModel({ key: model.key })
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
import { Button, Flex, Heading, SimpleGrid } from '@invoke-ai/ui-library';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { useControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings';
|
import { useControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings';
|
||||||
import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor';
|
import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor';
|
||||||
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
|
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
|
||||||
@ -10,17 +9,20 @@ import { useForm } from 'react-hook-form';
|
|||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiCheckBold } from 'react-icons/pi';
|
import { PiCheckBold } from 'react-icons/pi';
|
||||||
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||||
|
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
export type ControlNetOrT2IAdapterDefaultSettingsFormData = {
|
export type ControlNetOrT2IAdapterDefaultSettingsFormData = {
|
||||||
preprocessor: FormField<string>;
|
preprocessor: FormField<string>;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const ControlNetOrT2IAdapterDefaultSettings = () => {
|
type Props = {
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const ControlNetOrT2IAdapterDefaultSettings = ({ modelConfig }: Props) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { defaultSettingsDefaults, isLoading: isLoadingDefaultSettings } =
|
const defaultSettingsDefaults = useControlNetOrT2IAdapterDefaultSettings(modelConfig);
|
||||||
useControlNetOrT2IAdapterDefaultSettings(selectedModelKey);
|
|
||||||
|
|
||||||
const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();
|
const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();
|
||||||
|
|
||||||
@ -30,16 +32,12 @@ export const ControlNetOrT2IAdapterDefaultSettings = () => {
|
|||||||
|
|
||||||
const onSubmit = useCallback<SubmitHandler<ControlNetOrT2IAdapterDefaultSettingsFormData>>(
|
const onSubmit = useCallback<SubmitHandler<ControlNetOrT2IAdapterDefaultSettingsFormData>>(
|
||||||
(data) => {
|
(data) => {
|
||||||
if (!selectedModelKey) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const body = {
|
const body = {
|
||||||
preprocessor: data.preprocessor.isEnabled ? data.preprocessor.value : null,
|
preprocessor: data.preprocessor.isEnabled ? data.preprocessor.value : null,
|
||||||
};
|
};
|
||||||
|
|
||||||
updateModel({
|
updateModel({
|
||||||
key: selectedModelKey,
|
key: modelConfig.key,
|
||||||
body: { default_settings: body },
|
body: { default_settings: body },
|
||||||
})
|
})
|
||||||
.unwrap()
|
.unwrap()
|
||||||
@ -61,13 +59,9 @@ export const ControlNetOrT2IAdapterDefaultSettings = () => {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
[selectedModelKey, reset, updateModel, t]
|
[updateModel, modelConfig.key, t, reset]
|
||||||
);
|
);
|
||||||
|
|
||||||
if (isLoadingDefaultSettings) {
|
|
||||||
return <Text>{t('common.loading')}</Text>;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Flex gap="4" justifyContent="space-between" w="full" pb={4}>
|
<Flex gap="4" justifyContent="space-between" w="full" pb={4}>
|
||||||
|
@ -1,16 +1,18 @@
|
|||||||
import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
import { Button, Flex, Heading, SimpleGrid } from '@invoke-ai/ui-library';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings';
|
import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings';
|
||||||
import { DefaultHeight } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultHeight';
|
import { DefaultHeight } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultHeight';
|
||||||
import { DefaultWidth } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultWidth';
|
import { DefaultWidth } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultWidth';
|
||||||
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||||
import { toast } from 'features/toast/toast';
|
import { toast } from 'features/toast/toast';
|
||||||
import { useCallback } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import type { SubmitHandler } from 'react-hook-form';
|
import type { SubmitHandler } from 'react-hook-form';
|
||||||
import { useForm } from 'react-hook-form';
|
import { useForm } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiCheckBold } from 'react-icons/pi';
|
import { PiCheckBold } from 'react-icons/pi';
|
||||||
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||||
|
import type { MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
|
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
|
||||||
import { DefaultCfgScale } from './DefaultCfgScale';
|
import { DefaultCfgScale } from './DefaultCfgScale';
|
||||||
@ -35,16 +37,16 @@ export type MainModelDefaultSettingsFormData = {
|
|||||||
height: FormField<number>;
|
height: FormField<number>;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const MainModelDefaultSettings = () => {
|
type Props = {
|
||||||
|
modelConfig: MainModelConfig;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const MainModelDefaultSettings = ({ modelConfig }: Props) => {
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const {
|
const defaultSettingsDefaults = useMainModelDefaultSettings(modelConfig);
|
||||||
defaultSettingsDefaults,
|
const optimalDimension = useMemo(() => getOptimalDimension(modelConfig), [modelConfig]);
|
||||||
isLoading: isLoadingDefaultSettings,
|
|
||||||
optimalDimension,
|
|
||||||
} = useMainModelDefaultSettings(selectedModelKey);
|
|
||||||
|
|
||||||
const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();
|
const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();
|
||||||
|
|
||||||
const { handleSubmit, control, formState, reset } = useForm<MainModelDefaultSettingsFormData>({
|
const { handleSubmit, control, formState, reset } = useForm<MainModelDefaultSettingsFormData>({
|
||||||
@ -94,10 +96,6 @@ export const MainModelDefaultSettings = () => {
|
|||||||
[selectedModelKey, reset, updateModel, t]
|
[selectedModelKey, reset, updateModel, t]
|
||||||
);
|
);
|
||||||
|
|
||||||
if (isLoadingDefaultSettings) {
|
|
||||||
return <Text>{t('common.loading')}</Text>;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Flex gap="4" justifyContent="space-between" w="full" pb={4}>
|
<Flex gap="4" justifyContent="space-between" w="full" pb={4}>
|
||||||
|
@ -1,19 +1,10 @@
|
|||||||
import { Button, Flex, Heading, Spacer, Text } from '@invoke-ai/ui-library';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
import { IAINoContentFallback, IAINoContentFallbackWithSpinner } from 'common/components/IAIImageFallback';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useMemo } from 'react';
|
||||||
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
|
||||||
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
|
|
||||||
import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton';
|
|
||||||
import { toast } from 'features/toast/toast';
|
|
||||||
import { useCallback } from 'react';
|
|
||||||
import type { SubmitHandler } from 'react-hook-form';
|
|
||||||
import { useForm } from 'react-hook-form';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiCheckBold, PiXBold } from 'react-icons/pi';
|
import { PiExclamationMarkBold } from 'react-icons/pi';
|
||||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
|
||||||
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
|
|
||||||
|
|
||||||
import ModelImageUpload from './Fields/ModelImageUpload';
|
|
||||||
import { ModelEdit } from './ModelEdit';
|
import { ModelEdit } from './ModelEdit';
|
||||||
import { ModelView } from './ModelView';
|
import { ModelView } from './ModelView';
|
||||||
|
|
||||||
@ -21,100 +12,34 @@ export const Model = () => {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode);
|
const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode);
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
const { data: modelConfigs, isLoading } = useGetModelConfigsQuery();
|
||||||
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
|
const modelConfig = useMemo(() => {
|
||||||
const dispatch = useAppDispatch();
|
if (!modelConfigs) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (selectedModelKey === null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs, selectedModelKey);
|
||||||
|
|
||||||
const form = useForm<UpdateModelArg['body']>({
|
if (!modelConfig) {
|
||||||
defaultValues: data,
|
return null;
|
||||||
mode: 'onChange',
|
}
|
||||||
});
|
|
||||||
|
|
||||||
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
|
return modelConfig;
|
||||||
(values) => {
|
}, [modelConfigs, selectedModelKey]);
|
||||||
if (!data?.key) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const responseBody: UpdateModelArg = {
|
|
||||||
key: data.key,
|
|
||||||
body: values,
|
|
||||||
};
|
|
||||||
|
|
||||||
updateModel(responseBody)
|
|
||||||
.unwrap()
|
|
||||||
.then((payload) => {
|
|
||||||
form.reset(payload, { keepDefaultValues: true });
|
|
||||||
dispatch(setSelectedModelMode('view'));
|
|
||||||
toast({
|
|
||||||
id: 'MODEL_UPDATED',
|
|
||||||
title: t('modelManager.modelUpdated'),
|
|
||||||
status: 'success',
|
|
||||||
});
|
|
||||||
})
|
|
||||||
.catch((_) => {
|
|
||||||
form.reset();
|
|
||||||
toast({
|
|
||||||
id: 'MODEL_UPDATE_FAILED',
|
|
||||||
title: t('modelManager.modelUpdateFailed'),
|
|
||||||
status: 'error',
|
|
||||||
});
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[dispatch, data?.key, form, t, updateModel]
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleClickCancel = useCallback(() => {
|
|
||||||
dispatch(setSelectedModelMode('view'));
|
|
||||||
}, [dispatch]);
|
|
||||||
|
|
||||||
if (isLoading) {
|
if (isLoading) {
|
||||||
return <Text>{t('common.loading')}</Text>;
|
return <IAINoContentFallbackWithSpinner label={t('common.loading')} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!data) {
|
if (!modelConfig) {
|
||||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
return <IAINoContentFallback label={t('common.somethingWentWrong')} icon={PiExclamationMarkBold} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
if (selectedModelMode === 'view') {
|
||||||
<Flex flexDir="column" gap={4}>
|
return <ModelView modelConfig={modelConfig} />;
|
||||||
<Flex alignItems="flex-start" gap={4}>
|
}
|
||||||
<ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />
|
|
||||||
<Flex flexDir="column" gap={1} flexGrow={1} minW={0}>
|
return <ModelEdit modelConfig={modelConfig} />;
|
||||||
<Flex gap={2}>
|
|
||||||
<Heading as="h2" fontSize="lg" noOfLines={1} wordBreak="break-all">
|
|
||||||
{data.name}
|
|
||||||
</Heading>
|
|
||||||
<Spacer />
|
|
||||||
{selectedModelMode === 'view' && <ModelConvertButton modelKey={selectedModelKey} />}
|
|
||||||
{selectedModelMode === 'view' && <ModelEditButton />}
|
|
||||||
{selectedModelMode === 'edit' && (
|
|
||||||
<Button size="sm" onClick={handleClickCancel} leftIcon={<PiXBold />}>
|
|
||||||
{t('common.cancel')}
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
{selectedModelMode === 'edit' && (
|
|
||||||
<Button
|
|
||||||
size="sm"
|
|
||||||
colorScheme="invokeYellow"
|
|
||||||
leftIcon={<PiCheckBold />}
|
|
||||||
onClick={form.handleSubmit(onSubmit)}
|
|
||||||
isLoading={isSubmitting}
|
|
||||||
isDisabled={Boolean(Object.keys(form.formState.errors).length)}
|
|
||||||
>
|
|
||||||
{t('common.save')}
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
{data.source && (
|
|
||||||
<Text variant="subtext" noOfLines={1} wordBreak="break-all">
|
|
||||||
{t('modelManager.source')}: {data?.source}
|
|
||||||
</Text>
|
|
||||||
)}
|
|
||||||
<Text noOfLines={3}>{data.description}</Text>
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
{selectedModelMode === 'view' ? <ModelView /> : <ModelEdit form={form} onSubmit={onSubmit} />}
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
};
|
||||||
|
@ -8,52 +8,46 @@ import {
|
|||||||
UnorderedList,
|
UnorderedList,
|
||||||
useDisclosure,
|
useDisclosure,
|
||||||
} from '@invoke-ai/ui-library';
|
} from '@invoke-ai/ui-library';
|
||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
|
||||||
import { toast } from 'features/toast/toast';
|
import { toast } from 'features/toast/toast';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useConvertModelMutation, useGetModelConfigQuery } from 'services/api/endpoints/models';
|
import { useConvertModelMutation } from 'services/api/endpoints/models';
|
||||||
|
import type { CheckpointModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
interface ModelConvertProps {
|
interface ModelConvertProps {
|
||||||
modelKey: string | null;
|
modelConfig: CheckpointModelConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const ModelConvertButton = (props: ModelConvertProps) => {
|
export const ModelConvertButton = ({ modelConfig }: ModelConvertProps) => {
|
||||||
const { modelKey } = props;
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { data } = useGetModelConfigQuery(modelKey ?? skipToken);
|
|
||||||
const [convertModel, { isLoading }] = useConvertModelMutation();
|
const [convertModel, { isLoading }] = useConvertModelMutation();
|
||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
|
|
||||||
const modelConvertHandler = useCallback(() => {
|
const modelConvertHandler = useCallback(() => {
|
||||||
if (!data || isLoading) {
|
if (!modelConfig || isLoading) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const toastId = `CONVERTING_MODEL_${data.key}`;
|
const toastId = `CONVERTING_MODEL_${modelConfig.key}`;
|
||||||
toast({
|
toast({
|
||||||
id: toastId,
|
id: toastId,
|
||||||
title: `${t('modelManager.convertingModelBegin')}: ${data?.name}`,
|
title: `${t('modelManager.convertingModelBegin')}: ${modelConfig.name}`,
|
||||||
status: 'info',
|
status: 'info',
|
||||||
});
|
});
|
||||||
|
|
||||||
convertModel(data?.key)
|
convertModel(modelConfig.key)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.then(() => {
|
.then(() => {
|
||||||
toast({ id: toastId, title: `${t('modelManager.modelConverted')}: ${data?.name}`, status: 'success' });
|
toast({ id: toastId, title: `${t('modelManager.modelConverted')}: ${modelConfig.name}`, status: 'success' });
|
||||||
})
|
})
|
||||||
.catch(() => {
|
.catch(() => {
|
||||||
toast({
|
toast({
|
||||||
id: toastId,
|
id: toastId,
|
||||||
title: `${t('modelManager.modelConversionFailed')}: ${data?.name}`,
|
title: `${t('modelManager.modelConversionFailed')}: ${modelConfig.name}`,
|
||||||
status: 'error',
|
status: 'error',
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}, [data, isLoading, t, convertModel]);
|
}, [modelConfig, isLoading, t, convertModel]);
|
||||||
|
|
||||||
if (data?.format !== 'checkpoint') {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
@ -68,7 +62,7 @@ export const ModelConvertButton = (props: ModelConvertProps) => {
|
|||||||
🧨 {t('modelManager.convert')}
|
🧨 {t('modelManager.convert')}
|
||||||
</Button>
|
</Button>
|
||||||
<ConfirmationAlertDialog
|
<ConfirmationAlertDialog
|
||||||
title={`${t('modelManager.convert')} ${data?.name}`}
|
title={`${t('modelManager.convert')} ${modelConfig.name}`}
|
||||||
acceptCallback={modelConvertHandler}
|
acceptCallback={modelConvertHandler}
|
||||||
acceptButtonText={`${t('modelManager.convert')}`}
|
acceptButtonText={`${t('modelManager.convert')}`}
|
||||||
isOpen={isOpen}
|
isOpen={isOpen}
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import {
|
import {
|
||||||
|
Button,
|
||||||
Checkbox,
|
Checkbox,
|
||||||
Flex,
|
Flex,
|
||||||
FormControl,
|
FormControl,
|
||||||
@ -7,96 +8,152 @@ import {
|
|||||||
Heading,
|
Heading,
|
||||||
Input,
|
Input,
|
||||||
SimpleGrid,
|
SimpleGrid,
|
||||||
Text,
|
|
||||||
Textarea,
|
Textarea,
|
||||||
} from '@invoke-ai/ui-library';
|
} from '@invoke-ai/ui-library';
|
||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import type { SubmitHandler, UseFormReturn } from 'react-hook-form';
|
import { ModelHeader } from 'features/modelManagerV2/subpanels/ModelPanel/ModelHeader';
|
||||||
|
import { toast } from 'features/toast/toast';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
import { type SubmitHandler, useForm } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
import { PiCheckBold, PiXBold } from 'react-icons/pi';
|
||||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
import { type UpdateModelArg, useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import BaseModelSelect from './Fields/BaseModelSelect';
|
import BaseModelSelect from './Fields/BaseModelSelect';
|
||||||
import ModelVariantSelect from './Fields/ModelVariantSelect';
|
import ModelVariantSelect from './Fields/ModelVariantSelect';
|
||||||
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
|
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
form: UseFormReturn<UpdateModelArg['body']>;
|
modelConfig: AnyModelConfig;
|
||||||
onSubmit: SubmitHandler<UpdateModelArg['body']>;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const stringFieldOptions = {
|
const stringFieldOptions = {
|
||||||
validate: (value?: string | null) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
|
validate: (value?: string | null) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
|
||||||
};
|
};
|
||||||
|
|
||||||
export const ModelEdit = ({ form }: Props) => {
|
export const ModelEdit = ({ modelConfig }: Props) => {
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
|
||||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
if (isLoading) {
|
const form = useForm<UpdateModelArg['body']>({
|
||||||
return <Text>{t('common.loading')}</Text>;
|
defaultValues: modelConfig,
|
||||||
}
|
mode: 'onChange',
|
||||||
|
});
|
||||||
|
|
||||||
if (!data) {
|
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
|
||||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
(values) => {
|
||||||
}
|
const responseBody: UpdateModelArg = {
|
||||||
|
key: modelConfig.key,
|
||||||
|
body: values,
|
||||||
|
};
|
||||||
|
|
||||||
|
updateModel(responseBody)
|
||||||
|
.unwrap()
|
||||||
|
.then((payload) => {
|
||||||
|
form.reset(payload, { keepDefaultValues: true });
|
||||||
|
dispatch(setSelectedModelMode('view'));
|
||||||
|
toast({
|
||||||
|
id: 'MODEL_UPDATED',
|
||||||
|
title: t('modelManager.modelUpdated'),
|
||||||
|
status: 'success',
|
||||||
|
});
|
||||||
|
})
|
||||||
|
.catch((_) => {
|
||||||
|
form.reset();
|
||||||
|
toast({
|
||||||
|
id: 'MODEL_UPDATE_FAILED',
|
||||||
|
title: t('modelManager.modelUpdateFailed'),
|
||||||
|
status: 'error',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[dispatch, modelConfig.key, form, t, updateModel]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleClickCancel = useCallback(() => {
|
||||||
|
dispatch(setSelectedModelMode('view'));
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex flexDir="column" h="full">
|
<Flex flexDir="column" gap={4}>
|
||||||
<form>
|
<ModelHeader modelConfig={modelConfig}>
|
||||||
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
<Button flexShrink={0} size="sm" onClick={handleClickCancel} leftIcon={<PiXBold />}>
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(form.formState.errors.name)}>
|
{t('common.cancel')}
|
||||||
<FormLabel>{t('modelManager.modelName')}</FormLabel>
|
</Button>
|
||||||
<Input {...form.register('name', stringFieldOptions)} size="md" />
|
<Button
|
||||||
|
flexShrink={0}
|
||||||
|
size="sm"
|
||||||
|
colorScheme="invokeYellow"
|
||||||
|
leftIcon={<PiCheckBold />}
|
||||||
|
onClick={form.handleSubmit(onSubmit)}
|
||||||
|
isLoading={isSubmitting}
|
||||||
|
isDisabled={Boolean(Object.keys(form.formState.errors).length)}
|
||||||
|
>
|
||||||
|
{t('common.save')}
|
||||||
|
</Button>
|
||||||
|
</ModelHeader>
|
||||||
|
<Flex flexDir="column" h="full">
|
||||||
|
<form>
|
||||||
|
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
||||||
|
<FormControl
|
||||||
|
flexDir="column"
|
||||||
|
alignItems="flex-start"
|
||||||
|
gap={1}
|
||||||
|
isInvalid={Boolean(form.formState.errors.name)}
|
||||||
|
>
|
||||||
|
<FormLabel>{t('modelManager.modelName')}</FormLabel>
|
||||||
|
<Input {...form.register('name', stringFieldOptions)} size="md" />
|
||||||
|
|
||||||
{form.formState.errors.name?.message && (
|
{form.formState.errors.name?.message && (
|
||||||
<FormErrorMessage>{form.formState.errors.name?.message}</FormErrorMessage>
|
<FormErrorMessage>{form.formState.errors.name?.message}</FormErrorMessage>
|
||||||
)}
|
)}
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<Flex flexDir="column" gap={3} mt="4">
|
|
||||||
<Flex gap="4" alignItems="center">
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
|
||||||
<Textarea {...form.register('description')} minH={32} />
|
|
||||||
</FormControl>
|
</FormControl>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Heading as="h3" fontSize="md" mt="4">
|
|
||||||
{t('modelManager.modelSettings')}
|
<Flex flexDir="column" gap={3} mt="4">
|
||||||
</Heading>
|
<Flex gap="4" alignItems="center">
|
||||||
<SimpleGrid columns={2} gap={4}>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
|
||||||
<BaseModelSelect control={form.control} />
|
|
||||||
</FormControl>
|
|
||||||
{data.type === 'main' && (
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||||
<ModelVariantSelect control={form.control} />
|
<Textarea {...form.register('description')} minH={32} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
)}
|
</Flex>
|
||||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
<Heading as="h3" fontSize="md" mt="4">
|
||||||
<>
|
{t('modelManager.modelSettings')}
|
||||||
|
</Heading>
|
||||||
|
<SimpleGrid columns={2} gap={4}>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||||
|
<BaseModelSelect control={form.control} />
|
||||||
|
</FormControl>
|
||||||
|
{modelConfig.type === 'main' && (
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||||
<Input {...form.register('config_path', stringFieldOptions)} />
|
<ModelVariantSelect control={form.control} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
)}
|
||||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
{modelConfig.type === 'main' && modelConfig.format === 'checkpoint' && (
|
||||||
<PredictionTypeSelect control={form.control} />
|
<>
|
||||||
</FormControl>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
||||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
<Input {...form.register('config_path', stringFieldOptions)} />
|
||||||
<Checkbox {...form.register('upcast_attention')} />
|
</FormControl>
|
||||||
</FormControl>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
</>
|
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||||
)}
|
<PredictionTypeSelect control={form.control} />
|
||||||
</SimpleGrid>
|
</FormControl>
|
||||||
</Flex>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
</form>
|
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||||
|
<Checkbox {...form.register('upcast_attention')} />
|
||||||
|
</FormControl>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</SimpleGrid>
|
||||||
|
</Flex>
|
||||||
|
</form>
|
||||||
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -0,0 +1,36 @@
|
|||||||
|
import { Flex, Heading, Spacer, Text } from '@invoke-ai/ui-library';
|
||||||
|
import ModelImageUpload from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelImageUpload';
|
||||||
|
import type { PropsWithChildren } from 'react';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
|
type Props = PropsWithChildren<{
|
||||||
|
modelConfig: AnyModelConfig;
|
||||||
|
}>;
|
||||||
|
|
||||||
|
export const ModelHeader = memo(({ modelConfig, children }: Props) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
return (
|
||||||
|
<Flex alignItems="flex-start" gap={4}>
|
||||||
|
<ModelImageUpload model_key={modelConfig.key} model_image={modelConfig.cover_image} />
|
||||||
|
<Flex flexDir="column" gap={1} flexGrow={1} minW={0}>
|
||||||
|
<Flex gap={2}>
|
||||||
|
<Heading as="h2" fontSize="lg" noOfLines={1} wordBreak="break-all">
|
||||||
|
{modelConfig.name}
|
||||||
|
</Heading>
|
||||||
|
<Spacer />
|
||||||
|
{children}
|
||||||
|
</Flex>
|
||||||
|
{modelConfig.source && (
|
||||||
|
<Text variant="subtext" noOfLines={1} wordBreak="break-all">
|
||||||
|
{t('modelManager.source')}: {modelConfig.source}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
<Text noOfLines={3}>{modelConfig.description}</Text>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
ModelHeader.displayName = 'ModelHeader';
|
@ -1,55 +1,64 @@
|
|||||||
import { Box, Flex, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
import { Box, Flex, SimpleGrid } from '@invoke-ai/ui-library';
|
||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { ControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
|
import { ControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
|
||||||
|
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
|
||||||
|
import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton';
|
||||||
|
import { ModelHeader } from 'features/modelManagerV2/subpanels/ModelPanel/ModelHeader';
|
||||||
import { TriggerPhrases } from 'features/modelManagerV2/subpanels/ModelPanel/TriggerPhrases';
|
import { TriggerPhrases } from 'features/modelManagerV2/subpanels/ModelPanel/TriggerPhrases';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import { MainModelDefaultSettings } from './MainModelDefaultSettings/MainModelDefaultSettings';
|
import { MainModelDefaultSettings } from './MainModelDefaultSettings/MainModelDefaultSettings';
|
||||||
import { ModelAttrView } from './ModelAttrView';
|
import { ModelAttrView } from './ModelAttrView';
|
||||||
|
|
||||||
export const ModelView = () => {
|
type Props = {
|
||||||
|
modelConfig: AnyModelConfig;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const ModelView = ({ modelConfig }: Props) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
|
||||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
|
||||||
|
|
||||||
if (isLoading) {
|
|
||||||
return <Text>{t('common.loading')}</Text>;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!data) {
|
|
||||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
|
||||||
}
|
|
||||||
return (
|
return (
|
||||||
<Flex flexDir="column" h="full" gap={4}>
|
<Flex flexDir="column" gap={4}>
|
||||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
<ModelHeader modelConfig={modelConfig}>
|
||||||
<SimpleGrid columns={2} gap={4}>
|
{modelConfig.format === 'checkpoint' && modelConfig.type === 'main' && (
|
||||||
<ModelAttrView label={t('modelManager.baseModel')} value={data.base} />
|
<ModelConvertButton modelConfig={modelConfig} />
|
||||||
<ModelAttrView label={t('modelManager.modelType')} value={data.type} />
|
)}
|
||||||
<ModelAttrView label={t('common.format')} value={data.format} />
|
<ModelEditButton />
|
||||||
<ModelAttrView label={t('modelManager.path')} value={data.path} />
|
</ModelHeader>
|
||||||
{data.type === 'main' && <ModelAttrView label={t('modelManager.variant')} value={data.variant} />}
|
<Flex flexDir="column" h="full" gap={4}>
|
||||||
{data.type === 'main' && data.format === 'diffusers' && data.repo_variant && (
|
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||||
<ModelAttrView label={t('modelManager.repoVariant')} value={data.repo_variant} />
|
<SimpleGrid columns={2} gap={4}>
|
||||||
|
<ModelAttrView label={t('modelManager.baseModel')} value={modelConfig.base} />
|
||||||
|
<ModelAttrView label={t('modelManager.modelType')} value={modelConfig.type} />
|
||||||
|
<ModelAttrView label={t('common.format')} value={modelConfig.format} />
|
||||||
|
<ModelAttrView label={t('modelManager.path')} value={modelConfig.path} />
|
||||||
|
{modelConfig.type === 'main' && (
|
||||||
|
<ModelAttrView label={t('modelManager.variant')} value={modelConfig.variant} />
|
||||||
|
)}
|
||||||
|
{modelConfig.type === 'main' && modelConfig.format === 'diffusers' && modelConfig.repo_variant && (
|
||||||
|
<ModelAttrView label={t('modelManager.repoVariant')} value={modelConfig.repo_variant} />
|
||||||
|
)}
|
||||||
|
{modelConfig.type === 'main' && modelConfig.format === 'checkpoint' && (
|
||||||
|
<>
|
||||||
|
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelConfig.config_path} />
|
||||||
|
<ModelAttrView label={t('modelManager.predictionType')} value={modelConfig.prediction_type} />
|
||||||
|
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelConfig.upcast_attention}`} />
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
{modelConfig.type === 'ip_adapter' && modelConfig.format === 'invokeai' && (
|
||||||
|
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelConfig.image_encoder_model_id} />
|
||||||
|
)}
|
||||||
|
</SimpleGrid>
|
||||||
|
</Box>
|
||||||
|
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||||
|
{modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner' && (
|
||||||
|
<MainModelDefaultSettings modelConfig={modelConfig} />
|
||||||
)}
|
)}
|
||||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
{(modelConfig.type === 'controlnet' || modelConfig.type === 't2i_adapter') && (
|
||||||
<>
|
<ControlNetOrT2IAdapterDefaultSettings modelConfig={modelConfig} />
|
||||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={data.config_path} />
|
|
||||||
<ModelAttrView label={t('modelManager.predictionType')} value={data.prediction_type} />
|
|
||||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${data.upcast_attention}`} />
|
|
||||||
</>
|
|
||||||
)}
|
)}
|
||||||
{data.type === 'ip_adapter' && data.format === 'invokeai' && (
|
{(modelConfig.type === 'main' || modelConfig.type === 'lora') && <TriggerPhrases modelConfig={modelConfig} />}
|
||||||
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={data.image_encoder_model_id} />
|
</Box>
|
||||||
)}
|
</Flex>
|
||||||
</SimpleGrid>
|
|
||||||
</Box>
|
|
||||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
|
||||||
{data.type === 'main' && data.base !== 'sdxl-refiner' && <MainModelDefaultSettings />}
|
|
||||||
{(data.type === 'controlnet' || data.type === 't2i_adapter') && <ControlNetOrT2IAdapterDefaultSettings />}
|
|
||||||
{(data.type === 'main' || data.type === 'lora') && <TriggerPhrases />}
|
|
||||||
</Box>
|
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -9,19 +9,19 @@ import {
|
|||||||
TagCloseButton,
|
TagCloseButton,
|
||||||
TagLabel,
|
TagLabel,
|
||||||
} from '@invoke-ai/ui-library';
|
} from '@invoke-ai/ui-library';
|
||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import type { ChangeEvent } from 'react';
|
import type { ChangeEvent } from 'react';
|
||||||
import { useCallback, useMemo, useState } from 'react';
|
import { useCallback, useMemo, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiPlusBold } from 'react-icons/pi';
|
import { PiPlusBold } from 'react-icons/pi';
|
||||||
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
|
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||||
import { isLoRAModelConfig, isNonRefinerMainModelConfig } from 'services/api/types';
|
import type { LoRAModelConfig, MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
export const TriggerPhrases = () => {
|
type Props = {
|
||||||
|
modelConfig: MainModelConfig | LoRAModelConfig;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const TriggerPhrases = ({ modelConfig }: Props) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
|
||||||
const { currentData: modelConfig } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
|
||||||
const [phrase, setPhrase] = useState('');
|
const [phrase, setPhrase] = useState('');
|
||||||
|
|
||||||
const [updateModel, { isLoading }] = useUpdateModelMutation();
|
const [updateModel, { isLoading }] = useUpdateModelMutation();
|
||||||
@ -31,9 +31,6 @@ export const TriggerPhrases = () => {
|
|||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const triggerPhrases = useMemo(() => {
|
const triggerPhrases = useMemo(() => {
|
||||||
if (!modelConfig || (!isNonRefinerMainModelConfig(modelConfig) && !isLoRAModelConfig(modelConfig))) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
return modelConfig?.trigger_phrases || [];
|
return modelConfig?.trigger_phrases || [];
|
||||||
}, [modelConfig]);
|
}, [modelConfig]);
|
||||||
|
|
||||||
@ -48,10 +45,6 @@ export const TriggerPhrases = () => {
|
|||||||
}, [phrase, triggerPhrases]);
|
}, [phrase, triggerPhrases]);
|
||||||
|
|
||||||
const addTriggerPhrase = useCallback(async () => {
|
const addTriggerPhrase = useCallback(async () => {
|
||||||
if (!selectedModelKey) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!phrase.length || triggerPhrases.includes(phrase)) {
|
if (!phrase.length || triggerPhrases.includes(phrase)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -59,22 +52,18 @@ export const TriggerPhrases = () => {
|
|||||||
setPhrase('');
|
setPhrase('');
|
||||||
|
|
||||||
await updateModel({
|
await updateModel({
|
||||||
key: selectedModelKey,
|
key: modelConfig.key,
|
||||||
body: { trigger_phrases: [...triggerPhrases, phrase] },
|
body: { trigger_phrases: [...triggerPhrases, phrase] },
|
||||||
}).unwrap();
|
}).unwrap();
|
||||||
}, [updateModel, selectedModelKey, phrase, triggerPhrases]);
|
}, [phrase, triggerPhrases, updateModel, modelConfig.key]);
|
||||||
|
|
||||||
const removeTriggerPhrase = useCallback(
|
const removeTriggerPhrase = useCallback(
|
||||||
async (phraseToRemove: string) => {
|
async (phraseToRemove: string) => {
|
||||||
if (!selectedModelKey) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const filteredPhrases = triggerPhrases.filter((p) => p !== phraseToRemove);
|
const filteredPhrases = triggerPhrases.filter((p) => p !== phraseToRemove);
|
||||||
|
|
||||||
await updateModel({ key: selectedModelKey, body: { trigger_phrases: filteredPhrases } }).unwrap();
|
await updateModel({ key: modelConfig.key, body: { trigger_phrases: filteredPhrases } }).unwrap();
|
||||||
},
|
},
|
||||||
[updateModel, selectedModelKey, triggerPhrases]
|
[triggerPhrases, updateModel, modelConfig]
|
||||||
);
|
);
|
||||||
|
|
||||||
const onTriggerPhraseAddFormSubmit = useCallback(
|
const onTriggerPhraseAddFormSubmit = useCallback(
|
||||||
|
@ -242,7 +242,6 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
}
|
}
|
||||||
return tags;
|
return tags;
|
||||||
},
|
},
|
||||||
keepUnusedDataFor: 60 * 60 * 1000 * 24, // 1 day (infinite)
|
|
||||||
transformResponse: (response: GetModelConfigsResponse) => {
|
transformResponse: (response: GetModelConfigsResponse) => {
|
||||||
return modelConfigsAdapter.setAll(modelConfigsAdapter.getInitialState(), response.models);
|
return modelConfigsAdapter.setAll(modelConfigsAdapter.getInitialState(), response.models);
|
||||||
},
|
},
|
||||||
|
@ -54,7 +54,7 @@ export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
|
|||||||
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
|
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
|
||||||
type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
|
type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
|
||||||
type DiffusersModelConfig = S['MainDiffusersConfig'];
|
type DiffusersModelConfig = S['MainDiffusersConfig'];
|
||||||
type CheckpointModelConfig = S['MainCheckpointConfig'];
|
export type CheckpointModelConfig = S['MainCheckpointConfig'];
|
||||||
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
|
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
|
||||||
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
|
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
|
||||||
export type AnyModelConfig =
|
export type AnyModelConfig =
|
||||||
|
Loading…
Reference in New Issue
Block a user