mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): fix main model & control adapter model selects
This commit is contained in:
parent
eb27951b8c
commit
ed860ae851
@ -0,0 +1,88 @@
|
|||||||
|
import type { Item } from '@invoke-ai/ui-library';
|
||||||
|
import type { EntityState } from '@reduxjs/toolkit';
|
||||||
|
import { EMPTY_ARRAY } from 'app/store/util';
|
||||||
|
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||||
|
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
||||||
|
import { filter } from 'lodash-es';
|
||||||
|
import { useCallback, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
|
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
|
||||||
|
data: EntityState<T, string> | undefined;
|
||||||
|
isLoading: boolean;
|
||||||
|
selectedModel?: ModelIdentifierWithBase | null;
|
||||||
|
onChange: (value: T | null) => void;
|
||||||
|
modelFilter?: (model: T) => boolean;
|
||||||
|
isModelDisabled?: (model: T) => boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
type UseModelCustomSelectReturn = {
|
||||||
|
selectedItem: Item | null;
|
||||||
|
items: Item[];
|
||||||
|
onChange: (item: Item | null) => void;
|
||||||
|
placeholder: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
const modelFilterDefault = () => true;
|
||||||
|
const isModelDisabledDefault = () => false;
|
||||||
|
|
||||||
|
export const useModelCustomSelect = <T extends AnyModelConfig>({
|
||||||
|
data,
|
||||||
|
isLoading,
|
||||||
|
selectedModel,
|
||||||
|
onChange,
|
||||||
|
modelFilter = modelFilterDefault,
|
||||||
|
isModelDisabled = isModelDisabledDefault,
|
||||||
|
}: UseModelCustomSelectArg<T>): UseModelCustomSelectReturn => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const items: Item[] = useMemo(
|
||||||
|
() =>
|
||||||
|
data
|
||||||
|
? filter(data.entities, modelFilter).map<Item>((m) => ({
|
||||||
|
label: m.name,
|
||||||
|
value: m.key,
|
||||||
|
description: m.description,
|
||||||
|
group: MODEL_TYPE_SHORT_MAP[m.base],
|
||||||
|
isDisabled: isModelDisabled(m),
|
||||||
|
}))
|
||||||
|
: EMPTY_ARRAY,
|
||||||
|
[data, isModelDisabled, modelFilter]
|
||||||
|
);
|
||||||
|
|
||||||
|
const _onChange = useCallback(
|
||||||
|
(item: Item | null) => {
|
||||||
|
if (!item || !data) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const model = data.entities[item.value];
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
onChange(model);
|
||||||
|
},
|
||||||
|
[data, onChange]
|
||||||
|
);
|
||||||
|
|
||||||
|
const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]);
|
||||||
|
|
||||||
|
const placeholder = useMemo(() => {
|
||||||
|
if (isLoading) {
|
||||||
|
return t('common.loading');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (items.length === 0) {
|
||||||
|
return t('models.noModelsAvailable');
|
||||||
|
}
|
||||||
|
|
||||||
|
return t('models.selectModel');
|
||||||
|
}, [isLoading, items, t]);
|
||||||
|
|
||||||
|
return {
|
||||||
|
items,
|
||||||
|
onChange: _onChange,
|
||||||
|
selectedItem,
|
||||||
|
placeholder,
|
||||||
|
};
|
||||||
|
};
|
@ -1,34 +1,27 @@
|
|||||||
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
import { CustomSelect, FormControl } from '@invoke-ai/ui-library';
|
||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
|
||||||
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
|
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
|
||||||
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
|
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
|
||||||
import { useControlAdapterModelEntities } from 'features/controlAdapters/hooks/useControlAdapterModelEntities';
|
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
|
||||||
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
||||||
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
|
||||||
import { pick } from 'lodash-es';
|
import { pick } from 'lodash-es';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types';
|
||||||
import type { AnyModelConfig, ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
type ParamControlAdapterModelProps = {
|
type ParamControlAdapterModelProps = {
|
||||||
id: string;
|
id: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
const selectMainModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
|
|
||||||
|
|
||||||
const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
||||||
const isEnabled = useControlAdapterIsEnabled(id);
|
const isEnabled = useControlAdapterIsEnabled(id);
|
||||||
const controlAdapterType = useControlAdapterType(id);
|
const controlAdapterType = useControlAdapterType(id);
|
||||||
const model = useControlAdapterModel(id);
|
const model = useControlAdapterModel(id);
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||||
const mainModel = useAppSelector(selectMainModel);
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const models = useControlAdapterModelEntities(controlAdapterType);
|
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(model: ControlNetConfig | IPAdapterConfig | T2IAdapterConfig | null) => {
|
(model: ControlNetConfig | IPAdapterConfig | T2IAdapterConfig | null) => {
|
||||||
@ -50,34 +43,18 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
|||||||
[controlAdapterType, model]
|
[controlAdapterType, model]
|
||||||
);
|
);
|
||||||
|
|
||||||
const getIsDisabled = useCallback(
|
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
|
||||||
(model: AnyModelConfig): boolean => {
|
data,
|
||||||
const isCompatible = currentBaseModel === model.base;
|
isLoading,
|
||||||
const hasMainModel = Boolean(currentBaseModel);
|
|
||||||
return !hasMainModel || !isCompatible;
|
|
||||||
},
|
|
||||||
[currentBaseModel]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
|
|
||||||
modelEntities: models,
|
|
||||||
onChange: _onChange,
|
|
||||||
selectedModel,
|
selectedModel,
|
||||||
getIsDisabled,
|
onChange: _onChange,
|
||||||
|
modelFilter: (model) => model.base === currentBaseModel,
|
||||||
});
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Tooltip label={value?.description}>
|
<FormControl isDisabled={!items.length || !isEnabled} isInvalid={!selectedItem || !items.length}>
|
||||||
<FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base !== model?.base}>
|
<CustomSelect selectedItem={selectedItem} placeholder={placeholder} items={items} onChange={onChange} />
|
||||||
<Combobox
|
</FormControl>
|
||||||
options={options}
|
|
||||||
placeholder={t('controlnet.selectModel')}
|
|
||||||
value={value}
|
|
||||||
onChange={onChange}
|
|
||||||
noOptionsMessage={noOptionsMessage}
|
|
||||||
/>
|
|
||||||
</FormControl>
|
|
||||||
</Tooltip>
|
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1,23 +0,0 @@
|
|||||||
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
|
|
||||||
import {
|
|
||||||
useGetControlNetModelsQuery,
|
|
||||||
useGetIPAdapterModelsQuery,
|
|
||||||
useGetT2IAdapterModelsQuery,
|
|
||||||
} from 'services/api/endpoints/models';
|
|
||||||
|
|
||||||
export const useControlAdapterModelEntities = (type?: ControlAdapterType) => {
|
|
||||||
const { data: controlNetModelsData } = useGetControlNetModelsQuery();
|
|
||||||
const { data: t2iAdapterModelsData } = useGetT2IAdapterModelsQuery();
|
|
||||||
const { data: ipAdapterModelsData } = useGetIPAdapterModelsQuery();
|
|
||||||
|
|
||||||
if (type === 'controlnet') {
|
|
||||||
return controlNetModelsData;
|
|
||||||
}
|
|
||||||
if (type === 't2i_adapter') {
|
|
||||||
return t2iAdapterModelsData;
|
|
||||||
}
|
|
||||||
if (type === 'ip_adapter') {
|
|
||||||
return ipAdapterModelsData;
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
};
|
|
@ -0,0 +1,26 @@
|
|||||||
|
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
|
||||||
|
import {
|
||||||
|
useGetControlNetModelsQuery,
|
||||||
|
useGetIPAdapterModelsQuery,
|
||||||
|
useGetT2IAdapterModelsQuery,
|
||||||
|
} from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
export const useControlAdapterModelQuery = (type: ControlAdapterType) => {
|
||||||
|
const controlNetModelsQuery = useGetControlNetModelsQuery();
|
||||||
|
const t2iAdapterModelsQuery = useGetT2IAdapterModelsQuery();
|
||||||
|
const ipAdapterModelsQuery = useGetIPAdapterModelsQuery();
|
||||||
|
|
||||||
|
if (type === 'controlnet') {
|
||||||
|
return controlNetModelsQuery;
|
||||||
|
}
|
||||||
|
if (type === 't2i_adapter') {
|
||||||
|
return t2iAdapterModelsQuery;
|
||||||
|
}
|
||||||
|
if (type === 'ip_adapter') {
|
||||||
|
return ipAdapterModelsQuery;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert that the end of the function is not reachable.
|
||||||
|
const exhaustiveCheck: never = type;
|
||||||
|
return exhaustiveCheck;
|
||||||
|
};
|
@ -5,14 +5,16 @@ import {
|
|||||||
selectControlAdaptersSlice,
|
selectControlAdaptersSlice,
|
||||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
export const useControlAdapterType = (id: string) => {
|
export const useControlAdapterType = (id: string) => {
|
||||||
const selector = useMemo(
|
const selector = useMemo(
|
||||||
() =>
|
() =>
|
||||||
createMemoizedSelector(
|
createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
|
||||||
selectControlAdaptersSlice,
|
const type = selectControlAdapterById(controlAdapters, id)?.type;
|
||||||
(controlAdapters) => selectControlAdapterById(controlAdapters, id)?.type
|
assert(type !== undefined, `Control adapter with id ${id} not found`);
|
||||||
),
|
return type;
|
||||||
|
}),
|
||||||
[id]
|
[id]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -2,11 +2,7 @@ import { Flex, Text } from '@invoke-ai/ui-library';
|
|||||||
import { memo, useState } from 'react';
|
import { memo, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||||
import type {
|
import type { DiffusersModelConfig, LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
|
||||||
DiffusersModelConfig,
|
|
||||||
LoRAConfig,
|
|
||||||
MainModelConfig,
|
|
||||||
} from 'services/api/endpoints/models';
|
|
||||||
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
||||||
|
@ -1,62 +1,47 @@
|
|||||||
import { Box, Combobox, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library';
|
import { CustomSelect, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
|
||||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
|
||||||
import { modelSelected } from 'features/parameters/store/actions';
|
import { modelSelected } from 'features/parameters/store/actions';
|
||||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||||
import { pick } from 'lodash-es';
|
import { memo, useCallback } from 'react';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||||
import type { MainModelConfig } from 'services/api/endpoints/models';
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
import { getModelId, mainModelsAdapterSelectors, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
import type { MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
|
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
|
||||||
|
|
||||||
const ParamMainModelSelect = () => {
|
const ParamMainModelSelect = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const model = useAppSelector(selectModel);
|
const selectedModel = useAppSelector(selectModel);
|
||||||
const { data, isLoading } = useGetMainModelsQuery(NON_REFINER_BASE_MODELS);
|
const { data, isLoading } = useGetMainModelsQuery(NON_REFINER_BASE_MODELS);
|
||||||
const tooltipLabel = useMemo(() => {
|
|
||||||
if (!data || !model) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
return mainModelsAdapterSelectors.selectById(data, getModelId(model))?.description;
|
|
||||||
}, [data, model]);
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(model: MainModelConfig | null) => {
|
(model: MainModelConfig | null) => {
|
||||||
if (!model) {
|
if (!model) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(modelSelected(pick(model, ['base_model', 'model_name', 'model_type'])));
|
dispatch(modelSelected({ key: model.key, base: model.base }));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
|
||||||
modelEntities: data,
|
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
|
||||||
onChange: _onChange,
|
data,
|
||||||
selectedModel: model,
|
|
||||||
isLoading,
|
isLoading,
|
||||||
|
selectedModel,
|
||||||
|
onChange: _onChange,
|
||||||
});
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<FormControl isDisabled={!options.length} isInvalid={!options.length}>
|
<FormControl isDisabled={!items.length} isInvalid={!selectedItem || !items.length}>
|
||||||
<InformationalPopover feature="paramModel">
|
<InformationalPopover feature="paramModel">
|
||||||
<FormLabel>{t('modelManager.model')}</FormLabel>
|
<FormLabel>{t('modelManager.model')}</FormLabel>
|
||||||
</InformationalPopover>
|
</InformationalPopover>
|
||||||
<Tooltip label={tooltipLabel}>
|
<CustomSelect selectedItem={selectedItem} placeholder={placeholder} items={items} onChange={onChange} />
|
||||||
<Box w="full">
|
|
||||||
<Combobox
|
|
||||||
value={value}
|
|
||||||
placeholder={placeholder}
|
|
||||||
options={options}
|
|
||||||
onChange={onChange}
|
|
||||||
noOptionsMessage={noOptionsMessage}
|
|
||||||
/>
|
|
||||||
</Box>
|
|
||||||
</Tooltip>
|
|
||||||
</FormControl>
|
</FormControl>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import { createAction } from '@reduxjs/toolkit';
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
import type { ImageDTO, MainModelField } from 'services/api/types';
|
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||||
|
import type { ImageDTO } from 'services/api/types';
|
||||||
|
|
||||||
export const initialImageSelected = createAction<ImageDTO | undefined>('generation/initialImageSelected');
|
export const initialImageSelected = createAction<ImageDTO | undefined>('generation/initialImageSelected');
|
||||||
|
|
||||||
export const modelSelected = createAction<MainModelField>('generation/modelSelected');
|
export const modelSelected = createAction<ParameterModel>('generation/modelSelected');
|
||||||
|
@ -17,8 +17,8 @@ export const MODEL_TYPE_MAP = {
|
|||||||
*/
|
*/
|
||||||
export const MODEL_TYPE_SHORT_MAP = {
|
export const MODEL_TYPE_SHORT_MAP = {
|
||||||
any: 'Any',
|
any: 'Any',
|
||||||
'sd-1': 'SD1',
|
'sd-1': 'SD1.X',
|
||||||
'sd-2': 'SD2',
|
'sd-2': 'SD2.X',
|
||||||
sdxl: 'SDXL',
|
sdxl: 'SDXL',
|
||||||
'sdxl-refiner': 'SDXLR',
|
'sdxl-refiner': 'SDXLR',
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user