fix(ui): get vae model select working

This commit is contained in:
psychedelicious 2024-02-20 11:14:08 +11:00
parent f870f810d5
commit e7e3045a8a
4 changed files with 17 additions and 20 deletions

View File

@ -6,7 +6,6 @@ import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { groupBy, map, reduce } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { getModelId } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
@ -58,8 +57,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
const value = useMemo(
() =>
options.flatMap((o) => o.options).find((m) => (selectedModel ? m.value === getModelId(selectedModel) : false)) ??
null,
options.flatMap((o) => o.options).find((m) => (selectedModel ? m.value === selectedModel.key : false)) ?? null,
[options, selectedModel]
);

View File

@ -1,14 +1,14 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/endpoints/models';
import { getModelId } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
type UseModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined;
selectedModel?: Pick<T, 'base_model' | 'model_name' | 'model_type'> | null;
selectedModel?: ModelIdentifierWithBase | null;
onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean;
optionsFilter?: (model: T) => boolean;
@ -33,14 +33,14 @@ export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelCombobox
return map(modelEntities.entities)
.filter(optionsFilter)
.map((model) => ({
label: model.model_name,
value: model.id,
label: model.name,
value: model.key,
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
}));
}, [optionsFilter, getIsDisabled, modelEntities]);
const value = useMemo(
() => options.find((m) => (selectedModel ? m.value === getModelId(selectedModel) : false)),
() => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),
[options, selectedModel]
);

View File

@ -7,8 +7,8 @@ import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/ge
import { pick } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { VAEConfig } from 'services/api/endpoints/models';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import type { VAEConfig } from 'services/api/types';
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
const { model, vae } = generation;
@ -22,25 +22,27 @@ const ParamVAEModelSelect = () => {
const { data, isLoading } = useGetVaeModelsQuery();
const getIsDisabled = useCallback(
(vae: VAEConfig): boolean => {
const isCompatible = model?.base_model === vae.base_model;
const hasMainModel = Boolean(model?.base_model);
const isCompatible = model?.base === vae.base;
const hasMainModel = Boolean(model?.base);
return !hasMainModel || !isCompatible;
},
[model?.base_model]
[model?.base]
);
const _onChange = useCallback(
(vae: VAEConfig | null) => {
dispatch(vaeSelected(vae ? pick(vae, 'base_model', 'model_name') : null));
dispatch(vaeSelected(vae ? pick(vae, 'key', 'base') : null));
},
[dispatch]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
onChange: _onChange,
selectedModel: vae ? { ...vae, model_type: 'vae' } : null,
selectedModel: vae ? pick(vae, 'key', 'base') : null,
isLoading,
getIsDisabled,
});
console.log(value)
return (
<FormControl isDisabled={!options.length} isInvalid={!options.length}>
@ -50,7 +52,7 @@ const ParamVAEModelSelect = () => {
<Combobox
isClearable
value={value}
placeholder={value ? placeholder : t('models.defaultVAE')}
placeholder={value ? value.value : t('models.defaultVAE')}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}

View File

@ -21,9 +21,6 @@ import type {
import type { ApiTagDescription, tagTypes } from '..';
import { api, buildV2Url, LIST_TAG } from '..';
/* eslint-disable @typescript-eslint/no-explicit-any */
export const getModelId = (input: any): any => input;
type UpdateMainModelArg = {
base_model: BaseModelType;
model_name: string;