diff --git a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts index 55887eb3be..5b57fcd2bb 100644 --- a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts @@ -13,6 +13,7 @@ type UseGroupedModelComboboxArg = { onChange: (value: T | null) => void; getIsDisabled?: (model: T) => boolean; isLoading?: boolean; + groupByType?: boolean; }; type UseGroupedModelComboboxReturn = { @@ -23,17 +24,21 @@ type UseGroupedModelComboboxReturn = { noOptionsMessage: () => string; }; +const groupByBaseFunc = (model: T) => model.base.toUpperCase(); +const groupByBaseAndTypeFunc = (model: T) => + `${model.base.toUpperCase()} / ${model.type.replaceAll('_', ' ').toUpperCase()}`; + export const useGroupedModelCombobox = ( arg: UseGroupedModelComboboxArg ): UseGroupedModelComboboxReturn => { const { t } = useTranslation(); const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl'); - const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading } = arg; + const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, groupByType = false } = arg; const options = useMemo[]>(() => { if (!modelConfigs) { return []; } - const groupedModels = groupBy(modelConfigs, 'base'); + const groupedModels = groupBy(modelConfigs, groupByType ? groupByBaseAndTypeFunc : groupByBaseFunc); const _options = reduce( groupedModels, (acc, val, label) => { @@ -49,9 +54,9 @@ export const useGroupedModelCombobox = ( }, [] as GroupBase[] ); - _options.sort((a) => (a.label === base_model ? -1 : 1)); + _options.sort((a) => (a.label?.split('/')[0]?.toLowerCase().includes(base_model) ? -1 : 1)); return _options; - }, [getIsDisabled, modelConfigs, base_model]); + }, [modelConfigs, groupByType, getIsDisabled, base_model]); const value = useMemo( () =>