From afb46564e88c242fec9fcf52fb0c2ee5534122d7 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 12 Jul 2023 16:13:49 +1200 Subject: [PATCH 01/28] feat: Restore Update Model functionality --- invokeai/frontend/web/public/locales/en.json | 1 + .../ModelManagerPanel/CheckpointModelEdit.tsx | 61 +++++++++++++------ .../ModelManagerPanel/DiffusersModelEdit.tsx | 58 +++++++++++++----- .../web/src/services/api/endpoints/models.ts | 21 +++++++ 4 files changed, 108 insertions(+), 33 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 1a902a88b7..fc56f5a703 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -342,6 +342,7 @@ "safetensorModels": "SafeTensors", "modelAdded": "Model Added", "modelUpdated": "Model Updated", + "modelUpdateFailed": "Model Update Failed", "modelEntryDeleted": "Model Entry Deleted", "cannotUseSpaces": "Cannot Use Spaces", "addNew": "Add New", diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx index 0d5d21175a..5dbb64ca7d 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -11,7 +11,11 @@ import IAIButton from 'common/components/IAIButton'; import IAIInput from 'common/components/IAIInput'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; -import { S } from 'services/api/types'; + +import { makeToast } from 'app/components/Toaster'; +import { addToast } from 'features/system/store/systemSlice'; +import { useUpdateMainModelsMutation } from 'services/api/endpoints/models'; +import { components } from 'services/api/schema'; import ModelConvert from './ModelConvert'; const baseModelSelectData = [ @@ -25,13 +29,13 @@ const variantSelectData = [ { value: 'depth', label: 'Depth' }, ]; -export type CheckpointModel = - | S<'StableDiffusion1ModelCheckpointConfig'> - | S<'StableDiffusion2ModelCheckpointConfig'>; +export type CheckpointModelConfig = + | components['schemas']['StableDiffusion1ModelCheckpointConfig'] + | components['schemas']['StableDiffusion2ModelCheckpointConfig']; type CheckpointModelEditProps = { modelToEdit: string; - retrievedModel: CheckpointModel; + retrievedModel: CheckpointModelConfig; }; export default function CheckpointModelEdit(props: CheckpointModelEditProps) { @@ -41,25 +45,52 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) { const { modelToEdit, retrievedModel } = props; + const [updateMainModel, { error }] = useUpdateMainModelsMutation(); + const dispatch = useAppDispatch(); const { t } = useTranslation(); - const checkpointEditForm = useForm({ + const checkpointEditForm = useForm({ initialValues: { - name: retrievedModel.name, + name: retrievedModel.name ? retrievedModel.name : '', base_model: retrievedModel.base_model, type: 'main', - path: retrievedModel.path, - description: retrievedModel.description, + path: retrievedModel.path ? retrievedModel.path : '', + description: retrievedModel.description ? retrievedModel.description : '', model_format: 'checkpoint', - vae: retrievedModel.vae, - config: retrievedModel.config, + vae: retrievedModel.vae ? retrievedModel.vae : '', + config: retrievedModel.config ? retrievedModel.config : '', variant: retrievedModel.variant, }, }); - const editModelFormSubmitHandler = (values) => { - console.log(values); + const editModelFormSubmitHandler = (values: CheckpointModelConfig) => { + const responseBody = { + base_model: retrievedModel.base_model, + model_name: retrievedModel.name, + body: values, + }; + updateMainModel(responseBody); + + if (error) { + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdateFailed'), + status: 'success', + }) + ) + ); + } + + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdated'), + status: 'success', + }) + ) + ); }; return modelToEdit ? ( @@ -88,10 +119,6 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) { )} > - - | S<'StableDiffusion2ModelDiffusersConfig'>; +export type DiffusersModelConfig = + | components['schemas']['StableDiffusion1ModelDiffusersConfig'] + | components['schemas']['StableDiffusion2ModelDiffusersConfig']; type DiffusersModelEditProps = { modelToEdit: string; - retrievedModel: DiffusersModel; + retrievedModel: DiffusersModelConfig; }; const baseModelSelectData = [ @@ -39,24 +42,51 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) { ); const { retrievedModel, modelToEdit } = props; + const [updateMainModel, { error }] = useUpdateMainModelsMutation(); + const dispatch = useAppDispatch(); const { t } = useTranslation(); - const diffusersEditForm = useForm({ + const diffusersEditForm = useForm({ initialValues: { - name: retrievedModel.name, + name: retrievedModel.name ? retrievedModel.name : '', base_model: retrievedModel.base_model, type: 'main', - path: retrievedModel.path, - description: retrievedModel.description, + path: retrievedModel.path ? retrievedModel.path : '', + description: retrievedModel.description ? retrievedModel.description : '', model_format: 'diffusers', - vae: retrievedModel.vae, + vae: retrievedModel.vae ? retrievedModel.vae : '', variant: retrievedModel.variant, }, }); - const editModelFormSubmitHandler = (values) => { - console.log(values); + const editModelFormSubmitHandler = (values: DiffusersModelConfig) => { + const responseBody = { + base_model: retrievedModel.base_model, + model_name: retrievedModel.name, + body: values, + }; + updateMainModel(responseBody); + + if (error) { + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdateFailed'), + status: 'success', + }) + ) + ); + } + + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdated'), + status: 'success', + }) + ) + ); }; return modelToEdit ? ( @@ -77,10 +107,6 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) { )} > - ({ sortComparer: (a, b) => a.name.localeCompare(b.name), }); @@ -101,6 +108,19 @@ export const modelsApi = api.injectEndpoints({ ); }, }), + updateMainModels: build.mutation< + EntityState, + UpdateMainModelQuery + >({ + query: ({ base_model, model_name, body }) => { + return { + url: `models/${base_model}/main/${model_name}`, + method: 'PATCH', + body: body, + }; + }, + invalidatesTags: ['MainModel'], + }), getLoRAModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 'lora' } }), providesTags: (result, error, arg) => { @@ -244,4 +264,5 @@ export const { useGetLoRAModelsQuery, useGetTextualInversionModelsQuery, useGetVaeModelsQuery, + useUpdateMainModelsMutation, } = modelsApi; From 5a6ad99d4ee26f72da6d33015b9059aa892c25ac Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 12 Jul 2023 16:39:07 +1200 Subject: [PATCH 02/28] feat: Restore Delete Model Functionality --- .../backend/model_management/model_manager.py | 1 + .../ModelManagerPanel/ModelListItem.tsx | 9 ++++++++- .../web/src/services/api/endpoints/models.ts | 18 ++++++++++++++++++ 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 03514cfeff..213ebe8ad9 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -593,6 +593,7 @@ class ModelManager(object): rmtree(str(model_path)) else: model_path.unlink() + self.commit() # LS: tested def add_model( diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx index ab5fddd5ea..61c2c8b5e2 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx @@ -8,6 +8,7 @@ import IAIAlertDialog from 'common/components/IAIAlertDialog'; import IAIIconButton from 'common/components/IAIIconButton'; import { setOpenModel } from 'features/system/store/systemSlice'; import { useTranslation } from 'react-i18next'; +import { useDeleteMainModelsMutation } from 'services/api/endpoints/models'; type ModelListItemProps = { modelKey: string; @@ -24,6 +25,8 @@ export default function ModelListItem(props: ModelListItemProps) { (state: RootState) => state.system.openModel ); + const [deleteMainModel] = useDeleteMainModelsMutation(); + const { t } = useTranslation(); const dispatch = useAppDispatch(); @@ -35,7 +38,11 @@ export default function ModelListItem(props: ModelListItemProps) { }; const handleModelDelete = () => { - dispatch(deleteModel(modelKey)); + const [base_model, _, model_name] = modelKey.split('/'); + deleteMainModel({ + base_model: base_model, + model_name: model_name, + }); dispatch(setOpenModel(null)); }; diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 217a54a74d..767668f5b2 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -39,6 +39,11 @@ type UpdateMainModelQuery = { body: MainModelConfig; }; +type DeleteMainModelQuery = { + base_model: BaseModelType; + model_name: string; +}; + const mainModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.name.localeCompare(b.name), }); @@ -121,6 +126,18 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['MainModel'], }), + deleteMainModels: build.mutation< + EntityState, + DeleteMainModelQuery + >({ + query: ({ base_model, model_name }) => { + return { + url: `models/${base_model}/main/${model_name}`, + method: 'DELETE', + }; + }, + invalidatesTags: ['MainModel'], + }), getLoRAModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 'lora' } }), providesTags: (result, error, arg) => { @@ -265,4 +282,5 @@ export const { useGetTextualInversionModelsQuery, useGetVaeModelsQuery, useUpdateMainModelsMutation, + useDeleteMainModelsMutation, } = modelsApi; From 3568e28b1cb16c5e37bb14419ee78743de635bc9 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 12 Jul 2023 19:05:16 +1200 Subject: [PATCH 03/28] fix: Type resolutions & Bug Fixes - Fix checkpoint filter not working - Resolve all typescript and undefined issues in Model Manager List / Edit Forms and main panel --- .../subpanels/ModelManagerPanel.tsx | 12 ++++-- .../ModelManagerPanel/CheckpointModelEdit.tsx | 8 +++- .../ModelManagerPanel/DiffusersModelEdit.tsx | 8 +++- .../subpanels/ModelManagerPanel/ModelList.tsx | 38 ++++++++++--------- 4 files changed, 41 insertions(+), 25 deletions(-) diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx index b22a303571..ddf0874b21 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx @@ -17,19 +17,23 @@ export default function ModelManagerPanel() { const renderModelEditTabs = () => { if (!openModel || !mainModels) return; - if (mainModels['entities'][openModel]['model_format'] === 'diffusers') { + const openedModelData = mainModels['entities'][openModel]; + + if (openedModelData && openedModelData.model_format === 'diffusers') { return ( ); - } else { + } + + if (openedModelData && openedModelData.model_format === 'checkpoint') { return ( ); diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx index 5dbb64ca7d..1eeb6d2fc8 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -45,7 +45,7 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) { const { modelToEdit, retrievedModel } = props; - const [updateMainModel, { error }] = useUpdateMainModelsMutation(); + const [updateMainModel, { error, isLoading }] = useUpdateMainModelsMutation(); const dispatch = useAppDispatch(); const { t } = useTranslation(); @@ -145,7 +145,11 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) { label={t('modelManager.config')} {...checkpointEditForm.getInputProps('config')} /> - + {t('modelManager.updateModel')} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx index 696a620d65..377317775b 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx @@ -42,7 +42,7 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) { ); const { retrievedModel, modelToEdit } = props; - const [updateMainModel, { error }] = useUpdateMainModelsMutation(); + const [updateMainModel, { isLoading, error }] = useUpdateMainModelsMutation(); const dispatch = useAppDispatch(); const { t } = useTranslation(); @@ -129,7 +129,11 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) { label={t('modelManager.vaeLocation')} {...diffusersEditForm.getInputProps('vae')} /> - + {t('modelManager.updateModel')} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx index eb05e70357..3eee71e576 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -50,7 +50,7 @@ const ModelList = () => { const [searchText, setSearchText] = useState(''); const [isSelectedFilter, setIsSelectedFilter] = useState< - 'all' | 'ckpt' | 'diffusers' + 'all' | 'checkpoint' | 'diffusers' >('all'); const [_, startTransition] = useTransition(); @@ -73,35 +73,39 @@ const ModelList = () => { const modelList = mainModels.entities; Object.keys(modelList).forEach((model, i) => { - if ( - modelList[model].name.toLowerCase().includes(searchText.toLowerCase()) - ) { + const modelInfo = modelList[model]; + + // If no model info found for a model, ignore it + if (!modelInfo) return; + + if (modelInfo.name.toLowerCase().includes(searchText.toLowerCase())) { filteredModelListItemsToRender.push( ); - if (modelList[model]?.model_format === isSelectedFilter) { + if (modelInfo?.model_format === isSelectedFilter) { localFilteredModelListItemsToRender.push( ); } } - if (modelList[model]?.model_format !== 'diffusers') { + + if (modelInfo?.model_format !== 'diffusers') { ckptModelListItemsToRender.push( ); } else { @@ -109,8 +113,8 @@ const ModelList = () => { ); } @@ -170,7 +174,7 @@ const ModelList = () => { )} - {isSelectedFilter === 'ckpt' && ( + {isSelectedFilter === 'checkpoint' && ( {ckptModelListItemsToRender} @@ -206,8 +210,8 @@ const ModelList = () => { /> setIsSelectedFilter('ckpt')} - isActive={isSelectedFilter === 'ckpt'} + onClick={() => setIsSelectedFilter('checkpoint')} + isActive={isSelectedFilter === 'checkpoint'} /> From 310e401b033d212d18d69cf3d1f755f8476afc61 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 12 Jul 2023 20:10:33 +1200 Subject: [PATCH 04/28] feat: Create basic IAIMantineTextInput component for form usage --- .../src/common/components/IAIMantineInput.tsx | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 invokeai/frontend/web/src/common/components/IAIMantineInput.tsx diff --git a/invokeai/frontend/web/src/common/components/IAIMantineInput.tsx b/invokeai/frontend/web/src/common/components/IAIMantineInput.tsx new file mode 100644 index 0000000000..f7c2b91ff0 --- /dev/null +++ b/invokeai/frontend/web/src/common/components/IAIMantineInput.tsx @@ -0,0 +1,31 @@ +import { useColorMode } from '@chakra-ui/react'; +import { TextInput, TextInputProps } from '@mantine/core'; +import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; +import { mode } from 'theme/util/mode'; + +type IAIMantineTextInputProps = TextInputProps; + +export default function IAIMantineTextInput(props: IAIMantineTextInputProps) { + const { ...rest } = props; + const { base50, base100, base200, base800, base900, accent500, accent300 } = + useChakraThemeTokens(); + const { colorMode } = useColorMode(); + + return ( + ({ + input: { + color: mode(base900, base100)(colorMode), + backgroundColor: mode(base50, base900)(colorMode), + borderColor: mode(base200, base800)(colorMode), + borderWidth: 2, + outline: 'none', + ':focus': { + borderColor: mode(accent300, accent500)(colorMode), + }, + }, + })} + {...rest} + /> + ); +} From 6238a53fdd856df5f9a4e1c5676382dfb8f3a1b4 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 12 Jul 2023 20:11:05 +1200 Subject: [PATCH 05/28] feat: Add basic form validation for path input --- .../AddModelsPanel/AddDiffusersModel.tsx | 17 +++-------------- .../ModelManagerPanel/CheckpointModelEdit.tsx | 14 +++++++++----- .../ModelManagerPanel/DiffusersModelEdit.tsx | 12 ++++++++---- 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddDiffusersModel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddDiffusersModel.tsx index dd491828da..c871a0ede5 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddDiffusersModel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddDiffusersModel.tsx @@ -66,13 +66,13 @@ export default function AddDiffusersModel() { }; return ( - + {({ handleSubmit, errors, touched }) => ( - + {/* Name */} @@ -90,7 +90,6 @@ export default function AddDiffusersModel() { name="name" type="text" validate={baseValidation} - width="2xl" isRequired /> {!!errors.name && touched.name ? ( @@ -119,7 +118,6 @@ export default function AddDiffusersModel() { id="description" name="description" type="text" - width="2xl" isRequired /> {!!errors.description && touched.description ? ( @@ -153,13 +151,7 @@ export default function AddDiffusersModel() { {t('modelManager.modelLocation')} - + {!!errors.path && touched.path ? ( {errors.path} ) : ( @@ -181,7 +173,6 @@ export default function AddDiffusersModel() { id="repo_id" name="repo_id" type="text" - width="2xl" /> {!!errors.repo_id && touched.repo_id ? ( {errors.repo_id} @@ -220,7 +211,6 @@ export default function AddDiffusersModel() { id="vae.path" name="vae.path" type="text" - width="2xl" /> {!!errors.vae?.path && touched.vae?.path ? ( {errors.vae?.path} @@ -245,7 +235,6 @@ export default function AddDiffusersModel() { id="vae.repo_id" name="vae.repo_id" type="text" - width="2xl" /> {!!errors.vae?.repo_id && touched.vae?.repo_id ? ( {errors.vae?.repo_id} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx index 1eeb6d2fc8..8f70a4f322 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -8,11 +8,11 @@ import { useTranslation } from 'react-i18next'; import type { RootState } from 'app/store/store'; import IAIButton from 'common/components/IAIButton'; -import IAIInput from 'common/components/IAIInput'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; import { makeToast } from 'app/components/Toaster'; +import IAIMantineTextInput from 'common/components/IAIMantineInput'; import { addToast } from 'features/system/store/systemSlice'; import { useUpdateMainModelsMutation } from 'services/api/endpoints/models'; import { components } from 'services/api/schema'; @@ -62,6 +62,10 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) { config: retrievedModel.config ? retrievedModel.config : '', variant: retrievedModel.variant, }, + validate: { + path: (value) => + value.trim().length === 0 ? 'Must provide a path' : null, + }, }); const editModelFormSubmitHandler = (values: CheckpointModelConfig) => { @@ -119,7 +123,7 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) { )} > - @@ -133,15 +137,15 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) { data={variantSelectData} {...checkpointEditForm.getInputProps('variant')} /> - - - diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx index 377317775b..68135a26d4 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx @@ -9,7 +9,7 @@ import { useForm } from '@mantine/form'; import { makeToast } from 'app/components/Toaster'; import type { RootState } from 'app/store/store'; import IAIButton from 'common/components/IAIButton'; -import IAIInput from 'common/components/IAIInput'; +import IAIMantineTextInput from 'common/components/IAIMantineInput'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; import { addToast } from 'features/system/store/systemSlice'; @@ -58,6 +58,10 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) { vae: retrievedModel.vae ? retrievedModel.vae : '', variant: retrievedModel.variant, }, + validate: { + path: (value) => + value.trim().length === 0 ? 'Must provide a path' : null, + }, }); const editModelFormSubmitHandler = (values: DiffusersModelConfig) => { @@ -107,7 +111,7 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) { )} > - @@ -121,11 +125,11 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) { data={variantSelectData} {...diffusersEditForm.getInputProps('variant')} /> - - From 2cedf6aed5c9a6e4370cab82452b96c9168545b2 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 12 Jul 2023 20:40:58 +1200 Subject: [PATCH 06/28] feat: Restore Model Convert Functionality --- invokeai/frontend/web/public/locales/en.json | 3 +- .../ModelManagerPanel/DiffusersModelEdit.tsx | 2 +- .../ModelManagerPanel/ModelConvert.tsx | 65 ++++++++++++------- .../web/src/services/api/endpoints/models.ts | 18 +++++ 4 files changed, 61 insertions(+), 27 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index fc56f5a703..23e13fbbc7 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -409,7 +409,7 @@ "convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.", "convertToDiffusersHelpText3": "Your checkpoint file on the disk will NOT be deleted or modified in anyway. You can add your checkpoint to the Model Manager again if you want to.", "convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.", - "convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 4GB-7GB in size.", + "convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.", "convertToDiffusersHelpText6": "Do you wish to convert this model?", "convertToDiffusersSaveLocation": "Save Location", "v1": "v1", @@ -420,6 +420,7 @@ "pathToCustomConfig": "Path To Custom Config", "statusConverting": "Converting", "modelConverted": "Model Converted", + "modelConversionFailed": "Model Conversion Failed", "sameFolder": "Same folder", "invokeRoot": "InvokeAI folder", "custom": "Custom", diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx index 68135a26d4..776290bf08 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx @@ -77,7 +77,7 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) { addToast( makeToast({ title: t('modelManager.modelUpdateFailed'), - status: 'success', + status: 'error', }) ) ); diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx index 9f571c2fff..4502b339ac 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx @@ -1,23 +1,17 @@ -import { - Flex, - ListItem, - Radio, - RadioGroup, - Text, - Tooltip, - UnorderedList, -} from '@chakra-ui/react'; +import { Flex, ListItem, Text, UnorderedList } from '@chakra-ui/react'; // import { convertToDiffusers } from 'app/socketio/actions'; +import { makeToast } from 'app/components/Toaster'; import { useAppDispatch } from 'app/store/storeHooks'; import IAIAlertDialog from 'common/components/IAIAlertDialog'; import IAIButton from 'common/components/IAIButton'; -import IAIInput from 'common/components/IAIInput'; +import { addToast } from 'features/system/store/systemSlice'; import { useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; -import { CheckpointModel } from './CheckpointModelEdit'; +import { useConvertMainModelMutation } from 'services/api/endpoints/models'; +import { CheckpointModelConfig } from './CheckpointModelEdit'; interface ModelConvertProps { - model: CheckpointModel; + model: CheckpointModelConfig; } export default function ModelConvert(props: ModelConvertProps) { @@ -26,6 +20,9 @@ export default function ModelConvert(props: ModelConvertProps) { const dispatch = useAppDispatch(); const { t } = useTranslation(); + const [convertModel, { isLoading, error, data }] = + useConvertMainModelMutation(); + const [saveLocation, setSaveLocation] = useState('same'); const [customSaveLocation, setCustomSaveLocation] = useState(''); @@ -38,15 +35,33 @@ export default function ModelConvert(props: ModelConvertProps) { }; const modelConvertHandler = () => { - const modelToConvert = { - model_name: model, - save_location: saveLocation, - custom_location: - saveLocation === 'custom' && customSaveLocation !== '' - ? customSaveLocation - : null, + const responseBody = { + base_model: model.base_model, + model_name: model.name, }; - dispatch(convertToDiffusers(modelToConvert)); + convertModel(responseBody); + + if (error) { + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelConversionFailed'), + status: 'error', + }) + ) + ); + } + + if (data) { + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelConverted'), + status: 'success', + }) + ) + ); + } }; return ( @@ -60,6 +75,7 @@ export default function ModelConvert(props: ModelConvertProps) { size={'sm'} aria-label={t('modelManager.convertToDiffusers')} className=" modal-close-btn" + isLoading={isLoading} > 🧨 {t('modelManager.convertToDiffusers')} @@ -77,7 +93,7 @@ export default function ModelConvert(props: ModelConvertProps) { {t('modelManager.convertToDiffusersHelpText6')} - + {/* {t('modelManager.convertToDiffusersSaveLocation')} @@ -103,9 +119,9 @@ export default function ModelConvert(props: ModelConvertProps) { - + */} - {saveLocation === 'custom' && ( + {/* {saveLocation === 'custom' && ( {t('modelManager.customSaveLocation')} @@ -119,8 +135,7 @@ export default function ModelConvert(props: ModelConvertProps) { width="full" /> - )} - + )} */} ); } diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 767668f5b2..d5f08c864c 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -44,6 +44,11 @@ type DeleteMainModelQuery = { model_name: string; }; +type ConvertMainModelQuery = { + base_model: BaseModelType; + model_name: string; +}; + const mainModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.name.localeCompare(b.name), }); @@ -138,6 +143,18 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['MainModel'], }), + convertMainModel: build.mutation< + EntityState, + ConvertMainModelQuery + >({ + query: ({ base_model, model_name }) => { + return { + url: `models/convert/${base_model}/main/${model_name}`, + method: 'PUT', + }; + }, + invalidatesTags: ['MainModel'], + }), getLoRAModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 'lora' } }), providesTags: (result, error, arg) => { @@ -283,4 +300,5 @@ export const { useGetVaeModelsQuery, useUpdateMainModelsMutation, useDeleteMainModelsMutation, + useConvertMainModelMutation, } = modelsApi; From 683229e2855a6f444a7f0a37b3ab4233d6a2f1e2 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 12 Jul 2023 20:44:57 +1200 Subject: [PATCH 07/28] fix: Update model convert toast message --- .../ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx index 4502b339ac..3e7ca7469a 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx @@ -45,7 +45,7 @@ export default function ModelConvert(props: ModelConvertProps) { dispatch( addToast( makeToast({ - title: t('modelManager.modelConversionFailed'), + title: `${t('modelManager.modelConversionFailed')}: ${model.name}`, status: 'error', }) ) @@ -56,7 +56,7 @@ export default function ModelConvert(props: ModelConvertProps) { dispatch( addToast( makeToast({ - title: t('modelManager.modelConverted'), + title: `${t('modelManager.modelConverted')}: ${model.name}`, status: 'success', }) ) From 3db1aa738c1f57930f270bd11db4435aac749d09 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 12 Jul 2023 22:43:06 +1200 Subject: [PATCH 08/28] feat: Restore Model Merge functionality --- invokeai/frontend/web/public/locales/en.json | 4 +- .../subpanels/MergeModelsPanel.tsx | 202 +++++++++++++----- .../ModelManagerPanel/ModelConvert.tsx | 5 +- .../web/src/services/api/endpoints/models.ts | 24 ++- .../frontend/web/src/services/api/types.d.ts | 1 + 5 files changed, 172 insertions(+), 64 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 23e13fbbc7..e2e0e4ae95 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -427,6 +427,7 @@ "customSaveLocation": "Custom Save Location", "merge": "Merge", "modelsMerged": "Models Merged", + "modelsMergeFailed": "Model Merge Failed", "mergeModels": "Merge Models", "modelOne": "Model 1", "modelTwo": "Model 2", @@ -447,7 +448,8 @@ "weightedSum": "Weighted Sum", "none": "none", "addDifference": "Add Difference", - "pickModelType": "Pick Model Type" + "pickModelType": "Pick Model Type", + "selectModel": "Select Model" }, "parameters": { "general": "General", diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx index b71b5636b4..693a0130a7 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx @@ -1,35 +1,74 @@ -import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react'; -import { RootState } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { + Flex, + Radio, + RadioGroup, + Text, + Tooltip, + useColorMode, +} from '@chakra-ui/react'; +import { makeToast } from 'app/components/Toaster'; +import { useAppDispatch } from 'app/store/storeHooks'; import IAIButton from 'common/components/IAIButton'; import IAIInput from 'common/components/IAIInput'; -import IAISelect from 'common/components/IAISelect'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; import IAISlider from 'common/components/IAISlider'; +import { addToast } from 'features/system/store/systemSlice'; import { pickBy } from 'lodash-es'; -import { useState } from 'react'; +import { useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; -import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import { + useGetMainModelsQuery, + useMergeMainModelsMutation, +} from 'services/api/endpoints/models'; +import { BaseModelType, MergeModelConfig } from 'services/api/types'; +import { mode } from 'theme/util/mode'; + +const baseModelTypeSelectData = [ + { label: 'Stable Diffusion 1', value: 'sd-1' }, + { label: 'Stable Diffusion 2', value: 'sd-2' }, +]; export default function MergeModelsPanel() { const { t } = useTranslation(); + const { colorMode } = useColorMode(); const dispatch = useAppDispatch(); const { data } = useGetMainModelsQuery(); - const diffusersModels = pickBy( + const [mergeModels, { isLoading, error, data: mergedModelData }] = + useMergeMainModelsMutation(); + + const [baseModel, setBaseModel] = useState('sd-1'); + + const sd1DiffusersModels = pickBy( data?.entities, - (value, _) => value?.model_format === 'diffusers' + (value, _) => + value?.model_format === 'diffusers' && value?.base_model === 'sd-1' ); - const [modelOne, setModelOne] = useState( - Object.keys(diffusersModels)[0] + const sd2DiffusersModels = pickBy( + data?.entities, + (value, _) => + value?.model_format === 'diffusers' && value?.base_model === 'sd-2' ); - const [modelTwo, setModelTwo] = useState( - Object.keys(diffusersModels)[1] + + const modelsMap = useMemo(() => { + return { + 'sd-1': sd1DiffusersModels, + 'sd-2': sd2DiffusersModels, + }; + }, [sd1DiffusersModels, sd2DiffusersModels]); + + const [modelOne, setModelOne] = useState( + Object.keys(modelsMap[baseModel])[0] ); - const [modelThree, setModelThree] = useState('none'); + const [modelTwo, setModelTwo] = useState( + Object.keys(modelsMap[baseModel])[1] + ); + + const [modelThree, setModelThree] = useState(null); const [mergedModelName, setMergedModelName] = useState(''); const [modelMergeAlpha, setModelMergeAlpha] = useState(0.5); @@ -47,41 +86,72 @@ export default function MergeModelsPanel() { const [modelMergeForce, setModelMergeForce] = useState(false); - const modelOneList = Object.keys(diffusersModels).filter( - (model) => model !== modelTwo && model !== modelThree + const modelOneList = Object.keys( + modelsMap[baseModel as keyof typeof modelsMap] + ).filter((model) => model !== modelTwo && model !== modelThree); + + const modelTwoList = Object.keys( + modelsMap[baseModel as keyof typeof modelsMap] + ).filter((model) => model !== modelOne && model !== modelThree); + + const modelThreeList = Object.keys(modelsMap[baseModel]).filter( + (model) => model !== modelOne && model !== modelTwo ); - const modelTwoList = Object.keys(diffusersModels).filter( - (model) => model !== modelOne && model !== modelThree - ); - - const modelThreeList = [ - { key: t('modelManager.none'), value: 'none' }, - ...Object.keys(diffusersModels) - .filter((model) => model !== modelOne && model !== modelTwo) - .map((model) => ({ key: model, value: model })), - ]; - - const isProcessing = useAppSelector( - (state: RootState) => state.system.isProcessing - ); + const handleBaseModelChange = (v: string) => { + setBaseModel(v as BaseModelType); + setModelOne(null); + setModelTwo(null); + }; const mergeModelsHandler = () => { - let modelsToMerge: string[] = [modelOne, modelTwo, modelThree]; - modelsToMerge = modelsToMerge.filter((model) => model !== 'none'); + const models_names: string[] = []; - const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = { - models_to_merge: modelsToMerge, + let modelsToMerge: (string | null)[] = [modelOne, modelTwo, modelThree]; + modelsToMerge = modelsToMerge.filter((model) => model !== null); + modelsToMerge.forEach((model) => { + if (model) { + models_names.push(model?.split('/')[2]); + } + }); + + const mergeModelsInfo: MergeModelConfig = { + model_names: models_names, merged_model_name: - mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'), + mergedModelName !== '' ? mergedModelName : models_names.join('-'), alpha: modelMergeAlpha, interp: modelMergeInterp, - model_merge_save_path: - modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc, + // model_merge_save_path: + // modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc, force: modelMergeForce, }; - dispatch(mergeDiffusersModels(mergeModelsInfo)); + mergeModels({ + base_model: baseModel, + body: mergeModelsInfo, + }); + + if (error) { + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelsMergeFailed'), + status: 'error', + }) + ) + ); + } + + if (mergedModelData) { + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelsMerged'), + status: 'success', + }) + ) + ); + } }; return ( @@ -90,7 +160,6 @@ export default function MergeModelsPanel() { sx={{ flexDirection: 'column', rowGap: 1, - bg: 'base.900', }} > {t('modelManager.modelMergeHeaderHelp1')} @@ -98,26 +167,43 @@ export default function MergeModelsPanel() { {t('modelManager.modelMergeHeaderHelp2')} + - + setModelOne(e.target.value)} + w="100%" + value={modelOne} + placeholder={t('modelManager.selectModel')} + data={modelOneList} + onChange={(v) => setModelOne(v)} /> - setModelTwo(e.target.value)} + w="100%" + placeholder={t('modelManager.selectModel')} + value={modelTwo} + data={modelTwoList} + onChange={(v) => setModelTwo(v)} /> - { - if (e.target.value !== 'none') { - setModelThree(e.target.value); + data={modelThreeList} + w="100%" + placeholder={t('modelManager.selectModel')} + clearable + onChange={(v) => { + if (!v) { + setModelThree(null); setModelMergeInterp('add_difference'); } else { - setModelThree('none'); + setModelThree(v); setModelMergeInterp('weighted_sum'); } }} @@ -136,7 +222,7 @@ export default function MergeModelsPanel() { padding: 4, borderRadius: 'base', gap: 4, - bg: 'base.900', + bg: mode('base.100', 'base.800')(colorMode), }} > @@ -174,7 +260,7 @@ export default function MergeModelsPanel() { ) => setModelMergeInterp(v)} > - {modelThree === 'none' ? ( + {modelThree === null ? ( <> {t('modelManager.weightedSum')} @@ -199,7 +285,7 @@ export default function MergeModelsPanel() { - setModelMergeCustomSaveLoc(e.target.value)} /> )} - + */} {t('modelManager.merge')} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx index 3e7ca7469a..63e5767585 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx @@ -7,7 +7,8 @@ import IAIButton from 'common/components/IAIButton'; import { addToast } from 'features/system/store/systemSlice'; import { useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; -import { useConvertMainModelMutation } from 'services/api/endpoints/models'; + +import { useConvertMainModelsMutation } from 'services/api/endpoints/models'; import { CheckpointModelConfig } from './CheckpointModelEdit'; interface ModelConvertProps { @@ -21,7 +22,7 @@ export default function ModelConvert(props: ModelConvertProps) { const { t } = useTranslation(); const [convertModel, { isLoading, error, data }] = - useConvertMainModelMutation(); + useConvertMainModelsMutation(); const [saveLocation, setSaveLocation] = useState('same'); const [customSaveLocation, setCustomSaveLocation] = useState(''); diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index d5f08c864c..cea586a3ff 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -6,6 +6,7 @@ import { ControlNetModelConfig, LoRAModelConfig, MainModelConfig, + MergeModelConfig, TextualInversionModelConfig, VaeModelConfig, } from 'services/api/types'; @@ -49,6 +50,11 @@ type ConvertMainModelQuery = { model_name: string; }; +type MergeMainModelQuery = { + base_model: BaseModelType; + body: MergeModelConfig; +}; + const mainModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.name.localeCompare(b.name), }); @@ -143,7 +149,7 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['MainModel'], }), - convertMainModel: build.mutation< + convertMainModels: build.mutation< EntityState, ConvertMainModelQuery >({ @@ -155,6 +161,19 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['MainModel'], }), + mergeMainModels: build.mutation< + EntityState, + MergeMainModelQuery + >({ + query: ({ base_model, body }) => { + return { + url: `models/merge/${base_model}`, + method: 'PUT', + body: body, + }; + }, + invalidatesTags: ['MainModel'], + }), getLoRAModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 'lora' } }), providesTags: (result, error, arg) => { @@ -300,5 +319,6 @@ export const { useGetVaeModelsQuery, useUpdateMainModelsMutation, useDeleteMainModelsMutation, - useConvertMainModelMutation, + useConvertMainModelsMutation, + useMergeMainModelsMutation, } = modelsApi; diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts index ab8214a903..9c154cbc46 100644 --- a/invokeai/frontend/web/src/services/api/types.d.ts +++ b/invokeai/frontend/web/src/services/api/types.d.ts @@ -50,6 +50,7 @@ export type AnyModelConfig = | ControlNetModelConfig | TextualInversionModelConfig | MainModelConfig; +export type MergeModelConfig = components['schemas']['Body_merge_models']; // Graphs export type Graph = components['schemas']['Graph']; From 31bb4bfc618d68cee73304000d5813afec916125 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 12 Jul 2023 23:12:12 +1200 Subject: [PATCH 09/28] style: Update Model Manager Styling to new format --- .../IAIForms/IAIFormItemWrapper.tsx | 6 +- .../tabs/ModelManager/ModelManagerTab.tsx | 80 +++++++++------- .../ModelManager/subpanels/AddModelsPanel.tsx | 22 +---- .../subpanels/ModelManagerPanel/ModelList.tsx | 92 +++++++++---------- .../ModelManagerPanel/ModelListItem.tsx | 21 ++++- .../web/src/theme/components/button.ts | 16 ---- 6 files changed, 115 insertions(+), 122 deletions(-) diff --git a/invokeai/frontend/web/src/common/components/IAIForms/IAIFormItemWrapper.tsx b/invokeai/frontend/web/src/common/components/IAIForms/IAIFormItemWrapper.tsx index 1b1ca29d76..83e91366c2 100644 --- a/invokeai/frontend/web/src/common/components/IAIForms/IAIFormItemWrapper.tsx +++ b/invokeai/frontend/web/src/common/components/IAIForms/IAIFormItemWrapper.tsx @@ -1,11 +1,13 @@ -import { Flex } from '@chakra-ui/react'; +import { Flex, useColorMode } from '@chakra-ui/react'; import { ReactElement } from 'react'; +import { mode } from 'theme/util/mode'; export function IAIFormItemWrapper({ children, }: { children: ReactElement | ReactElement[]; }) { + const { colorMode } = useColorMode(); return ( {children} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx index 8d675b17c8..7a56a3f651 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx @@ -1,6 +1,14 @@ -import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react'; +import { + Tab, + TabList, + TabPanel, + TabPanels, + Tabs, + useColorMode, +} from '@chakra-ui/react'; import i18n from 'i18n'; import { ReactNode, memo } from 'react'; +import { mode } from 'theme/util/mode'; import AddModelsPanel from './subpanels/AddModelsPanel'; import MergeModelsPanel from './subpanels/MergeModelsPanel'; import ModelManagerPanel from './subpanels/ModelManagerPanel'; @@ -31,41 +39,43 @@ const modelManagerTabs: ModelManagerTabInfo[] = [ }, ]; -const renderTabsList = () => { - const modelManagerTabListsToRender: ReactNode[] = []; - modelManagerTabs.forEach((modelManagerTab) => { - modelManagerTabListsToRender.push( - {modelManagerTab.label} - ); - }); - - return ( - - {modelManagerTabListsToRender} - - ); -}; - -const renderTabPanels = () => { - const modelManagerTabPanelsToRender: ReactNode[] = []; - modelManagerTabs.forEach((modelManagerTab) => { - modelManagerTabPanelsToRender.push( - {modelManagerTab.content} - ); - }); - - return {modelManagerTabPanelsToRender}; -}; - const ModelManagerTab = () => { + const { colorMode } = useColorMode(); + + const renderTabsList = () => { + const modelManagerTabListsToRender: ReactNode[] = []; + + modelManagerTabs.forEach((modelManagerTab) => { + modelManagerTabListsToRender.push( + {modelManagerTab.label} + ); + }); + + return ( + + {modelManagerTabListsToRender} + + ); + }; + + const renderTabPanels = () => { + const modelManagerTabPanelsToRender: ReactNode[] = []; + modelManagerTabs.forEach((modelManagerTab) => { + modelManagerTabPanelsToRender.push( + {modelManagerTab.content} + ); + }); + + return {modelManagerTabPanelsToRender}; + }; return ( state.ui.addNewModelUIOption ); + const { colorMode } = useColorMode(); + const dispatch = useAppDispatch(); const { t } = useTranslation(); @@ -20,27 +22,13 @@ export default function AddModelsPanel() { dispatch(setAddNewModelUIOption('ckpt'))} - sx={{ - backgroundColor: - addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.700', - '&:hover': { - backgroundColor: - addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.600', - }, - }} + isChecked={addNewModelUIOption == 'ckpt'} > {t('modelManager.addCheckpointModel')} dispatch(setAddNewModelUIOption('diffusers'))} - sx={{ - backgroundColor: - addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.700', - '&:hover': { - backgroundColor: - addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.600', - }, - }} + isChecked={addNewModelUIOption == 'diffusers'} > {t('modelManager.addDiffuserModel')} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx index 3eee71e576..803821f0e9 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -1,4 +1,4 @@ -import { Box, Flex, Spinner, Text } from '@chakra-ui/react'; +import { Box, Flex, Spinner, Text, useColorMode } from '@chakra-ui/react'; import IAIButton from 'common/components/IAIButton'; import IAIInput from 'common/components/IAIInput'; @@ -9,6 +9,7 @@ import { useTranslation } from 'react-i18next'; import type { ChangeEvent, ReactNode } from 'react'; import React, { useMemo, useState, useTransition } from 'react'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import { mode } from 'theme/util/mode'; function ModelFilterButton({ label, @@ -20,16 +21,7 @@ function ModelFilterButton({ onClick: () => void; }) { return ( - + {label} ); @@ -37,6 +29,7 @@ function ModelFilterButton({ const ModelList = () => { const { data: mainModels } = useGetMainModelsQuery(); + const { colorMode } = useColorMode(); const [renderModelList, setRenderModelList] = React.useState(false); @@ -130,41 +123,46 @@ const ModelList = () => { {isSelectedFilter === 'all' && ( <> - - - {t('modelManager.diffusersModels')} - - {diffusersModelListItemsToRender} - - - - {t('modelManager.checkpointModels')} - - {ckptModelListItemsToRender} - + {diffusersModelListItemsToRender.length > 0 && ( + + + {t('modelManager.diffusersModels')} + + {diffusersModelListItemsToRender} + + )} + + {ckptModelListItemsToRender.length > 0 && ( + + + {t('modelManager.checkpointModels')} + + {ckptModelListItemsToRender} + + )} )} @@ -181,7 +179,7 @@ const ModelList = () => { )} ); - }, [mainModels, searchText, t, isSelectedFilter]); + }, [mainModels, searchText, t, isSelectedFilter, colorMode]); return ( diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx index 61c2c8b5e2..993602da4c 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx @@ -1,5 +1,12 @@ import { DeleteIcon, EditIcon } from '@chakra-ui/icons'; -import { Box, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react'; +import { + Box, + Flex, + Spacer, + Text, + Tooltip, + useColorMode, +} from '@chakra-ui/react'; // import { deleteModel, requestModelChange } from 'app/socketio/actions'; import { RootState } from 'app/store/store'; @@ -9,6 +16,8 @@ import IAIIconButton from 'common/components/IAIIconButton'; import { setOpenModel } from 'features/system/store/systemSlice'; import { useTranslation } from 'react-i18next'; import { useDeleteMainModelsMutation } from 'services/api/endpoints/models'; +import { BaseModelType } from 'services/api/types'; +import { mode } from 'theme/util/mode'; type ModelListItemProps = { modelKey: string; @@ -21,6 +30,8 @@ export default function ModelListItem(props: ModelListItemProps) { (state: RootState) => state.system ); + const { colorMode } = useColorMode(); + const openModel = useAppSelector( (state: RootState) => state.system.openModel ); @@ -40,7 +51,7 @@ export default function ModelListItem(props: ModelListItemProps) { const handleModelDelete = () => { const [base_model, _, model_name] = modelKey.split('/'); deleteMainModel({ - base_model: base_model, + base_model: base_model as BaseModelType, model_name: model_name, }); dispatch(setOpenModel(null)); @@ -54,14 +65,14 @@ export default function ModelListItem(props: ModelListItemProps) { sx={ modelKey === openModel ? { - bg: 'accent.750', + bg: mode('accent.200', 'accent.600')(colorMode), _hover: { - bg: 'accent.750', + bg: mode('accent.200', 'accent.600')(colorMode), }, } : { _hover: { - bg: 'base.750', + bg: mode('base.100', 'base.800')(colorMode), }, } } diff --git a/invokeai/frontend/web/src/theme/components/button.ts b/invokeai/frontend/web/src/theme/components/button.ts index 7bb8a39a71..a59b9df826 100644 --- a/invokeai/frontend/web/src/theme/components/button.ts +++ b/invokeai/frontend/web/src/theme/components/button.ts @@ -19,16 +19,8 @@ const invokeAI = defineStyle((props) => { bg: mode('base.200', 'base.600')(props), color: mode('base.850', 'base.100')(props), borderRadius: 'base', - textShadow: mode( - '0 0 0.3rem var(--invokeai-colors-base-50)', - '0 0 0.3rem var(--invokeai-colors-base-900)' - )(props), svg: { fill: mode('base.850', 'base.100')(props), - filter: mode( - 'drop-shadow(0px 0px 0.3rem var(--invokeai-colors-base-100))', - 'drop-shadow(0px 0px 0.3rem var(--invokeai-colors-base-800))' - )(props), }, _hover: { bg: mode('base.300', 'base.500')(props), @@ -57,16 +49,8 @@ const invokeAI = defineStyle((props) => { bg: mode(`${c}.400`, `${c}.600`)(props), color: mode(`base.50`, `base.100`)(props), borderRadius: 'base', - textShadow: mode( - `0 0 0.3rem var(--invokeai-colors-${c}-600)`, - `0 0 0.3rem var(--invokeai-colors-${c}-800)` - )(props), svg: { fill: mode(`base.50`, `base.100`)(props), - filter: mode( - `drop-shadow(0px 0px 0.3rem var(--invokeai-colors-${c}-600))`, - `drop-shadow(0px 0px 0.3rem var(--invokeai-colors-${c}-800))` - )(props), }, _disabled, _hover: { From abe2a0f9b443c6b903e3579ce279af43f2a38080 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Fri, 14 Jul 2023 15:53:28 +1200 Subject: [PATCH 10/28] fix: merge conflicts (name renamed to model_name) for models --- .../ModelManagerPanel/CheckpointModelEdit.tsx | 2 +- .../subpanels/ModelManagerPanel/DiffusersModelEdit.tsx | 2 +- .../subpanels/ModelManagerPanel/ModelConvert.tsx | 10 ++++++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx index 30ec4cec32..105d2355e1 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -71,7 +71,7 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) { const editModelFormSubmitHandler = (values: CheckpointModelConfig) => { const responseBody = { base_model: retrievedModel.base_model, - model_name: retrievedModel.name, + model_name: retrievedModel.model_name, body: values, }; updateMainModel(responseBody); diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx index 343fa729fe..800babaf59 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx @@ -67,7 +67,7 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) { const editModelFormSubmitHandler = (values: DiffusersModelConfig) => { const responseBody = { base_model: retrievedModel.base_model, - model_name: retrievedModel.name, + model_name: retrievedModel.model_name, body: values, }; updateMainModel(responseBody); diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx index 63e5767585..1fc2eecfed 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx @@ -38,7 +38,7 @@ export default function ModelConvert(props: ModelConvertProps) { const modelConvertHandler = () => { const responseBody = { base_model: model.base_model, - model_name: model.name, + model_name: model.model_name, }; convertModel(responseBody); @@ -46,7 +46,9 @@ export default function ModelConvert(props: ModelConvertProps) { dispatch( addToast( makeToast({ - title: `${t('modelManager.modelConversionFailed')}: ${model.name}`, + title: `${t('modelManager.modelConversionFailed')}: ${ + model.model_name + }`, status: 'error', }) ) @@ -57,7 +59,7 @@ export default function ModelConvert(props: ModelConvertProps) { dispatch( addToast( makeToast({ - title: `${t('modelManager.modelConverted')}: ${model.name}`, + title: `${t('modelManager.modelConverted')}: ${model.model_name}`, status: 'success', }) ) @@ -67,7 +69,7 @@ export default function ModelConvert(props: ModelConvertProps) { return ( Date: Fri, 14 Jul 2023 15:32:09 +1000 Subject: [PATCH 11/28] fix(ui): fix rtk tags I had mixed up `type` and `id` on a bunch of the tags. Fixing those --- .../src/services/api/endpoints/boardImages.ts | 11 ++++++----- .../web/src/services/api/endpoints/boards.ts | 19 +++++++++++++------ .../web/src/services/api/endpoints/models.ts | 18 +++++++++--------- 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/endpoints/boardImages.ts b/invokeai/frontend/web/src/services/api/endpoints/boardImages.ts index f7a486e8fc..39deaf4172 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/boardImages.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/boardImages.ts @@ -4,7 +4,7 @@ import { paths } from '../schema'; type ListBoardImagesArg = paths['/api/v1/board_images/{board_id}']['get']['parameters']['path'] & - paths['/api/v1/board_images/{board_id}']['get']['parameters']['query']; + paths['/api/v1/board_images/{board_id}']['get']['parameters']['query']; type AddImageToBoardArg = paths['/api/v1/board_images/']['post']['requestBody']['content']['application/json']; @@ -25,11 +25,12 @@ export const boardImagesApi = api.injectEndpoints({ query: ({ board_id, offset, limit }) => ({ url: `board_images/${board_id}`, method: 'GET', - }), providesTags: (result, error, arg) => { // any list of boardimages - const tags: ApiFullTagDescription[] = [{ id: 'BoardImage', type: `${arg.board_id}_${LIST_TAG}` }]; + const tags: ApiFullTagDescription[] = [ + { type: 'BoardImage', id: `${arg.board_id}_${LIST_TAG}` }, + ]; if (result) { // and individual tags for each boardimage @@ -57,7 +58,7 @@ export const boardImagesApi = api.injectEndpoints({ }), invalidatesTags: (result, error, arg) => [ { type: 'BoardImage' }, - { type: 'Board', id: arg.board_id } + { type: 'Board', id: arg.board_id }, ], }), @@ -69,7 +70,7 @@ export const boardImagesApi = api.injectEndpoints({ }), invalidatesTags: (result, error, arg) => [ { type: 'BoardImage' }, - { type: 'Board', id: arg.board_id } + { type: 'Board', id: arg.board_id }, ], }), }), diff --git a/invokeai/frontend/web/src/services/api/endpoints/boards.ts b/invokeai/frontend/web/src/services/api/endpoints/boards.ts index 64ab21075d..fc3cb530a4 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/boards.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/boards.ts @@ -20,7 +20,7 @@ export const boardsApi = api.injectEndpoints({ query: (arg) => ({ url: 'boards/', params: arg }), providesTags: (result, error, arg) => { // any list of boards - const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST_TAG }]; + const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }]; if (result) { // and individual tags for each board @@ -43,7 +43,7 @@ export const boardsApi = api.injectEndpoints({ }), providesTags: (result, error, arg) => { // any list of boards - const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST_TAG }]; + const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }]; if (result) { // and individual tags for each board @@ -69,7 +69,7 @@ export const boardsApi = api.injectEndpoints({ method: 'POST', params: { board_name }, }), - invalidatesTags: [{ id: 'Board', type: LIST_TAG }], + invalidatesTags: [{ type: 'Board', id: LIST_TAG }], }), updateBoard: build.mutation({ @@ -87,8 +87,15 @@ export const boardsApi = api.injectEndpoints({ invalidatesTags: (result, error, arg) => [{ type: 'Board', id: arg }], }), deleteBoardAndImages: build.mutation({ - query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE', params: { include_images: true } }), - invalidatesTags: (result, error, arg) => [{ type: 'Board', id: arg }, { type: 'Image', id: LIST_TAG }], + query: (board_id) => ({ + url: `boards/${board_id}`, + method: 'DELETE', + params: { include_images: true }, + }), + invalidatesTags: (result, error, arg) => [ + { type: 'Board', id: arg }, + { type: 'Image', id: LIST_TAG }, + ], }), }), }); @@ -99,5 +106,5 @@ export const { useCreateBoardMutation, useUpdateBoardMutation, useDeleteBoardMutation, - useDeleteBoardAndImagesMutation + useDeleteBoardAndImagesMutation, } = boardsApi; diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 95b2146448..f239d4c5e4 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -99,7 +99,7 @@ export const modelsApi = api.injectEndpoints({ query: () => ({ url: 'models/', params: { model_type: 'main' } }), providesTags: (result, error, arg) => { const tags: ApiFullTagDescription[] = [ - { id: 'MainModel', type: LIST_TAG }, + { type: 'MainModel', id: LIST_TAG }, ]; if (result) { @@ -138,7 +138,7 @@ export const modelsApi = api.injectEndpoints({ body: body, }; }, - invalidatesTags: ['MainModel'], + invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], }), deleteMainModels: build.mutation< EntityState, @@ -150,7 +150,7 @@ export const modelsApi = api.injectEndpoints({ method: 'DELETE', }; }, - invalidatesTags: ['MainModel'], + invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], }), convertMainModels: build.mutation< EntityState, @@ -162,7 +162,7 @@ export const modelsApi = api.injectEndpoints({ method: 'PUT', }; }, - invalidatesTags: ['MainModel'], + invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], }), mergeMainModels: build.mutation< EntityState, @@ -175,13 +175,13 @@ export const modelsApi = api.injectEndpoints({ body: body, }; }, - invalidatesTags: ['MainModel'], + invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], }), getLoRAModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 'lora' } }), providesTags: (result, error, arg) => { const tags: ApiFullTagDescription[] = [ - { id: 'LoRAModel', type: LIST_TAG }, + { type: 'LoRAModel', id: LIST_TAG }, ]; if (result) { @@ -216,7 +216,7 @@ export const modelsApi = api.injectEndpoints({ query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }), providesTags: (result, error, arg) => { const tags: ApiFullTagDescription[] = [ - { id: 'ControlNetModel', type: LIST_TAG }, + { type: 'ControlNetModel', id: LIST_TAG }, ]; if (result) { @@ -248,7 +248,7 @@ export const modelsApi = api.injectEndpoints({ query: () => ({ url: 'models/', params: { model_type: 'vae' } }), providesTags: (result, error, arg) => { const tags: ApiFullTagDescription[] = [ - { id: 'VaeModel', type: LIST_TAG }, + { type: 'VaeModel', id: LIST_TAG }, ]; if (result) { @@ -283,7 +283,7 @@ export const modelsApi = api.injectEndpoints({ query: () => ({ url: 'models/', params: { model_type: 'embedding' } }), providesTags: (result, error, arg) => { const tags: ApiFullTagDescription[] = [ - { id: 'TextualInversionModel', type: LIST_TAG }, + { type: 'TextualInversionModel', id: LIST_TAG }, ]; if (result) { From 834774ce4c82822f62818b27a63f311ef897e83a Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Fri, 14 Jul 2023 18:16:34 +1200 Subject: [PATCH 12/28] fix: Merge Conflicts --- .../subpanels/ModelManagerPanel/CheckpointModelEdit.tsx | 2 +- .../subpanels/ModelManagerPanel/DiffusersModelEdit.tsx | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx index 105d2355e1..c8b86cbd68 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -9,10 +9,10 @@ import { useTranslation } from 'react-i18next'; import type { RootState } from 'app/store/store'; import IAIButton from 'common/components/IAIButton'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; -import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; import { makeToast } from 'app/components/Toaster'; import IAIMantineTextInput from 'common/components/IAIMantineInput'; +import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { addToast } from 'features/system/store/systemSlice'; import { useUpdateMainModelsMutation } from 'services/api/endpoints/models'; import { components } from 'services/api/schema'; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx index 800babaf59..c073d10817 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx @@ -11,7 +11,8 @@ import type { RootState } from 'app/store/store'; import IAIButton from 'common/components/IAIButton'; import IAIMantineTextInput from 'common/components/IAIMantineInput'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; -import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; + +import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { addToast } from 'features/system/store/systemSlice'; import { useUpdateMainModelsMutation } from 'services/api/endpoints/models'; import { components } from 'services/api/schema'; From 66b12ab0ea8ed783bc410c176a713ce7f3c3f46e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 14 Jul 2023 16:59:13 +1000 Subject: [PATCH 13/28] fix(ui): do not blacklist the rtk query events doing so breaks the devtools --- invokeai/frontend/web/src/app/store/store.ts | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index da09b496d7..2bafd21a74 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -100,10 +100,11 @@ export const store = configureStore({ // manually type state, cannot type the arg // const typedState = state as ReturnType; - if (action.type.startsWith('api/')) { - // don't log api actions, with manual cache updates they are extremely noisy - return false; - } + // TODO: doing this breaks the rtk query devtools, commenting out for now + // if (action.type.startsWith('api/')) { + // // don't log api actions, with manual cache updates they are extremely noisy + // return false; + // } if (actionsDenylist.includes(action.type)) { // don't log other noisy actions From b2005d821ae864dbd32d9da7de8de42430e3df74 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 14 Jul 2023 16:59:31 +1000 Subject: [PATCH 14/28] fix(ui): fix types for models queries --- .../web/src/services/api/endpoints/models.ts | 37 ++++++++++++------- .../frontend/web/src/services/api/schema.d.ts | 24 ++++++------ 2 files changed, 35 insertions(+), 26 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index f239d4c5e4..1038a88c09 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -12,6 +12,7 @@ import { } from 'services/api/types'; import { ApiFullTagDescription, LIST_TAG, api } from '..'; +import { paths } from '../schema'; export type MainModelConfigEntity = MainModelConfig & { id: string }; @@ -34,27 +35,38 @@ type AnyModelConfigEntity = | TextualInversionModelConfigEntity | VaeModelConfigEntity; -type UpdateMainModelQuery = { +type UpdateMainModelArg = { base_model: BaseModelType; model_name: string; body: MainModelConfig; }; -type DeleteMainModelQuery = { +type UpdateMainModelResponse = + paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json']; + +type DeleteMainModelArg = { base_model: BaseModelType; model_name: string; }; -type ConvertMainModelQuery = { +type DeleteMainModelResponse = void; + +type ConvertMainModelArg = { base_model: BaseModelType; model_name: string; }; -type MergeMainModelQuery = { +type ConvertMainModelResponse = + paths['/api/v1/models/convert/{base_model}/{model_type}/{model_name}']['put']['responses']['200']['content']['application/json']; + +type MergeMainModelArg = { base_model: BaseModelType; body: MergeModelConfig; }; +type MergeMainModelResponse = + paths['/api/v1/models/merge/{base_model}']['put']['responses']['200']['content']['application/json']; + const mainModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), }); @@ -128,8 +140,8 @@ export const modelsApi = api.injectEndpoints({ }, }), updateMainModels: build.mutation< - EntityState, - UpdateMainModelQuery + UpdateMainModelResponse, + UpdateMainModelArg >({ query: ({ base_model, model_name, body }) => { return { @@ -141,8 +153,8 @@ export const modelsApi = api.injectEndpoints({ invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], }), deleteMainModels: build.mutation< - EntityState, - DeleteMainModelQuery + DeleteMainModelResponse, + DeleteMainModelArg >({ query: ({ base_model, model_name }) => { return { @@ -153,8 +165,8 @@ export const modelsApi = api.injectEndpoints({ invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], }), convertMainModels: build.mutation< - EntityState, - ConvertMainModelQuery + ConvertMainModelResponse, + ConvertMainModelArg >({ query: ({ base_model, model_name }) => { return { @@ -164,10 +176,7 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], }), - mergeMainModels: build.mutation< - EntityState, - MergeMainModelQuery - >({ + mergeMainModels: build.mutation({ query: ({ base_model, body }) => { return { url: `models/merge/${base_model}`, diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index 3da7a0bf8d..acbed14eac 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -3290,7 +3290,7 @@ export type components = { /** ModelsList */ ModelsList: { /** Models */ - models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[]; + models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[]; }; /** * MultiplyInvocation @@ -4605,18 +4605,18 @@ export type components = { */ image?: components["schemas"]["ImageField"]; }; - /** - * StableDiffusion1ModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusion2ModelFormat * @description An enumeration. * @enum {string} */ StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; + /** + * StableDiffusion1ModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; }; responses: never; parameters: never; @@ -4997,7 +4997,7 @@ export type operations = { /** @description The model imported successfully */ 201: { content: { - "application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; + "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"]; }; }; /** @description The model could not be found */ @@ -5065,14 +5065,14 @@ export type operations = { }; requestBody: { content: { - "application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; + "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"]; }; }; responses: { /** @description The model was updated successfully */ 200: { content: { - "application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; + "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"]; }; }; /** @description Bad request */ @@ -5106,7 +5106,7 @@ export type operations = { /** @description Model converted successfully */ 200: { content: { - "application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; + "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"]; }; }; /** @description Bad request */ @@ -5141,7 +5141,7 @@ export type operations = { /** @description Model converted successfully */ 200: { content: { - "application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; + "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"]; }; }; /** @description Incompatible models */ From a0cb18a12c4a9e0772322c5dd752b19f002847c4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 14 Jul 2023 17:34:13 +1000 Subject: [PATCH 15/28] feat(ui): refetch models on socket connect --- .../listeners/socketio/socketConnected.ts | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts index fe4bce682b..f01c3911da 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -1,6 +1,7 @@ import { log } from 'app/logging/useLogger'; -import { appSocketConnected, socketConnected } from 'services/events/actions'; +import { modelsApi } from 'services/api/endpoints/models'; import { receivedOpenAPISchema } from 'services/api/thunks/schema'; +import { appSocketConnected, socketConnected } from 'services/events/actions'; import { startAppListening } from '../..'; const moduleLog = log.child({ namespace: 'socketio' }); @@ -23,6 +24,13 @@ export const addSocketConnectedEventListener = () => { // pass along the socket event as an application action dispatch(appSocketConnected(action.payload)); + + // update all server state + dispatch(modelsApi.endpoints.getMainModels.initiate()); + dispatch(modelsApi.endpoints.getControlNetModels.initiate()); + dispatch(modelsApi.endpoints.getLoRAModels.initiate()); + dispatch(modelsApi.endpoints.getTextualInversionModels.initiate()); + dispatch(modelsApi.endpoints.getVaeModels.initiate()); }, }); }; From d8437d3036f16fedc08a99ffbf4c5f0a60b254a8 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 14 Jul 2023 17:34:34 +1000 Subject: [PATCH 16/28] feat(ui): add simple selectIsBusy selector --- .../web/src/features/system/store/systemSelectors.ts | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/system/store/systemSelectors.ts b/invokeai/frontend/web/src/features/system/store/systemSelectors.ts index e280210069..0d53da85e6 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSelectors.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSelectors.ts @@ -1,7 +1,7 @@ import { createSelector } from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { reduce, pickBy } from 'lodash-es'; +import { pickBy, reduce } from 'lodash-es'; export const systemSelector = (state: RootState) => state.system; @@ -50,3 +50,8 @@ export const languageSelector = createSelector( export const isProcessingSelector = (state: RootState) => state.system.isProcessing; + +export const selectIsBusy = createSelector( + (state: RootState) => state, + (state) => state.system.isProcessing || !state.system.isConnected +); From 48a8bd49858a358229e6dfa037d71120b200b247 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 14 Jul 2023 17:35:45 +1000 Subject: [PATCH 17/28] feat(ui): add model update for success/failure handling --- .../ModelManagerPanel/CheckpointModelEdit.tsx | 45 ++++++++++--------- .../ModelManagerPanel/DiffusersModelEdit.tsx | 45 ++++++++++--------- 2 files changed, 48 insertions(+), 42 deletions(-) diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx index c8b86cbd68..629f01e8fd 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -74,27 +74,30 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) { model_name: retrievedModel.model_name, body: values, }; - updateMainModel(responseBody); - - if (error) { - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelUpdateFailed'), - status: 'success', - }) - ) - ); - } - - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelUpdated'), - status: 'success', - }) - ) - ); + updateMainModel(responseBody) + .unwrap() + .then((payload) => { + checkpointEditForm.setValues(payload as CheckpointModelConfig); + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdated'), + status: 'success', + }) + ) + ); + }) + .catch((error) => { + checkpointEditForm.reset(); + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdateFailed'), + status: 'error', + }) + ) + ); + }); }; return modelToEdit ? ( diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx index c073d10817..0bf3d87838 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx @@ -71,27 +71,30 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) { model_name: retrievedModel.model_name, body: values, }; - updateMainModel(responseBody); - - if (error) { - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelUpdateFailed'), - status: 'error', - }) - ) - ); - } - - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelUpdated'), - status: 'success', - }) - ) - ); + updateMainModel(responseBody) + .unwrap() + .then((payload) => { + diffusersEditForm.setValues(payload as DiffusersModelConfig); + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdated'), + status: 'success', + }) + ) + ); + }) + .catch((error) => { + diffusersEditForm.reset(); + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdateFailed'), + status: 'error', + }) + ) + ); + }); }; return modelToEdit ? ( From 6d7fb49a7a0791ccedec4e8ebba755a8fae2fff3 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 14 Jul 2023 17:36:10 +1000 Subject: [PATCH 18/28] fix(ui): fix model edit button disabled status --- .../subpanels/ModelManagerPanel/CheckpointModelEdit.tsx | 8 +++----- .../subpanels/ModelManagerPanel/DiffusersModelEdit.tsx | 9 ++++----- .../subpanels/ModelManagerPanel/ModelListItem.tsx | 9 ++++----- 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx index 629f01e8fd..91d668f1e2 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -6,13 +6,13 @@ import { Divider, Flex, Text } from '@chakra-ui/react'; import { useForm } from '@mantine/form'; import { useTranslation } from 'react-i18next'; -import type { RootState } from 'app/store/store'; import IAIButton from 'common/components/IAIButton'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { makeToast } from 'app/components/Toaster'; import IAIMantineTextInput from 'common/components/IAIMantineInput'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; +import { selectIsBusy } from 'features/system/store/systemSelectors'; import { addToast } from 'features/system/store/systemSlice'; import { useUpdateMainModelsMutation } from 'services/api/endpoints/models'; import { components } from 'services/api/schema'; @@ -39,9 +39,7 @@ type CheckpointModelEditProps = { }; export default function CheckpointModelEdit(props: CheckpointModelEditProps) { - const isProcessing = useAppSelector( - (state: RootState) => state.system.isProcessing - ); + const isBusy = useAppSelector(selectIsBusy); const { modelToEdit, retrievedModel } = props; @@ -153,8 +151,8 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) { {...checkpointEditForm.getInputProps('config')} /> {t('modelManager.updateModel')} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx index 0bf3d87838..65793c9d2c 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx @@ -7,12 +7,12 @@ import { useTranslation } from 'react-i18next'; import { useForm } from '@mantine/form'; import { makeToast } from 'app/components/Toaster'; -import type { RootState } from 'app/store/store'; import IAIButton from 'common/components/IAIButton'; import IAIMantineTextInput from 'common/components/IAIMantineInput'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; +import { selectIsBusy } from 'features/system/store/systemSelectors'; import { addToast } from 'features/system/store/systemSlice'; import { useUpdateMainModelsMutation } from 'services/api/endpoints/models'; import { components } from 'services/api/schema'; @@ -38,9 +38,8 @@ const variantSelectData = [ ]; export default function DiffusersModelEdit(props: DiffusersModelEditProps) { - const isProcessing = useAppSelector( - (state: RootState) => state.system.isProcessing - ); + const isBusy = useAppSelector(selectIsBusy); + const { retrievedModel, modelToEdit } = props; const [updateMainModel, { isLoading, error }] = useUpdateMainModelsMutation(); @@ -138,8 +137,8 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) { {...diffusersEditForm.getInputProps('vae')} /> {t('modelManager.updateModel')} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx index 993602da4c..5a90327aa3 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx @@ -13,6 +13,7 @@ import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIAlertDialog from 'common/components/IAIAlertDialog'; import IAIIconButton from 'common/components/IAIIconButton'; +import { selectIsBusy } from 'features/system/store/systemSelectors'; import { setOpenModel } from 'features/system/store/systemSlice'; import { useTranslation } from 'react-i18next'; import { useDeleteMainModelsMutation } from 'services/api/endpoints/models'; @@ -26,9 +27,7 @@ type ModelListItemProps = { }; export default function ModelListItem(props: ModelListItemProps) { - const { isProcessing, isConnected } = useAppSelector( - (state: RootState) => state.system - ); + const isBusy = useAppSelector(selectIsBusy); const { colorMode } = useColorMode(); @@ -89,7 +88,7 @@ export default function ModelListItem(props: ModelListItemProps) { size="sm" onClick={openModelHandler} aria-label={t('accessibility.modifyConfig')} - isDisabled={status === 'active' || isProcessing || !isConnected} + isDisabled={isBusy} /> } size="sm" aria-label={t('modelManager.deleteConfig')} - isDisabled={status === 'active' || isProcessing || !isConnected} + isDisabled={isBusy} colorScheme="error" /> } From f2af82bf7354ec9931252fa2884fb958e6093ee9 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 14 Jul 2023 17:39:00 +1000 Subject: [PATCH 19/28] feat(ui): add model convert for success/failure handling --- .../ModelManagerPanel/ModelConvert.tsx | 49 +++++++++---------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx index 1fc2eecfed..5df7631772 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx @@ -40,31 +40,30 @@ export default function ModelConvert(props: ModelConvertProps) { base_model: model.base_model, model_name: model.model_name, }; - convertModel(responseBody); - - if (error) { - dispatch( - addToast( - makeToast({ - title: `${t('modelManager.modelConversionFailed')}: ${ - model.model_name - }`, - status: 'error', - }) - ) - ); - } - - if (data) { - dispatch( - addToast( - makeToast({ - title: `${t('modelManager.modelConverted')}: ${model.model_name}`, - status: 'success', - }) - ) - ); - } + convertModel(responseBody) + .unwrap() + .then((payload) => { + dispatch( + addToast( + makeToast({ + title: `${t('modelManager.modelConverted')}: ${model.model_name}`, + status: 'success', + }) + ) + ); + }) + .catch((error) => { + dispatch( + addToast( + makeToast({ + title: `${t('modelManager.modelConversionFailed')}: ${ + model.model_name + }`, + status: 'error', + }) + ) + ); + }); }; return ( From 1e5ae9d986b11c15c2225ea3f55a4f2853e4be2f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 14 Jul 2023 19:22:37 +1000 Subject: [PATCH 20/28] feat(ui): refactor model manager ui - simplify UI logic in `ModelManagerPanel` components - fix up the types a bit to make it easier to select models - remove `openModel` state, just make it a useState since it is very local to model manager --- .../system/store/systemPersistDenylist.ts | 1 - .../src/features/system/store/systemSlice.ts | 6 - .../tabs/ModelManager/ModelManagerTab.tsx | 67 +--- .../subpanels/ModelManagerPanel.tsx | 83 +++-- .../ModelManagerPanel/CheckpointModelEdit.tsx | 138 ++++---- .../ModelManagerPanel/DiffusersModelEdit.tsx | 130 ++++---- .../subpanels/ModelManagerPanel/ModelList.tsx | 310 ++++++------------ .../ModelManagerPanel/ModelListItem.tsx | 161 ++++----- .../web/src/services/api/endpoints/models.ts | 10 +- .../frontend/web/src/services/api/types.d.ts | 8 +- .../frontend/web/src/theme/components/text.ts | 2 +- 11 files changed, 384 insertions(+), 532 deletions(-) diff --git a/invokeai/frontend/web/src/features/system/store/systemPersistDenylist.ts b/invokeai/frontend/web/src/features/system/store/systemPersistDenylist.ts index c2481c29df..bba279c4bc 100644 --- a/invokeai/frontend/web/src/features/system/store/systemPersistDenylist.ts +++ b/invokeai/frontend/web/src/features/system/store/systemPersistDenylist.ts @@ -13,7 +13,6 @@ export const systemPersistDenylist: (keyof SystemState)[] = [ 'isProcessing', 'totalIterations', 'totalSteps', - 'openModel', 'isCancelScheduled', 'progressImage', 'wereModelsReceived', diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index 01c1344263..4d723378ba 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -46,7 +46,6 @@ export interface SystemState { toastQueue: UseToastOptions[]; searchFolder: string | null; foundModels: InvokeAI.FoundModel[] | null; - openModel: string | null; /** * The current progress image */ @@ -109,7 +108,6 @@ export const initialSystemState: SystemState = { toastQueue: [], searchFolder: null, foundModels: null, - openModel: null, progressImage: null, shouldAntialiasProgressImage: false, sessionId: null, @@ -164,9 +162,6 @@ export const systemSlice = createSlice({ ) => { state.foundModels = action.payload; }, - setOpenModel: (state, action: PayloadAction) => { - state.openModel = action.payload; - }, /** * A cancel was scheduled */ @@ -433,7 +428,6 @@ export const { clearToastQueue, setSearchFolder, setFoundModels, - setOpenModel, cancelScheduled, scheduledCancelAborted, cancelTypeChanged, diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx index 7a56a3f651..70de375774 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx @@ -1,14 +1,6 @@ -import { - Tab, - TabList, - TabPanel, - TabPanels, - Tabs, - useColorMode, -} from '@chakra-ui/react'; +import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react'; import i18n from 'i18n'; import { ReactNode, memo } from 'react'; -import { mode } from 'theme/util/mode'; import AddModelsPanel from './subpanels/AddModelsPanel'; import MergeModelsPanel from './subpanels/MergeModelsPanel'; import ModelManagerPanel from './subpanels/ModelManagerPanel'; @@ -21,7 +13,7 @@ type ModelManagerTabInfo = { content: ReactNode; }; -const modelManagerTabs: ModelManagerTabInfo[] = [ +const tabs: ModelManagerTabInfo[] = [ { id: 'modelManager', label: i18n.t('modelManager.modelManager'), @@ -40,50 +32,25 @@ const modelManagerTabs: ModelManagerTabInfo[] = [ ]; const ModelManagerTab = () => { - const { colorMode } = useColorMode(); - - const renderTabsList = () => { - const modelManagerTabListsToRender: ReactNode[] = []; - - modelManagerTabs.forEach((modelManagerTab) => { - modelManagerTabListsToRender.push( - {modelManagerTab.label} - ); - }); - - return ( - - {modelManagerTabListsToRender} - - ); - }; - - const renderTabPanels = () => { - const modelManagerTabPanelsToRender: ReactNode[] = []; - modelManagerTabs.forEach((modelManagerTab) => { - modelManagerTabPanelsToRender.push( - {modelManagerTab.content} - ); - }); - - return {modelManagerTabPanelsToRender}; - }; return ( - {renderTabsList()} - {renderTabPanels()} + + {tabs.map((tab) => ( + + {tab.label} + + ))} + + + {tabs.map((tab) => ( + {tab.content} + ))} + ); }; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx index ddf0874b21..f681b79437 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx @@ -1,48 +1,59 @@ -import { Flex } from '@chakra-ui/react'; -import { RootState } from 'app/store/store'; -import { useAppSelector } from 'app/store/storeHooks'; +import { Flex, Text } from '@chakra-ui/react'; -import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import { useState } from 'react'; +import { + MainModelConfigEntity, + useGetMainModelsQuery, +} from 'services/api/endpoints/models'; import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; import ModelList from './ModelManagerPanel/ModelList'; export default function ModelManagerPanel() { - const { data: mainModels } = useGetMainModelsQuery(); + const [selectedModelId, setSelectedModelId] = useState(); + const { model } = useGetMainModelsQuery(undefined, { + selectFromResult: ({ data }) => ({ + model: selectedModelId ? data?.entities[selectedModelId] : undefined, + }), + }); - const openModel = useAppSelector( - (state: RootState) => state.system.openModel - ); - - const renderModelEditTabs = () => { - if (!openModel || !mainModels) return; - - const openedModelData = mainModels['entities'][openModel]; - - if (openedModelData && openedModelData.model_format === 'diffusers') { - return ( - - ); - } - - if (openedModelData && openedModelData.model_format === 'checkpoint') { - return ( - - ); - } - }; return ( - - {renderModelEditTabs()} + + ); } + +type ModelEditProps = { + model: MainModelConfigEntity | undefined; +}; + +const ModelEdit = (props: ModelEditProps) => { + const { model } = props; + + if (model?.model_format === 'checkpoint') { + return ; + } + + if (model?.model_format === 'diffusers') { + return ; + } + + return ( + + Pick A Model To Edit + + ); +}; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx index 91d668f1e2..2cdaea904f 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -1,21 +1,20 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; - import { Divider, Flex, Text } from '@chakra-ui/react'; - -// import { addNewModel } from 'app/socketio/actions'; import { useForm } from '@mantine/form'; -import { useTranslation } from 'react-i18next'; - -import IAIButton from 'common/components/IAIButton'; -import IAIMantineSelect from 'common/components/IAIMantineSelect'; - import { makeToast } from 'app/components/Toaster'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import IAIButton from 'common/components/IAIButton'; import IAIMantineTextInput from 'common/components/IAIMantineInput'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { selectIsBusy } from 'features/system/store/systemSelectors'; import { addToast } from 'features/system/store/systemSlice'; -import { useUpdateMainModelsMutation } from 'services/api/endpoints/models'; -import { components } from 'services/api/schema'; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + CheckpointModelConfigEntity, + useUpdateMainModelsMutation, +} from 'services/api/endpoints/models'; +import { CheckpointModelConfig } from 'services/api/types'; import ModelConvert from './ModelConvert'; const baseModelSelectData = [ @@ -29,36 +28,31 @@ const variantSelectData = [ { value: 'depth', label: 'Depth' }, ]; -export type CheckpointModelConfig = - | components['schemas']['StableDiffusion1ModelCheckpointConfig'] - | components['schemas']['StableDiffusion2ModelCheckpointConfig']; - type CheckpointModelEditProps = { - modelToEdit: string; - retrievedModel: CheckpointModelConfig; + model: CheckpointModelConfigEntity; }; export default function CheckpointModelEdit(props: CheckpointModelEditProps) { const isBusy = useAppSelector(selectIsBusy); - const { modelToEdit, retrievedModel } = props; + const { model } = props; - const [updateMainModel, { error, isLoading }] = useUpdateMainModelsMutation(); + const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation(); const dispatch = useAppDispatch(); const { t } = useTranslation(); const checkpointEditForm = useForm({ initialValues: { - model_name: retrievedModel.model_name ? retrievedModel.model_name : '', - base_model: retrievedModel.base_model, + model_name: model.model_name ? model.model_name : '', + base_model: model.base_model, model_type: 'main', - path: retrievedModel.path ? retrievedModel.path : '', - description: retrievedModel.description ? retrievedModel.description : '', + path: model.path ? model.path : '', + description: model.description ? model.description : '', model_format: 'checkpoint', - vae: retrievedModel.vae ? retrievedModel.vae : '', - config: retrievedModel.config ? retrievedModel.config : '', - variant: retrievedModel.variant, + vae: model.vae ? model.vae : '', + config: model.config ? model.config : '', + variant: model.variant, }, validate: { path: (value) => @@ -66,50 +60,60 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) { }, }); - const editModelFormSubmitHandler = (values: CheckpointModelConfig) => { - const responseBody = { - base_model: retrievedModel.base_model, - model_name: retrievedModel.model_name, - body: values, - }; - updateMainModel(responseBody) - .unwrap() - .then((payload) => { - checkpointEditForm.setValues(payload as CheckpointModelConfig); - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelUpdated'), - status: 'success', - }) - ) - ); - }) - .catch((error) => { - checkpointEditForm.reset(); - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelUpdateFailed'), - status: 'error', - }) - ) - ); - }); - }; + const editModelFormSubmitHandler = useCallback( + (values: CheckpointModelConfig) => { + const responseBody = { + base_model: model.base_model, + model_name: model.model_name, + body: values, + }; + updateMainModel(responseBody) + .unwrap() + .then((payload) => { + checkpointEditForm.setValues(payload as CheckpointModelConfig); + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdated'), + status: 'success', + }) + ) + ); + }) + .catch((error) => { + checkpointEditForm.reset(); + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdateFailed'), + status: 'error', + }) + ) + ); + }); + }, + [ + checkpointEditForm, + dispatch, + model.base_model, + model.model_name, + t, + updateMainModel, + ] + ); - return modelToEdit ? ( + return ( - {retrievedModel.model_name} + {model.model_name} - {MODEL_TYPE_MAP[retrievedModel.base_model]} Model + {MODEL_TYPE_MAP[model.base_model]} Model - + @@ -161,17 +165,5 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) { - ) : ( - - Pick A Model To Edit - ); } diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx index 65793c9d2c..5c1667b331 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx @@ -1,29 +1,23 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; - import { Divider, Flex, Text } from '@chakra-ui/react'; - -// import { addNewModel } from 'app/socketio/actions'; -import { useTranslation } from 'react-i18next'; - import { useForm } from '@mantine/form'; import { makeToast } from 'app/components/Toaster'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIButton from 'common/components/IAIButton'; import IAIMantineTextInput from 'common/components/IAIMantineInput'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; - import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { selectIsBusy } from 'features/system/store/systemSelectors'; import { addToast } from 'features/system/store/systemSlice'; -import { useUpdateMainModelsMutation } from 'services/api/endpoints/models'; -import { components } from 'services/api/schema'; - -export type DiffusersModelConfig = - | components['schemas']['StableDiffusion1ModelDiffusersConfig'] - | components['schemas']['StableDiffusion2ModelDiffusersConfig']; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + DiffusersModelConfigEntity, + useUpdateMainModelsMutation, +} from 'services/api/endpoints/models'; +import { DiffusersModelConfig } from 'services/api/types'; type DiffusersModelEditProps = { - modelToEdit: string; - retrievedModel: DiffusersModelConfig; + model: DiffusersModelConfigEntity; }; const baseModelSelectData = [ @@ -40,23 +34,23 @@ const variantSelectData = [ export default function DiffusersModelEdit(props: DiffusersModelEditProps) { const isBusy = useAppSelector(selectIsBusy); - const { retrievedModel, modelToEdit } = props; + const { model } = props; - const [updateMainModel, { isLoading, error }] = useUpdateMainModelsMutation(); + const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation(); const dispatch = useAppDispatch(); const { t } = useTranslation(); const diffusersEditForm = useForm({ initialValues: { - model_name: retrievedModel.model_name ? retrievedModel.model_name : '', - base_model: retrievedModel.base_model, + model_name: model.model_name ? model.model_name : '', + base_model: model.base_model, model_type: 'main', - path: retrievedModel.path ? retrievedModel.path : '', - description: retrievedModel.description ? retrievedModel.description : '', + path: model.path ? model.path : '', + description: model.description ? model.description : '', model_format: 'diffusers', - vae: retrievedModel.vae ? retrievedModel.vae : '', - variant: retrievedModel.variant, + vae: model.vae ? model.vae : '', + variant: model.variant, }, validate: { path: (value) => @@ -64,46 +58,56 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) { }, }); - const editModelFormSubmitHandler = (values: DiffusersModelConfig) => { - const responseBody = { - base_model: retrievedModel.base_model, - model_name: retrievedModel.model_name, - body: values, - }; - updateMainModel(responseBody) - .unwrap() - .then((payload) => { - diffusersEditForm.setValues(payload as DiffusersModelConfig); - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelUpdated'), - status: 'success', - }) - ) - ); - }) - .catch((error) => { - diffusersEditForm.reset(); - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelUpdateFailed'), - status: 'error', - }) - ) - ); - }); - }; + const editModelFormSubmitHandler = useCallback( + (values: DiffusersModelConfig) => { + const responseBody = { + base_model: model.base_model, + model_name: model.model_name, + body: values, + }; + updateMainModel(responseBody) + .unwrap() + .then((payload) => { + diffusersEditForm.setValues(payload as DiffusersModelConfig); + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdated'), + status: 'success', + }) + ) + ); + }) + .catch((error) => { + diffusersEditForm.reset(); + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdateFailed'), + status: 'error', + }) + ) + ); + }); + }, + [ + diffusersEditForm, + dispatch, + model.base_model, + model.model_name, + t, + updateMainModel, + ] + ); - return modelToEdit ? ( + return ( - {retrievedModel.model_name} + {model.model_name} - {MODEL_TYPE_MAP[retrievedModel.base_model]} Model + {MODEL_TYPE_MAP[model.base_model]} Model @@ -146,17 +150,5 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) { - ) : ( - - Pick A Model To Edit - ); } diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx index 655b8c4de3..b0e44f7615 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -1,187 +1,46 @@ -import { Box, Flex, Spinner, Text, useColorMode } from '@chakra-ui/react'; +import { ButtonGroup, Flex, Text } from '@chakra-ui/react'; +import { EntityState } from '@reduxjs/toolkit'; import IAIButton from 'common/components/IAIButton'; import IAIInput from 'common/components/IAIInput'; - +import { forEach } from 'lodash-es'; +import type { ChangeEvent } from 'react'; +import { useCallback, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + MainModelConfigEntity, + useGetMainModelsQuery, +} from 'services/api/endpoints/models'; import ModelListItem from './ModelListItem'; -import { useTranslation } from 'react-i18next'; +type ModelListProps = { + selectedModelId: string | undefined; + setSelectedModelId: (name: string | undefined) => void; +}; -import type { ChangeEvent, ReactNode } from 'react'; -import React, { useMemo, useState, useTransition } from 'react'; -import { useGetMainModelsQuery } from 'services/api/endpoints/models'; -import { mode } from 'theme/util/mode'; - -function ModelFilterButton({ - label, - isActive, - onClick, -}: { - label: string; - isActive: boolean; - onClick: () => void; -}) { - return ( - - {label} - - ); -} - -const ModelList = () => { - const { data: mainModels } = useGetMainModelsQuery(); - const { colorMode } = useColorMode(); - - const [renderModelList, setRenderModelList] = React.useState(false); - - React.useEffect(() => { - const timer = setTimeout(() => { - setRenderModelList(true); - }, 200); - - return () => clearTimeout(timer); - }, []); - - const [searchText, setSearchText] = useState(''); - const [isSelectedFilter, setIsSelectedFilter] = useState< - 'all' | 'checkpoint' | 'diffusers' - >('all'); - const [_, startTransition] = useTransition(); +type ModelFormat = 'all' | 'checkpoint' | 'diffusers'; +const ModelList = (props: ModelListProps) => { + const { selectedModelId, setSelectedModelId } = props; const { t } = useTranslation(); + const [nameFilter, setNameFilter] = useState(''); + const [modelFormatFilter, setModelFormatFilter] = + useState('all'); - const handleSearchFilter = (e: ChangeEvent) => { - startTransition(() => { - setSearchText(e.target.value); - }); - }; + const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, { + selectFromResult: ({ data }) => ({ + filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter), + }), + }); - const renderModelListItems = useMemo(() => { - const ckptModelListItemsToRender: ReactNode[] = []; - const diffusersModelListItemsToRender: ReactNode[] = []; - const filteredModelListItemsToRender: ReactNode[] = []; - const localFilteredModelListItemsToRender: ReactNode[] = []; + const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, { + selectFromResult: ({ data }) => ({ + filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter), + }), + }); - if (!mainModels) return; - - const modelList = mainModels.entities; - - Object.keys(modelList).forEach((model, i) => { - const modelInfo = modelList[model]; - - // If no model info found for a model, ignore it - if (!modelInfo) return; - - if ( - modelInfo.model_name.toLowerCase().includes(searchText.toLowerCase()) - ) { - filteredModelListItemsToRender.push( - - ); - if (modelInfo?.model_format === isSelectedFilter) { - localFilteredModelListItemsToRender.push( - - ); - } - } - - if (modelInfo?.model_format !== 'diffusers') { - ckptModelListItemsToRender.push( - - ); - } else { - diffusersModelListItemsToRender.push( - - ); - } - }); - - return searchText !== '' ? ( - isSelectedFilter === 'all' ? ( - {filteredModelListItemsToRender} - ) : ( - {localFilteredModelListItemsToRender} - ) - ) : ( - - {isSelectedFilter === 'all' && ( - <> - {diffusersModelListItemsToRender.length > 0 && ( - - - {t('modelManager.diffusersModels')} - - {diffusersModelListItemsToRender} - - )} - - {ckptModelListItemsToRender.length > 0 && ( - - - {t('modelManager.checkpointModels')} - - {ckptModelListItemsToRender} - - )} - - )} - - {isSelectedFilter === 'diffusers' && ( - - {diffusersModelListItemsToRender} - - )} - - {isSelectedFilter === 'checkpoint' && ( - - {ckptModelListItemsToRender} - - )} - - ); - }, [mainModels, searchText, t, isSelectedFilter, colorMode]); + const handleSearchFilter = useCallback((e: ChangeEvent) => { + setNameFilter(e.target.value); + }, []); return ( @@ -189,7 +48,6 @@ const ModelList = () => { onChange={handleSearchFilter} label={t('modelManager.search')} /> - { overflow="scroll" paddingInlineEnd={4} > - - setIsSelectedFilter('all')} - isActive={isSelectedFilter === 'all'} - /> - setIsSelectedFilter('diffusers')} - isActive={isSelectedFilter === 'diffusers'} - /> - setIsSelectedFilter('checkpoint')} - isActive={isSelectedFilter === 'checkpoint'} - /> - - - {renderModelList ? ( - renderModelListItems - ) : ( - + setModelFormatFilter('all')} + isChecked={modelFormatFilter === 'all'} + size="sm" > - + {t('modelManager.allModels')} + + setModelFormatFilter('diffusers')} + isChecked={modelFormatFilter === 'diffusers'} + > + {t('modelManager.diffusersModels')} + + setModelFormatFilter('checkpoint')} + isChecked={modelFormatFilter === 'checkpoint'} + > + {t('modelManager.checkpointModels')} + + + + {['all', 'diffusers'].includes(modelFormatFilter) && ( + + + Diffusers + + {filteredDiffusersModels.map((model) => ( + + ))} + + )} + {['all', 'checkpoint'].includes(modelFormatFilter) && ( + + + Checkpoint + + {filteredCheckpointModels.map((model) => ( + + ))} )} @@ -233,3 +115,27 @@ const ModelList = () => { }; export default ModelList; + +const modelsFilter = ( + data: EntityState | undefined, + model_format: ModelFormat, + nameFilter: string +) => { + const filteredModels: MainModelConfigEntity[] = []; + forEach(data?.entities, (model) => { + if (!model) { + return; + } + + const matchesFilter = model.model_name + .toLowerCase() + .includes(nameFilter.toLowerCase()); + + const matchesFormat = model.model_format === model_format; + + if (matchesFilter && matchesFormat) { + filteredModels.push(model); + } + }); + return filteredModels; +}; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx index 5a90327aa3..84ed784d1e 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx @@ -1,115 +1,96 @@ -import { DeleteIcon, EditIcon } from '@chakra-ui/icons'; -import { - Box, - Flex, - Spacer, - Text, - Tooltip, - useColorMode, -} from '@chakra-ui/react'; - -// import { deleteModel, requestModelChange } from 'app/socketio/actions'; -import { RootState } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { DeleteIcon } from '@chakra-ui/icons'; +import { Box, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react'; +import { useAppSelector } from 'app/store/storeHooks'; import IAIAlertDialog from 'common/components/IAIAlertDialog'; +import IAIButton from 'common/components/IAIButton'; import IAIIconButton from 'common/components/IAIIconButton'; import { selectIsBusy } from 'features/system/store/systemSelectors'; -import { setOpenModel } from 'features/system/store/systemSlice'; +import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import { useDeleteMainModelsMutation } from 'services/api/endpoints/models'; -import { BaseModelType } from 'services/api/types'; -import { mode } from 'theme/util/mode'; +import { FaEdit } from 'react-icons/fa'; +import { + MainModelConfigEntity, + useDeleteMainModelsMutation, +} from 'services/api/endpoints/models'; type ModelListItemProps = { - modelKey: string; - name: string; - description: string | undefined; + model: MainModelConfigEntity; + isSelected: boolean; + setSelectedModelId: (v: string | undefined) => void; }; export default function ModelListItem(props: ModelListItemProps) { const isBusy = useAppSelector(selectIsBusy); - - const { colorMode } = useColorMode(); - - const openModel = useAppSelector( - (state: RootState) => state.system.openModel - ); - + const { t } = useTranslation(); const [deleteMainModel] = useDeleteMainModelsMutation(); - const { t } = useTranslation(); + const { model, isSelected, setSelectedModelId } = props; - const dispatch = useAppDispatch(); + const handleSelectModel = useCallback(() => { + setSelectedModelId(model.id); + }, [model.id, setSelectedModelId]); - const { modelKey, name, description } = props; - - const openModelHandler = () => { - dispatch(setOpenModel(modelKey)); - }; - - const handleModelDelete = () => { - const [base_model, _, model_name] = modelKey.split('/'); - deleteMainModel({ - base_model: base_model as BaseModelType, - model_name: model_name, - }); - dispatch(setOpenModel(null)); - }; + const handleModelDelete = useCallback(() => { + deleteMainModel(model); + setSelectedModelId(undefined); + }, [deleteMainModel, model, setSelectedModelId]); return ( - - - - {name} - - - - + + + + + {model.model_name} + + + } + icon={} size="sm" - onClick={openModelHandler} + onClick={handleSelectModel} aria-label={t('accessibility.modifyConfig')} isDisabled={isBusy} + variant="link" /> - } - size="sm" - aria-label={t('modelManager.deleteConfig')} - isDisabled={isBusy} - colorScheme="error" - /> - } - > - -

{t('modelManager.deleteMsg1')}

-

{t('modelManager.deleteMsg2')}

-
-
+ } + aria-label={t('modelManager.deleteConfig')} + isDisabled={isBusy} + colorScheme="error" + /> + } + > + +

