mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): only show refiner models on refiner model select
This commit is contained in:
parent
4e9841c924
commit
3e6173ee8c
@ -34,7 +34,7 @@ export const useGroupedModelInvSelect = <T extends AnyModelConfigEntity>(
|
||||
);
|
||||
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } =
|
||||
arg;
|
||||
const options = useMemo(() => {
|
||||
const options = useMemo<GroupBase<InvSelectOption>[]>(() => {
|
||||
if (!modelEntities) {
|
||||
return [];
|
||||
}
|
||||
|
@ -0,0 +1,91 @@
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import { map } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import { getModelId } from 'services/api/endpoints/models';
|
||||
|
||||
import type { InvSelectOnChange, InvSelectOption } from './types';
|
||||
|
||||
type UseModelInvSelectArg<T extends AnyModelConfigEntity> = {
|
||||
modelEntities: EntityState<T, string> | undefined;
|
||||
selectedModel?: Pick<T, 'base_model' | 'model_name' | 'model_type'> | null;
|
||||
onChange: (value: T | null) => void;
|
||||
getIsDisabled?: (model: T) => boolean;
|
||||
optionsFilter?: (model: T) => boolean;
|
||||
isLoading?: boolean;
|
||||
};
|
||||
|
||||
type UseModelInvSelectReturn = {
|
||||
value: InvSelectOption | undefined | null;
|
||||
options: InvSelectOption[];
|
||||
onChange: InvSelectOnChange;
|
||||
placeholder: string;
|
||||
noOptionsMessage: () => string;
|
||||
};
|
||||
|
||||
export const useModelInvSelect = <T extends AnyModelConfigEntity>(
|
||||
arg: UseModelInvSelectArg<T>
|
||||
): UseModelInvSelectReturn => {
|
||||
const { t } = useTranslation();
|
||||
const {
|
||||
modelEntities,
|
||||
selectedModel,
|
||||
getIsDisabled,
|
||||
onChange,
|
||||
isLoading,
|
||||
optionsFilter = () => true,
|
||||
} = arg;
|
||||
const options = useMemo<InvSelectOption[]>(() => {
|
||||
if (!modelEntities) {
|
||||
return [];
|
||||
}
|
||||
return map(modelEntities.entities)
|
||||
.filter(optionsFilter)
|
||||
.map((model) => ({
|
||||
label: model.model_name,
|
||||
value: model.id,
|
||||
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
|
||||
}));
|
||||
}, [optionsFilter, getIsDisabled, modelEntities]);
|
||||
|
||||
const value = useMemo(
|
||||
() =>
|
||||
options.find((m) =>
|
||||
selectedModel ? m.value === getModelId(selectedModel) : false
|
||||
),
|
||||
[options, selectedModel]
|
||||
);
|
||||
|
||||
const _onChange = useCallback<InvSelectOnChange>(
|
||||
(v) => {
|
||||
if (!v) {
|
||||
onChange(null);
|
||||
return;
|
||||
}
|
||||
const model = modelEntities?.entities[v.value];
|
||||
if (!model) {
|
||||
onChange(null);
|
||||
return;
|
||||
}
|
||||
onChange(model);
|
||||
},
|
||||
[modelEntities?.entities, onChange]
|
||||
);
|
||||
|
||||
const placeholder = useMemo(() => {
|
||||
if (isLoading) {
|
||||
return t('common.loading');
|
||||
}
|
||||
|
||||
if (options.length === 0) {
|
||||
return t('models.noModelsAvailable');
|
||||
}
|
||||
|
||||
return t('models.selectModel');
|
||||
}, [isLoading, options, t]);
|
||||
|
||||
const noOptionsMessage = useCallback(() => t('models.noMatchingModels'), [t]);
|
||||
|
||||
return { options, value, onChange: _onChange, placeholder, noOptionsMessage };
|
||||
};
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InvControl } from 'common/components/InvControl/InvControl';
|
||||
import { InvSelect } from 'common/components/InvSelect/InvSelect';
|
||||
import { useGroupedModelInvSelect } from 'common/components/InvSelect/useGroupedModelInvSelect';
|
||||
import { useModelInvSelect } from 'common/components/InvSelect/useModelInvSelect';
|
||||
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -15,6 +15,9 @@ const selector = createMemoizedSelector(stateSelector, (state) => ({
|
||||
model: state.sdxl.refinerModel,
|
||||
}));
|
||||
|
||||
const optionsFilter = (model: MainModelConfigEntity) =>
|
||||
model.base_model === 'sdxl-refiner';
|
||||
|
||||
const ParamSDXLRefinerModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { model } = useAppSelector(selector);
|
||||
@ -37,11 +40,12 @@ const ParamSDXLRefinerModelSelect = () => {
|
||||
[dispatch]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } =
|
||||
useGroupedModelInvSelect({
|
||||
useModelInvSelect({
|
||||
modelEntities: data,
|
||||
onChange: _onChange,
|
||||
selectedModel: model,
|
||||
isLoading,
|
||||
optionsFilter,
|
||||
});
|
||||
return (
|
||||
<InvControl
|
||||
|
Loading…
x
Reference in New Issue
Block a user