mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): get vae model select working
This commit is contained in:
parent
f870f810d5
commit
e7e3045a8a
@ -6,7 +6,6 @@ import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import { groupBy, map, reduce } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { getModelId } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
||||
@ -58,8 +57,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
||||
|
||||
const value = useMemo(
|
||||
() =>
|
||||
options.flatMap((o) => o.options).find((m) => (selectedModel ? m.value === getModelId(selectedModel) : false)) ??
|
||||
null,
|
||||
options.flatMap((o) => o.options).find((m) => (selectedModel ? m.value === selectedModel.key : false)) ?? null,
|
||||
[options, selectedModel]
|
||||
);
|
||||
|
||||
|
@ -1,14 +1,14 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import { map } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig } from 'services/api/endpoints/models';
|
||||
import { getModelId } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
type UseModelComboboxArg<T extends AnyModelConfig> = {
|
||||
modelEntities: EntityState<T, string> | undefined;
|
||||
selectedModel?: Pick<T, 'base_model' | 'model_name' | 'model_type'> | null;
|
||||
selectedModel?: ModelIdentifierWithBase | null;
|
||||
onChange: (value: T | null) => void;
|
||||
getIsDisabled?: (model: T) => boolean;
|
||||
optionsFilter?: (model: T) => boolean;
|
||||
@ -33,14 +33,14 @@ export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelCombobox
|
||||
return map(modelEntities.entities)
|
||||
.filter(optionsFilter)
|
||||
.map((model) => ({
|
||||
label: model.model_name,
|
||||
value: model.id,
|
||||
label: model.name,
|
||||
value: model.key,
|
||||
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
|
||||
}));
|
||||
}, [optionsFilter, getIsDisabled, modelEntities]);
|
||||
|
||||
const value = useMemo(
|
||||
() => options.find((m) => (selectedModel ? m.value === getModelId(selectedModel) : false)),
|
||||
() => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),
|
||||
[options, selectedModel]
|
||||
);
|
||||
|
||||
|
@ -7,8 +7,8 @@ import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/ge
|
||||
import { pick } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { VAEConfig } from 'services/api/endpoints/models';
|
||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { VAEConfig } from 'services/api/types';
|
||||
|
||||
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
|
||||
const { model, vae } = generation;
|
||||
@ -22,26 +22,28 @@ const ParamVAEModelSelect = () => {
|
||||
const { data, isLoading } = useGetVaeModelsQuery();
|
||||
const getIsDisabled = useCallback(
|
||||
(vae: VAEConfig): boolean => {
|
||||
const isCompatible = model?.base_model === vae.base_model;
|
||||
const hasMainModel = Boolean(model?.base_model);
|
||||
const isCompatible = model?.base === vae.base;
|
||||
const hasMainModel = Boolean(model?.base);
|
||||
return !hasMainModel || !isCompatible;
|
||||
},
|
||||
[model?.base_model]
|
||||
[model?.base]
|
||||
);
|
||||
const _onChange = useCallback(
|
||||
(vae: VAEConfig | null) => {
|
||||
dispatch(vaeSelected(vae ? pick(vae, 'base_model', 'model_name') : null));
|
||||
dispatch(vaeSelected(vae ? pick(vae, 'key', 'base') : null));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
onChange: _onChange,
|
||||
selectedModel: vae ? { ...vae, model_type: 'vae' } : null,
|
||||
selectedModel: vae ? pick(vae, 'key', 'base') : null,
|
||||
isLoading,
|
||||
getIsDisabled,
|
||||
});
|
||||
|
||||
console.log(value)
|
||||
|
||||
return (
|
||||
<FormControl isDisabled={!options.length} isInvalid={!options.length}>
|
||||
<InformationalPopover feature="paramVAE">
|
||||
@ -50,7 +52,7 @@ const ParamVAEModelSelect = () => {
|
||||
<Combobox
|
||||
isClearable
|
||||
value={value}
|
||||
placeholder={value ? placeholder : t('models.defaultVAE')}
|
||||
placeholder={value ? value.value : t('models.defaultVAE')}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
|
@ -21,9 +21,6 @@ import type {
|
||||
import type { ApiTagDescription, tagTypes } from '..';
|
||||
import { api, buildV2Url, LIST_TAG } from '..';
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
export const getModelId = (input: any): any => input;
|
||||
|
||||
type UpdateMainModelArg = {
|
||||
base_model: BaseModelType;
|
||||
model_name: string;
|
||||
|
Loading…
Reference in New Issue
Block a user