refactor(ui): compute prompt trigger options in the component

We can derive the valid trigger options in the component without needing to lift the options list into global state.
This commit is contained in:
psychedelicious 2024-03-05 23:07:18 +11:00
parent 8319aca5f9
commit b0275700b3
6 changed files with 53 additions and 79 deletions

View File

@ -55,7 +55,6 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested'; import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store'; import type { AppDispatch, RootState } from 'app/store/store';
import { addPromptTriggerListChanged } from './listeners/promptTriggerListChanged';
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings'; import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -156,6 +155,5 @@ addUpscaleRequestedListener(startAppListening);
// Prompts // Prompts
addDynamicPromptsListener(startAppListening); addDynamicPromptsListener(startAppListening);
addPromptTriggerListChanged(startAppListening);
addSetDefaultSettingsListener(startAppListening); addSetDefaultSettingsListener(startAppListening);

View File

@ -1,46 +0,0 @@
import type { ComboboxOption } from '@invoke-ai/ui-library';
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { loraAdded, loraIsEnabledChanged, loraRecalled, loraRemoved } from 'features/lora/store/loraSlice';
import { modelChanged, triggerPhrasesChanged } from 'features/parameters/store/generationSlice';
import { modelsApi } from 'services/api/endpoints/models';
const matcher = isAnyOf(loraAdded, loraRemoved, loraRecalled, loraIsEnabledChanged, modelChanged);
export const addPromptTriggerListChanged = (startAppListening: AppStartListening) => {
startAppListening({
matcher,
effect: async (action, { dispatch, getState, cancelActiveListeners }) => {
cancelActiveListeners();
const state = getState();
const { model: mainModel } = state.generation;
const { loras } = state.lora;
let triggerPhrases: ComboboxOption[] = [];
if (!mainModel) {
dispatch(triggerPhrasesChanged([]));
return;
}
const { data: mainModelData } = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(mainModel.key));
triggerPhrases = (mainModelData?.trigger_phrases || []).map((phrase) => ({ label: phrase, value: phrase }));
for (let index = 0; index < Object.values(loras).length; index++) {
const lora = Object.values(loras)[index];
if (lora && lora.isEnabled) {
const { data: loraMetadata } = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(lora.model.key));
const { data: loraConfig } = modelsApi.endpoints.getModelConfig.select(lora.model.key)(state);
const loraTriggerPhrases = (loraMetadata?.trigger_phrases || []).map((phrase) => ({
label: phrase,
value: phrase,
description: loraConfig?.name ? `(${loraConfig?.name})` : '',
}));
triggerPhrases = [...triggerPhrases, ...loraTriggerPhrases];
}
}
dispatch(triggerPhrasesChanged(triggerPhrases));
},
});
};

View File

