feat(ui): fix main model & control adapter model selects

This commit is contained in:
psychedelicious 2024-02-16 22:41:09 +11:00
parent db363b5178
commit e50b76571a
9 changed files with 154 additions and 102 deletions

View File

@ -0,0 +1,88 @@
import type { Item } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/util';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import { filter } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
data: EntityState<T, string> | undefined;
isLoading: boolean;
selectedModel?: ModelIdentifierWithBase | null;
onChange: (value: T | null) => void;
modelFilter?: (model: T) => boolean;
isModelDisabled?: (model: T) => boolean;
};
type UseModelCustomSelectReturn = {
selectedItem: Item | null;
items: Item[];
onChange: (item: Item | null) => void;
placeholder: string;
};
const modelFilterDefault = () => true;
const isModelDisabledDefault = () => false;
export const useModelCustomSelect = <T extends AnyModelConfig>({
data,
isLoading,
selectedModel,
onChange,
modelFilter = modelFilterDefault,
isModelDisabled = isModelDisabledDefault,
}: UseModelCustomSelectArg<T>): UseModelCustomSelectReturn => {
const { t } = useTranslation();
const items: Item[] = useMemo(
() =>
data
? filter(data.entities, modelFilter).map<Item>((m) => ({
label: m.name,
value: m.key,
description: m.description,
group: MODEL_TYPE_SHORT_MAP[m.base],
isDisabled: isModelDisabled(m),
}))
: EMPTY_ARRAY,
[data, isModelDisabled, modelFilter]
);
const _onChange = useCallback(
(item: Item | null) => {
if (!item || !data) {
return;
}
const model = data.entities[item.value];
if (!model) {
return;
}
onChange(model);
},
[data, onChange]
);
const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]);
const placeholder = useMemo(() => {
if (isLoading) {
return t('common.loading');
}
if (items.length === 0) {
return t('models.noModelsAvailable');
}
return t('models.selectModel');
}, [isLoading, items, t]);
return {
items,
onChange: _onChange,
selectedItem,
placeholder,
};
};

View File

@ -1,34 +1,27 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { CustomSelect, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
import { useControlAdapterModelEntities } from 'features/controlAdapters/hooks/useControlAdapterModelEntities';
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { pick } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig, ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types';
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types';
type ParamControlAdapterModelProps = {
id: string;
};
const selectMainModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
const isEnabled = useControlAdapterIsEnabled(id);
const controlAdapterType = useControlAdapterType(id);
const model = useControlAdapterModel(id);
const dispatch = useAppDispatch();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const mainModel = useAppSelector(selectMainModel);
const { t } = useTranslation();
const models = useControlAdapterModelEntities(controlAdapterType);
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
const _onChange = useCallback(
(model: ControlNetConfig | IPAdapterConfig | T2IAdapterConfig | null) => {
@ -50,34 +43,18 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
[controlAdapterType, model]
);
const getIsDisabled = useCallback(
(model: AnyModelConfig): boolean => {
const isCompatible = currentBaseModel === model.base;
const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible;
},
[currentBaseModel]
);
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: models,
onChange: _onChange,
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
data,
isLoading,
selectedModel,
getIsDisabled,
onChange: _onChange,
modelFilter: (model) => model.base === currentBaseModel,
});
return (
<Tooltip label={value?.description}>
<FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base !== model?.base}>
<Combobox
options={options}
placeholder={t('controlnet.selectModel')}
value={value}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
<FormControl isDisabled={!items.length || !isEnabled} isInvalid={!selectedItem || !items.length}>
<CustomSelect selectedItem={selectedItem} placeholder={placeholder} items={items} onChange={onChange} />
</FormControl>
);
};

View File

@ -1,23 +0,0 @@
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
import {
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery,
} from 'services/api/endpoints/models';
export const useControlAdapterModelEntities = (type?: ControlAdapterType) => {
const { data: controlNetModelsData } = useGetControlNetModelsQuery();
const { data: t2iAdapterModelsData } = useGetT2IAdapterModelsQuery();
const { data: ipAdapterModelsData } = useGetIPAdapterModelsQuery();
if (type === 'controlnet') {
return controlNetModelsData;
}
if (type === 't2i_adapter') {
return t2iAdapterModelsData;
}
if (type === 'ip_adapter') {
return ipAdapterModelsData;
}
return;
};

