From 6ff1c7d54182995fd7c005ba0ceea6804640a292 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 10 May 2024 18:23:37 +1000 Subject: [PATCH] feat(ui): add group by base & type to useGroupedModelCombobox hook This allows comboboxes for models to have more granular groupings. For example, Control Adapter models can be grouped by base model & model type. Before: - `SD-1` - `SDXL` After: - `SD-1 / ControlNet` - `SD-1 / T2I Adapter` - `SDXL / ControlNet` - `SDXL / T2I Adapter` --- .../web/src/common/hooks/useGroupedModelCombobox.ts | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts index 55887eb3be..5b57fcd2bb 100644 --- a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts @@ -13,6 +13,7 @@ type UseGroupedModelComboboxArg = { onChange: (value: T | null) => void; getIsDisabled?: (model: T) => boolean; isLoading?: boolean; + groupByType?: boolean; }; type UseGroupedModelComboboxReturn = { @@ -23,17 +24,21 @@ type UseGroupedModelComboboxReturn = { noOptionsMessage: () => string; }; +const groupByBaseFunc = (model: T) => model.base.toUpperCase(); +const groupByBaseAndTypeFunc = (model: T) => + `${model.base.toUpperCase()} / ${model.type.replaceAll('_', ' ').toUpperCase()}`; + export const useGroupedModelCombobox = ( arg: UseGroupedModelComboboxArg ): UseGroupedModelComboboxReturn => { const { t } = useTranslation(); const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl'); - const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading } = arg; + const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, groupByType = false } = arg; const options = useMemo[]>(() => { if (!modelConfigs) { return []; } - const groupedModels = groupBy(modelConfigs, 'base'); + const groupedModels = groupBy(modelConfigs, groupByType ? groupByBaseAndTypeFunc : groupByBaseFunc); const _options = reduce( groupedModels, (acc, val, label) => { @@ -49,9 +54,9 @@ export const useGroupedModelCombobox = ( }, [] as GroupBase[] ); - _options.sort((a) => (a.label === base_model ? -1 : 1)); + _options.sort((a) => (a.label?.split('/')[0]?.toLowerCase().includes(base_model) ? -1 : 1)); return _options; - }, [getIsDisabled, modelConfigs, base_model]); + }, [modelConfigs, groupByType, getIsDisabled, base_model]); const value = useMemo( () =>