use a listener to recalculate trigger phrases when model or lora list changes

This commit is contained in:
Mary Hipp 2024-03-04 12:08:41 -05:00 committed by psychedelicious
parent caafbf2f0d
commit ef171e890a
6 changed files with 99 additions and 25 deletions

View File

@ -55,6 +55,7 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store';
import { addPromptTriggerListChanged } from './listeners/promptTriggerListChanged';
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
export const listenerMiddleware = createListenerMiddleware();
@ -153,7 +154,8 @@ addFirstListImagesListener(startAppListening);
// Ad-hoc upscale workflwo
addUpscaleRequestedListener(startAppListening);
// Dynamic prompts
// Prompts
addDynamicPromptsListener(startAppListening);
addPromptTriggerListChanged(startAppListening);
addSetDefaultSettingsListener(startAppListening);

View File

@ -0,0 +1,39 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { loraAdded, loraRemoved } from 'features/lora/store/loraSlice';
import { modelChanged, triggerPhrasesChanged } from 'features/parameters/store/generationSlice';
import { modelsApi } from 'services/api/endpoints/models';
const matcher = isAnyOf(loraAdded, loraRemoved, modelChanged);
export const addPromptTriggerListChanged = (startAppListening: AppStartListening) => {
startAppListening({
matcher,
effect: async (action, { dispatch, getState, cancelActiveListeners }) => {
cancelActiveListeners();
const state = getState();
const { model: mainModel } = state.generation;
const { loras } = state.lora;
let triggerPhrases: string[] = [];
if (!mainModel) {
dispatch(triggerPhrasesChanged([]));
return;
}
const { data: mainModelData } = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(mainModel.key));
triggerPhrases = mainModelData?.trigger_phrases || [];
for (let index = 0; index < Object.values(loras).length; index++) {
const lora = Object.values(loras)[index];
if (lora) {
const { data: loraData } = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(lora.model.key));
triggerPhrases = [...triggerPhrases, ...(loraData?.trigger_phrases || [])];
}
}
dispatch(triggerPhrasesChanged(triggerPhrases));
},
});
};

View File

@ -56,6 +56,7 @@ const initialGenerationState: GenerationState = {
shouldUseCpuNoise: true,
shouldShowAdvancedOptions: false,
aspectRatio: { ...initialAspectRatioState },
triggerPhrases: [],
};
export const generationSlice = createSlice({
@ -207,6 +208,9 @@ export const generationSlice = createSlice({
aspectRatioChanged: (state, action: PayloadAction<AspectRatioState>) => {
state.aspectRatio = action.payload;
},
triggerPhrasesChanged: (state, action: PayloadAction<string[]>) => {
state.triggerPhrases = action.payload;
},
},
extraReducers: (builder) => {
builder.addCase(configChanged, (state, action) => {
@ -285,6 +289,7 @@ export const {
heightChanged,
widthRecalled,
heightRecalled,
triggerPhrasesChanged,
} = generationSlice.actions;
export const { selectOptimalDimension } = generationSlice.selectors;

View File

@ -51,6 +51,7 @@ export interface GenerationState {
shouldUseCpuNoise: boolean;
shouldShowAdvancedOptions: boolean;
aspectRatio: AspectRatioState;
triggerPhrases: string[];
}
export type PayloadActionWithOptimalDimension<T = void> = PayloadAction<T, string, { optimalDimension: number }>;

View File

@ -1,13 +1,13 @@
import type { ChakraProps, ComboboxOnChange } from '@invoke-ai/ui-library';
import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase, OptionsOrGroups } from 'chakra-react-select';
import type { PromptTriggerSelectProps } from 'features/prompt/types';
import { t } from 'i18next';
import { map } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetModelMetadataQuery, useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
@ -15,10 +15,9 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
const { t } = useTranslation();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const currentModelKey = useAppSelector((s) => s.generation.model?.key);
const triggerPhrases = useAppSelector((s) => s.generation.triggerPhrases);
const { data, isLoading } = useGetTextualInversionModelsQuery();
const { data: metadata } = useGetModelMetadataQuery(currentModelKey ?? skipToken);
const _onChange = useCallback<ComboboxOnChange>(
(v) => {
@ -32,34 +31,28 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
[onSelect]
);
const embeddingOptions = useMemo(() => {
if (!data) {
return [];
}
const compatibleEmbeddingsArray = map(data.entities).filter((model) => model.base === currentBaseModel);
return [
{
label: t('prompt.compatibleEmbeddings'),
options: compatibleEmbeddingsArray.map((model) => ({ label: model.name, value: `<${model.name}>` })),
},
];
}, [data, currentBaseModel, t]);
const options = useMemo(() => {
if (!metadata || !metadata.trigger_phrases) {
return [...embeddingOptions];
let embeddingOptions: OptionsOrGroups<ComboboxOption, GroupBase<ComboboxOption>> = [];
if (data) {
const compatibleEmbeddingsArray = map(data.entities).filter((model) => model.base === currentBaseModel);
embeddingOptions = [
{
label: t('prompt.compatibleEmbeddings'),
options: compatibleEmbeddingsArray.map((model) => ({ label: model.name, value: `<${model.name}>` })),
},
];
}
const metadataOptions = [
{
label: t('modelManager.triggerPhrases'),
options: metadata.trigger_phrases.map((phrase) => ({ label: phrase, value: phrase })),
options: triggerPhrases.map((phrase) => ({ label: phrase, value: phrase })),
},
];
return [...metadataOptions, ...embeddingOptions];
}, [embeddingOptions, metadata, t]);
}, [data, currentBaseModel, triggerPhrases, t]);
return (
<FormControl>

View File

@ -4135,6 +4135,22 @@ export type components = {
*/
type: "gradient_mask_output";
};
/**
* GradientMaskOutput
* @description Outputs a denoise mask and an image representing the total gradient of the mask.
*/
GradientMaskOutput: {
/** @description Mask for denoise model run */
denoise_mask: components["schemas"]["DenoiseMaskField"];
/** @description Image representing the total gradient area of the mask. For paste-back purposes. */
expanded_mask_area: components["schemas"]["ImageField"];
/**
* type
* @default gradient_mask_output
* @constant
*/
type: "gradient_mask_output";
};
/** Graph */
Graph: {
/**
@ -7377,6 +7393,21 @@ export type components = {
/** Cfg Rescale Multiplier */
cfg_rescale_multiplier: number | null;
};
/** ModelDefaultSettings */
ModelDefaultSettings: {
/** Vae */
vae: string | null;
/** Vae Precision */
vae_precision: string | null;
/** Scheduler */
scheduler: ("ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm") | null;
/** Steps */
steps: number | null;
/** Cfg Scale */
cfg_scale: number | null;
/** Cfg Rescale Multiplier */
cfg_rescale_multiplier: number | null;
};
/**
* ModelFormat
* @description Storage format of model.
@ -7526,9 +7557,12 @@ export type components = {
* @description A set of changes to apply to model metadata.
*
* Only limited changes are valid:
* - `default_settings`: the user-configured default settings for this model
* - `trigger_phrases`: the list of trigger phrases for this model
*/
ModelMetadataChanges: {
/** @description The user-configured default settings for this model */
default_settings?: components["schemas"]["ModelDefaultSettings"] | null;
/**
* Trigger Phrases
* @description The model's list of trigger phrases