View File

@ -0,0 +1,26 @@
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
import {
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery,
} from 'services/api/endpoints/models';
export const useControlAdapterModelQuery = (type: ControlAdapterType) => {
const controlNetModelsQuery = useGetControlNetModelsQuery();
const t2iAdapterModelsQuery = useGetT2IAdapterModelsQuery();
const ipAdapterModelsQuery = useGetIPAdapterModelsQuery();
if (type === 'controlnet') {
return controlNetModelsQuery;
}
if (type === 't2i_adapter') {
return t2iAdapterModelsQuery;
}
if (type === 'ip_adapter') {
return ipAdapterModelsQuery;
}
// Assert that the end of the function is not reachable.
const exhaustiveCheck: never = type;
return exhaustiveCheck;
};

View File

@ -5,14 +5,16 @@ import {
selectControlAdaptersSlice,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { useMemo } from 'react';
import { assert } from 'tsafe';
export const useControlAdapterType = (id: string) => {
const selector = useMemo(
() =>
createMemoizedSelector(
selectControlAdaptersSlice,
(controlAdapters) => selectControlAdapterById(controlAdapters, id)?.type
),
createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
const type = selectControlAdapterById(controlAdapters, id)?.type;
assert(type !== undefined, `Control adapter with id ${id} not found`);
return type;
}),
[id]
);

View File

@ -2,11 +2,7 @@ import { Flex, Text } from '@invoke-ai/ui-library';
import { memo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import type {
DiffusersModelConfig,
LoRAConfig,
MainModelConfig,
} from 'services/api/endpoints/models';
import type { DiffusersModelConfig, LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';

View File

@ -1,62 +1,47 @@
import { Box, Combobox, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library';
import { CustomSelect, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
import { modelSelected } from 'features/parameters/store/actions';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { pick } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import type { MainModelConfig } from 'services/api/endpoints/models';
import { getModelId, mainModelsAdapterSelectors, useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import type { MainModelConfig } from 'services/api/types';
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
const ParamMainModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const model = useAppSelector(selectModel);
const selectedModel = useAppSelector(selectModel);
const { data, isLoading } = useGetMainModelsQuery(NON_REFINER_BASE_MODELS);
const tooltipLabel = useMemo(() => {
if (!data || !model) {
return;
}
return mainModelsAdapterSelectors.selectById(data, getModelId(model))?.description;
}, [data, model]);
const _onChange = useCallback(
(model: MainModelConfig | null) => {
if (!model) {
return;
}
dispatch(modelSelected(pick(model, ['base_model', 'model_name', 'model_type'])));
dispatch(modelSelected({ key: model.key, base: model.base }));
},
[dispatch]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
onChange: _onChange,
selectedModel: model,
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
data,
isLoading,
selectedModel,
onChange: _onChange,
});
return (
<FormControl isDisabled={!options.length} isInvalid={!options.length}>
<FormControl isDisabled={!items.length} isInvalid={!selectedItem || !items.length}>
<InformationalPopover feature="paramModel">
<FormLabel>{t('modelManager.model')}</FormLabel>
</InformationalPopover>
<Tooltip label={tooltipLabel}>
<Box w="full">
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</Box>
</Tooltip>
<CustomSelect selectedItem={selectedItem} placeholder={placeholder} items={items} onChange={onChange} />
</FormControl>
);
};

View File

@ -1,6 +1,7 @@
import { createAction } from '@reduxjs/toolkit';
import type { ImageDTO, MainModelField } from 'services/api/types';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import type { ImageDTO } from 'services/api/types';
export const initialImageSelected = createAction<ImageDTO | undefined>('generation/initialImageSelected');
export const modelSelected = createAction<MainModelField>('generation/modelSelected');
export const modelSelected = createAction<ParameterModel>('generation/modelSelected');

View File

@ -17,8 +17,8 @@ export const MODEL_TYPE_MAP = {
*/
export const MODEL_TYPE_SHORT_MAP = {
any: 'Any',
'sd-1': 'SD1',
'sd-2': 'SD2',
'sd-1': 'SD1.X',
'sd-2': 'SD2.X',
sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR',
};