mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip: Move Model Selector to own file
This commit is contained in:
parent
61c426f502
commit
4133d77772
@ -5,7 +5,8 @@ import {
|
|||||||
ModelInputFieldTemplate,
|
ModelInputFieldTemplate,
|
||||||
ModelInputFieldValue,
|
ModelInputFieldValue,
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { modelSelector } from 'features/system/components/ModelSelect';
|
|
||||||
|
import { modelSelector } from 'features/system/store/modelSelectors';
|
||||||
import { ChangeEvent, memo } from 'react';
|
import { ChangeEvent, memo } from 'react';
|
||||||
import { FieldComponentProps } from './types';
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
@ -16,7 +17,8 @@ const ModelInputFieldComponent = (
|
|||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { sd1ModelData, sd2ModelData } = useAppSelector(modelSelector);
|
const { sd1ModelDropDownData, sd2ModelDropdownData } =
|
||||||
|
useAppSelector(modelSelector);
|
||||||
|
|
||||||
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
|
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
|
||||||
dispatch(
|
dispatch(
|
||||||
@ -31,8 +33,8 @@ const ModelInputFieldComponent = (
|
|||||||
return (
|
return (
|
||||||
<NativeSelect
|
<NativeSelect
|
||||||
onChange={handleValueChanged}
|
onChange={handleValueChanged}
|
||||||
value={field.value || sd1ModelData[0].value}
|
value={field.value || sd1ModelDropDownData[0].value}
|
||||||
data={sd1ModelData.concat(sd2ModelData)}
|
data={sd1ModelDropDownData.concat(sd2ModelDropdownData)}
|
||||||
></NativeSelect>
|
></NativeSelect>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,63 +1,14 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { isEqual } from 'lodash-es';
|
|
||||||
import { memo, useCallback, useEffect } from 'react';
|
import { memo, useCallback, useEffect } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIMantineSelect, {
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
IAISelectDataType,
|
|
||||||
} from 'common/components/IAIMantineSelect';
|
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
|
||||||
import {
|
import {
|
||||||
modelSelected,
|
modelSelected,
|
||||||
setCurrentModelType,
|
setCurrentModelType,
|
||||||
} from 'features/parameters/store/generationSlice';
|
} from 'features/parameters/store/generationSlice';
|
||||||
|
|
||||||
import {
|
import { modelSelector } from '../store/modelSelectors';
|
||||||
selectAllSD1Models,
|
|
||||||
selectByIdSD1Models,
|
|
||||||
} from '../store/models/sd1ModelSlice';
|
|
||||||
import {
|
|
||||||
selectAllSD2Models,
|
|
||||||
selectByIdSD2Models,
|
|
||||||
} from '../store/models/sd2ModelSlice';
|
|
||||||
|
|
||||||
export const modelSelector = createSelector(
|
|
||||||
[(state: RootState) => state, generationSelector],
|
|
||||||
(state, generation) => {
|
|
||||||
let selectedModel = selectByIdSD1Models(state, generation.model);
|
|
||||||
if (selectedModel === undefined)
|
|
||||||
selectedModel = selectByIdSD2Models(state, generation.model);
|
|
||||||
|
|
||||||
const sd1ModelData = selectAllSD1Models(state)
|
|
||||||
.map<IAISelectDataType>((m) => ({
|
|
||||||
value: m.name,
|
|
||||||
label: m.name,
|
|
||||||
group: '1.x Models',
|
|
||||||
}))
|
|
||||||
.sort((a, b) => a.label.localeCompare(b.label));
|
|
||||||
|
|
||||||
const sd2ModelData = selectAllSD2Models(state)
|
|
||||||
.map<IAISelectDataType>((m) => ({
|
|
||||||
value: m.name,
|
|
||||||
label: m.name,
|
|
||||||
group: '2.x Models',
|
|
||||||
}))
|
|
||||||
.sort((a, b) => a.label.localeCompare(b.label));
|
|
||||||
|
|
||||||
return {
|
|
||||||
selectedModel,
|
|
||||||
sd1ModelData,
|
|
||||||
sd2ModelData,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
export type ModelLoaderTypes = 'sd1_model_loader' | 'sd2_model_loader';
|
export type ModelLoaderTypes = 'sd1_model_loader' | 'sd2_model_loader';
|
||||||
|
|
||||||
@ -69,7 +20,7 @@ const MODEL_LOADER_MAP = {
|
|||||||
const ModelSelect = () => {
|
const ModelSelect = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { selectedModel, sd1ModelData, sd2ModelData } =
|
const { selectedModel, sd1ModelDropDownData, sd2ModelDropdownData } =
|
||||||
useAppSelector(modelSelector);
|
useAppSelector(modelSelector);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@ -97,7 +48,7 @@ const ModelSelect = () => {
|
|||||||
label={t('modelManager.model')}
|
label={t('modelManager.model')}
|
||||||
value={selectedModel?.name ?? ''}
|
value={selectedModel?.name ?? ''}
|
||||||
placeholder="Pick one"
|
placeholder="Pick one"
|
||||||
data={sd1ModelData.concat(sd2ModelData)}
|
data={sd1ModelDropDownData.concat(sd2ModelDropdownData)}
|
||||||
onChange={handleChangeModel}
|
onChange={handleChangeModel}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
|
@ -1,3 +1,54 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
|
import { IAISelectDataType } from 'common/components/IAIMantineSelect';
|
||||||
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
|
import { isEqual } from 'lodash-es';
|
||||||
|
import {
|
||||||
|
selectAllSD1Models,
|
||||||
|
selectByIdSD1Models,
|
||||||
|
} from './models/sd1ModelSlice';
|
||||||
|
import {
|
||||||
|
selectAllSD2Models,
|
||||||
|
selectByIdSD2Models,
|
||||||
|
} from './models/sd2ModelSlice';
|
||||||
|
|
||||||
export const modelSelector = (state: RootState) => state.models;
|
export const modelSelector = createSelector(
|
||||||
|
[(state: RootState) => state, generationSelector],
|
||||||
|
(state, generation) => {
|
||||||
|
let selectedModel = selectByIdSD1Models(state, generation.model);
|
||||||
|
if (selectedModel === undefined)
|
||||||
|
selectedModel = selectByIdSD2Models(state, generation.model);
|
||||||
|
|
||||||
|
const sd1Models = selectAllSD1Models(state);
|
||||||
|
const sd2Models = selectAllSD2Models(state);
|
||||||
|
|
||||||
|
const sd1ModelDropDownData = selectAllSD1Models(state)
|
||||||
|
.map<IAISelectDataType>((m) => ({
|
||||||
|
value: m.name,
|
||||||
|
label: m.name,
|
||||||
|
group: '1.x Models',
|
||||||
|
}))
|
||||||
|
.sort((a, b) => a.label.localeCompare(b.label));
|
||||||
|
|
||||||
|
const sd2ModelDropdownData = selectAllSD2Models(state)
|
||||||
|
.map<IAISelectDataType>((m) => ({
|
||||||
|
value: m.name,
|
||||||
|
label: m.name,
|
||||||
|
group: '2.x Models',
|
||||||
|
}))
|
||||||
|
.sort((a, b) => a.label.localeCompare(b.label));
|
||||||
|
|
||||||
|
return {
|
||||||
|
selectedModel,
|
||||||
|
sd1Models,
|
||||||
|
sd2Models,
|
||||||
|
sd1ModelDropDownData,
|
||||||
|
sd2ModelDropdownData,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
Loading…
Reference in New Issue
Block a user