From 4133d7777203159267361e69017067d1b074d610 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sun, 18 Jun 2023 09:19:13 +1200 Subject: [PATCH] wip: Move Model Selector to own file --- .../fields/ModelInputFieldComponent.tsx | 10 ++-- .../system/components/ModelSelect.tsx | 57 ++----------------- .../features/system/store/modelSelectors.ts | 53 ++++++++++++++++- 3 files changed, 62 insertions(+), 58 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx index d3d37765f4..3842e8da3a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx @@ -5,7 +5,8 @@ import { ModelInputFieldTemplate, ModelInputFieldValue, } from 'features/nodes/types/types'; -import { modelSelector } from 'features/system/components/ModelSelect'; + +import { modelSelector } from 'features/system/store/modelSelectors'; import { ChangeEvent, memo } from 'react'; import { FieldComponentProps } from './types'; @@ -16,7 +17,8 @@ const ModelInputFieldComponent = ( const dispatch = useAppDispatch(); - const { sd1ModelData, sd2ModelData } = useAppSelector(modelSelector); + const { sd1ModelDropDownData, sd2ModelDropdownData } = + useAppSelector(modelSelector); const handleValueChanged = (e: ChangeEvent) => { dispatch( @@ -31,8 +33,8 @@ const ModelInputFieldComponent = ( return ( ); }; diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index bf0775d52e..43de144991 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -1,63 +1,14 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { isEqual } from 'lodash-es'; import { memo, useCallback, useEffect } from 'react'; import { useTranslation } from 'react-i18next'; -import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAIMantineSelect, { - IAISelectDataType, -} from 'common/components/IAIMantineSelect'; -import { generationSelector } from 'features/parameters/store/generationSelectors'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { modelSelected, setCurrentModelType, } from 'features/parameters/store/generationSlice'; -import { - selectAllSD1Models, - selectByIdSD1Models, -} from '../store/models/sd1ModelSlice'; -import { - selectAllSD2Models, - selectByIdSD2Models, -} from '../store/models/sd2ModelSlice'; - -export const modelSelector = createSelector( - [(state: RootState) => state, generationSelector], - (state, generation) => { - let selectedModel = selectByIdSD1Models(state, generation.model); - if (selectedModel === undefined) - selectedModel = selectByIdSD2Models(state, generation.model); - - const sd1ModelData = selectAllSD1Models(state) - .map((m) => ({ - value: m.name, - label: m.name, - group: '1.x Models', - })) - .sort((a, b) => a.label.localeCompare(b.label)); - - const sd2ModelData = selectAllSD2Models(state) - .map((m) => ({ - value: m.name, - label: m.name, - group: '2.x Models', - })) - .sort((a, b) => a.label.localeCompare(b.label)); - - return { - selectedModel, - sd1ModelData, - sd2ModelData, - }; - }, - { - memoizeOptions: { - resultEqualityCheck: isEqual, - }, - } -); +import { modelSelector } from '../store/modelSelectors'; export type ModelLoaderTypes = 'sd1_model_loader' | 'sd2_model_loader'; @@ -69,7 +20,7 @@ const MODEL_LOADER_MAP = { const ModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { selectedModel, sd1ModelData, sd2ModelData } = + const { selectedModel, sd1ModelDropDownData, sd2ModelDropdownData } = useAppSelector(modelSelector); useEffect(() => { @@ -97,7 +48,7 @@ const ModelSelect = () => { label={t('modelManager.model')} value={selectedModel?.name ?? ''} placeholder="Pick one" - data={sd1ModelData.concat(sd2ModelData)} + data={sd1ModelDropDownData.concat(sd2ModelDropdownData)} onChange={handleChangeModel} /> ); diff --git a/invokeai/frontend/web/src/features/system/store/modelSelectors.ts b/invokeai/frontend/web/src/features/system/store/modelSelectors.ts index f857bc85bc..6e101da5f5 100644 --- a/invokeai/frontend/web/src/features/system/store/modelSelectors.ts +++ b/invokeai/frontend/web/src/features/system/store/modelSelectors.ts @@ -1,3 +1,54 @@ +import { createSelector } from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; +import { IAISelectDataType } from 'common/components/IAIMantineSelect'; +import { generationSelector } from 'features/parameters/store/generationSelectors'; +import { isEqual } from 'lodash-es'; +import { + selectAllSD1Models, + selectByIdSD1Models, +} from './models/sd1ModelSlice'; +import { + selectAllSD2Models, + selectByIdSD2Models, +} from './models/sd2ModelSlice'; -export const modelSelector = (state: RootState) => state.models; +export const modelSelector = createSelector( + [(state: RootState) => state, generationSelector], + (state, generation) => { + let selectedModel = selectByIdSD1Models(state, generation.model); + if (selectedModel === undefined) + selectedModel = selectByIdSD2Models(state, generation.model); + + const sd1Models = selectAllSD1Models(state); + const sd2Models = selectAllSD2Models(state); + + const sd1ModelDropDownData = selectAllSD1Models(state) + .map((m) => ({ + value: m.name, + label: m.name, + group: '1.x Models', + })) + .sort((a, b) => a.label.localeCompare(b.label)); + + const sd2ModelDropdownData = selectAllSD2Models(state) + .map((m) => ({ + value: m.name, + label: m.name, + group: '2.x Models', + })) + .sort((a, b) => a.label.localeCompare(b.label)); + + return { + selectedModel, + sd1Models, + sd2Models, + sd1ModelDropDownData, + sd2ModelDropdownData, + }; + }, + { + memoizeOptions: { + resultEqualityCheck: isEqual, + }, + } +);