feat(ui): IAICustomSelect v2, implement for scheduler & model

This commit is contained in:
psychedelicious 2023-05-13 23:46:47 +10:00
parent 37da0fc075
commit 658b556544
8 changed files with 181 additions and 176 deletions

View File

@ -1,28 +1,25 @@
import { CheckIcon, ChevronUpIcon } from '@chakra-ui/icons'; import { CheckIcon } from '@chakra-ui/icons';
import { import {
Box,
Flex, Flex,
FlexProps,
FormControl, FormControl,
FormControlProps, FormControlProps,
FormLabel, FormLabel,
Grid, Grid,
GridItem, GridItem,
Input,
List, List,
ListItem, ListItem,
Select, Select,
Spacer,
Text, Text,
Tooltip,
TooltipProps,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useEnsureOnScreen } from 'common/hooks/useEnsureOnScreen'; import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom';
import { useSelect } from 'downshift'; import { useSelect } from 'downshift';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useRef } from 'react'; import { memo } from 'react';
import { useIntersection } from 'react-use';
const BUTTON_BG = 'base.900';
const BORDER_HOVER = 'base.700';
const BORDER_FOCUS = 'accent.600';
type IAICustomSelectProps = { type IAICustomSelectProps = {
label?: string; label?: string;
@ -31,6 +28,9 @@ type IAICustomSelectProps = {
setSelectedItem: (v: string | null | undefined) => void; setSelectedItem: (v: string | null | undefined) => void;
withCheckIcon?: boolean; withCheckIcon?: boolean;
formControlProps?: FormControlProps; formControlProps?: FormControlProps;
buttonProps?: FlexProps;
tooltip?: string;
tooltipProps?: Omit<TooltipProps, 'children'>;
}; };
const IAICustomSelect = (props: IAICustomSelectProps) => { const IAICustomSelect = (props: IAICustomSelectProps) => {
@ -41,6 +41,9 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
selectedItem, selectedItem,
withCheckIcon, withCheckIcon,
formControlProps, formControlProps,
tooltip,
buttonProps,
tooltipProps,
} = props; } = props;
const { const {
@ -57,27 +60,28 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
setSelectedItem(newSelectedItem), setSelectedItem(newSelectedItem),
}); });
const toggleButtonRef = useRef<HTMLButtonElement>(null); const { refs, floatingStyles } = useFloating<HTMLButtonElement>({
const menuRef = useRef<HTMLUListElement>(null); whileElementsMounted: autoUpdate,
middleware: [offset(4), shift({ crossAxis: true, padding: 8 })],
});
return ( return (
<FormControl {...formControlProps}> <FormControl sx={{ w: 'full' }} {...formControlProps}>
{label && ( {label && (
<FormLabel <FormLabel
{...getLabelProps()} {...getLabelProps()}
onClick={() => { onClick={() => {
toggleButtonRef.current && toggleButtonRef.current.focus(); refs.floating.current && refs.floating.current.focus();
}} }}
> >
{label} {label}
</FormLabel> </FormLabel>
)} )}
<Tooltip label={tooltip} {...tooltipProps}>
<Select <Select
{...getToggleButtonProps({ ref: refs.setReference })}
{...buttonProps}
as={Flex} as={Flex}
{...getToggleButtonProps({
ref: toggleButtonRef,
})}
ref={toggleButtonRef}
sx={{ sx={{
alignItems: 'center', alignItems: 'center',
userSelect: 'none', userSelect: 'none',
@ -88,12 +92,17 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
{selectedItem} {selectedItem}
</Text> </Text>
</Select> </Select>
</Tooltip>
<Box {...getMenuProps()}>
{isOpen && (
<List <List
{...getMenuProps({ ref: menuRef })}
as={Flex} as={Flex}
ref={refs.setFloating}
sx={{ sx={{
position: 'absolute', ...floatingStyles,
visibility: isOpen ? 'visible' : 'hidden', width: 'max-content',
top: 0,
left: 0,
flexDirection: 'column', flexDirection: 'column',
zIndex: 1, zIndex: 1,
bg: 'base.800', bg: 'base.800',
@ -105,12 +114,10 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
px: 0, px: 0,
h: 'fit-content', h: 'fit-content',
maxH: 64, maxH: 64,
mt: 1,
}} }}
> >
<OverlayScrollbarsComponent defer> <OverlayScrollbarsComponent>
{isOpen && {items.map((item, index) => (
items.map((item, index) => (
<ListItem <ListItem
sx={{ sx={{
bg: highlightedIndex === index ? 'base.700' : undefined, bg: highlightedIndex === index ? 'base.700' : undefined,
@ -156,6 +163,8 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
))} ))}
</OverlayScrollbarsComponent> </OverlayScrollbarsComponent>
</List> </List>
)}
</Box>
</FormControl> </FormControl>
); );
}; };

View File

@ -1,5 +1,6 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAICustomSelect from 'common/components/IAICustomSelect';
import IAISelect from 'common/components/IAISelect'; import IAISelect from 'common/components/IAISelect';
import { setSampler } from 'features/parameters/store/generationSlice'; import { setSampler } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
@ -23,21 +24,26 @@ const ParamSampler = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const handleChange = useCallback( const handleChange = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => dispatch(setSampler(e.target.value)), (v: string | null | undefined) => {
if (!v) {
return;
}
dispatch(setSampler(v));
},
[dispatch] [dispatch]
); );
return ( return (
<IAISelect <IAICustomSelect
label={t('parameters.sampler')} label={t('parameters.sampler')}
value={sampler} selectedItem={sampler}
onChange={handleChange} setSelectedItem={handleChange}
validValues={ items={
['img2img', 'unifiedCanvas'].includes(activeTabName) ['img2img', 'unifiedCanvas'].includes(activeTabName)
? img2imgSchedulers ? img2imgSchedulers
: schedulers : schedulers
} }
minWidth={36} withCheckIcon
/> />
); );
}; };

View File

@ -0,0 +1,19 @@
import { Box, Flex } from '@chakra-ui/react';
import { memo } from 'react';
import ParamSampler from './ParamSampler';
import ModelSelect from 'features/system/components/ModelSelect';
const ParamSchedulerAndModel = () => {
return (
<Flex gap={3} w="full">
<Box w="16rem">
<ParamSampler />
</Box>
<Box w="full">
<ModelSelect />
</Box>
</Flex>
);
};
export default memo(ParamSchedulerAndModel);

View File

@ -2,8 +2,9 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai'; import * as InvokeAI from 'app/types/invokeai';
import promptToString from 'common/util/promptToString'; import promptToString from 'common/util/promptToString';
import { clamp } from 'lodash-es'; import { clamp, sample } from 'lodash-es';
import { setAllParametersReducer } from './setAllParametersReducer'; import { setAllParametersReducer } from './setAllParametersReducer';
import { receivedModels } from 'services/thunks/model';
export interface GenerationState { export interface GenerationState {
cfgScale: number; cfgScale: number;
@ -236,6 +237,16 @@ export const generationSlice = createSlice({
state.model = action.payload; state.model = action.payload;
}, },
}, },
extraReducers: (builder) => {
builder.addCase(receivedModels.fulfilled, (state, action) => {
if (!state.model) {
const randomModel = sample(action.payload);
if (randomModel) {
state.model = randomModel.name;
}
}
});
},
}); });
export const { export const {

View File

@ -1,21 +1,20 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { ChangeEvent, memo } from 'react'; import { memo, useCallback } from 'react';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISelect from 'common/components/IAISelect';
import { selectModelsById, selectModelsIds } from '../store/modelSlice'; import { selectModelsById, selectModelsIds } from '../store/modelSlice';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { modelSelected } from 'features/parameters/store/generationSlice'; import { modelSelected } from 'features/parameters/store/generationSlice';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import IAICustomSelect from 'common/components/IAICustomSelect';
const selector = createSelector( const selector = createSelector(
[(state: RootState) => state, generationSelector], [(state: RootState) => state, generationSelector],
(state, generation) => { (state, generation) => {
// const selectedModel = selectedModelSelector(state);
const selectedModel = selectModelsById(state, generation.model); const selectedModel = selectModelsById(state, generation.model);
const allModelNames = selectModelsIds(state); const allModelNames = selectModelsIds(state).map((id) => String(id));
return { return {
allModelNames, allModelNames,
selectedModel, selectedModel,
@ -32,19 +31,25 @@ const ModelSelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { allModelNames, selectedModel } = useAppSelector(selector); const { allModelNames, selectedModel } = useAppSelector(selector);
const handleChangeModel = (e: ChangeEvent<HTMLSelectElement>) => { const handleChangeModel = useCallback(
dispatch(modelSelected(e.target.value)); (v: string | null | undefined) => {
}; if (!v) {
return;
}
dispatch(modelSelected(v));
},
[dispatch]
);
return ( return (
<IAISelect <IAICustomSelect
label={t('modelManager.model')} label={t('modelManager.model')}
style={{ fontSize: 'sm' }} tooltip={selectedModel?.description}
aria-label={t('accessibility.modelSelect')} items={allModelNames}
tooltip={selectedModel?.description || ''} selectedItem={selectedModel?.name ?? ''}
value={selectedModel?.name || undefined} setSelectedItem={handleChangeModel}
validValues={allModelNames} withCheckIcon={true}
onChange={handleChangeModel} tooltipProps={{ placement: 'top', hasArrow: true }}
/> />
); );
}; };

View File

@ -1,5 +1,5 @@
import { memo } from 'react'; import { memo } from 'react';
import { Box, Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { uiSelector } from 'features/ui/store/uiSelectors'; import { uiSelector } from 'features/ui/store/uiSelectors';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
@ -9,11 +9,10 @@ import ParamSteps from 'features/parameters/components/Parameters/Core/ParamStep
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale'; import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth'; import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth';
import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight'; import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight';
import ParamSampler from 'features/parameters/components/Parameters/Core/ParamSampler';
import ModelSelect from 'features/system/components/ModelSelect';
import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength'; import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength';
import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit'; import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel';
const selector = createSelector( const selector = createSelector(
[uiSelector, generationSelector], [uiSelector, generationSelector],
@ -48,14 +47,7 @@ const ImageToImageTabCoreParameters = () => {
<ParamHeight isDisabled={!shouldFitToWidthHeight} /> <ParamHeight isDisabled={!shouldFitToWidthHeight} />
<ImageToImageStrength /> <ImageToImageStrength />
<ImageToImageFit /> <ImageToImageFit />
<Flex gap={3} w="full"> <ParamSchedulerAndModel />
<Box flexGrow={2}>
<ParamSampler />
</Box>
<Box flexGrow={3}>
<ModelSelect />
</Box>
</Flex>
</Flex> </Flex>
) : ( ) : (
<Flex sx={{ gap: 2, flexDirection: 'column' }}> <Flex sx={{ gap: 2, flexDirection: 'column' }}>
@ -64,14 +56,7 @@ const ImageToImageTabCoreParameters = () => {
<ParamSteps /> <ParamSteps />
<ParamCFGScale /> <ParamCFGScale />
</Flex> </Flex>
<Flex gap={3} w="full"> <ParamSchedulerAndModel />
<Box flexGrow={2}>
<ParamSampler />
</Box>
<Box flexGrow={3}>
<ModelSelect />
</Box>
</Flex>
<ParamWidth isDisabled={!shouldFitToWidthHeight} /> <ParamWidth isDisabled={!shouldFitToWidthHeight} />
<ParamHeight isDisabled={!shouldFitToWidthHeight} /> <ParamHeight isDisabled={!shouldFitToWidthHeight} />
<ImageToImageStrength /> <ImageToImageStrength />

View File

@ -3,14 +3,13 @@ import ParamSteps from 'features/parameters/components/Parameters/Core/ParamStep
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale'; import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth'; import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth';
import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight'; import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight';
import ParamSampler from 'features/parameters/components/Parameters/Core/ParamSampler'; import { Flex } from '@chakra-ui/react';
import ModelSelect from 'features/system/components/ModelSelect';
import { Box, Flex } from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { uiSelector } from 'features/ui/store/uiSelectors'; import { uiSelector } from 'features/ui/store/uiSelectors';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { memo } from 'react'; import { memo } from 'react';
import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel';
const selector = createSelector( const selector = createSelector(
uiSelector, uiSelector,
@ -42,14 +41,7 @@ const TextToImageTabCoreParameters = () => {
<ParamCFGScale /> <ParamCFGScale />
<ParamWidth /> <ParamWidth />
<ParamHeight /> <ParamHeight />
<Flex gap={3} w="full"> <ParamSchedulerAndModel />
<Box flexGrow={2}>
<ParamSampler />
</Box>
<Box flexGrow={3}>
<ModelSelect />
</Box>
</Flex>
</Flex> </Flex>
) : ( ) : (
<Flex sx={{ gap: 2, flexDirection: 'column' }}> <Flex sx={{ gap: 2, flexDirection: 'column' }}>
@ -58,14 +50,7 @@ const TextToImageTabCoreParameters = () => {
<ParamSteps /> <ParamSteps />
<ParamCFGScale /> <ParamCFGScale />
</Flex> </Flex>
<Flex gap={3} w="full"> <ParamSchedulerAndModel />
<Box flexGrow={2}>
<ParamSampler />
</Box>
<Box flexGrow={3}>
<ModelSelect />
</Box>
</Flex>
<ParamWidth /> <ParamWidth />
<ParamHeight /> <ParamHeight />
</Flex> </Flex>

View File

@ -1,10 +1,9 @@
import { memo } from 'react'; import { memo } from 'react';
import { Box, Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { uiSelector } from 'features/ui/store/uiSelectors'; import { uiSelector } from 'features/ui/store/uiSelectors';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import ModelSelect from 'features/system/components/ModelSelect';
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations'; import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale'; import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
@ -12,7 +11,7 @@ import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidt
import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight'; import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight';
import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength'; import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength';
import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit'; import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit';
import ParamSampler from 'features/parameters/components/Parameters/Core/ParamSampler'; import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel';
const selector = createSelector( const selector = createSelector(
uiSelector, uiSelector,
@ -46,14 +45,7 @@ const UnifiedCanvasCoreParameters = () => {
<ParamHeight /> <ParamHeight />
<ImageToImageStrength /> <ImageToImageStrength />
<ImageToImageFit /> <ImageToImageFit />
<Flex gap={3} w="full"> <ParamSchedulerAndModel />
<Box flexGrow={2}>
<ParamSampler />
</Box>
<Box flexGrow={3}>
<ModelSelect />
</Box>
</Flex>
</Flex> </Flex>
) : ( ) : (
<Flex sx={{ gap: 2, flexDirection: 'column' }}> <Flex sx={{ gap: 2, flexDirection: 'column' }}>
@ -62,14 +54,7 @@ const UnifiedCanvasCoreParameters = () => {
<ParamSteps /> <ParamSteps />
<ParamCFGScale /> <ParamCFGScale />
</Flex> </Flex>
<Flex gap={3} w="full"> <ParamSchedulerAndModel />
<Box flexGrow={2}>
<ParamSampler />
</Box>
<Box flexGrow={3}>
<ModelSelect />
</Box>
</Flex>
<ParamWidth /> <ParamWidth />
<ParamHeight /> <ParamHeight />
<ImageToImageStrength /> <ImageToImageStrength />