feat(ui): single getModelConfigs query

Single query, with simple wrapper hooks (type-safe). Updated everywhere in frontend.
This commit is contained in:
psychedelicious
2024-03-14 23:37:40 +11:00
parent ed20255abf
commit 19d66d5ec7
31 changed files with 447 additions and 790 deletions

View File

@ -1,15 +1,14 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { groupBy, map, reduce } from 'lodash-es';
import { groupBy, reduce } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined;
modelConfigs: T[];
selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean;
@ -29,13 +28,12 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
): UseGroupedModelComboboxReturn => {
const { t } = useTranslation();
const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg;
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading } = arg;
const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
if (!modelEntities) {
if (!modelConfigs) {
return [];
}
const modelEntitiesArray = map(modelEntities.entities);
const groupedModels = groupBy(modelEntitiesArray, 'base');
const groupedModels = groupBy(modelConfigs, 'base');
const _options = reduce(
groupedModels,
(acc, val, label) => {
@ -53,7 +51,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
);
_options.sort((a) => (a.label === base_model ? -1 : 1));
return _options;
}, [getIsDisabled, modelEntities, base_model]);
}, [getIsDisabled, modelConfigs, base_model]);
const value = useMemo(
() =>
@ -67,14 +65,14 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
onChange(null);
return;
}
const model = modelEntities?.entities[v.value];
const model = modelConfigs.find((m) => m.key === v.value);
if (!model) {
onChange(null);
return;
}
onChange(model);
},
[modelEntities?.entities, onChange]
[modelConfigs, onChange]
);
const placeholder = useMemo(() => {

View File

@ -1,13 +1,11 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
type UseModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined;
modelConfigs: T[];
selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean;
@ -25,19 +23,14 @@ type UseModelComboboxReturn = {
export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelComboboxArg<T>): UseModelComboboxReturn => {
const { t } = useTranslation();
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
const options = useMemo<ComboboxOption[]>(() => {
if (!modelEntities) {
return [];
}
return map(modelEntities.entities)
.filter(optionsFilter)
.map((model) => ({
label: model.name,
value: model.key,
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
}));
}, [optionsFilter, getIsDisabled, modelEntities]);
return modelConfigs.filter(optionsFilter).map((model) => ({
label: model.name,
value: model.key,
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
}));
}, [optionsFilter, getIsDisabled, modelConfigs]);
const value = useMemo(
() => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),
@ -50,14 +43,14 @@ export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelCombobox
onChange(null);
return;
}
const model = modelEntities?.entities[v.value];
const model = modelConfigs.find((m) => m.key === v.value);
if (!model) {
onChange(null);
return;
}
onChange(model);
},
[modelEntities?.entities, onChange]
[modelConfigs, onChange]
);
const placeholder = useMemo(() => {

View File

@ -1,15 +1,12 @@
import type { Item } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import { filter } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
data: EntityState<T, string> | undefined;
modelConfigs: T[];
isLoading: boolean;
selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void;
@ -28,7 +25,7 @@ const modelFilterDefault = () => true;
const isModelDisabledDefault = () => false;
export const useModelCustomSelect = <T extends AnyModelConfig>({
data,
modelConfigs,
isLoading,
selectedModel,
onChange,
@ -39,30 +36,28 @@ export const useModelCustomSelect = <T extends AnyModelConfig>({
const items: Item[] = useMemo(
() =>
data
? filter(data.entities, modelFilter).map<Item>((m) => ({
label: m.name,
value: m.key,
description: m.description,
group: MODEL_TYPE_SHORT_MAP[m.base],
isDisabled: isModelDisabled(m),
}))
: EMPTY_ARRAY,
[data, isModelDisabled, modelFilter]
modelConfigs.filter(modelFilter).map<Item>((m) => ({
label: m.name,
value: m.key,
description: m.description,
group: MODEL_TYPE_SHORT_MAP[m.base],
isDisabled: isModelDisabled(m),
})),
[modelConfigs, isModelDisabled, modelFilter]
);
const _onChange = useCallback(
(item: Item | null) => {
if (!item || !data) {
if (!item || !modelConfigs) {
return;
}
const model = data.entities[item.value];
const model = modelConfigs.find((m) => m.key === item.value);
if (!model) {
return;
}
onChange(model);
},
[data, onChange]
[modelConfigs, onChange]
);
const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]);