mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): get vae model select working
This commit is contained in:
parent
f870f810d5
commit
e7e3045a8a
@ -6,7 +6,6 @@ import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
|||||||
import { groupBy, map, reduce } from 'lodash-es';
|
import { groupBy, map, reduce } from 'lodash-es';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { getModelId } from 'services/api/endpoints/models';
|
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
||||||
@ -58,8 +57,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
|||||||
|
|
||||||
const value = useMemo(
|
const value = useMemo(
|
||||||
() =>
|
() =>
|
||||||
options.flatMap((o) => o.options).find((m) => (selectedModel ? m.value === getModelId(selectedModel) : false)) ??
|
options.flatMap((o) => o.options).find((m) => (selectedModel ? m.value === selectedModel.key : false)) ?? null,
|
||||||
null,
|
|
||||||
[options, selectedModel]
|
[options, selectedModel]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||||
import type { EntityState } from '@reduxjs/toolkit';
|
import type { EntityState } from '@reduxjs/toolkit';
|
||||||
|
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||||
import { map } from 'lodash-es';
|
import { map } from 'lodash-es';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { AnyModelConfig } from 'services/api/endpoints/models';
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
import { getModelId } from 'services/api/endpoints/models';
|
|
||||||
|
|
||||||
type UseModelComboboxArg<T extends AnyModelConfig> = {
|
type UseModelComboboxArg<T extends AnyModelConfig> = {
|
||||||
modelEntities: EntityState<T, string> | undefined;
|
modelEntities: EntityState<T, string> | undefined;
|
||||||
selectedModel?: Pick<T, 'base_model' | 'model_name' | 'model_type'> | null;
|
selectedModel?: ModelIdentifierWithBase | null;
|
||||||
onChange: (value: T | null) => void;
|
onChange: (value: T | null) => void;
|
||||||
getIsDisabled?: (model: T) => boolean;
|
getIsDisabled?: (model: T) => boolean;
|
||||||
optionsFilter?: (model: T) => boolean;
|
optionsFilter?: (model: T) => boolean;
|
||||||
@ -33,14 +33,14 @@ export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelCombobox
|
|||||||
return map(modelEntities.entities)
|
return map(modelEntities.entities)
|
||||||
.filter(optionsFilter)
|
.filter(optionsFilter)
|
||||||
.map((model) => ({
|
.map((model) => ({
|
||||||
label: model.model_name,
|
label: model.name,
|
||||||
value: model.id,
|
value: model.key,
|
||||||
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
|
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
|
||||||
}));
|
}));
|
||||||
}, [optionsFilter, getIsDisabled, modelEntities]);
|
}, [optionsFilter, getIsDisabled, modelEntities]);
|
||||||
|
|
||||||
const value = useMemo(
|
const value = useMemo(
|
||||||
() => options.find((m) => (selectedModel ? m.value === getModelId(selectedModel) : false)),
|
() => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),
|
||||||
[options, selectedModel]
|
[options, selectedModel]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -7,8 +7,8 @@ import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/ge
|
|||||||
import { pick } from 'lodash-es';
|
import { pick } from 'lodash-es';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { VAEConfig } from 'services/api/endpoints/models';
|
|
||||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
import type { VAEConfig } from 'services/api/types';
|
||||||
|
|
||||||
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
|
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
|
||||||
const { model, vae } = generation;
|
const { model, vae } = generation;
|
||||||
@ -22,26 +22,28 @@ const ParamVAEModelSelect = () => {
|
|||||||
const { data, isLoading } = useGetVaeModelsQuery();
|
const { data, isLoading } = useGetVaeModelsQuery();
|
||||||
const getIsDisabled = useCallback(
|
const getIsDisabled = useCallback(
|
||||||
(vae: VAEConfig): boolean => {
|
(vae: VAEConfig): boolean => {
|
||||||
const isCompatible = model?.base_model === vae.base_model;
|
const isCompatible = model?.base === vae.base;
|
||||||
const hasMainModel = Boolean(model?.base_model);
|
const hasMainModel = Boolean(model?.base);
|
||||||
return !hasMainModel || !isCompatible;
|
return !hasMainModel || !isCompatible;
|
||||||
},
|
},
|
||||||
[model?.base_model]
|
[model?.base]
|
||||||
);
|
);
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(vae: VAEConfig | null) => {
|
(vae: VAEConfig | null) => {
|
||||||
dispatch(vaeSelected(vae ? pick(vae, 'base_model', 'model_name') : null));
|
dispatch(vaeSelected(vae ? pick(vae, 'key', 'base') : null));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelEntities: data,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: vae ? { ...vae, model_type: 'vae' } : null,
|
selectedModel: vae ? pick(vae, 'key', 'base') : null,
|
||||||
isLoading,
|
isLoading,
|
||||||
getIsDisabled,
|
getIsDisabled,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
console.log(value)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<FormControl isDisabled={!options.length} isInvalid={!options.length}>
|
<FormControl isDisabled={!options.length} isInvalid={!options.length}>
|
||||||
<InformationalPopover feature="paramVAE">
|
<InformationalPopover feature="paramVAE">
|
||||||
@ -50,7 +52,7 @@ const ParamVAEModelSelect = () => {
|
|||||||
<Combobox
|
<Combobox
|
||||||
isClearable
|
isClearable
|
||||||
value={value}
|
value={value}
|
||||||
placeholder={value ? placeholder : t('models.defaultVAE')}
|
placeholder={value ? value.value : t('models.defaultVAE')}
|
||||||
options={options}
|
options={options}
|
||||||
onChange={onChange}
|
onChange={onChange}
|
||||||
noOptionsMessage={noOptionsMessage}
|
noOptionsMessage={noOptionsMessage}
|
||||||
|
@ -21,9 +21,6 @@ import type {
|
|||||||
import type { ApiTagDescription, tagTypes } from '..';
|
import type { ApiTagDescription, tagTypes } from '..';
|
||||||
import { api, buildV2Url, LIST_TAG } from '..';
|
import { api, buildV2Url, LIST_TAG } from '..';
|
||||||
|
|
||||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
|
||||||
export const getModelId = (input: any): any => input;
|
|
||||||
|
|
||||||
type UpdateMainModelArg = {
|
type UpdateMainModelArg = {
|
||||||
base_model: BaseModelType;
|
base_model: BaseModelType;
|
||||||
model_name: string;
|
model_name: string;
|
||||||
|
Loading…
Reference in New Issue
Block a user