feat(ui): add group by base & type to useGroupedModelCombobox hook

This allows comboboxes for models to have more granular groupings. For example, Control Adapter models can be grouped by base model & model type.

Before:
- `SD-1`
- `SDXL`

After:
- `SD-1 / ControlNet`
- `SD-1 / T2I Adapter`
- `SDXL / ControlNet`
- `SDXL / T2I Adapter`
This commit is contained in:
psychedelicious 2024-05-10 18:23:37 +10:00
parent 19f5a9c3a9
commit 6ff1c7d541

View File

@ -13,6 +13,7 @@ type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
onChange: (value: T | null) => void; onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean; getIsDisabled?: (model: T) => boolean;
isLoading?: boolean; isLoading?: boolean;
groupByType?: boolean;
}; };
type UseGroupedModelComboboxReturn = { type UseGroupedModelComboboxReturn = {
@ -23,17 +24,21 @@ type UseGroupedModelComboboxReturn = {
noOptionsMessage: () => string; noOptionsMessage: () => string;
}; };
const groupByBaseFunc = <T extends AnyModelConfig>(model: T) => model.base.toUpperCase();
const groupByBaseAndTypeFunc = <T extends AnyModelConfig>(model: T) =>
`${model.base.toUpperCase()} / ${model.type.replaceAll('_', ' ').toUpperCase()}`;
export const useGroupedModelCombobox = <T extends AnyModelConfig>( export const useGroupedModelCombobox = <T extends AnyModelConfig>(
arg: UseGroupedModelComboboxArg<T> arg: UseGroupedModelComboboxArg<T>
): UseGroupedModelComboboxReturn => { ): UseGroupedModelComboboxReturn => {
const { t } = useTranslation(); const { t } = useTranslation();
const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl'); const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading } = arg; const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, groupByType = false } = arg;
const options = useMemo<GroupBase<ComboboxOption>[]>(() => { const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
if (!modelConfigs) { if (!modelConfigs) {
return []; return [];
} }
const groupedModels = groupBy(modelConfigs, 'base'); const groupedModels = groupBy(modelConfigs, groupByType ? groupByBaseAndTypeFunc : groupByBaseFunc);
const _options = reduce( const _options = reduce(
groupedModels, groupedModels,
(acc, val, label) => { (acc, val, label) => {
@ -49,9 +54,9 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
}, },
[] as GroupBase<ComboboxOption>[] [] as GroupBase<ComboboxOption>[]
); );
_options.sort((a) => (a.label === base_model ? -1 : 1)); _options.sort((a) => (a.label?.split('/')[0]?.toLowerCase().includes(base_model) ? -1 : 1));
return _options; return _options;
}, [getIsDisabled, modelConfigs, base_model]); }, [modelConfigs, groupByType, getIsDisabled, base_model]);
const value = useMemo( const value = useMemo(
() => () =>