From a33327c651903f07625dc3e84dd3ef6f9fb15b2f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 9 Jun 2023 15:56:43 +1000 Subject: [PATCH] feat(ui): enhance IAICustomSelect Now accepts an array of strings or array of `IAICustomSelectOption`s. This supports custom labels and tooltips within the select component. --- .../src/common/components/IAICustomSelect.tsx | 70 +++++---- .../parameters/ParamControlNetModel.tsx | 23 ++- .../ParamControlNetProcessorSelect.tsx | 29 +++- .../features/controlNet/store/constants.ts | 141 +++++++++++++----- .../controlNet/store/controlNetSlice.ts | 17 ++- .../Parameters/Core/ParamScheduler.tsx | 14 +- .../system/components/ModelSelect.tsx | 37 ++--- 7 files changed, 215 insertions(+), 116 deletions(-) diff --git a/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx b/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx index 9accceb846..5dead4dce5 100644 --- a/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx +++ b/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx @@ -2,7 +2,6 @@ import { CheckIcon, ChevronUpIcon } from '@chakra-ui/icons'; import { Box, Flex, - FlexProps, FormControl, FormControlProps, FormLabel, @@ -16,6 +15,7 @@ import { } from '@chakra-ui/react'; import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom'; import { useSelect } from 'downshift'; +import { isString } from 'lodash-es'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; import { memo, useMemo } from 'react'; @@ -23,15 +23,19 @@ import { getInputOutlineStyles } from 'theme/util/getInputOutlineStyles'; export type ItemTooltips = { [key: string]: string }; +export type IAICustomSelectOption = { + value: string; + label: string; + tooltip?: string; +}; + type IAICustomSelectProps = { label?: string; - items: string[]; - itemTooltips?: ItemTooltips; - selectedItem: string; - setSelectedItem: (v: string | null | undefined) => void; + value: string; + data: IAICustomSelectOption[] | string[]; + onChange: (v: string) => void; withCheckIcon?: boolean; formControlProps?: FormControlProps; - buttonProps?: FlexProps; tooltip?: string; tooltipProps?: Omit; ellipsisPosition?: 'start' | 'end'; @@ -40,18 +44,33 @@ type IAICustomSelectProps = { const IAICustomSelect = (props: IAICustomSelectProps) => { const { label, - items, - itemTooltips, - setSelectedItem, - selectedItem, withCheckIcon, formControlProps, tooltip, - buttonProps, tooltipProps, ellipsisPosition = 'end', + data, + value, + onChange, } = props; + const values = useMemo(() => { + return data.map((v) => { + if (isString(v)) { + return { value: v, label: v }; + } + return v; + }); + }, [data]); + + const stringValues = useMemo(() => { + return values.map((v) => v.value); + }, [values]); + + const valueData = useMemo(() => { + return values.find((v) => v.value === value); + }, [values, value]); + const { isOpen, getToggleButtonProps, @@ -60,10 +79,11 @@ const IAICustomSelect = (props: IAICustomSelectProps) => { highlightedIndex, getItemProps, } = useSelect({ - items, - selectedItem, - onSelectedItemChange: ({ selectedItem: newSelectedItem }) => - setSelectedItem(newSelectedItem), + items: stringValues, + selectedItem: value, + onSelectedItemChange: ({ selectedItem: newSelectedItem }) => { + newSelectedItem && onChange(newSelectedItem); + }, }); const { refs, floatingStyles } = useFloating({ @@ -94,7 +114,6 @@ const IAICustomSelect = (props: IAICustomSelectProps) => { { direction: labelTextDirection, }} > - {selectedItem} + {valueData?.label} { }} > - {items.map((item, index) => { - const isSelected = selectedItem === item; + {values.map((v, index) => { + const isSelected = value === v.value; const isHighlighted = highlightedIndex === index; const fontWeight = isSelected ? 700 : 500; const bg = isHighlighted @@ -166,9 +185,9 @@ const IAICustomSelect = (props: IAICustomSelectProps) => { : undefined; return ( @@ -182,8 +201,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => { transitionProperty: 'common', transitionDuration: '0.15s', }} - key={`${item}${index}`} - {...getItemProps({ item, index })} + {...getItemProps({ item: v.value, index })} > {withCheckIcon ? ( @@ -198,7 +216,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => { fontWeight, }} > - {item} + {v.label} @@ -210,7 +228,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => { fontWeight, }} > - {item} + {v.label} )} diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx index 187d296a4f..222e8d657b 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx @@ -1,17 +1,26 @@ import { useAppDispatch } from 'app/store/storeHooks'; -import IAICustomSelect from 'common/components/IAICustomSelect'; +import IAICustomSelect, { + IAICustomSelectOption, +} from 'common/components/IAICustomSelect'; import { CONTROLNET_MODELS, - ControlNetModel, + ControlNetModelName, } from 'features/controlNet/store/constants'; import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice'; +import { map } from 'lodash-es'; import { memo, useCallback } from 'react'; type ParamControlNetModelProps = { controlNetId: string; - model: ControlNetModel; + model: ControlNetModelName; }; +const DATA: IAICustomSelectOption[] = map(CONTROLNET_MODELS, (m) => ({ + value: m.type, + label: m.label, + tooltip: m.type, +})); + const ParamControlNetModel = (props: ParamControlNetModelProps) => { const { controlNetId, model } = props; const dispatch = useAppDispatch(); @@ -19,7 +28,7 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => { const handleModelChanged = useCallback( (val: string | null | undefined) => { // TODO: do not cast - const model = val as ControlNetModel; + const model = val as ControlNetModelName; dispatch(controlNetModelChanged({ controlNetId, model })); }, [controlNetId, dispatch] @@ -29,9 +38,9 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => { diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx index 019b5ef849..19f05bc53d 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx @@ -1,4 +1,6 @@ -import IAICustomSelect from 'common/components/IAICustomSelect'; +import IAICustomSelect, { + IAICustomSelectOption, +} from 'common/components/IAICustomSelect'; import { memo, useCallback } from 'react'; import { ControlNetProcessorNode, @@ -7,15 +9,28 @@ import { import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice'; import { useAppDispatch } from 'app/store/storeHooks'; import { CONTROLNET_PROCESSORS } from '../../store/constants'; +import { map } from 'lodash-es'; type ParamControlNetProcessorSelectProps = { controlNetId: string; processorNode: ControlNetProcessorNode; }; -const CONTROLNET_PROCESSOR_TYPES = Object.keys( - CONTROLNET_PROCESSORS -) as ControlNetProcessorType[]; +const CONTROLNET_PROCESSOR_TYPES: IAICustomSelectOption[] = map( + CONTROLNET_PROCESSORS, + (p) => ({ + value: p.type, + label: p.label, + tooltip: p.description, + }) +).sort((a, b) => + // sort 'none' to the top + a.value === 'none' + ? -1 + : b.value === 'none' + ? 1 + : a.label.localeCompare(b.label) +); const ParamControlNetProcessorSelect = ( props: ParamControlNetProcessorSelectProps @@ -36,9 +51,9 @@ const ParamControlNetProcessorSelect = ( return ( ); diff --git a/invokeai/frontend/web/src/features/controlNet/store/constants.ts b/invokeai/frontend/web/src/features/controlNet/store/constants.ts index c8689badf5..3c539ba639 100644 --- a/invokeai/frontend/web/src/features/controlNet/store/constants.ts +++ b/invokeai/frontend/web/src/features/controlNet/store/constants.ts @@ -5,12 +5,12 @@ import { } from './types'; type ControlNetProcessorsDict = Record< - ControlNetProcessorType, + string, { - type: ControlNetProcessorType; + type: ControlNetProcessorType | 'none'; label: string; description: string; - default: RequiredControlNetProcessorNode; + default: RequiredControlNetProcessorNode | { type: 'none' }; } >; @@ -23,10 +23,10 @@ type ControlNetProcessorsDict = Record< * * TODO: Generate from the OpenAPI schema */ -export const CONTROLNET_PROCESSORS = { +export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = { none: { type: 'none', - label: 'None', + label: 'none', description: '', default: { type: 'none', @@ -116,7 +116,7 @@ export const CONTROLNET_PROCESSORS = { }, mlsd_image_processor: { type: 'mlsd_image_processor', - label: 'MLSD', + label: 'M-LSD', description: '', default: { id: 'mlsd_image_processor', @@ -174,39 +174,98 @@ export const CONTROLNET_PROCESSORS = { }, }; -export const CONTROLNET_MODELS = [ - 'lllyasviel/control_v11p_sd15_canny', - 'lllyasviel/control_v11p_sd15_inpaint', - 'lllyasviel/control_v11p_sd15_mlsd', - 'lllyasviel/control_v11f1p_sd15_depth', - 'lllyasviel/control_v11p_sd15_normalbae', - 'lllyasviel/control_v11p_sd15_seg', - 'lllyasviel/control_v11p_sd15_lineart', - 'lllyasviel/control_v11p_sd15s2_lineart_anime', - 'lllyasviel/control_v11p_sd15_scribble', - 'lllyasviel/control_v11p_sd15_softedge', - 'lllyasviel/control_v11e_sd15_shuffle', - 'lllyasviel/control_v11p_sd15_openpose', - 'lllyasviel/control_v11f1e_sd15_tile', - 'lllyasviel/control_v11e_sd15_ip2p', - 'CrucibleAI/ControlNetMediaPipeFace', -]; - -export type ControlNetModel = (typeof CONTROLNET_MODELS)[number]; - -export const CONTROLNET_MODEL_MAP: Record< - ControlNetModel, - ControlNetProcessorType -> = { - 'lllyasviel/control_v11p_sd15_canny': 'canny_image_processor', - 'lllyasviel/control_v11p_sd15_mlsd': 'mlsd_image_processor', - 'lllyasviel/control_v11f1p_sd15_depth': 'midas_depth_image_processor', - 'lllyasviel/control_v11p_sd15_normalbae': 'normalbae_image_processor', - 'lllyasviel/control_v11p_sd15_lineart': 'lineart_image_processor', - 'lllyasviel/control_v11p_sd15s2_lineart_anime': - 'lineart_anime_image_processor', - 'lllyasviel/control_v11p_sd15_softedge': 'hed_image_processor', - 'lllyasviel/control_v11e_sd15_shuffle': 'content_shuffle_image_processor', - 'lllyasviel/control_v11p_sd15_openpose': 'openpose_image_processor', - 'CrucibleAI/ControlNetMediaPipeFace': 'mediapipe_face_processor', +type ControlNetModel = { + type: string; + label: string; + description?: string; + defaultProcessor?: ControlNetProcessorType; }; + +export const CONTROLNET_MODELS: Record = { + 'lllyasviel/control_v11p_sd15_canny': { + type: 'lllyasviel/control_v11p_sd15_canny', + label: 'Canny', + description: '', + defaultProcessor: 'canny_image_processor', + }, + 'lllyasviel/control_v11p_sd15_inpaint': { + type: 'lllyasviel/control_v11p_sd15_inpaint', + label: 'Inpaint', + description: 'Requires preprocessed control image', + }, + 'lllyasviel/control_v11p_sd15_mlsd': { + type: 'lllyasviel/control_v11p_sd15_mlsd', + label: 'M-LSD', + description: '', + defaultProcessor: 'mlsd_image_processor', + }, + 'lllyasviel/control_v11f1p_sd15_depth': { + type: 'lllyasviel/control_v11f1p_sd15_depth', + label: 'Depth', + description: '', + defaultProcessor: 'midas_depth_image_processor', + }, + 'lllyasviel/control_v11p_sd15_normalbae': { + type: 'lllyasviel/control_v11p_sd15_normalbae', + label: 'Normal Map (BAE)', + description: '', + defaultProcessor: 'normalbae_image_processor', + }, + 'lllyasviel/control_v11p_sd15_seg': { + type: 'lllyasviel/control_v11p_sd15_seg', + label: 'Segment Anything', + description: 'Requires preprocessed control image', + }, + 'lllyasviel/control_v11p_sd15_lineart': { + type: 'lllyasviel/control_v11p_sd15_lineart', + label: 'Lineart', + description: '', + defaultProcessor: 'lineart_image_processor', + }, + 'lllyasviel/control_v11p_sd15s2_lineart_anime': { + type: 'lllyasviel/control_v11p_sd15s2_lineart_anime', + label: 'Lineart Anime', + description: '', + defaultProcessor: 'lineart_anime_image_processor', + }, + 'lllyasviel/control_v11p_sd15_scribble': { + type: 'lllyasviel/control_v11p_sd15_scribble', + label: 'Scribble', + description: 'Requires preprocessed control image', + }, + 'lllyasviel/control_v11p_sd15_softedge': { + type: 'lllyasviel/control_v11p_sd15_softedge', + label: 'Soft Edge', + description: '', + defaultProcessor: 'hed_image_processor', + }, + 'lllyasviel/control_v11e_sd15_shuffle': { + type: 'lllyasviel/control_v11e_sd15_shuffle', + label: 'Content Shuffle', + description: '', + defaultProcessor: 'content_shuffle_image_processor', + }, + 'lllyasviel/control_v11p_sd15_openpose': { + type: 'lllyasviel/control_v11p_sd15_openpose', + label: 'Openpose', + description: '', + defaultProcessor: 'openpose_image_processor', + }, + 'lllyasviel/control_v11f1e_sd15_tile': { + type: 'lllyasviel/control_v11f1e_sd15_tile', + label: 'Tile (experimental)', + }, + 'lllyasviel/control_v11e_sd15_ip2p': { + type: 'lllyasviel/control_v11e_sd15_ip2p', + label: 'Pix2Pix (experimental)', + description: 'Requires preprocessed control image', + }, + 'CrucibleAI/ControlNetMediaPipeFace': { + type: 'CrucibleAI/ControlNetMediaPipeFace', + label: 'Mediapipe Face', + description: '', + defaultProcessor: 'mediapipe_face_processor', + }, +}; + +export type ControlNetModelName = keyof typeof CONTROLNET_MODELS; diff --git a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts index 2558a38ab2..d71ff4da68 100644 --- a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts +++ b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts @@ -9,9 +9,8 @@ import { } from './types'; import { CONTROLNET_MODELS, - CONTROLNET_MODEL_MAP, CONTROLNET_PROCESSORS, - ControlNetModel, + ControlNetModelName, } from './constants'; import { controlNetImageProcessed } from './actions'; import { imageDeleted, imageUrlsReceived } from 'services/thunks/image'; @@ -21,7 +20,7 @@ import { appSocketInvocationError } from 'services/events/actions'; export const initialControlNet: Omit = { isEnabled: true, - model: CONTROLNET_MODELS[0], + model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type, weight: 1, beginStepPct: 0, endStepPct: 1, @@ -36,7 +35,7 @@ export const initialControlNet: Omit = { export type ControlNetConfig = { controlNetId: string; isEnabled: boolean; - model: ControlNetModel; + model: ControlNetModelName; weight: number; beginStepPct: number; endStepPct: number; @@ -138,14 +137,17 @@ export const controlNetSlice = createSlice({ }, controlNetModelChanged: ( state, - action: PayloadAction<{ controlNetId: string; model: ControlNetModel }> + action: PayloadAction<{ + controlNetId: string; + model: ControlNetModelName; + }> ) => { const { controlNetId, model } = action.payload; state.controlNets[controlNetId].model = model; state.controlNets[controlNetId].processedControlImage = null; if (state.controlNets[controlNetId].shouldAutoConfig) { - const processorType = CONTROLNET_MODEL_MAP[model]; + const processorType = CONTROLNET_MODELS[model].defaultProcessor; if (processorType) { state.controlNets[controlNetId].processorType = processorType; state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ @@ -225,7 +227,8 @@ export const controlNetSlice = createSlice({ if (newShouldAutoConfig) { // manage the processor for the user const processorType = - CONTROLNET_MODEL_MAP[state.controlNets[controlNetId].model]; + CONTROLNET_MODELS[state.controlNets[controlNetId].model] + .defaultProcessor; if (processorType) { state.controlNets[controlNetId].processorType = processorType; state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx index f4413c4cf6..2aa762b477 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx @@ -14,9 +14,11 @@ const selector = createSelector( (ui, generation) => { // TODO: DPMSolverSinglestepScheduler is fixed in https://github.com/huggingface/diffusers/pull/3413 // but we need to wait for the next release before removing this special handling. - const allSchedulers = ui.schedulers.filter((scheduler) => { - return !['dpmpp_2s'].includes(scheduler); - }); + const allSchedulers = ui.schedulers + .filter((scheduler) => { + return !['dpmpp_2s'].includes(scheduler); + }) + .sort((a, b) => a.localeCompare(b)); return { scheduler: generation.scheduler, @@ -45,9 +47,9 @@ const ParamScheduler = () => { return ( ); diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index be4be8ceaa..1eb8e4cb4c 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -4,34 +4,29 @@ import { isEqual } from 'lodash-es'; import { useTranslation } from 'react-i18next'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { - selectModelsAll, - selectModelsById, - selectModelsIds, -} from '../store/modelSlice'; +import { selectModelsAll, selectModelsById } from '../store/modelSlice'; import { RootState } from 'app/store/store'; import { modelSelected } from 'features/parameters/store/generationSlice'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import IAICustomSelect, { - ItemTooltips, + IAICustomSelectOption, } from 'common/components/IAICustomSelect'; const selector = createSelector( [(state: RootState) => state, generationSelector], (state, generation) => { const selectedModel = selectModelsById(state, generation.model); - const allModelNames = selectModelsIds(state).map((id) => String(id)); - const allModelTooltips = selectModelsAll(state).reduce( - (allModelTooltips, model) => { - allModelTooltips[model.name] = model.description ?? ''; - return allModelTooltips; - }, - {} as ItemTooltips - ); + + const modelData = selectModelsAll(state) + .map((m) => ({ + value: m.name, + label: m.name, + tooltip: m.description, + })) + .sort((a, b) => a.label.localeCompare(b.label)); return { - allModelNames, - allModelTooltips, selectedModel, + modelData, }; }, { @@ -44,8 +39,7 @@ const selector = createSelector( const ModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { allModelNames, allModelTooltips, selectedModel } = - useAppSelector(selector); + const { selectedModel, modelData } = useAppSelector(selector); const handleChangeModel = useCallback( (v: string | null | undefined) => { if (!v) { @@ -60,10 +54,9 @@ const ModelSelect = () => {