feat(ui): only show refiner models on refiner model select

This commit is contained in:
psychedelicious 2023-12-29 17:38:01 +11:00 committed by Kent Keirsey
parent 4e9841c924
commit 3e6173ee8c
3 changed files with 98 additions and 3 deletions

View File

@ -34,7 +34,7 @@ export const useGroupedModelInvSelect = <T extends AnyModelConfigEntity>(
);
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } =
arg;
const options = useMemo(() => {
const options = useMemo<GroupBase<InvSelectOption>[]>(() => {
if (!modelEntities) {
return [];
}

View File

@ -0,0 +1,91 @@
import type { EntityState } from '@reduxjs/toolkit';
import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfigEntity } from 'services/api/endpoints/models';
import { getModelId } from 'services/api/endpoints/models';
import type { InvSelectOnChange, InvSelectOption } from './types';
type UseModelInvSelectArg<T extends AnyModelConfigEntity> = {
modelEntities: EntityState<T, string> | undefined;
selectedModel?: Pick<T, 'base_model' | 'model_name' | 'model_type'> | null;
onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean;
optionsFilter?: (model: T) => boolean;
isLoading?: boolean;
};
type UseModelInvSelectReturn = {
value: InvSelectOption | undefined | null;
options: InvSelectOption[];
onChange: InvSelectOnChange;
placeholder: string;
noOptionsMessage: () => string;
};
export const useModelInvSelect = <T extends AnyModelConfigEntity>(
arg: UseModelInvSelectArg<T>
): UseModelInvSelectReturn => {
const { t } = useTranslation();
const {
modelEntities,
selectedModel,
getIsDisabled,
onChange,
isLoading,
optionsFilter = () => true,
} = arg;
const options = useMemo<InvSelectOption[]>(() => {
if (!modelEntities) {
return [];
}
return map(modelEntities.entities)
.filter(optionsFilter)
.map((model) => ({
label: model.model_name,
value: model.id,
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
}));
}, [optionsFilter, getIsDisabled, modelEntities]);
const value = useMemo(
() =>
options.find((m) =>
selectedModel ? m.value === getModelId(selectedModel) : false
),
[options, selectedModel]
);
const _onChange = useCallback<InvSelectOnChange>(
(v) => {
if (!v) {
onChange(null);
return;
}
const model = modelEntities?.entities[v.value];
if (!model) {
onChange(null);
return;
}
onChange(model);
},
[modelEntities?.entities, onChange]
);
const placeholder = useMemo(() => {
if (isLoading) {
return t('common.loading');
}
if (options.length === 0) {
return t('models.noModelsAvailable');
}
return t('models.selectModel');
}, [isLoading, options, t]);
const noOptionsMessage = useCallback(() => t('models.noMatchingModels'), [t]);
return { options, value, onChange: _onChange, placeholder, noOptionsMessage };
};

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InvControl } from 'common/components/InvControl/InvControl';
import { InvSelect } from 'common/components/InvSelect/InvSelect';
import { useGroupedModelInvSelect } from 'common/components/InvSelect/useGroupedModelInvSelect';
import { useModelInvSelect } from 'common/components/InvSelect/useModelInvSelect';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@ -15,6 +15,9 @@ const selector = createMemoizedSelector(stateSelector, (state) => ({
model: state.sdxl.refinerModel,
}));
const optionsFilter = (model: MainModelConfigEntity) =>
model.base_model === 'sdxl-refiner';
const ParamSDXLRefinerModelSelect = () => {
const dispatch = useAppDispatch();
const { model } = useAppSelector(selector);
@ -37,11 +40,12 @@ const ParamSDXLRefinerModelSelect = () => {
[dispatch]
);
const { options, value, onChange, placeholder, noOptionsMessage } =
useGroupedModelInvSelect({
useModelInvSelect({
modelEntities: data,
onChange: _onChange,
selectedModel: model,
isLoading,
optionsFilter,
});
return (
<InvControl