Use a listener to recalculate trigger phrases when the model or LoRA list changes

This commit is contained in:
Mary Hipp 2024-03-04 12:08:41 -05:00
parent 62f6c0a7fa
commit 641eefa4a1
6 changed files with 112 additions and 30 deletions

View File

@ -55,6 +55,7 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store';
import { addPromptTriggerListChanged } from './listeners/promptTriggerListChanged';
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
export const listenerMiddleware = createListenerMiddleware();
@ -153,7 +154,8 @@ addFirstListImagesListener(startAppListening);
// Ad-hoc upscale workflow
addUpscaleRequestedListener(startAppListening);
// Dynamic prompts
// Prompts
addDynamicPromptsListener(startAppListening);
addPromptTriggerListChanged(startAppListening);
addSetDefaultSettingsListener(startAppListening);

View File

@ -0,0 +1,39 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { loraAdded, loraRemoved } from 'features/lora/store/loraSlice';
import { modelChanged, triggerPhrasesChanged } from 'features/parameters/store/generationSlice';
import { modelsApi } from 'services/api/endpoints/models';
const matcher = isAnyOf(loraAdded, loraRemoved, modelChanged);
/**
 * Registers a listener that recomputes the flat list of prompt trigger phrases
 * whenever the main model changes or a LoRA is added/removed.
 *
 * The resulting list is the main model's trigger phrases followed by each
 * LoRA's trigger phrases (in `Object.values` order), dispatched via
 * `triggerPhrasesChanged`.
 */
export const addPromptTriggerListChanged = (startAppListening: AppStartListening) => {
  startAppListening({
    matcher,
    effect: async (action, { dispatch, getState, cancelActiveListeners }) => {
      // Debounce: only the most recent model/LoRA change should win.
      cancelActiveListeners();
      const state = getState();
      const { model: mainModel } = state.generation;
      const { loras } = state.lora;

      // No main model selected -> there can be no trigger phrases.
      if (!mainModel) {
        dispatch(triggerPhrasesChanged([]));
        return;
      }

      // Collect every model key we need metadata for: main model first,
      // then each LoRA (preserving the original iteration order).
      const modelKeys: string[] = [mainModel.key];
      for (const lora of Object.values(loras)) {
        if (lora) {
          modelKeys.push(lora.model.key);
        }
      }

      // Fetch all metadata in parallel instead of serially awaiting each
      // request; RTK Query dedupes/caches these by key, and Promise.all
      // preserves result order so phrase ordering is unchanged.
      const results = await Promise.all(
        modelKeys.map((key) => dispatch(modelsApi.endpoints.getModelMetadata.initiate(key)))
      );

      const triggerPhrases = results.flatMap(({ data }) => data?.trigger_phrases ?? []);

      dispatch(triggerPhrasesChanged(triggerPhrases));
    },
  });
};

View File

@ -56,6 +56,7 @@ const initialGenerationState: GenerationState = {
shouldUseCpuNoise: true,
shouldShowAdvancedOptions: false,
aspectRatio: { ...initialAspectRatioState },
triggerPhrases: [],
};
export const generationSlice = createSlice({
@ -207,6 +208,9 @@ export const generationSlice = createSlice({
aspectRatioChanged: (state, action: PayloadAction<AspectRatioState>) => {
state.aspectRatio = action.payload;
},
triggerPhrasesChanged: (state, action: PayloadAction<string[]>) => {
state.triggerPhrases = action.payload;
},
},
extraReducers: (builder) => {
builder.addCase(configChanged, (state, action) => {
@ -285,6 +289,7 @@ export const {
heightChanged,
widthRecalled,
heightRecalled,
triggerPhrasesChanged,
} = generationSlice.actions;
export const { selectOptimalDimension } = generationSlice.selectors;

View File

@ -51,6 +51,7 @@ export interface GenerationState {
shouldUseCpuNoise: boolean;
shouldShowAdvancedOptions: boolean;
aspectRatio: AspectRatioState;
triggerPhrases: string[];
}
export type PayloadActionWithOptimalDimension<T = void> = PayloadAction<T, string, { optimalDimension: number }>;

View File

@ -1,13 +1,13 @@
import type { ChakraProps, ComboboxOnChange } from '@invoke-ai/ui-library';
import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase, OptionsOrGroups } from 'chakra-react-select';
import type { PromptTriggerSelectProps } from 'features/prompt/types';
import { t } from 'i18next';
import { map } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetModelMetadataQuery, useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
@ -15,10 +15,9 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
const { t } = useTranslation();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const currentModelKey = useAppSelector((s) => s.generation.model?.key);
const triggerPhrases = useAppSelector((s) => s.generation.triggerPhrases);
const { data, isLoading } = useGetTextualInversionModelsQuery();
const { data: metadata } = useGetModelMetadataQuery(currentModelKey ?? skipToken);
const _onChange = useCallback<ComboboxOnChange>(
(v) => {
@ -32,34 +31,28 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
[onSelect]
);
const embeddingOptions = useMemo(() => {
if (!data) {
return [];
}
const options = useMemo(() => {
let embeddingOptions: OptionsOrGroups<ComboboxOption, GroupBase<ComboboxOption>> = [];
if (data) {
const compatibleEmbeddingsArray = map(data.entities).filter((model) => model.base === currentBaseModel);
return [
embeddingOptions = [
{
label: t('prompt.compatibleEmbeddings'),
options: compatibleEmbeddingsArray.map((model) => ({ label: model.name, value: `<${model.name}>` })),
},
];
}, [data, currentBaseModel, t]);
const options = useMemo(() => {
if (!metadata || !metadata.trigger_phrases) {
return [...embeddingOptions];
}
const metadataOptions = [
{
label: t('modelManager.triggerPhrases'),
options: metadata.trigger_phrases.map((phrase) => ({ label: phrase, value: phrase })),
options: triggerPhrases.map((phrase) => ({ label: phrase, value: phrase })),
},
];
return [...metadataOptions, ...embeddingOptions];
}, [embeddingOptions, metadata, t]);
}, [data, currentBaseModel, triggerPhrases, t]);
return (
<FormControl>

File diff suppressed because one or more lines are too long