From 4a2f34f77fda055361827efdde3aeae7a0fb24c4 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sat, 15 Jul 2023 22:23:00 +1200 Subject: [PATCH] wip: Model Search Going to rework the whole thing. The old system is convoluted and too difficult to plug back. --- .../AddModelsPanel/FoundModelsList.tsx | 34 ++ .../AddModelsPanel/SearchFolderForm.tsx | 114 +++++ .../subpanels/AddModelsPanel/SearchModels.tsx | 429 +---------------- .../AddModelsPanel/SearchModelsOld.tsx | 430 ++++++++++++++++++ .../web/src/services/api/endpoints/models.ts | 17 +- .../frontend/web/src/services/api/schema.d.ts | 12 +- .../frontend/web/src/services/api/types.d.ts | 1 + 7 files changed, 609 insertions(+), 428 deletions(-) create mode 100644 invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/FoundModelsList.tsx create mode 100644 invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/SearchFolderForm.tsx create mode 100644 invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/SearchModelsOld.tsx diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/FoundModelsList.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/FoundModelsList.tsx new file mode 100644 index 0000000000..af862d005d --- /dev/null +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/FoundModelsList.tsx @@ -0,0 +1,34 @@ +import { Flex } from '@chakra-ui/react'; +import { RootState } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { useGetModelsInFolderQuery } from 'services/api/endpoints/models'; + +export default function FoundModelsList() { + const searchFolder = useAppSelector( + (state: RootState) => state.modelmanager.searchFolder + ); + + const { data: foundModels } = useGetModelsInFolderQuery({ + search_path: searchFolder ? searchFolder : '', + }); + + console.log(foundModels); + + const renderFoundModels = () => { + if (!searchFolder) return; + + if (!foundModels || foundModels.length === 0) { + return No Models Found; + } + + return ( + + {foundModels.map((model) => ( + {model} + ))} + + ); + }; + + return renderFoundModels(); +} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/SearchFolderForm.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/SearchFolderForm.tsx new file mode 100644 index 0000000000..10d0f51665 --- /dev/null +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/SearchFolderForm.tsx @@ -0,0 +1,114 @@ +import { Flex, Text } from '@chakra-ui/react'; +import { useForm } from '@mantine/form'; +import { RootState } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import IAIIconButton from 'common/components/IAIIconButton'; +import IAIInput from 'common/components/IAIInput'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { FaSearch, FaTrash } from 'react-icons/fa'; +import { setSearchFolder } from '../../store/modelManagerSlice'; + +type SearchFolderForm = { + folder: string; +}; + +function SearchFolderForm() { + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const searchFolder = useAppSelector( + (state: RootState) => state.modelmanager.searchFolder + ); + + const searchFolderForm = useForm({ + initialValues: { + folder: '', + }, + }); + + const searchFolderFormSubmitHandler = useCallback( + (values: SearchFolderForm) => { + dispatch(setSearchFolder(values.folder)); + }, + [dispatch] + ); + + return ( + + searchFolderFormSubmitHandler(values) + )} + style={{ width: '100%' }} + > + + + + Search Folder + + {!searchFolder ? ( + + ) : ( + + {searchFolder} + + )} + + + + } + fontSize={18} + size="sm" + type="submit" + /> + } + size="sm" + onClick={() => dispatch(setSearchFolder(null))} + isDisabled={!searchFolder} + colorScheme="red" + /> + + + + ); +} + +export default memo(SearchFolderForm); diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/SearchModels.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/SearchModels.tsx index 3381cb85d3..e3e48c7e6b 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/SearchModels.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/SearchModels.tsx @@ -1,430 +1,17 @@ -import IAIButton from 'common/components/IAIButton'; -import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; -import IAIIconButton from 'common/components/IAIIconButton'; -import React from 'react'; - -import { - Badge, - Flex, - FormControl, - HStack, - Radio, - RadioGroup, - Spacer, - Text, -} from '@chakra-ui/react'; -import { createSelector } from '@reduxjs/toolkit'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { systemSelector } from 'features/system/store/systemSelectors'; +import { Flex } from '@chakra-ui/react'; +import { useAppDispatch } from 'app/store/storeHooks'; import { useTranslation } from 'react-i18next'; - -import { FaSearch, FaTrash } from 'react-icons/fa'; - -// import { addNewModel, searchForModels } from 'app/socketio/actions'; -import { - setFoundModels, - setSearchFolder, -} from 'features/system/store/systemSlice'; -import { setShouldShowExistingModelsInSearch } from 'features/ui/store/uiSlice'; - -import type { FoundModel } from 'app/types/invokeai'; -import type { RootState } from 'app/store/store'; -import IAIInput from 'common/components/IAIInput'; -import { Field, Formik } from 'formik'; -import { forEach, remove } from 'lodash-es'; -import type { ChangeEvent, ReactNode } from 'react'; -import IAIForm from 'common/components/IAIForm'; - -const existingModelsSelector = createSelector([systemSelector], (system) => { - const { model_list } = system; - - const existingModels: string[] = []; - - forEach(model_list, (value) => { - existingModels.push(value.weights); - }); - - return existingModels; -}); - -interface SearchModelEntry { - model: FoundModel; - modelsToAdd: string[]; - setModelsToAdd: React.Dispatch>; -} - -function SearchModelEntry({ - model, - modelsToAdd, - setModelsToAdd, -}: SearchModelEntry) { - const { t } = useTranslation(); - const existingModels = useAppSelector(existingModelsSelector); - - const foundModelsChangeHandler = (e: ChangeEvent) => { - if (!modelsToAdd.includes(e.target.value)) { - setModelsToAdd([...modelsToAdd, e.target.value]); - } else { - setModelsToAdd(remove(modelsToAdd, (v) => v !== e.target.value)); - } - }; - - return ( - - - {model.name}} - isChecked={modelsToAdd.includes(model.name)} - isDisabled={existingModels.includes(model.location)} - onChange={foundModelsChangeHandler} - > - {existingModels.includes(model.location) && ( - {t('modelManager.modelExists')} - )} - - - {model.location} - - - ); -} +import FoundModelsList from './FoundModelsList'; +import SearchFolderForm from './SearchFolderForm'; export default function SearchModels() { const dispatch = useAppDispatch(); - const { t } = useTranslation(); - const searchFolder = useAppSelector( - (state: RootState) => state.system.searchFolder - ); - - const foundModels = useAppSelector( - (state: RootState) => state.system.foundModels - ); - - const existingModels = useAppSelector(existingModelsSelector); - - const shouldShowExistingModelsInSearch = useAppSelector( - (state: RootState) => state.ui.shouldShowExistingModelsInSearch - ); - - const isProcessing = useAppSelector( - (state: RootState) => state.system.isProcessing - ); - - const [modelsToAdd, setModelsToAdd] = React.useState([]); - const [modelType, setModelType] = React.useState('v1'); - const [pathToConfig, setPathToConfig] = React.useState(''); - - const resetSearchModelHandler = () => { - dispatch(setSearchFolder(null)); - dispatch(setFoundModels(null)); - setModelsToAdd([]); - }; - - const findModelsHandler = (values: { checkpointFolder: string }) => { - dispatch(searchForModels(values.checkpointFolder)); - }; - - const addAllToSelected = () => { - setModelsToAdd([]); - if (foundModels) { - foundModels.forEach((model) => { - if (!existingModels.includes(model.location)) { - setModelsToAdd((currentModels) => { - return [...currentModels, model.name]; - }); - } - }); - } - }; - - const removeAllFromSelected = () => { - setModelsToAdd([]); - }; - - const addSelectedModels = () => { - const modelsToBeAdded = foundModels?.filter((foundModel) => - modelsToAdd.includes(foundModel.name) - ); - - const configFiles = { - v1: 'configs/stable-diffusion/v1-inference.yaml', - v2_base: 'configs/stable-diffusion/v2-inference-v.yaml', - v2_768: 'configs/stable-diffusion/v2-inference-v.yaml', - inpainting: 'configs/stable-diffusion/v1-inpainting-inference.yaml', - custom: pathToConfig, - }; - - modelsToBeAdded?.forEach((model) => { - const modelFormat = { - name: model.name, - description: '', - config: configFiles[modelType as keyof typeof configFiles], - weights: model.location, - vae: '', - width: 512, - height: 512, - default: false, - format: 'ckpt', - }; - dispatch(addNewModel(modelFormat)); - }); - setModelsToAdd([]); - }; - - const renderFoundModels = () => { - const newFoundModels: ReactNode[] = []; - const existingFoundModels: ReactNode[] = []; - - if (foundModels) { - foundModels.forEach((model, index) => { - if (existingModels.includes(model.location)) { - existingFoundModels.push( - - ); - } else { - newFoundModels.push( - - ); - } - }); - } - - return ( - - {newFoundModels} - {shouldShowExistingModelsInSearch && existingFoundModels} - - ); - }; - return ( - <> - {searchFolder ? ( - - - - {t('modelManager.checkpointFolder')} - - {searchFolder} - - - } - fontSize={18} - disabled={isProcessing} - onClick={() => dispatch(searchForModels(searchFolder))} - /> - } - onClick={resetSearchModelHandler} - /> - - ) : ( - { - findModelsHandler(values); - }} - > - {({ handleSubmit }) => ( - - - - - - } - aria-label={t('modelManager.findModels')} - tooltip={t('modelManager.findModels')} - type="submit" - disabled={isProcessing} - px={8} - > - {t('modelManager.findModels')} - - - - )} - - )} - {foundModels && ( - - - - {t('modelManager.modelsFound')}: {foundModels.length} - - - {t('modelManager.selected')}: {modelsToAdd.length} - - - - - - {t('modelManager.selectAll')} - - - {t('modelManager.deselectAll')} - - - dispatch( - setShouldShowExistingModelsInSearch( - !shouldShowExistingModelsInSearch - ) - ) - } - /> - - - - {t('modelManager.addSelected')} - - - - - - - {t('modelManager.pickModelType')} - - setModelType(v)} - defaultValue="v1" - name="model_type" - > - - - {t('modelManager.v1')} - - - {t('modelManager.v2_base')} - - - {t('modelManager.v2_768')} - - - {t('modelManager.inpainting')} - - - {t('modelManager.customConfig')} - - - - - - {modelType === 'custom' && ( - - - {t('modelManager.pathToCustomConfig')} - - { - if (e.target.value !== '') setPathToConfig(e.target.value); - }} - width="full" - /> - - )} - - - - {foundModels.length > 0 ? ( - renderFoundModels() - ) : ( - - {t('modelManager.noModelsFound')} - - )} - - - )} - > + + + + ); } diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/SearchModelsOld.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/SearchModelsOld.tsx new file mode 100644 index 0000000000..3381cb85d3 --- /dev/null +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/SearchModelsOld.tsx @@ -0,0 +1,430 @@ +import IAIButton from 'common/components/IAIButton'; +import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; +import IAIIconButton from 'common/components/IAIIconButton'; +import React from 'react'; + +import { + Badge, + Flex, + FormControl, + HStack, + Radio, + RadioGroup, + Spacer, + Text, +} from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { systemSelector } from 'features/system/store/systemSelectors'; +import { useTranslation } from 'react-i18next'; + +import { FaSearch, FaTrash } from 'react-icons/fa'; + +// import { addNewModel, searchForModels } from 'app/socketio/actions'; +import { + setFoundModels, + setSearchFolder, +} from 'features/system/store/systemSlice'; +import { setShouldShowExistingModelsInSearch } from 'features/ui/store/uiSlice'; + +import type { FoundModel } from 'app/types/invokeai'; +import type { RootState } from 'app/store/store'; +import IAIInput from 'common/components/IAIInput'; +import { Field, Formik } from 'formik'; +import { forEach, remove } from 'lodash-es'; +import type { ChangeEvent, ReactNode } from 'react'; +import IAIForm from 'common/components/IAIForm'; + +const existingModelsSelector = createSelector([systemSelector], (system) => { + const { model_list } = system; + + const existingModels: string[] = []; + + forEach(model_list, (value) => { + existingModels.push(value.weights); + }); + + return existingModels; +}); + +interface SearchModelEntry { + model: FoundModel; + modelsToAdd: string[]; + setModelsToAdd: React.Dispatch>; +} + +function SearchModelEntry({ + model, + modelsToAdd, + setModelsToAdd, +}: SearchModelEntry) { + const { t } = useTranslation(); + const existingModels = useAppSelector(existingModelsSelector); + + const foundModelsChangeHandler = (e: ChangeEvent) => { + if (!modelsToAdd.includes(e.target.value)) { + setModelsToAdd([...modelsToAdd, e.target.value]); + } else { + setModelsToAdd(remove(modelsToAdd, (v) => v !== e.target.value)); + } + }; + + return ( + + + {model.name}} + isChecked={modelsToAdd.includes(model.name)} + isDisabled={existingModels.includes(model.location)} + onChange={foundModelsChangeHandler} + > + {existingModels.includes(model.location) && ( + {t('modelManager.modelExists')} + )} + + + {model.location} + + + ); +} + +export default function SearchModels() { + const dispatch = useAppDispatch(); + + const { t } = useTranslation(); + + const searchFolder = useAppSelector( + (state: RootState) => state.system.searchFolder + ); + + const foundModels = useAppSelector( + (state: RootState) => state.system.foundModels + ); + + const existingModels = useAppSelector(existingModelsSelector); + + const shouldShowExistingModelsInSearch = useAppSelector( + (state: RootState) => state.ui.shouldShowExistingModelsInSearch + ); + + const isProcessing = useAppSelector( + (state: RootState) => state.system.isProcessing + ); + + const [modelsToAdd, setModelsToAdd] = React.useState([]); + const [modelType, setModelType] = React.useState('v1'); + const [pathToConfig, setPathToConfig] = React.useState(''); + + const resetSearchModelHandler = () => { + dispatch(setSearchFolder(null)); + dispatch(setFoundModels(null)); + setModelsToAdd([]); + }; + + const findModelsHandler = (values: { checkpointFolder: string }) => { + dispatch(searchForModels(values.checkpointFolder)); + }; + + const addAllToSelected = () => { + setModelsToAdd([]); + if (foundModels) { + foundModels.forEach((model) => { + if (!existingModels.includes(model.location)) { + setModelsToAdd((currentModels) => { + return [...currentModels, model.name]; + }); + } + }); + } + }; + + const removeAllFromSelected = () => { + setModelsToAdd([]); + }; + + const addSelectedModels = () => { + const modelsToBeAdded = foundModels?.filter((foundModel) => + modelsToAdd.includes(foundModel.name) + ); + + const configFiles = { + v1: 'configs/stable-diffusion/v1-inference.yaml', + v2_base: 'configs/stable-diffusion/v2-inference-v.yaml', + v2_768: 'configs/stable-diffusion/v2-inference-v.yaml', + inpainting: 'configs/stable-diffusion/v1-inpainting-inference.yaml', + custom: pathToConfig, + }; + + modelsToBeAdded?.forEach((model) => { + const modelFormat = { + name: model.name, + description: '', + config: configFiles[modelType as keyof typeof configFiles], + weights: model.location, + vae: '', + width: 512, + height: 512, + default: false, + format: 'ckpt', + }; + dispatch(addNewModel(modelFormat)); + }); + setModelsToAdd([]); + }; + + const renderFoundModels = () => { + const newFoundModels: ReactNode[] = []; + const existingFoundModels: ReactNode[] = []; + + if (foundModels) { + foundModels.forEach((model, index) => { + if (existingModels.includes(model.location)) { + existingFoundModels.push( + + ); + } else { + newFoundModels.push( + + ); + } + }); + } + + return ( + + {newFoundModels} + {shouldShowExistingModelsInSearch && existingFoundModels} + + ); + }; + + return ( + <> + {searchFolder ? ( + + + + {t('modelManager.checkpointFolder')} + + {searchFolder} + + + } + fontSize={18} + disabled={isProcessing} + onClick={() => dispatch(searchForModels(searchFolder))} + /> + } + onClick={resetSearchModelHandler} + /> + + ) : ( + { + findModelsHandler(values); + }} + > + {({ handleSubmit }) => ( + + + + + + } + aria-label={t('modelManager.findModels')} + tooltip={t('modelManager.findModels')} + type="submit" + disabled={isProcessing} + px={8} + > + {t('modelManager.findModels')} + + + + )} + + )} + {foundModels && ( + + + + {t('modelManager.modelsFound')}: {foundModels.length} + + + {t('modelManager.selected')}: {modelsToAdd.length} + + + + + + {t('modelManager.selectAll')} + + + {t('modelManager.deselectAll')} + + + dispatch( + setShouldShowExistingModelsInSearch( + !shouldShowExistingModelsInSearch + ) + ) + } + /> + + + + {t('modelManager.addSelected')} + + + + + + + {t('modelManager.pickModelType')} + + setModelType(v)} + defaultValue="v1" + name="model_type" + > + + + {t('modelManager.v1')} + + + {t('modelManager.v2_base')} + + + {t('modelManager.v2_768')} + + + {t('modelManager.inpainting')} + + + {t('modelManager.customConfig')} + + + + + + {modelType === 'custom' && ( + + + {t('modelManager.pathToCustomConfig')} + + { + if (e.target.value !== '') setPathToConfig(e.target.value); + }} + width="full" + /> + + )} + + + + {foundModels.length > 0 ? ( + renderFoundModels() + ) : ( + + {t('modelManager.noModelsFound')} + + )} + + + )} + > + ); +} diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 79e685313e..a838a82f46 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -14,8 +14,9 @@ import { VaeModelConfig, } from 'services/api/types'; +import queryString from 'query-string'; import { ApiFullTagDescription, LIST_TAG, api } from '..'; -import { paths } from '../schema'; +import { operations, paths } from '../schema'; export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string }; export type CheckpointModelConfigEntity = CheckpointModelConfig & { @@ -77,6 +78,11 @@ type MergeMainModelArg = { type MergeMainModelResponse = paths['/api/v1/models/merge/{base_model}']['put']['responses']['200']['content']['application/json']; +type SearchFolderResponse = + paths['/api/v1/models/search']['get']['responses']['200']['content']['application/json']; + +type SearchFolderArg = operations['search_for_models']['parameters']['query']; + const mainModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), }); @@ -331,6 +337,14 @@ export const modelsApi = api.injectEndpoints({ ); }, }), + getModelsInFolder: build.query({ + query: (arg) => { + const folderQueryStr = queryString.stringify(arg, {}); + return { + url: `/models/search?${folderQueryStr}`, + }; + }, + }), }), }); @@ -344,4 +358,5 @@ export const { useDeleteMainModelsMutation, useConvertMainModelsMutation, useMergeMainModelsMutation, + useGetModelsInFolderQuery, } = modelsApi; diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index 610e9fa05e..892ed289c1 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -4655,18 +4655,18 @@ export type components = { */ image?: components["schemas"]["ImageField"]; }; - /** - * StableDiffusion1ModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusion2ModelFormat * @description An enumeration. * @enum {string} */ StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; + /** + * StableDiffusion1ModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; }; responses: never; parameters: never; diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts index 57258fb19b..c945b7de78 100644 --- a/invokeai/frontend/web/src/services/api/types.d.ts +++ b/invokeai/frontend/web/src/services/api/types.d.ts @@ -58,6 +58,7 @@ export type AnyModelConfig = export type MergeModelConfig = components['schemas']['Body_merge_models']; export type ConvertModelConfig = components['schemas']['Body_convert_model']; +export type SearchFolderConfig = components['schemas']; // Graphs export type Graph = components['schemas']['Graph'];
- {t('modelManager.modelsFound')}: {foundModels.length} -
- {t('modelManager.selected')}: {modelsToAdd.length} -
+ {t('modelManager.modelsFound')}: {foundModels.length} +
+ {t('modelManager.selected')}: {modelsToAdd.length} +