{t('modelManager.deleteMsg1')}

+

{t('modelManager.deleteMsg2')}

+
+
); } diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 1038a88c09..c86ad91100 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -3,7 +3,9 @@ import { cloneDeep } from 'lodash-es'; import { AnyModelConfig, BaseModelType, + CheckpointModelConfig, ControlNetModelConfig, + DiffusersModelConfig, LoRAModelConfig, MainModelConfig, MergeModelConfig, @@ -14,7 +16,13 @@ import { import { ApiFullTagDescription, LIST_TAG, api } from '..'; import { paths } from '../schema'; -export type MainModelConfigEntity = MainModelConfig & { id: string }; +export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string }; +export type CheckpointModelConfigEntity = CheckpointModelConfig & { + id: string; +}; +export type MainModelConfigEntity = + | DiffusersModelConfigEntity + | CheckpointModelConfigEntity; export type LoRAModelConfigEntity = LoRAModelConfig & { id: string }; diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts index c2657701e7..fcbbd1a6a0 100644 --- a/invokeai/frontend/web/src/services/api/types.d.ts +++ b/invokeai/frontend/web/src/services/api/types.d.ts @@ -42,11 +42,13 @@ export type ControlNetModelConfig = components['schemas']['ControlNetModelConfig']; export type TextualInversionModelConfig = components['schemas']['TextualInversionModelConfig']; -export type MainModelConfig = - | components['schemas']['StableDiffusion1ModelCheckpointConfig'] +export type DiffusersModelConfig = | components['schemas']['StableDiffusion1ModelDiffusersConfig'] - | components['schemas']['StableDiffusion2ModelCheckpointConfig'] | components['schemas']['StableDiffusion2ModelDiffusersConfig']; +export type CheckpointModelConfig = + | components['schemas']['StableDiffusion1ModelCheckpointConfig'] + | components['schemas']['StableDiffusion2ModelCheckpointConfig']; +export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig; export type AnyModelConfig = | LoRAModelConfig | VaeModelConfig diff --git a/invokeai/frontend/web/src/theme/components/text.ts b/invokeai/frontend/web/src/theme/components/text.ts index 2404bf0594..cccbcf2391 100644 --- a/invokeai/frontend/web/src/theme/components/text.ts +++ b/invokeai/frontend/web/src/theme/components/text.ts @@ -2,7 +2,7 @@ import { defineStyle, defineStyleConfig } from '@chakra-ui/react'; import { mode } from '@chakra-ui/theme-tools'; const subtext = defineStyle((props) => ({ - color: mode('colors.base.500', 'colors.base.400')(props), + color: mode('base.500', 'base.400')(props), })); export const textTheme = defineStyleConfig({ From 0b2f0c05b229d7de219da343b49439b3445f01e3 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 14 Jul 2023 19:31:52 +1000 Subject: [PATCH 21/28] fix(ui): fix selecting model does not update form --- .../tabs/ModelManager/subpanels/ModelManagerPanel.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx index f681b79437..d2beb3a674 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx @@ -36,11 +36,11 @@ const ModelEdit = (props: ModelEditProps) => { const { model } = props; if (model?.model_format === 'checkpoint') { - return ; + return ; } if (model?.model_format === 'diffusers') { - return ; + return ; } return ( From 56d209842f4d3d8fadd664200d6c632b7c77c8b5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 14 Jul 2023 19:46:18 +1000 Subject: [PATCH 22/28] feat(ui): only show modellistitem when none in array --- .../subpanels/ModelManagerPanel/ModelList.tsx | 62 ++++++++++--------- 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx index b0e44f7615..003cb309b8 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -79,36 +79,38 @@ const ModelList = (props: ModelListProps) => { - {['all', 'diffusers'].includes(modelFormatFilter) && ( - - - Diffusers - - {filteredDiffusersModels.map((model) => ( - - ))} - - )} - {['all', 'checkpoint'].includes(modelFormatFilter) && ( - - - Checkpoint - - {filteredCheckpointModels.map((model) => ( - - ))} - - )} + {['all', 'diffusers'].includes(modelFormatFilter) && + filteredDiffusersModels.length > 0 && ( + + + Diffusers + + {filteredDiffusersModels.map((model) => ( + + ))} + + )} + {['all', 'checkpoint'].includes(modelFormatFilter) && + filteredCheckpointModels.length > 0 && ( + + + Checkpoint + + {filteredCheckpointModels.map((model) => ( + + ))} + + )}
); From eb2a7058bf697d11a55f92d5a85d718501056d6d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 14 Jul 2023 19:49:05 +1000 Subject: [PATCH 23/28] feat(ui): tweak fontSize in modellist --- .../ModelManager/subpanels/ModelManagerPanel/ModelList.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx index 003cb309b8..ecb9383b1e 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -82,7 +82,7 @@ const ModelList = (props: ModelListProps) => { {['all', 'diffusers'].includes(modelFormatFilter) && filteredDiffusersModels.length > 0 && ( - + Diffusers {filteredDiffusersModels.map((model) => ( @@ -98,7 +98,7 @@ const ModelList = (props: ModelListProps) => { {['all', 'checkpoint'].includes(modelFormatFilter) && filteredCheckpointModels.length > 0 && ( - + Checkpoint {filteredCheckpointModels.map((model) => ( From d4dfd84525ae8e94fa58b4d79b824bbfc92d4bfd Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 14 Jul 2023 20:12:02 +1000 Subject: [PATCH 24/28] feat(ui): mm colors --- .../tabs/ModelManager/ModelManagerTab.tsx | 6 ++-- .../subpanels/ModelManagerPanel.tsx | 11 ++++--- .../ModelManagerPanel/ModelListItem.tsx | 33 +++++++------------ 3 files changed, 22 insertions(+), 28 deletions(-) diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx index 70de375774..9aced0dda8 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx @@ -46,9 +46,11 @@ const ModelManagerTab = () => { ))} - + {tabs.map((tab) => ( - {tab.content} + + {tab.content} + ))}
diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx index d2beb3a674..f49294cfb0 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx @@ -18,7 +18,7 @@ export default function ModelManagerPanel() { }); return ( - + { return ( - Pick A Model To Edit + No Model Selected ); }; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx index 84ed784d1e..5ad0016fea 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx @@ -1,13 +1,11 @@ import { DeleteIcon } from '@chakra-ui/icons'; -import { Box, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react'; +import { Button, Flex, Text, Tooltip } from '@chakra-ui/react'; import { useAppSelector } from 'app/store/storeHooks'; import IAIAlertDialog from 'common/components/IAIAlertDialog'; -import IAIButton from 'common/components/IAIButton'; import IAIIconButton from 'common/components/IAIIconButton'; import { selectIsBusy } from 'features/system/store/systemSelectors'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import { FaEdit } from 'react-icons/fa'; import { MainModelConfigEntity, useDeleteMainModelsMutation, @@ -38,40 +36,33 @@ export default function ModelListItem(props: ModelListItemProps) { return ( - - - {model.model_name} - - - - } - size="sm" - onClick={handleSelectModel} - aria-label={t('accessibility.modifyConfig')} - isDisabled={isBusy} - variant="link" - /> + + {model.model_name} + Date: Fri, 14 Jul 2023 23:00:38 +1000 Subject: [PATCH 25/28] feat(ui): extract mantine component styles to hook, add less opinionated mantine components IAIMantineSelect and IAIMantineMultiSelect have a bit of extra logic that prevents simple select functionality from working as expected. - extract the styles into hooks - rename those two components to IAIMantineSearchableSelect and IAIMantineSearchableMultiSelect - Create IAIMantineSelect (which is just a dropdown) and use it in model manager and a few other places When we only have a few options to present and searching is not efficient, we should use this instead. --- .../app/components/ThemeLocaleProvider.tsx | 3 +- .../src/common/components/IAIMantineInput.tsx | 17 +- .../components/IAIMantineMultiSelect.tsx | 113 +----------- .../components/IAIMantineSearchableSelect.tsx | 78 +++++++++ .../common/components/IAIMantineSelect.tsx | 159 +---------------- .../IAICanvasToolbar/IAICanvasToolbar.tsx | 4 +- .../parameters/ParamControlNetModel.tsx | 6 +- .../ParamControlNetProcessorSelect.tsx | 14 +- .../components/ParamEmbeddingPopover.tsx | 4 +- .../Boards/UpdateImageBoardModal.tsx | 6 +- .../lora/components/ParamLoraSelect.tsx | 4 +- .../features/nodes/components/AddNodeMenu.tsx | 4 +- .../fields/LoRAModelInputFieldComponent.tsx | 4 +- .../fields/ModelInputFieldComponent.tsx | 6 +- .../fields/VaeModelInputFieldComponent.tsx | 4 +- .../ParamScaleBeforeProcessing.tsx | 4 +- .../Parameters/Core/ParamScheduler.tsx | 4 +- .../FaceRestore/FaceRestoreType.tsx | 4 +- .../MainModel/ParamMainModelSelect.tsx | 6 +- .../Parameters/Upscale/UpscaleScale.tsx | 4 +- .../VAEModel/ParamVAEModelSelect.tsx | 4 +- .../subpanels/MergeModelsPanel.tsx | 7 +- .../ModelManagerPanel/ModelListItem.tsx | 6 +- .../UnifiedCanvas/UnifiedCanvasContent.tsx | 10 +- .../hooks/useMantineMultiSelectStyles.ts | 140 +++++++++++++++ .../hooks/useMantineSelectStyles.ts | 134 ++++++++++++++ .../frontend/web/src/mantine-theme/theme.ts | 164 ++++++++++++++++-- 27 files changed, 580 insertions(+), 333 deletions(-) create mode 100644 invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx create mode 100644 invokeai/frontend/web/src/mantine-theme/hooks/useMantineMultiSelectStyles.ts create mode 100644 invokeai/frontend/web/src/mantine-theme/hooks/useMantineSelectStyles.ts diff --git a/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx b/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx index 1e86e0ce1b..132667f8ab 100644 --- a/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx +++ b/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx @@ -9,7 +9,6 @@ import { theme as invokeAITheme } from 'theme/theme'; import '@fontsource-variable/inter'; import { MantineProvider } from '@mantine/core'; -import { mantineTheme } from 'mantine-theme/theme'; import 'overlayscrollbars/overlayscrollbars.css'; import 'theme/css/overlayscrollbars.css'; @@ -36,7 +35,7 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) { }, [direction]); return ( - + {children} diff --git a/invokeai/frontend/web/src/common/components/IAIMantineInput.tsx b/invokeai/frontend/web/src/common/components/IAIMantineInput.tsx index f7c2b91ff0..d60f6614df 100644 --- a/invokeai/frontend/web/src/common/components/IAIMantineInput.tsx +++ b/invokeai/frontend/web/src/common/components/IAIMantineInput.tsx @@ -7,8 +7,17 @@ type IAIMantineTextInputProps = TextInputProps; export default function IAIMantineTextInput(props: IAIMantineTextInputProps) { const { ...rest } = props; - const { base50, base100, base200, base800, base900, accent500, accent300 } = - useChakraThemeTokens(); + const { + base50, + base100, + base200, + base300, + base800, + base700, + base900, + accent500, + accent300, + } = useChakraThemeTokens(); const { colorMode } = useColorMode(); return ( @@ -24,6 +33,10 @@ export default function IAIMantineTextInput(props: IAIMantineTextInputProps) { borderColor: mode(accent300, accent500)(colorMode), }, }, + label: { + color: mode(base700, base300)(colorMode), + fontWeight: 'normal', + }, })} {...rest} /> diff --git a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx index dc6db707e7..dd5c602150 100644 --- a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx +++ b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx @@ -1,10 +1,9 @@ -import { Tooltip, useColorMode, useToken } from '@chakra-ui/react'; +import { Tooltip } from '@chakra-ui/react'; import { MultiSelect, MultiSelectProps } from '@mantine/core'; import { useAppDispatch } from 'app/store/storeHooks'; -import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice'; +import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles'; import { KeyboardEvent, RefObject, memo, useCallback } from 'react'; -import { mode } from 'theme/util/mode'; type IAIMultiSelectProps = MultiSelectProps & { tooltip?: string; @@ -14,25 +13,6 @@ type IAIMultiSelectProps = MultiSelectProps & { const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { const { searchable = true, tooltip, inputRef, ...rest } = props; const dispatch = useAppDispatch(); - const { - base50, - base100, - base200, - base300, - base400, - base500, - base600, - base700, - base800, - base900, - accent200, - accent300, - accent400, - accent500, - accent600, - } = useChakraThemeTokens(); - const [boxShadow] = useToken('shadows', ['dark-lg']); - const { colorMode } = useColorMode(); const handleKeyDown = useCallback( (e: KeyboardEvent) => { @@ -52,6 +32,8 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { [dispatch] ); + const styles = useMantineMultiSelectStyles(); + return ( { onKeyUp={handleKeyUp} searchable={searchable} maxDropdownHeight={300} - styles={() => ({ - label: { - color: mode(base700, base300)(colorMode), - fontWeight: 'normal', - }, - searchInput: { - ':placeholder': { - color: mode(base300, base700)(colorMode), - }, - }, - input: { - backgroundColor: mode(base50, base900)(colorMode), - borderWidth: '2px', - borderColor: mode(base200, base800)(colorMode), - color: mode(base900, base100)(colorMode), - paddingRight: 24, - fontWeight: 600, - '&:hover': { borderColor: mode(base300, base600)(colorMode) }, - '&:focus': { - borderColor: mode(accent300, accent600)(colorMode), - }, - '&:is(:focus, :hover)': { - borderColor: mode(base400, base500)(colorMode), - }, - '&:focus-within': { - borderColor: mode(accent200, accent600)(colorMode), - }, - '&[data-disabled]': { - backgroundColor: mode(base300, base700)(colorMode), - color: mode(base600, base400)(colorMode), - cursor: 'not-allowed', - }, - }, - value: { - backgroundColor: mode(base200, base800)(colorMode), - color: mode(base900, base100)(colorMode), - button: { - color: mode(base900, base100)(colorMode), - }, - '&:hover': { - backgroundColor: mode(base300, base700)(colorMode), - cursor: 'pointer', - }, - }, - dropdown: { - backgroundColor: mode(base200, base800)(colorMode), - borderColor: mode(base200, base800)(colorMode), - boxShadow, - }, - item: { - backgroundColor: mode(base200, base800)(colorMode), - color: mode(base800, base200)(colorMode), - padding: 6, - '&[data-hovered]': { - color: mode(base900, base100)(colorMode), - backgroundColor: mode(base300, base700)(colorMode), - }, - '&[data-active]': { - backgroundColor: mode(base300, base700)(colorMode), - '&:hover': { - color: mode(base900, base100)(colorMode), - backgroundColor: mode(base300, base700)(colorMode), - }, - }, - '&[data-selected]': { - backgroundColor: mode(accent400, accent600)(colorMode), - color: mode(base50, base100)(colorMode), - fontWeight: 600, - '&:hover': { - backgroundColor: mode(accent500, accent500)(colorMode), - color: mode('white', base50)(colorMode), - }, - }, - '&[data-disabled]': { - color: mode(base500, base600)(colorMode), - cursor: 'not-allowed', - }, - }, - rightSection: { - width: 24, - padding: 20, - button: { - color: mode(base900, base100)(colorMode), - }, - }, - })} + styles={styles} {...rest} /> diff --git a/invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx new file mode 100644 index 0000000000..edf1665bb4 --- /dev/null +++ b/invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx @@ -0,0 +1,78 @@ +import { Tooltip } from '@chakra-ui/react'; +import { Select, SelectProps } from '@mantine/core'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice'; +import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles'; +import { KeyboardEvent, RefObject, memo, useCallback, useState } from 'react'; + +export type IAISelectDataType = { + value: string; + label: string; + tooltip?: string; +}; + +type IAISelectProps = SelectProps & { + tooltip?: string; + inputRef?: RefObject; +}; + +const IAIMantineSearchableSelect = (props: IAISelectProps) => { + const { searchable = true, tooltip, inputRef, onChange, ...rest } = props; + const dispatch = useAppDispatch(); + + const [searchValue, setSearchValue] = useState(''); + + // we want to capture shift keypressed even when an input is focused + const handleKeyDown = useCallback( + (e: KeyboardEvent) => { + if (e.shiftKey) { + dispatch(shiftKeyPressed(true)); + } + }, + [dispatch] + ); + + const handleKeyUp = useCallback( + (e: KeyboardEvent) => { + if (!e.shiftKey) { + dispatch(shiftKeyPressed(false)); + } + }, + [dispatch] + ); + + // wrap onChange to clear search value on select + const handleChange = useCallback( + (v: string | null) => { + setSearchValue(''); + + if (!onChange) { + return; + } + + onChange(v); + }, + [onChange] + ); + + const styles = useMantineSelectStyles(); + + return ( + + ({ - label: { - color: mode(base700, base300)(colorMode), - fontWeight: 'normal', - }, - input: { - backgroundColor: mode(base50, base900)(colorMode), - borderWidth: '2px', - borderColor: mode(base200, base800)(colorMode), - color: mode(base900, base100)(colorMode), - paddingRight: 24, - fontWeight: 600, - '&:hover': { borderColor: mode(base300, base600)(colorMode) }, - '&:focus': { - borderColor: mode(accent300, accent600)(colorMode), - }, - '&:is(:focus, :hover)': { - borderColor: mode(base400, base500)(colorMode), - }, - '&:focus-within': { - borderColor: mode(accent200, accent600)(colorMode), - }, - '&[data-disabled]': { - backgroundColor: mode(base300, base700)(colorMode), - color: mode(base600, base400)(colorMode), - cursor: 'not-allowed', - }, - }, - value: { - backgroundColor: mode(base100, base900)(colorMode), - color: mode(base900, base100)(colorMode), - button: { - color: mode(base900, base100)(colorMode), - }, - '&:hover': { - backgroundColor: mode(base300, base700)(colorMode), - cursor: 'pointer', - }, - }, - dropdown: { - backgroundColor: mode(base200, base800)(colorMode), - borderColor: mode(base200, base800)(colorMode), - boxShadow, - }, - item: { - backgroundColor: mode(base200, base800)(colorMode), - color: mode(base800, base200)(colorMode), - padding: 6, - '&[data-hovered]': { - color: mode(base900, base100)(colorMode), - backgroundColor: mode(base300, base700)(colorMode), - }, - '&[data-active]': { - backgroundColor: mode(base300, base700)(colorMode), - '&:hover': { - color: mode(base900, base100)(colorMode), - backgroundColor: mode(base300, base700)(colorMode), - }, - }, - '&[data-selected]': { - backgroundColor: mode(accent400, accent600)(colorMode), - color: mode(base50, base100)(colorMode), - fontWeight: 600, - '&:hover': { - backgroundColor: mode(accent500, accent500)(colorMode), - color: mode('white', base50)(colorMode), - }, - }, - '&[data-disabled]': { - color: mode(base500, base600)(colorMode), - cursor: 'not-allowed', - }, - }, - rightSection: { - width: 32, - button: { - color: mode(base900, base100)(colorMode), - }, - }, - })} - {...rest} - /> +