mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): fix main model & control adapter model selects
This commit is contained in:
parent
db363b5178
commit
e50b76571a
@ -0,0 +1,88 @@
|
||||
import type { Item } from '@invoke-ai/ui-library';
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import { EMPTY_ARRAY } from 'app/store/util';
|
||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
||||
import { filter } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
|
||||
data: EntityState<T, string> | undefined;
|
||||
isLoading: boolean;
|
||||
selectedModel?: ModelIdentifierWithBase | null;
|
||||
onChange: (value: T | null) => void;
|
||||
modelFilter?: (model: T) => boolean;
|
||||
isModelDisabled?: (model: T) => boolean;
|
||||
};
|
||||
|
||||
type UseModelCustomSelectReturn = {
|
||||
selectedItem: Item | null;
|
||||
items: Item[];
|
||||
onChange: (item: Item | null) => void;
|
||||
placeholder: string;
|
||||
};
|
||||
|
||||
const modelFilterDefault = () => true;
|
||||
const isModelDisabledDefault = () => false;
|
||||
|
||||
export const useModelCustomSelect = <T extends AnyModelConfig>({
|
||||
data,
|
||||
isLoading,
|
||||
selectedModel,
|
||||
onChange,
|
||||
modelFilter = modelFilterDefault,
|
||||
isModelDisabled = isModelDisabledDefault,
|
||||
}: UseModelCustomSelectArg<T>): UseModelCustomSelectReturn => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const items: Item[] = useMemo(
|
||||
() =>
|
||||
data
|
||||
? filter(data.entities, modelFilter).map<Item>((m) => ({
|
||||
label: m.name,
|
||||
value: m.key,
|
||||
description: m.description,
|
||||
group: MODEL_TYPE_SHORT_MAP[m.base],
|
||||
isDisabled: isModelDisabled(m),
|
||||
}))
|
||||
: EMPTY_ARRAY,
|
||||
[data, isModelDisabled, modelFilter]
|
||||
);
|
||||
|
||||
const _onChange = useCallback(
|
||||
(item: Item | null) => {
|
||||
if (!item || !data) {
|
||||
return;
|
||||
}
|
||||
const model = data.entities[item.value];
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
onChange(model);
|
||||
},
|
||||
[data, onChange]
|
||||
);
|
||||
|
||||
const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]);
|
||||
|
||||
const placeholder = useMemo(() => {
|
||||
if (isLoading) {
|
||||
return t('common.loading');
|
||||
}
|
||||
|
||||
if (items.length === 0) {
|
||||
return t('models.noModelsAvailable');
|
||||
}
|
||||
|
||||
return t('models.selectModel');
|
||||
}, [isLoading, items, t]);
|
||||
|
||||
return {
|
||||
items,
|
||||
onChange: _onChange,
|
||||
selectedItem,
|
||||
placeholder,
|
||||
};
|
||||
};
|
@ -1,34 +1,27 @@
|
||||
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { CustomSelect, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
|
||||
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
|
||||
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
|
||||
import { useControlAdapterModelEntities } from 'features/controlAdapters/hooks/useControlAdapterModelEntities';
|
||||
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
|
||||
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
||||
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { pick } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig, ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types';
|
||||
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types';
|
||||
|
||||
type ParamControlAdapterModelProps = {
|
||||
id: string;
|
||||
};
|
||||
|
||||
const selectMainModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
|
||||
|
||||
const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
||||
const isEnabled = useControlAdapterIsEnabled(id);
|
||||
const controlAdapterType = useControlAdapterType(id);
|
||||
const model = useControlAdapterModel(id);
|
||||
const dispatch = useAppDispatch();
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
const mainModel = useAppSelector(selectMainModel);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const models = useControlAdapterModelEntities(controlAdapterType);
|
||||
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
|
||||
|
||||
const _onChange = useCallback(
|
||||
(model: ControlNetConfig | IPAdapterConfig | T2IAdapterConfig | null) => {
|
||||
@ -50,34 +43,18 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
||||
[controlAdapterType, model]
|
||||
);
|
||||
|
||||
const getIsDisabled = useCallback(
|
||||
(model: AnyModelConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === model.base;
|
||||
const hasMainModel = Boolean(currentBaseModel);
|
||||
return !hasMainModel || !isCompatible;
|
||||
},
|
||||
[currentBaseModel]
|
||||
);
|
||||
|
||||
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelEntities: models,
|
||||
onChange: _onChange,
|
||||
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
|
||||
data,
|
||||
isLoading,
|
||||
selectedModel,
|
||||
getIsDisabled,
|
||||
onChange: _onChange,
|
||||
modelFilter: (model) => model.base === currentBaseModel,
|
||||
});
|
||||
|
||||
return (
|
||||
<Tooltip label={value?.description}>
|
||||
<FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base !== model?.base}>
|
||||
<Combobox
|
||||
options={options}
|
||||
placeholder={t('controlnet.selectModel')}
|
||||
value={value}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
<FormControl isDisabled={!items.length || !isEnabled} isInvalid={!selectedItem || !items.length}>
|
||||
<CustomSelect selectedItem={selectedItem} placeholder={placeholder} items={items} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -1,23 +0,0 @@
|
||||
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
|
||||
import {
|
||||
useGetControlNetModelsQuery,
|
||||
useGetIPAdapterModelsQuery,
|
||||
useGetT2IAdapterModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
|
||||
export const useControlAdapterModelEntities = (type?: ControlAdapterType) => {
|
||||
const { data: controlNetModelsData } = useGetControlNetModelsQuery();
|
||||
const { data: t2iAdapterModelsData } = useGetT2IAdapterModelsQuery();
|
||||
const { data: ipAdapterModelsData } = useGetIPAdapterModelsQuery();
|
||||
|
||||
if (type === 'controlnet') {
|
||||
return controlNetModelsData;
|
||||
}
|
||||
if (type === 't2i_adapter') {
|
||||
return t2iAdapterModelsData;
|
||||
}
|
||||
if (type === 'ip_adapter') {
|
||||
return ipAdapterModelsData;
|
||||
}
|
||||
return;
|
||||
};
|
@ -0,0 +1,26 @@
|
||||
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
|
||||
import {
|
||||
useGetControlNetModelsQuery,
|
||||
useGetIPAdapterModelsQuery,
|
||||
useGetT2IAdapterModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
|
||||
export const useControlAdapterModelQuery = (type: ControlAdapterType) => {
|
||||
const controlNetModelsQuery = useGetControlNetModelsQuery();
|
||||
const t2iAdapterModelsQuery = useGetT2IAdapterModelsQuery();
|
||||
const ipAdapterModelsQuery = useGetIPAdapterModelsQuery();
|
||||
|
||||
if (type === 'controlnet') {
|
||||
return controlNetModelsQuery;
|
||||
}
|
||||
if (type === 't2i_adapter') {
|
||||
return t2iAdapterModelsQuery;
|
||||
}
|
||||
if (type === 'ip_adapter') {
|
||||
return ipAdapterModelsQuery;
|
||||
}
|
||||
|
||||
// Assert that the end of the function is not reachable.
|
||||
const exhaustiveCheck: never = type;
|
||||
return exhaustiveCheck;
|
||||
};
|
@ -5,14 +5,16 @@ import {
|
||||
selectControlAdaptersSlice,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { useMemo } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const useControlAdapterType = (id: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(
|
||||
selectControlAdaptersSlice,
|
||||
(controlAdapters) => selectControlAdapterById(controlAdapters, id)?.type
|
||||
),
|
||||
createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
|
||||
const type = selectControlAdapterById(controlAdapters, id)?.type;
|
||||
assert(type !== undefined, `Control adapter with id ${id} not found`);
|
||||
return type;
|
||||
}),
|
||||
[id]
|
||||
);
|
||||
|
||||
|
@ -2,11 +2,7 @@ import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { memo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
import type {
|
||||
DiffusersModelConfig,
|
||||
LoRAConfig,
|
||||
MainModelConfig,
|
||||
} from 'services/api/endpoints/models';
|
||||
import type { DiffusersModelConfig, LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
|
||||
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
||||
|
@ -1,62 +1,47 @@
|
||||
import { Box, Combobox, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { CustomSelect, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { pick } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import type { MainModelConfig } from 'services/api/endpoints/models';
|
||||
import { getModelId, mainModelsAdapterSelectors, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
|
||||
|
||||
const ParamMainModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const model = useAppSelector(selectModel);
|
||||
const selectedModel = useAppSelector(selectModel);
|
||||
const { data, isLoading } = useGetMainModelsQuery(NON_REFINER_BASE_MODELS);
|
||||
const tooltipLabel = useMemo(() => {
|
||||
if (!data || !model) {
|
||||
return;
|
||||
}
|
||||
return mainModelsAdapterSelectors.selectById(data, getModelId(model))?.description;
|
||||
}, [data, model]);
|
||||
|
||||
const _onChange = useCallback(
|
||||
(model: MainModelConfig | null) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
dispatch(modelSelected(pick(model, ['base_model', 'model_name', 'model_type'])));
|
||||
dispatch(modelSelected({ key: model.key, base: model.base }));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
onChange: _onChange,
|
||||
selectedModel: model,
|
||||
|
||||
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
|
||||
data,
|
||||
isLoading,
|
||||
selectedModel,
|
||||
onChange: _onChange,
|
||||
});
|
||||
|
||||
return (
|
||||
<FormControl isDisabled={!options.length} isInvalid={!options.length}>
|
||||
<FormControl isDisabled={!items.length} isInvalid={!selectedItem || !items.length}>
|
||||
<InformationalPopover feature="paramModel">
|
||||
<FormLabel>{t('modelManager.model')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Tooltip label={tooltipLabel}>
|
||||
<Box w="full">
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</Box>
|
||||
</Tooltip>
|
||||
<CustomSelect selectedItem={selectedItem} placeholder={placeholder} items={items} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type { ImageDTO, MainModelField } from 'services/api/types';
|
||||
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const initialImageSelected = createAction<ImageDTO | undefined>('generation/initialImageSelected');
|
||||
|
||||
export const modelSelected = createAction<MainModelField>('generation/modelSelected');
|
||||
export const modelSelected = createAction<ParameterModel>('generation/modelSelected');
|
||||
|
@ -17,8 +17,8 @@ export const MODEL_TYPE_MAP = {
|
||||
*/
|
||||
export const MODEL_TYPE_SHORT_MAP = {
|
||||
any: 'Any',
|
||||
'sd-1': 'SD1',
|
||||
'sd-2': 'SD2',
|
||||
'sd-1': 'SD1.X',
|
||||
'sd-2': 'SD2.X',
|
||||
sdxl: 'SDXL',
|
||||
'sdxl-refiner': 'SDXLR',
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user