make trigger phrases a list of options and add lora name as description to appear in dropdown

This commit is contained in:
Mary Hipp 2024-03-04 14:56:37 -05:00
parent 5284ba1812
commit 6ae33f2a21
5 changed files with 20 additions and 9 deletions

View File

@ -1,3 +1,4 @@
import type { ComboboxOption } from '@invoke-ai/ui-library';
import { isAnyOf } from '@reduxjs/toolkit'; import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { loraAdded, loraIsEnabledChanged, loraRecalled, loraRemoved } from 'features/lora/store/loraSlice'; import { loraAdded, loraIsEnabledChanged, loraRecalled, loraRemoved } from 'features/lora/store/loraSlice';
@ -15,7 +16,7 @@ export const addPromptTriggerListChanged = (startAppListening: AppStartListening
const { model: mainModel } = state.generation; const { model: mainModel } = state.generation;
const { loras } = state.lora; const { loras } = state.lora;
let triggerPhrases: string[] = []; let triggerPhrases: ComboboxOption[] = [];
if (!mainModel) { if (!mainModel) {
dispatch(triggerPhrasesChanged([])); dispatch(triggerPhrasesChanged([]));
@ -23,13 +24,19 @@ export const addPromptTriggerListChanged = (startAppListening: AppStartListening
} }
const { data: mainModelData } = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(mainModel.key)); const { data: mainModelData } = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(mainModel.key));
triggerPhrases = mainModelData?.trigger_phrases || []; triggerPhrases = (mainModelData?.trigger_phrases || []).map((phrase) => ({ label: phrase, value: phrase }));
for (let index = 0; index < Object.values(loras).length; index++) { for (let index = 0; index < Object.values(loras).length; index++) {
const lora = Object.values(loras)[index]; const lora = Object.values(loras)[index];
if (lora && lora.isEnabled) { if (lora && lora.isEnabled) {
const { data: loraData } = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(lora.model.key)); const { data: loraMetadata } = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(lora.model.key));
triggerPhrases = [...triggerPhrases, ...(loraData?.trigger_phrases || [])]; const { data: loraConfig } = modelsApi.endpoints.getModelConfig.select(lora.model.key)(state);
const loraTriggerPhrases = (loraMetadata?.trigger_phrases || []).map((phrase) => ({
label: phrase,
value: phrase,
description: loraConfig?.name ? `(${loraConfig?.name})` : '',
}));
triggerPhrases = [...triggerPhrases, ...loraTriggerPhrases];
} }
} }

View File

@ -1,10 +1,11 @@
import { Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query'; import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import Loading from 'common/components/Loading/Loading';
import { selectConfigSlice } from 'features/system/store/configSlice'; import { selectConfigSlice } from 'features/system/store/configSlice';
import { isNil } from 'lodash-es'; import { isNil } from 'lodash-es';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetModelMetadataQuery } from 'services/api/endpoints/models'; import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm'; import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm';
@ -23,6 +24,7 @@ const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config)
export const DefaultSettings = () => { export const DefaultSettings = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { t } = useTranslation();
const { data, isLoading } = useGetModelMetadataQuery(selectedModelKey ?? skipToken); const { data, isLoading } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } = const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
@ -59,7 +61,7 @@ export const DefaultSettings = () => {
]); ]);
if (isLoading) { if (isLoading) {
return <Loading />; return <Text>{t('common.loading')}</Text>;
} }
return <DefaultSettingsForm defaultSettingsDefaults={defaultSettingsDefaults} />; return <DefaultSettingsForm defaultSettingsDefaults={defaultSettingsDefaults} />;

View File

@ -1,3 +1,4 @@
import type { ComboboxOption } from '@invoke-ai/ui-library';
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { PersistConfig, RootState } from 'app/store/store';
@ -208,7 +209,7 @@ export const generationSlice = createSlice({
aspectRatioChanged: (state, action: PayloadAction<AspectRatioState>) => { aspectRatioChanged: (state, action: PayloadAction<AspectRatioState>) => {
state.aspectRatio = action.payload; state.aspectRatio = action.payload;
}, },
triggerPhrasesChanged: (state, action: PayloadAction<string[]>) => { triggerPhrasesChanged: (state, action: PayloadAction<ComboboxOption[]>) => {
state.triggerPhrases = action.payload; state.triggerPhrases = action.payload;
}, },
}, },

View File

@ -1,3 +1,4 @@
import type { ComboboxOption } from '@invoke-ai/ui-library';
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types'; import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
import type { import type {
@ -51,7 +52,7 @@ export interface GenerationState {
shouldUseCpuNoise: boolean; shouldUseCpuNoise: boolean;
shouldShowAdvancedOptions: boolean; shouldShowAdvancedOptions: boolean;
aspectRatio: AspectRatioState; aspectRatio: AspectRatioState;
triggerPhrases: string[]; triggerPhrases: ComboboxOption[];
} }
export type PayloadActionWithOptimalDimension<T = void> = PayloadAction<T, string, { optimalDimension: number }>; export type PayloadActionWithOptimalDimension<T = void> = PayloadAction<T, string, { optimalDimension: number }>;

View File

@ -48,7 +48,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
const metadataOptions = [ const metadataOptions = [
{ {
label: t('modelManager.triggerPhrases'), label: t('modelManager.triggerPhrases'),
options: triggerPhrases.map((phrase) => ({ label: phrase, value: phrase })), options: triggerPhrases,
}, },
]; ];
return [...metadataOptions, ...embeddingOptions]; return [...metadataOptions, ...embeddingOptions];