mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
use a listener to recalculate trigger phrases when model or lora list changes
This commit is contained in:
parent
62f6c0a7fa
commit
641eefa4a1
@ -55,6 +55,7 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
|
|||||||
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
|
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
|
||||||
import type { AppDispatch, RootState } from 'app/store/store';
|
import type { AppDispatch, RootState } from 'app/store/store';
|
||||||
|
|
||||||
|
import { addPromptTriggerListChanged } from './listeners/promptTriggerListChanged';
|
||||||
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
|
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
|
||||||
|
|
||||||
export const listenerMiddleware = createListenerMiddleware();
|
export const listenerMiddleware = createListenerMiddleware();
|
||||||
@ -153,7 +154,8 @@ addFirstListImagesListener(startAppListening);
|
|||||||
// Ad-hoc upscale workflwo
|
// Ad-hoc upscale workflwo
|
||||||
addUpscaleRequestedListener(startAppListening);
|
addUpscaleRequestedListener(startAppListening);
|
||||||
|
|
||||||
// Dynamic prompts
|
// Prompts
|
||||||
addDynamicPromptsListener(startAppListening);
|
addDynamicPromptsListener(startAppListening);
|
||||||
|
addPromptTriggerListChanged(startAppListening);
|
||||||
|
|
||||||
addSetDefaultSettingsListener(startAppListening);
|
addSetDefaultSettingsListener(startAppListening);
|
||||||
|
@ -0,0 +1,39 @@
|
|||||||
|
import { isAnyOf } from '@reduxjs/toolkit';
|
||||||
|
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||||
|
import { loraAdded, loraRemoved } from 'features/lora/store/loraSlice';
|
||||||
|
import { modelChanged, triggerPhrasesChanged } from 'features/parameters/store/generationSlice';
|
||||||
|
import { modelsApi } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
const matcher = isAnyOf(loraAdded, loraRemoved, modelChanged);
|
||||||
|
|
||||||
|
export const addPromptTriggerListChanged = (startAppListening: AppStartListening) => {
|
||||||
|
startAppListening({
|
||||||
|
matcher,
|
||||||
|
effect: async (action, { dispatch, getState, cancelActiveListeners }) => {
|
||||||
|
cancelActiveListeners();
|
||||||
|
const state = getState();
|
||||||
|
const { model: mainModel } = state.generation;
|
||||||
|
const { loras } = state.lora;
|
||||||
|
|
||||||
|
let triggerPhrases: string[] = [];
|
||||||
|
|
||||||
|
if (!mainModel) {
|
||||||
|
dispatch(triggerPhrasesChanged([]));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { data: mainModelData } = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(mainModel.key));
|
||||||
|
triggerPhrases = mainModelData?.trigger_phrases || [];
|
||||||
|
|
||||||
|
for (let index = 0; index < Object.values(loras).length; index++) {
|
||||||
|
const lora = Object.values(loras)[index];
|
||||||
|
if (lora) {
|
||||||
|
const { data: loraData } = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(lora.model.key));
|
||||||
|
triggerPhrases = [...triggerPhrases, ...(loraData?.trigger_phrases || [])];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(triggerPhrasesChanged(triggerPhrases));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -56,6 +56,7 @@ const initialGenerationState: GenerationState = {
|
|||||||
shouldUseCpuNoise: true,
|
shouldUseCpuNoise: true,
|
||||||
shouldShowAdvancedOptions: false,
|
shouldShowAdvancedOptions: false,
|
||||||
aspectRatio: { ...initialAspectRatioState },
|
aspectRatio: { ...initialAspectRatioState },
|
||||||
|
triggerPhrases: [],
|
||||||
};
|
};
|
||||||
|
|
||||||
export const generationSlice = createSlice({
|
export const generationSlice = createSlice({
|
||||||
@ -207,6 +208,9 @@ export const generationSlice = createSlice({
|
|||||||
aspectRatioChanged: (state, action: PayloadAction<AspectRatioState>) => {
|
aspectRatioChanged: (state, action: PayloadAction<AspectRatioState>) => {
|
||||||
state.aspectRatio = action.payload;
|
state.aspectRatio = action.payload;
|
||||||
},
|
},
|
||||||
|
triggerPhrasesChanged: (state, action: PayloadAction<string[]>) => {
|
||||||
|
state.triggerPhrases = action.payload;
|
||||||
|
},
|
||||||
},
|
},
|
||||||
extraReducers: (builder) => {
|
extraReducers: (builder) => {
|
||||||
builder.addCase(configChanged, (state, action) => {
|
builder.addCase(configChanged, (state, action) => {
|
||||||
@ -285,6 +289,7 @@ export const {
|
|||||||
heightChanged,
|
heightChanged,
|
||||||
widthRecalled,
|
widthRecalled,
|
||||||
heightRecalled,
|
heightRecalled,
|
||||||
|
triggerPhrasesChanged,
|
||||||
} = generationSlice.actions;
|
} = generationSlice.actions;
|
||||||
|
|
||||||
export const { selectOptimalDimension } = generationSlice.selectors;
|
export const { selectOptimalDimension } = generationSlice.selectors;
|
||||||
|
@ -51,6 +51,7 @@ export interface GenerationState {
|
|||||||
shouldUseCpuNoise: boolean;
|
shouldUseCpuNoise: boolean;
|
||||||
shouldShowAdvancedOptions: boolean;
|
shouldShowAdvancedOptions: boolean;
|
||||||
aspectRatio: AspectRatioState;
|
aspectRatio: AspectRatioState;
|
||||||
|
triggerPhrases: string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export type PayloadActionWithOptimalDimension<T = void> = PayloadAction<T, string, { optimalDimension: number }>;
|
export type PayloadActionWithOptimalDimension<T = void> = PayloadAction<T, string, { optimalDimension: number }>;
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
import type { ChakraProps, ComboboxOnChange } from '@invoke-ai/ui-library';
|
import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import type { GroupBase, OptionsOrGroups } from 'chakra-react-select';
|
||||||
import type { PromptTriggerSelectProps } from 'features/prompt/types';
|
import type { PromptTriggerSelectProps } from 'features/prompt/types';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import { map } from 'lodash-es';
|
import { map } from 'lodash-es';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetModelMetadataQuery, useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
|
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
|
||||||
|
|
||||||
@ -15,10 +15,9 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||||
const currentModelKey = useAppSelector((s) => s.generation.model?.key);
|
const triggerPhrases = useAppSelector((s) => s.generation.triggerPhrases);
|
||||||
|
|
||||||
const { data, isLoading } = useGetTextualInversionModelsQuery();
|
const { data, isLoading } = useGetTextualInversionModelsQuery();
|
||||||
const { data: metadata } = useGetModelMetadataQuery(currentModelKey ?? skipToken);
|
|
||||||
|
|
||||||
const _onChange = useCallback<ComboboxOnChange>(
|
const _onChange = useCallback<ComboboxOnChange>(
|
||||||
(v) => {
|
(v) => {
|
||||||
@ -32,34 +31,28 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
|||||||
[onSelect]
|
[onSelect]
|
||||||
);
|
);
|
||||||
|
|
||||||
const embeddingOptions = useMemo(() => {
|
const options = useMemo(() => {
|
||||||
if (!data) {
|
let embeddingOptions: OptionsOrGroups<ComboboxOption, GroupBase<ComboboxOption>> = [];
|
||||||
return [];
|
|
||||||
}
|
|
||||||
|
|
||||||
|
if (data) {
|
||||||
const compatibleEmbeddingsArray = map(data.entities).filter((model) => model.base === currentBaseModel);
|
const compatibleEmbeddingsArray = map(data.entities).filter((model) => model.base === currentBaseModel);
|
||||||
|
|
||||||
return [
|
embeddingOptions = [
|
||||||
{
|
{
|
||||||
label: t('prompt.compatibleEmbeddings'),
|
label: t('prompt.compatibleEmbeddings'),
|
||||||
options: compatibleEmbeddingsArray.map((model) => ({ label: model.name, value: `<${model.name}>` })),
|
options: compatibleEmbeddingsArray.map((model) => ({ label: model.name, value: `<${model.name}>` })),
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
}, [data, currentBaseModel, t]);
|
|
||||||
|
|
||||||
const options = useMemo(() => {
|
|
||||||
if (!metadata || !metadata.trigger_phrases) {
|
|
||||||
return [...embeddingOptions];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const metadataOptions = [
|
const metadataOptions = [
|
||||||
{
|
{
|
||||||
label: t('modelManager.triggerPhrases'),
|
label: t('modelManager.triggerPhrases'),
|
||||||
options: metadata.trigger_phrases.map((phrase) => ({ label: phrase, value: phrase })),
|
options: triggerPhrases.map((phrase) => ({ label: phrase, value: phrase })),
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
return [...metadataOptions, ...embeddingOptions];
|
return [...metadataOptions, ...embeddingOptions];
|
||||||
}, [embeddingOptions, metadata, t]);
|
}, [data, currentBaseModel, triggerPhrases, t]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<FormControl>
|
<FormControl>
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user