feat(ui): enhance IAICustomSelect

Now accepts an array of strings or array of `IAICustomSelectOption`s. This supports custom labels and tooltips within the select component.
This commit is contained in:
psychedelicious 2023-06-09 15:56:43 +10:00
parent 6ad7cc4f2a
commit a33327c651
7 changed files with 215 additions and 116 deletions

View File

@ -2,7 +2,6 @@ import { CheckIcon, ChevronUpIcon } from '@chakra-ui/icons';
import {
Box,
Flex,
FlexProps,
FormControl,
FormControlProps,
FormLabel,
@ -16,6 +15,7 @@ import {
} from '@chakra-ui/react';
import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom';
import { useSelect } from 'downshift';
import { isString } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useMemo } from 'react';
@ -23,15 +23,19 @@ import { getInputOutlineStyles } from 'theme/util/getInputOutlineStyles';
export type ItemTooltips = { [key: string]: string };
export type IAICustomSelectOption = {
value: string;
label: string;
tooltip?: string;
};
type IAICustomSelectProps = {
label?: string;
items: string[];
itemTooltips?: ItemTooltips;
selectedItem: string;
setSelectedItem: (v: string | null | undefined) => void;
value: string;
data: IAICustomSelectOption[] | string[];
onChange: (v: string) => void;
withCheckIcon?: boolean;
formControlProps?: FormControlProps;
buttonProps?: FlexProps;
tooltip?: string;
tooltipProps?: Omit<TooltipProps, 'children'>;
ellipsisPosition?: 'start' | 'end';
@ -40,18 +44,33 @@ type IAICustomSelectProps = {
const IAICustomSelect = (props: IAICustomSelectProps) => {
const {
label,
items,
itemTooltips,
setSelectedItem,
selectedItem,
withCheckIcon,
formControlProps,
tooltip,
buttonProps,
tooltipProps,
ellipsisPosition = 'end',
data,
value,
onChange,
} = props;
const values = useMemo(() => {
return data.map<IAICustomSelectOption>((v) => {
if (isString(v)) {
return { value: v, label: v };
}
return v;
});
}, [data]);
const stringValues = useMemo(() => {
return values.map((v) => v.value);
}, [values]);
const valueData = useMemo(() => {
return values.find((v) => v.value === value);
}, [values, value]);
const {
isOpen,
getToggleButtonProps,
@ -60,10 +79,11 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
highlightedIndex,
getItemProps,
} = useSelect({
items,
selectedItem,
onSelectedItemChange: ({ selectedItem: newSelectedItem }) =>
setSelectedItem(newSelectedItem),
items: stringValues,
selectedItem: value,
onSelectedItemChange: ({ selectedItem: newSelectedItem }) => {
newSelectedItem && onChange(newSelectedItem);
},
});
const { refs, floatingStyles } = useFloating<HTMLButtonElement>({
@ -94,7 +114,6 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
<Tooltip label={tooltip} {...tooltipProps}>
<Flex
{...getToggleButtonProps({ ref: refs.setReference })}
{...buttonProps}
sx={{
alignItems: 'center',
userSelect: 'none',
@ -119,7 +138,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
direction: labelTextDirection,
}}
>
{selectedItem}
{valueData?.label}
</Text>
<ChevronUpIcon
sx={{
@ -155,8 +174,8 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
}}
>
<OverlayScrollbarsComponent>
{items.map((item, index) => {
const isSelected = selectedItem === item;
{values.map((v, index) => {
const isSelected = value === v.value;
const isHighlighted = highlightedIndex === index;
const fontWeight = isSelected ? 700 : 500;
const bg = isHighlighted
@ -166,9 +185,9 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
: undefined;
return (
<Tooltip
isDisabled={!itemTooltips}
key={`${item}${index}`}
label={itemTooltips?.[item]}
isDisabled={!v.tooltip}
key={`${v.value}${index}`}
label={v.tooltip}
hasArrow
placement="right"
>
@ -182,8 +201,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
transitionProperty: 'common',
transitionDuration: '0.15s',
}}
key={`${item}${index}`}
{...getItemProps({ item, index })}
{...getItemProps({ item: v.value, index })}
>
{withCheckIcon ? (
<Grid gridTemplateColumns="1.25rem auto">
@ -198,7 +216,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
fontWeight,
}}
>
{item}
{v.label}
</Text>
</GridItem>
</Grid>
@ -210,7 +228,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
fontWeight,
}}
>
{item}
{v.label}
</Text>
)}
</ListItem>

View File

@ -1,17 +1,26 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAICustomSelect from 'common/components/IAICustomSelect';
import IAICustomSelect, {
IAICustomSelectOption,
} from 'common/components/IAICustomSelect';
import {
CONTROLNET_MODELS,
ControlNetModel,
ControlNetModelName,
} from 'features/controlNet/store/constants';
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
type ParamControlNetModelProps = {
controlNetId: string;
model: ControlNetModel;
model: ControlNetModelName;
};
const DATA: IAICustomSelectOption[] = map(CONTROLNET_MODELS, (m) => ({
value: m.type,
label: m.label,
tooltip: m.type,
}));
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const { controlNetId, model } = props;
const dispatch = useAppDispatch();
@ -19,7 +28,7 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const handleModelChanged = useCallback(
(val: string | null | undefined) => {
// TODO: do not cast
const model = val as ControlNetModel;
const model = val as ControlNetModelName;
dispatch(controlNetModelChanged({ controlNetId, model }));
},
[controlNetId, dispatch]
@ -29,9 +38,9 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
<IAICustomSelect
tooltip={model}
tooltipProps={{ placement: 'top', hasArrow: true }}
items={CONTROLNET_MODELS}
selectedItem={model}
setSelectedItem={handleModelChanged}
data={DATA}
value={model}
onChange={handleModelChanged}
ellipsisPosition="start"
withCheckIcon
/>

View File

@ -1,4 +1,6 @@
import IAICustomSelect from 'common/components/IAICustomSelect';
import IAICustomSelect, {
IAICustomSelectOption,
} from 'common/components/IAICustomSelect';
import { memo, useCallback } from 'react';
import {
ControlNetProcessorNode,
@ -7,15 +9,28 @@ import {
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { map } from 'lodash-es';
type ParamControlNetProcessorSelectProps = {
controlNetId: string;
processorNode: ControlNetProcessorNode;
};
const CONTROLNET_PROCESSOR_TYPES = Object.keys(
CONTROLNET_PROCESSORS
) as ControlNetProcessorType[];
const CONTROLNET_PROCESSOR_TYPES: IAICustomSelectOption[] = map(
CONTROLNET_PROCESSORS,
(p) => ({
value: p.type,
label: p.label,
tooltip: p.description,
})
).sort((a, b) =>
// sort 'none' to the top
a.value === 'none'
? -1
: b.value === 'none'
? 1
: a.label.localeCompare(b.label)
);
const ParamControlNetProcessorSelect = (
props: ParamControlNetProcessorSelectProps
@ -36,9 +51,9 @@ const ParamControlNetProcessorSelect = (
return (
<IAICustomSelect
label="Processor"
items={CONTROLNET_PROCESSOR_TYPES}
selectedItem={processorNode.type ?? 'canny_image_processor'}
setSelectedItem={handleProcessorTypeChanged}
value={processorNode.type ?? 'canny_image_processor'}
data={CONTROLNET_PROCESSOR_TYPES}
onChange={handleProcessorTypeChanged}
withCheckIcon
/>
);

View File

@ -5,12 +5,12 @@ import {
} from './types';
type ControlNetProcessorsDict = Record<
ControlNetProcessorType,
string,
{
type: ControlNetProcessorType;
type: ControlNetProcessorType | 'none';
label: string;
description: string;
default: RequiredControlNetProcessorNode;
default: RequiredControlNetProcessorNode | { type: 'none' };
}
>;
@ -23,10 +23,10 @@ type ControlNetProcessorsDict = Record<
*
* TODO: Generate from the OpenAPI schema
*/
export const CONTROLNET_PROCESSORS = {
export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
none: {
type: 'none',
label: 'None',
label: 'none',
description: '',
default: {
type: 'none',
@ -116,7 +116,7 @@ export const CONTROLNET_PROCESSORS = {
},
mlsd_image_processor: {
type: 'mlsd_image_processor',
label: 'MLSD',
label: 'M-LSD',
description: '',
default: {
id: 'mlsd_image_processor',
@ -174,39 +174,98 @@ export const CONTROLNET_PROCESSORS = {
},
};
export const CONTROLNET_MODELS = [
'lllyasviel/control_v11p_sd15_canny',
'lllyasviel/control_v11p_sd15_inpaint',
'lllyasviel/control_v11p_sd15_mlsd',
'lllyasviel/control_v11f1p_sd15_depth',
'lllyasviel/control_v11p_sd15_normalbae',
'lllyasviel/control_v11p_sd15_seg',
'lllyasviel/control_v11p_sd15_lineart',
'lllyasviel/control_v11p_sd15s2_lineart_anime',
'lllyasviel/control_v11p_sd15_scribble',
'lllyasviel/control_v11p_sd15_softedge',
'lllyasviel/control_v11e_sd15_shuffle',
'lllyasviel/control_v11p_sd15_openpose',
'lllyasviel/control_v11f1e_sd15_tile',
'lllyasviel/control_v11e_sd15_ip2p',
'CrucibleAI/ControlNetMediaPipeFace',
];
export type ControlNetModel = (typeof CONTROLNET_MODELS)[number];
export const CONTROLNET_MODEL_MAP: Record<
ControlNetModel,
ControlNetProcessorType
> = {
'lllyasviel/control_v11p_sd15_canny': 'canny_image_processor',
'lllyasviel/control_v11p_sd15_mlsd': 'mlsd_image_processor',
'lllyasviel/control_v11f1p_sd15_depth': 'midas_depth_image_processor',
'lllyasviel/control_v11p_sd15_normalbae': 'normalbae_image_processor',
'lllyasviel/control_v11p_sd15_lineart': 'lineart_image_processor',
'lllyasviel/control_v11p_sd15s2_lineart_anime':
'lineart_anime_image_processor',
'lllyasviel/control_v11p_sd15_softedge': 'hed_image_processor',
'lllyasviel/control_v11e_sd15_shuffle': 'content_shuffle_image_processor',
'lllyasviel/control_v11p_sd15_openpose': 'openpose_image_processor',
'CrucibleAI/ControlNetMediaPipeFace': 'mediapipe_face_processor',
type ControlNetModel = {
type: string;
label: string;
description?: string;
defaultProcessor?: ControlNetProcessorType;
};
export const CONTROLNET_MODELS: Record<string, ControlNetModel> = {
'lllyasviel/control_v11p_sd15_canny': {
type: 'lllyasviel/control_v11p_sd15_canny',
label: 'Canny',
description: '',
defaultProcessor: 'canny_image_processor',
},
'lllyasviel/control_v11p_sd15_inpaint': {
type: 'lllyasviel/control_v11p_sd15_inpaint',
label: 'Inpaint',
description: 'Requires preprocessed control image',
},
'lllyasviel/control_v11p_sd15_mlsd': {
type: 'lllyasviel/control_v11p_sd15_mlsd',
label: 'M-LSD',
description: '',
defaultProcessor: 'mlsd_image_processor',
},
'lllyasviel/control_v11f1p_sd15_depth': {
type: 'lllyasviel/control_v11f1p_sd15_depth',
label: 'Depth',
description: '',
defaultProcessor: 'midas_depth_image_processor',
},
'lllyasviel/control_v11p_sd15_normalbae': {
type: 'lllyasviel/control_v11p_sd15_normalbae',
label: 'Normal Map (BAE)',
description: '',
defaultProcessor: 'normalbae_image_processor',
},
'lllyasviel/control_v11p_sd15_seg': {
type: 'lllyasviel/control_v11p_sd15_seg',
label: 'Segment Anything',
description: 'Requires preprocessed control image',
},
'lllyasviel/control_v11p_sd15_lineart': {
type: 'lllyasviel/control_v11p_sd15_lineart',
label: 'Lineart',
description: '',
defaultProcessor: 'lineart_image_processor',
},
'lllyasviel/control_v11p_sd15s2_lineart_anime': {
type: 'lllyasviel/control_v11p_sd15s2_lineart_anime',
label: 'Lineart Anime',
description: '',
defaultProcessor: 'lineart_anime_image_processor',
},
'lllyasviel/control_v11p_sd15_scribble': {
type: 'lllyasviel/control_v11p_sd15_scribble',
label: 'Scribble',
description: 'Requires preprocessed control image',
},
'lllyasviel/control_v11p_sd15_softedge': {
type: 'lllyasviel/control_v11p_sd15_softedge',
label: 'Soft Edge',
description: '',
defaultProcessor: 'hed_image_processor',
},
'lllyasviel/control_v11e_sd15_shuffle': {
type: 'lllyasviel/control_v11e_sd15_shuffle',
label: 'Content Shuffle',
description: '',
defaultProcessor: 'content_shuffle_image_processor',
},
'lllyasviel/control_v11p_sd15_openpose': {
type: 'lllyasviel/control_v11p_sd15_openpose',
label: 'Openpose',
description: '',
defaultProcessor: 'openpose_image_processor',
},
'lllyasviel/control_v11f1e_sd15_tile': {
type: 'lllyasviel/control_v11f1e_sd15_tile',
label: 'Tile (experimental)',
},
'lllyasviel/control_v11e_sd15_ip2p': {
type: 'lllyasviel/control_v11e_sd15_ip2p',
label: 'Pix2Pix (experimental)',
description: 'Requires preprocessed control image',
},
'CrucibleAI/ControlNetMediaPipeFace': {
type: 'CrucibleAI/ControlNetMediaPipeFace',
label: 'Mediapipe Face',
description: '',
defaultProcessor: 'mediapipe_face_processor',
},
};
export type ControlNetModelName = keyof typeof CONTROLNET_MODELS;

View File

@ -9,9 +9,8 @@ import {
} from './types';
import {
CONTROLNET_MODELS,
CONTROLNET_MODEL_MAP,
CONTROLNET_PROCESSORS,
ControlNetModel,
ControlNetModelName,
} from './constants';
import { controlNetImageProcessed } from './actions';
import { imageDeleted, imageUrlsReceived } from 'services/thunks/image';
@ -21,7 +20,7 @@ import { appSocketInvocationError } from 'services/events/actions';
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true,
model: CONTROLNET_MODELS[0],
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
@ -36,7 +35,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
export type ControlNetConfig = {
controlNetId: string;
isEnabled: boolean;
model: ControlNetModel;
model: ControlNetModelName;
weight: number;
beginStepPct: number;
endStepPct: number;
@ -138,14 +137,17 @@ export const controlNetSlice = createSlice({
},
controlNetModelChanged: (
state,
action: PayloadAction<{ controlNetId: string; model: ControlNetModel }>
action: PayloadAction<{
controlNetId: string;
model: ControlNetModelName;
}>
) => {
const { controlNetId, model } = action.payload;
state.controlNets[controlNetId].model = model;
state.controlNets[controlNetId].processedControlImage = null;
if (state.controlNets[controlNetId].shouldAutoConfig) {
const processorType = CONTROLNET_MODEL_MAP[model];
const processorType = CONTROLNET_MODELS[model].defaultProcessor;
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
@ -225,7 +227,8 @@ export const controlNetSlice = createSlice({
if (newShouldAutoConfig) {
// manage the processor for the user
const processorType =
CONTROLNET_MODEL_MAP[state.controlNets[controlNetId].model];
CONTROLNET_MODELS[state.controlNets[controlNetId].model]
.defaultProcessor;
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[

View File

@ -14,9 +14,11 @@ const selector = createSelector(
(ui, generation) => {
// TODO: DPMSolverSinglestepScheduler is fixed in https://github.com/huggingface/diffusers/pull/3413
// but we need to wait for the next release before removing this special handling.
const allSchedulers = ui.schedulers.filter((scheduler) => {
return !['dpmpp_2s'].includes(scheduler);
});
const allSchedulers = ui.schedulers
.filter((scheduler) => {
return !['dpmpp_2s'].includes(scheduler);
})
.sort((a, b) => a.localeCompare(b));
return {
scheduler: generation.scheduler,
@ -45,9 +47,9 @@ const ParamScheduler = () => {
return (
<IAICustomSelect
label={t('parameters.scheduler')}
selectedItem={scheduler}
setSelectedItem={handleChange}
items={allSchedulers}
value={scheduler}
data={allSchedulers}
onChange={handleChange}
withCheckIcon
/>
);

View File

@ -4,34 +4,29 @@ import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
selectModelsAll,
selectModelsById,
selectModelsIds,
} from '../store/modelSlice';
import { selectModelsAll, selectModelsById } from '../store/modelSlice';
import { RootState } from 'app/store/store';
import { modelSelected } from 'features/parameters/store/generationSlice';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import IAICustomSelect, {
ItemTooltips,
IAICustomSelectOption,
} from 'common/components/IAICustomSelect';
const selector = createSelector(
[(state: RootState) => state, generationSelector],
(state, generation) => {
const selectedModel = selectModelsById(state, generation.model);
const allModelNames = selectModelsIds(state).map((id) => String(id));
const allModelTooltips = selectModelsAll(state).reduce(
(allModelTooltips, model) => {
allModelTooltips[model.name] = model.description ?? '';
return allModelTooltips;
},
{} as ItemTooltips
);
const modelData = selectModelsAll(state)
.map<IAICustomSelectOption>((m) => ({
value: m.name,
label: m.name,
tooltip: m.description,
}))
.sort((a, b) => a.label.localeCompare(b.label));
return {
allModelNames,
allModelTooltips,
selectedModel,
modelData,
};
},
{
@ -44,8 +39,7 @@ const selector = createSelector(
const ModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { allModelNames, allModelTooltips, selectedModel } =
useAppSelector(selector);
const { selectedModel, modelData } = useAppSelector(selector);
const handleChangeModel = useCallback(
(v: string | null | undefined) => {
if (!v) {
@ -60,10 +54,9 @@ const ModelSelect = () => {
<IAICustomSelect
label={t('modelManager.model')}
tooltip={selectedModel?.description}
items={allModelNames}
itemTooltips={allModelTooltips}
selectedItem={selectedModel?.name ?? ''}
setSelectedItem={handleChangeModel}
data={modelData}
value={selectedModel?.name ?? ''}
onChange={handleChangeModel}
withCheckIcon={true}
tooltipProps={{ placement: 'top', hasArrow: true }}
/>