mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(ui): compute prompt trigger options in the component
We can derive the valid trigger options in the component without needing to lift the options list into global state.
This commit is contained in:
parent
8319aca5f9
commit
b0275700b3
@ -55,7 +55,6 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
|
||||
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addPromptTriggerListChanged } from './listeners/promptTriggerListChanged';
|
||||
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
@ -156,6 +155,5 @@ addUpscaleRequestedListener(startAppListening);
|
||||
|
||||
// Prompts
|
||||
addDynamicPromptsListener(startAppListening);
|
||||
addPromptTriggerListChanged(startAppListening);
|
||||
|
||||
addSetDefaultSettingsListener(startAppListening);
|
||||
|
@ -1,46 +0,0 @@
|
||||
import type { ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { loraAdded, loraIsEnabledChanged, loraRecalled, loraRemoved } from 'features/lora/store/loraSlice';
|
||||
import { modelChanged, triggerPhrasesChanged } from 'features/parameters/store/generationSlice';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
|
||||
const matcher = isAnyOf(loraAdded, loraRemoved, loraRecalled, loraIsEnabledChanged, modelChanged);
|
||||
|
||||
export const addPromptTriggerListChanged = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
matcher,
|
||||
effect: async (action, { dispatch, getState, cancelActiveListeners }) => {
|
||||
cancelActiveListeners();
|
||||
const state = getState();
|
||||
const { model: mainModel } = state.generation;
|
||||
const { loras } = state.lora;
|
||||
|
||||
let triggerPhrases: ComboboxOption[] = [];
|
||||
|
||||
if (!mainModel) {
|
||||
dispatch(triggerPhrasesChanged([]));
|
||||
return;
|
||||
}
|
||||
|
||||
const { data: mainModelData } = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(mainModel.key));
|
||||
triggerPhrases = (mainModelData?.trigger_phrases || []).map((phrase) => ({ label: phrase, value: phrase }));
|
||||
|
||||
for (let index = 0; index < Object.values(loras).length; index++) {
|
||||
const lora = Object.values(loras)[index];
|
||||
if (lora && lora.isEnabled) {
|
||||
const { data: loraMetadata } = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(lora.model.key));
|
||||
const { data: loraConfig } = modelsApi.endpoints.getModelConfig.select(lora.model.key)(state);
|
||||
const loraTriggerPhrases = (loraMetadata?.trigger_phrases || []).map((phrase) => ({
|
||||
label: phrase,
|
||||
value: phrase,
|
||||
description: loraConfig?.name ? `(${loraConfig?.name})` : '',
|
||||
}));
|
||||
triggerPhrases = [...triggerPhrases, ...loraTriggerPhrases];
|
||||
}
|
||||
}
|
||||
|
||||
dispatch(triggerPhrasesChanged(triggerPhrases));
|
||||
},
|
||||
});
|
||||
};
|
@ -1,4 +1,3 @@
|
||||
import type { ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
@ -57,7 +56,6 @@ const initialGenerationState: GenerationState = {
|
||||
shouldUseCpuNoise: true,
|
||||
shouldShowAdvancedOptions: false,
|
||||
aspectRatio: { ...initialAspectRatioState },
|
||||
triggerPhrases: [],
|
||||
};
|
||||
|
||||
export const generationSlice = createSlice({
|
||||
@ -209,9 +207,6 @@ export const generationSlice = createSlice({
|
||||
aspectRatioChanged: (state, action: PayloadAction<AspectRatioState>) => {
|
||||
state.aspectRatio = action.payload;
|
||||
},
|
||||
triggerPhrasesChanged: (state, action: PayloadAction<ComboboxOption[]>) => {
|
||||
state.triggerPhrases = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(configChanged, (state, action) => {
|
||||
@ -290,7 +285,6 @@ export const {
|
||||
heightChanged,
|
||||
widthRecalled,
|
||||
heightRecalled,
|
||||
triggerPhrasesChanged,
|
||||
} = generationSlice.actions;
|
||||
|
||||
export const { selectOptimalDimension } = generationSlice.selectors;
|
||||
|
@ -1,4 +1,3 @@
|
||||
import type { ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
|
||||
import type {
|
||||
@ -52,7 +51,6 @@ export interface GenerationState {
|
||||
shouldUseCpuNoise: boolean;
|
||||
shouldShowAdvancedOptions: boolean;
|
||||
aspectRatio: AspectRatioState;
|
||||
triggerPhrases: ComboboxOption[];
|
||||
}
|
||||
|
||||
export type PayloadActionWithOptimalDimension<T = void> = PayloadAction<T, string, { optimalDimension: number }>;
|
||||
|
@ -1,23 +1,32 @@
|
||||
import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { GroupBase, OptionsOrGroups } from 'chakra-react-select';
|
||||
import type { GroupBase } from 'chakra-react-select';
|
||||
import { selectLoraSlice } from 'features/lora/store/loraSlice';
|
||||
import type { PromptTriggerSelectProps } from 'features/prompt/types';
|
||||
import { t } from 'i18next';
|
||||
import { map } from 'lodash-es';
|
||||
import { flatten, map } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
import {
|
||||
loraModelsAdapterSelectors,
|
||||
textualInversionModelsAdapterSelectors,
|
||||
useGetLoRAModelsQuery,
|
||||
useGetTextualInversionModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
|
||||
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
|
||||
|
||||
const selectLoRAs = createMemoizedSelector(selectLoraSlice, (loras) => loras.loras);
|
||||
|
||||
export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSelectProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
const triggerPhrases = useAppSelector((s) => s.generation.triggerPhrases);
|
||||
|
||||
const { data, isLoading } = useGetTextualInversionModelsQuery();
|
||||
const addedLoRAs = useAppSelector(selectLoRAs);
|
||||
const { data: loraModels, isLoading: isLoadingLoRAs } = useGetLoRAModelsQuery();
|
||||
const { data: tiModels, isLoading: isLoadingTIs } = useGetTextualInversionModelsQuery();
|
||||
|
||||
const _onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
@ -32,32 +41,48 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
||||
);
|
||||
|
||||
const options = useMemo(() => {
|
||||
let embeddingOptions: OptionsOrGroups<ComboboxOption, GroupBase<ComboboxOption>> = [];
|
||||
const _options: GroupBase<ComboboxOption>[] = [];
|
||||
|
||||
if (data) {
|
||||
const compatibleEmbeddingsArray = map(data.entities).filter((model) => model.base === currentBaseModel);
|
||||
if (tiModels) {
|
||||
const embeddingOptions = textualInversionModelsAdapterSelectors
|
||||
.selectAll(tiModels)
|
||||
.filter((ti) => ti.base === currentBaseModel)
|
||||
.map((model) => ({ label: model.name, value: `<${model.name}>` }));
|
||||
|
||||
embeddingOptions = [
|
||||
{
|
||||
if (embeddingOptions.length > 0) {
|
||||
_options.push({
|
||||
label: t('prompt.compatibleEmbeddings'),
|
||||
options: compatibleEmbeddingsArray.map((model) => ({ label: model.name, value: `<${model.name}>` })),
|
||||
},
|
||||
];
|
||||
options: embeddingOptions,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const metadataOptions = [
|
||||
{
|
||||
label: t('modelManager.triggerPhrases'),
|
||||
options: triggerPhrases,
|
||||
},
|
||||
];
|
||||
return [...metadataOptions, ...embeddingOptions];
|
||||
}, [data, currentBaseModel, triggerPhrases, t]);
|
||||
if (loraModels) {
|
||||
const triggerPhraseOptions = loraModelsAdapterSelectors
|
||||
.selectAll(loraModels)
|
||||
.filter((lora) => map(addedLoRAs, (l) => l.model.key).includes(lora.key))
|
||||
.map((lora) => {
|
||||
if (lora.trigger_phrases) {
|
||||
return lora.trigger_phrases.map((triggerPhrase) => ({ label: triggerPhrase, value: triggerPhrase }));
|
||||
}
|
||||
return [];
|
||||
})
|
||||
.flatMap((x) => x);
|
||||
|
||||
if (triggerPhraseOptions.length > 0) {
|
||||
_options.push({
|
||||
label: t('modelManager.triggerPhrases'),
|
||||
options: flatten(triggerPhraseOptions),
|
||||
});
|
||||
}
|
||||
}
|
||||
return _options;
|
||||
}, [tiModels, loraModels, t, currentBaseModel, addedLoRAs]);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<Combobox
|
||||
placeholder={isLoading ? t('common.loading') : t('prompt.addPromptTrigger')}
|
||||
placeholder={isLoadingLoRAs || isLoadingTIs ? t('common.loading') : t('prompt.addPromptTrigger')}
|
||||
defaultMenuIsOpen
|
||||
autoFocus
|
||||
value={null}
|
||||
|
@ -66,6 +66,7 @@ const loraModelsAdapter = createEntityAdapter<LoRAModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
@ -85,6 +86,10 @@ const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelC
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors(
|
||||
undefined,
|
||||
getSelectorsOptions
|
||||
);
|
||||
const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
|
Loading…
Reference in New Issue
Block a user