fix(ui): get vae model select working

This commit is contained in:
psychedelicious 2024-02-20 11:14:08 +11:00
parent f870f810d5
commit e7e3045a8a
4 changed files with 17 additions and 20 deletions

View File

@ -6,7 +6,6 @@ import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { groupBy, map, reduce } from 'lodash-es'; import { groupBy, map, reduce } from 'lodash-es';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { getModelId } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types'; import type { AnyModelConfig } from 'services/api/types';
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = { type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
@ -58,8 +57,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
const value = useMemo( const value = useMemo(
() => () =>
options.flatMap((o) => o.options).find((m) => (selectedModel ? m.value === getModelId(selectedModel) : false)) ?? options.flatMap((o) => o.options).find((m) => (selectedModel ? m.value === selectedModel.key : false)) ?? null,
null,
[options, selectedModel] [options, selectedModel]
); );

View File

@ -1,14 +1,14 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit'; import type { EntityState } from '@reduxjs/toolkit';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/endpoints/models'; import type { AnyModelConfig } from 'services/api/types';
import { getModelId } from 'services/api/endpoints/models';
type UseModelComboboxArg<T extends AnyModelConfig> = { type UseModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined; modelEntities: EntityState<T, string> | undefined;
selectedModel?: Pick<T, 'base_model' | 'model_name' | 'model_type'> | null; selectedModel?: ModelIdentifierWithBase | null;
onChange: (value: T | null) => void; onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean; getIsDisabled?: (model: T) => boolean;
optionsFilter?: (model: T) => boolean; optionsFilter?: (model: T) => boolean;
@ -33,14 +33,14 @@ export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelCombobox
return map(modelEntities.entities) return map(modelEntities.entities)
.filter(optionsFilter) .filter(optionsFilter)
.map((model) => ({ .map((model) => ({
label: model.model_name, label: model.name,
value: model.id, value: model.key,
isDisabled: getIsDisabled ? getIsDisabled(model) : false, isDisabled: getIsDisabled ? getIsDisabled(model) : false,
})); }));
}, [optionsFilter, getIsDisabled, modelEntities]); }, [optionsFilter, getIsDisabled, modelEntities]);
const value = useMemo( const value = useMemo(
() => options.find((m) => (selectedModel ? m.value === getModelId(selectedModel) : false)), () => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),
[options, selectedModel] [options, selectedModel]
); );

View File

@ -7,8 +7,8 @@ import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/ge
import { pick } from 'lodash-es'; import { pick } from 'lodash-es';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { VAEConfig } from 'services/api/endpoints/models';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import type { VAEConfig } from 'services/api/types';
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => { const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
const { model, vae } = generation; const { model, vae } = generation;
@ -22,26 +22,28 @@ const ParamVAEModelSelect = () => {
const { data, isLoading } = useGetVaeModelsQuery(); const { data, isLoading } = useGetVaeModelsQuery();
const getIsDisabled = useCallback( const getIsDisabled = useCallback(
(vae: VAEConfig): boolean => { (vae: VAEConfig): boolean => {
const isCompatible = model?.base_model === vae.base_model; const isCompatible = model?.base === vae.base;
const hasMainModel = Boolean(model?.base_model); const hasMainModel = Boolean(model?.base);
return !hasMainModel || !isCompatible; return !hasMainModel || !isCompatible;
}, },
[model?.base_model] [model?.base]
); );
const _onChange = useCallback( const _onChange = useCallback(
(vae: VAEConfig | null) => { (vae: VAEConfig | null) => {
dispatch(vaeSelected(vae ? pick(vae, 'base_model', 'model_name') : null)); dispatch(vaeSelected(vae ? pick(vae, 'key', 'base') : null));
}, },
[dispatch] [dispatch]
); );
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data, modelEntities: data,
onChange: _onChange, onChange: _onChange,
selectedModel: vae ? { ...vae, model_type: 'vae' } : null, selectedModel: vae ? pick(vae, 'key', 'base') : null,
isLoading, isLoading,
getIsDisabled, getIsDisabled,
}); });
console.log(value)
return ( return (
<FormControl isDisabled={!options.length} isInvalid={!options.length}> <FormControl isDisabled={!options.length} isInvalid={!options.length}>
<InformationalPopover feature="paramVAE"> <InformationalPopover feature="paramVAE">
@ -50,7 +52,7 @@ const ParamVAEModelSelect = () => {
<Combobox <Combobox
isClearable isClearable
value={value} value={value}
placeholder={value ? placeholder : t('models.defaultVAE')} placeholder={value ? value.value : t('models.defaultVAE')}
options={options} options={options}
onChange={onChange} onChange={onChange}
noOptionsMessage={noOptionsMessage} noOptionsMessage={noOptionsMessage}

View File

@ -21,9 +21,6 @@ import type {
import type { ApiTagDescription, tagTypes } from '..'; import type { ApiTagDescription, tagTypes } from '..';
import { api, buildV2Url, LIST_TAG } from '..'; import { api, buildV2Url, LIST_TAG } from '..';
/* eslint-disable @typescript-eslint/no-explicit-any */
export const getModelId = (input: any): any => input;
type UpdateMainModelArg = { type UpdateMainModelArg = {
base_model: BaseModelType; base_model: BaseModelType;
model_name: string; model_name: string;