mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): single getModelConfigs query
Single query, with simple wrapper hooks (type-safe). Updated everywhere in frontend.
This commit is contained in:
@ -1,15 +1,14 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { GroupBase } from 'chakra-react-select';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { groupBy, map, reduce } from 'lodash-es';
|
||||
import { groupBy, reduce } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
||||
modelEntities: EntityState<T, string> | undefined;
|
||||
modelConfigs: T[];
|
||||
selectedModel?: ModelIdentifierField | null;
|
||||
onChange: (value: T | null) => void;
|
||||
getIsDisabled?: (model: T) => boolean;
|
||||
@ -29,13 +28,12 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
||||
): UseGroupedModelComboboxReturn => {
|
||||
const { t } = useTranslation();
|
||||
const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
|
||||
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg;
|
||||
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading } = arg;
|
||||
const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
|
||||
if (!modelEntities) {
|
||||
if (!modelConfigs) {
|
||||
return [];
|
||||
}
|
||||
const modelEntitiesArray = map(modelEntities.entities);
|
||||
const groupedModels = groupBy(modelEntitiesArray, 'base');
|
||||
const groupedModels = groupBy(modelConfigs, 'base');
|
||||
const _options = reduce(
|
||||
groupedModels,
|
||||
(acc, val, label) => {
|
||||
@ -53,7 +51,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
||||
);
|
||||
_options.sort((a) => (a.label === base_model ? -1 : 1));
|
||||
return _options;
|
||||
}, [getIsDisabled, modelEntities, base_model]);
|
||||
}, [getIsDisabled, modelConfigs, base_model]);
|
||||
|
||||
const value = useMemo(
|
||||
() =>
|
||||
@ -67,14 +65,14 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
||||
onChange(null);
|
||||
return;
|
||||
}
|
||||
const model = modelEntities?.entities[v.value];
|
||||
const model = modelConfigs.find((m) => m.key === v.value);
|
||||
if (!model) {
|
||||
onChange(null);
|
||||
return;
|
||||
}
|
||||
onChange(model);
|
||||
},
|
||||
[modelEntities?.entities, onChange]
|
||||
[modelConfigs, onChange]
|
||||
);
|
||||
|
||||
const placeholder = useMemo(() => {
|
||||
|
@ -1,13 +1,11 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { map } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
type UseModelComboboxArg<T extends AnyModelConfig> = {
|
||||
modelEntities: EntityState<T, string> | undefined;
|
||||
modelConfigs: T[];
|
||||
selectedModel?: ModelIdentifierField | null;
|
||||
onChange: (value: T | null) => void;
|
||||
getIsDisabled?: (model: T) => boolean;
|
||||
@ -25,19 +23,14 @@ type UseModelComboboxReturn = {
|
||||
|
||||
export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelComboboxArg<T>): UseModelComboboxReturn => {
|
||||
const { t } = useTranslation();
|
||||
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
|
||||
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
|
||||
const options = useMemo<ComboboxOption[]>(() => {
|
||||
if (!modelEntities) {
|
||||
return [];
|
||||
}
|
||||
return map(modelEntities.entities)
|
||||
.filter(optionsFilter)
|
||||
.map((model) => ({
|
||||
label: model.name,
|
||||
value: model.key,
|
||||
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
|
||||
}));
|
||||
}, [optionsFilter, getIsDisabled, modelEntities]);
|
||||
return modelConfigs.filter(optionsFilter).map((model) => ({
|
||||
label: model.name,
|
||||
value: model.key,
|
||||
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
|
||||
}));
|
||||
}, [optionsFilter, getIsDisabled, modelConfigs]);
|
||||
|
||||
const value = useMemo(
|
||||
() => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),
|
||||
@ -50,14 +43,14 @@ export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelCombobox
|
||||
onChange(null);
|
||||
return;
|
||||
}
|
||||
const model = modelEntities?.entities[v.value];
|
||||
const model = modelConfigs.find((m) => m.key === v.value);
|
||||
if (!model) {
|
||||
onChange(null);
|
||||
return;
|
||||
}
|
||||
onChange(model);
|
||||
},
|
||||
[modelEntities?.entities, onChange]
|
||||
[modelConfigs, onChange]
|
||||
);
|
||||
|
||||
const placeholder = useMemo(() => {
|
||||
|
@ -1,15 +1,12 @@
|
||||
import type { Item } from '@invoke-ai/ui-library';
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
||||
import { filter } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
|
||||
data: EntityState<T, string> | undefined;
|
||||
modelConfigs: T[];
|
||||
isLoading: boolean;
|
||||
selectedModel?: ModelIdentifierField | null;
|
||||
onChange: (value: T | null) => void;
|
||||
@ -28,7 +25,7 @@ const modelFilterDefault = () => true;
|
||||
const isModelDisabledDefault = () => false;
|
||||
|
||||
export const useModelCustomSelect = <T extends AnyModelConfig>({
|
||||
data,
|
||||
modelConfigs,
|
||||
isLoading,
|
||||
selectedModel,
|
||||
onChange,
|
||||
@ -39,30 +36,28 @@ export const useModelCustomSelect = <T extends AnyModelConfig>({
|
||||
|
||||
const items: Item[] = useMemo(
|
||||
() =>
|
||||
data
|
||||
? filter(data.entities, modelFilter).map<Item>((m) => ({
|
||||
label: m.name,
|
||||
value: m.key,
|
||||
description: m.description,
|
||||
group: MODEL_TYPE_SHORT_MAP[m.base],
|
||||
isDisabled: isModelDisabled(m),
|
||||
}))
|
||||
: EMPTY_ARRAY,
|
||||
[data, isModelDisabled, modelFilter]
|
||||
modelConfigs.filter(modelFilter).map<Item>((m) => ({
|
||||
label: m.name,
|
||||
value: m.key,
|
||||
description: m.description,
|
||||
group: MODEL_TYPE_SHORT_MAP[m.base],
|
||||
isDisabled: isModelDisabled(m),
|
||||
})),
|
||||
[modelConfigs, isModelDisabled, modelFilter]
|
||||
);
|
||||
|
||||
const _onChange = useCallback(
|
||||
(item: Item | null) => {
|
||||
if (!item || !data) {
|
||||
if (!item || !modelConfigs) {
|
||||
return;
|
||||
}
|
||||
const model = data.entities[item.value];
|
||||
const model = modelConfigs.find((m) => m.key === item.value);
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
onChange(model);
|
||||
},
|
||||
[data, onChange]
|
||||
[modelConfigs, onChange]
|
||||
);
|
||||
|
||||
const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]);
|
||||
|
Reference in New Issue
Block a user