@ -1,4 +1,3 @@
import type { ComboboxOption } from '@invoke-ai/ui-library';
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { PersistConfig, RootState } from 'app/store/store';
@ -57,7 +56,6 @@ const initialGenerationState: GenerationState = {
shouldUseCpuNoise: true, shouldUseCpuNoise: true,
shouldShowAdvancedOptions: false, shouldShowAdvancedOptions: false,
aspectRatio: { ...initialAspectRatioState }, aspectRatio: { ...initialAspectRatioState },
triggerPhrases: [],
}; };
export const generationSlice = createSlice({ export const generationSlice = createSlice({
@ -209,9 +207,6 @@ export const generationSlice = createSlice({
aspectRatioChanged: (state, action: PayloadAction<AspectRatioState>) => { aspectRatioChanged: (state, action: PayloadAction<AspectRatioState>) => {
state.aspectRatio = action.payload; state.aspectRatio = action.payload;
}, },
triggerPhrasesChanged: (state, action: PayloadAction<ComboboxOption[]>) => {
state.triggerPhrases = action.payload;
},
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
builder.addCase(configChanged, (state, action) => { builder.addCase(configChanged, (state, action) => {
@ -290,7 +285,6 @@ export const {
heightChanged, heightChanged,
widthRecalled, widthRecalled,
heightRecalled, heightRecalled,
triggerPhrasesChanged,
} = generationSlice.actions; } = generationSlice.actions;
export const { selectOptimalDimension } = generationSlice.selectors; export const { selectOptimalDimension } = generationSlice.selectors;

View File

@ -1,4 +1,3 @@
import type { ComboboxOption } from '@invoke-ai/ui-library';
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types'; import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
import type { import type {
@ -52,7 +51,6 @@ export interface GenerationState {
shouldUseCpuNoise: boolean; shouldUseCpuNoise: boolean;
shouldShowAdvancedOptions: boolean; shouldShowAdvancedOptions: boolean;
aspectRatio: AspectRatioState; aspectRatio: AspectRatioState;
triggerPhrases: ComboboxOption[];
} }
export type PayloadActionWithOptimalDimension<T = void> = PayloadAction<T, string, { optimalDimension: number }>; export type PayloadActionWithOptimalDimension<T = void> = PayloadAction<T, string, { optimalDimension: number }>;

View File

@ -1,23 +1,32 @@
import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library'; import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase, OptionsOrGroups } from 'chakra-react-select'; import type { GroupBase } from 'chakra-react-select';
import { selectLoraSlice } from 'features/lora/store/loraSlice';
import type { PromptTriggerSelectProps } from 'features/prompt/types'; import type { PromptTriggerSelectProps } from 'features/prompt/types';
import { t } from 'i18next'; import { t } from 'i18next';
import { map } from 'lodash-es'; import { flatten, map } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models'; import {
loraModelsAdapterSelectors,
textualInversionModelsAdapterSelectors,
useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery,
} from 'services/api/endpoints/models';
const noOptionsMessage = () => t('prompt.noMatchingTriggers'); const noOptionsMessage = () => t('prompt.noMatchingTriggers');
const selectLoRAs = createMemoizedSelector(selectLoraSlice, (loras) => loras.loras);
export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSelectProps) => { export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSelectProps) => {
const { t } = useTranslation(); const { t } = useTranslation();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base); const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const triggerPhrases = useAppSelector((s) => s.generation.triggerPhrases); const addedLoRAs = useAppSelector(selectLoRAs);
const { data: loraModels, isLoading: isLoadingLoRAs } = useGetLoRAModelsQuery();
const { data, isLoading } = useGetTextualInversionModelsQuery(); const { data: tiModels, isLoading: isLoadingTIs } = useGetTextualInversionModelsQuery();
const _onChange = useCallback<ComboboxOnChange>( const _onChange = useCallback<ComboboxOnChange>(
(v) => { (v) => {
@ -32,32 +41,48 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
); );
const options = useMemo(() => { const options = useMemo(() => {
let embeddingOptions: OptionsOrGroups<ComboboxOption, GroupBase<ComboboxOption>> = []; const _options: GroupBase<ComboboxOption>[] = [];
if (data) { if (tiModels) {
const compatibleEmbeddingsArray = map(data.entities).filter((model) => model.base === currentBaseModel); const embeddingOptions = textualInversionModelsAdapterSelectors
.selectAll(tiModels)
.filter((ti) => ti.base === currentBaseModel)
.map((model) => ({ label: model.name, value: `<${model.name}>` }));
embeddingOptions = [ if (embeddingOptions.length > 0) {
{ _options.push({
label: t('prompt.compatibleEmbeddings'), label: t('prompt.compatibleEmbeddings'),
options: compatibleEmbeddingsArray.map((model) => ({ label: model.name, value: `<${model.name}>` })), options: embeddingOptions,
}, });
]; }
} }
const metadataOptions = [ if (loraModels) {
{ const triggerPhraseOptions = loraModelsAdapterSelectors
label: t('modelManager.triggerPhrases'), .selectAll(loraModels)
options: triggerPhrases, .filter((lora) => map(addedLoRAs, (l) => l.model.key).includes(lora.key))
}, .map((lora) => {
]; if (lora.trigger_phrases) {
return [...metadataOptions, ...embeddingOptions]; return lora.trigger_phrases.map((triggerPhrase) => ({ label: triggerPhrase, value: triggerPhrase }));
}, [data, currentBaseModel, triggerPhrases, t]); }
return [];
})
.flatMap((x) => x);
if (triggerPhraseOptions.length > 0) {
_options.push({
label: t('modelManager.triggerPhrases'),
options: flatten(triggerPhraseOptions),
});
}
}
return _options;
}, [tiModels, loraModels, t, currentBaseModel, addedLoRAs]);
return ( return (
<FormControl> <FormControl>
<Combobox <Combobox
placeholder={isLoading ? t('common.loading') : t('prompt.addPromptTrigger')} placeholder={isLoadingLoRAs || isLoadingTIs ? t('common.loading') : t('prompt.addPromptTrigger')}
defaultMenuIsOpen defaultMenuIsOpen
autoFocus autoFocus
value={null} value={null}

View File

@ -66,6 +66,7 @@ const loraModelsAdapter = createEntityAdapter<LoRAModelConfig, string>({
selectId: (entity) => entity.key, selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.name.localeCompare(b.name),
}); });
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({ const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({
selectId: (entity) => entity.key, selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.name.localeCompare(b.name),
@ -85,6 +86,10 @@ const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelC
selectId: (entity) => entity.key, selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.name.localeCompare(b.name),
}); });
export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors(
undefined,
getSelectorsOptions
);
const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({ const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
selectId: (entity) => entity.key, selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.name.localeCompare(b.name),