mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
use a listener to recalculate trigger phrases when model or lora list changes
This commit is contained in:
parent
62f6c0a7fa
commit
641eefa4a1
@ -55,6 +55,7 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
|
||||
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addPromptTriggerListChanged } from './listeners/promptTriggerListChanged';
|
||||
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
@ -153,7 +154,8 @@ addFirstListImagesListener(startAppListening);
|
||||
// Ad-hoc upscale workflwo
|
||||
addUpscaleRequestedListener(startAppListening);
|
||||
|
||||
// Dynamic prompts
|
||||
// Prompts
|
||||
addDynamicPromptsListener(startAppListening);
|
||||
addPromptTriggerListChanged(startAppListening);
|
||||
|
||||
addSetDefaultSettingsListener(startAppListening);
|
||||
|
@ -0,0 +1,39 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { loraAdded, loraRemoved } from 'features/lora/store/loraSlice';
|
||||
import { modelChanged, triggerPhrasesChanged } from 'features/parameters/store/generationSlice';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
|
||||
const matcher = isAnyOf(loraAdded, loraRemoved, modelChanged);
|
||||
|
||||
export const addPromptTriggerListChanged = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
matcher,
|
||||
effect: async (action, { dispatch, getState, cancelActiveListeners }) => {
|
||||
cancelActiveListeners();
|
||||
const state = getState();
|
||||
const { model: mainModel } = state.generation;
|
||||
const { loras } = state.lora;
|
||||
|
||||
let triggerPhrases: string[] = [];
|
||||
|
||||
if (!mainModel) {
|
||||
dispatch(triggerPhrasesChanged([]));
|
||||
return;
|
||||
}
|
||||
|
||||
const { data: mainModelData } = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(mainModel.key));
|
||||
triggerPhrases = mainModelData?.trigger_phrases || [];
|
||||
|
||||
for (let index = 0; index < Object.values(loras).length; index++) {
|
||||
const lora = Object.values(loras)[index];
|
||||
if (lora) {
|
||||
const { data: loraData } = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(lora.model.key));
|
||||
triggerPhrases = [...triggerPhrases, ...(loraData?.trigger_phrases || [])];
|
||||
}
|
||||
}
|
||||
|
||||
dispatch(triggerPhrasesChanged(triggerPhrases));
|
||||
},
|
||||
});
|
||||
};
|
@ -56,6 +56,7 @@ const initialGenerationState: GenerationState = {
|
||||
shouldUseCpuNoise: true,
|
||||
shouldShowAdvancedOptions: false,
|
||||
aspectRatio: { ...initialAspectRatioState },
|
||||
triggerPhrases: [],
|
||||
};
|
||||
|
||||
export const generationSlice = createSlice({
|
||||
@ -207,6 +208,9 @@ export const generationSlice = createSlice({
|
||||
aspectRatioChanged: (state, action: PayloadAction<AspectRatioState>) => {
|
||||
state.aspectRatio = action.payload;
|
||||
},
|
||||
triggerPhrasesChanged: (state, action: PayloadAction<string[]>) => {
|
||||
state.triggerPhrases = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(configChanged, (state, action) => {
|
||||
@ -285,6 +289,7 @@ export const {
|
||||
heightChanged,
|
||||
widthRecalled,
|
||||
heightRecalled,
|
||||
triggerPhrasesChanged,
|
||||
} = generationSlice.actions;
|
||||
|
||||
export const { selectOptimalDimension } = generationSlice.selectors;
|
||||
|
@ -51,6 +51,7 @@ export interface GenerationState {
|
||||
shouldUseCpuNoise: boolean;
|
||||
shouldShowAdvancedOptions: boolean;
|
||||
aspectRatio: AspectRatioState;
|
||||
triggerPhrases: string[];
|
||||
}
|
||||
|
||||
export type PayloadActionWithOptimalDimension<T = void> = PayloadAction<T, string, { optimalDimension: number }>;
|
||||
|
@ -1,13 +1,13 @@
|
||||
import type { ChakraProps, ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { GroupBase, OptionsOrGroups } from 'chakra-react-select';
|
||||
import type { PromptTriggerSelectProps } from 'features/prompt/types';
|
||||
import { t } from 'i18next';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelMetadataQuery, useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
|
||||
|
||||
@ -15,10 +15,9 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
||||
const { t } = useTranslation();
|
||||
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
const currentModelKey = useAppSelector((s) => s.generation.model?.key);
|
||||
const triggerPhrases = useAppSelector((s) => s.generation.triggerPhrases);
|
||||
|
||||
const { data, isLoading } = useGetTextualInversionModelsQuery();
|
||||
const { data: metadata } = useGetModelMetadataQuery(currentModelKey ?? skipToken);
|
||||
|
||||
const _onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
@ -32,34 +31,28 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
||||
[onSelect]
|
||||
);
|
||||
|
||||
const embeddingOptions = useMemo(() => {
|
||||
if (!data) {
|
||||
return [];
|
||||
}
|
||||
const options = useMemo(() => {
|
||||
let embeddingOptions: OptionsOrGroups<ComboboxOption, GroupBase<ComboboxOption>> = [];
|
||||
|
||||
if (data) {
|
||||
const compatibleEmbeddingsArray = map(data.entities).filter((model) => model.base === currentBaseModel);
|
||||
|
||||
return [
|
||||
embeddingOptions = [
|
||||
{
|
||||
label: t('prompt.compatibleEmbeddings'),
|
||||
options: compatibleEmbeddingsArray.map((model) => ({ label: model.name, value: `<${model.name}>` })),
|
||||
},
|
||||
];
|
||||
}, [data, currentBaseModel, t]);
|
||||
|
||||
const options = useMemo(() => {
|
||||
if (!metadata || !metadata.trigger_phrases) {
|
||||
return [...embeddingOptions];
|
||||
}
|
||||
|
||||
const metadataOptions = [
|
||||
{
|
||||
label: t('modelManager.triggerPhrases'),
|
||||
options: metadata.trigger_phrases.map((phrase) => ({ label: phrase, value: phrase })),
|
||||
options: triggerPhrases.map((phrase) => ({ label: phrase, value: phrase })),
|
||||
},
|
||||
];
|
||||
return [...metadataOptions, ...embeddingOptions];
|
||||
}, [embeddingOptions, metadata, t]);
|
||||
}, [data, currentBaseModel, triggerPhrases, t]);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user