mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
use a listener to recalculate trigger phrases when model or lora list changes
This commit is contained in:
parent
caafbf2f0d
commit
ef171e890a
@ -55,6 +55,7 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
|
||||
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addPromptTriggerListChanged } from './listeners/promptTriggerListChanged';
|
||||
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
@ -153,7 +154,8 @@ addFirstListImagesListener(startAppListening);
|
||||
// Ad-hoc upscale workflow
|
||||
addUpscaleRequestedListener(startAppListening);
|
||||
|
||||
// Dynamic prompts
|
||||
// Prompts
|
||||
addDynamicPromptsListener(startAppListening);
|
||||
addPromptTriggerListChanged(startAppListening);
|
||||
|
||||
addSetDefaultSettingsListener(startAppListening);
|
||||
|
@ -0,0 +1,39 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { loraAdded, loraRemoved } from 'features/lora/store/loraSlice';
|
||||
import { modelChanged, triggerPhrasesChanged } from 'features/parameters/store/generationSlice';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
|
||||
const matcher = isAnyOf(loraAdded, loraRemoved, modelChanged);
|
||||
|
||||
export const addPromptTriggerListChanged = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
matcher,
|
||||
effect: async (action, { dispatch, getState, cancelActiveListeners }) => {
|
||||
cancelActiveListeners();
|
||||
const state = getState();
|
||||
const { model: mainModel } = state.generation;
|
||||
const { loras } = state.lora;
|
||||
|
||||
let triggerPhrases: string[] = [];
|
||||
|
||||
if (!mainModel) {
|
||||
dispatch(triggerPhrasesChanged([]));
|
||||
return;
|
||||
}
|
||||
|
||||
const { data: mainModelData } = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(mainModel.key));
|
||||
triggerPhrases = mainModelData?.trigger_phrases || [];
|
||||
|
||||
for (let index = 0; index < Object.values(loras).length; index++) {
|
||||
const lora = Object.values(loras)[index];
|
||||
if (lora) {
|
||||
const { data: loraData } = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(lora.model.key));
|
||||
triggerPhrases = [...triggerPhrases, ...(loraData?.trigger_phrases || [])];
|
||||
}
|
||||
}
|
||||
|
||||
dispatch(triggerPhrasesChanged(triggerPhrases));
|
||||
},
|
||||
});
|
||||
};
|
@ -56,6 +56,7 @@ const initialGenerationState: GenerationState = {
|
||||
shouldUseCpuNoise: true,
|
||||
shouldShowAdvancedOptions: false,
|
||||
aspectRatio: { ...initialAspectRatioState },
|
||||
triggerPhrases: [],
|
||||
};
|
||||
|
||||
export const generationSlice = createSlice({
|
||||
@ -207,6 +208,9 @@ export const generationSlice = createSlice({
|
||||
aspectRatioChanged: (state, action: PayloadAction<AspectRatioState>) => {
|
||||
state.aspectRatio = action.payload;
|
||||
},
|
||||
triggerPhrasesChanged: (state, action: PayloadAction<string[]>) => {
|
||||
state.triggerPhrases = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(configChanged, (state, action) => {
|
||||
@ -285,6 +289,7 @@ export const {
|
||||
heightChanged,
|
||||
widthRecalled,
|
||||
heightRecalled,
|
||||
triggerPhrasesChanged,
|
||||
} = generationSlice.actions;
|
||||
|
||||
export const { selectOptimalDimension } = generationSlice.selectors;
|
||||
|
@ -51,6 +51,7 @@ export interface GenerationState {
|
||||
shouldUseCpuNoise: boolean;
|
||||
shouldShowAdvancedOptions: boolean;
|
||||
aspectRatio: AspectRatioState;
|
||||
triggerPhrases: string[];
|
||||
}
|
||||
|
||||
export type PayloadActionWithOptimalDimension<T = void> = PayloadAction<T, string, { optimalDimension: number }>;
|
||||
|
@ -1,13 +1,13 @@
|
||||
import type { ChakraProps, ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { GroupBase, OptionsOrGroups } from 'chakra-react-select';
|
||||
import type { PromptTriggerSelectProps } from 'features/prompt/types';
|
||||
import { t } from 'i18next';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelMetadataQuery, useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
|
||||
|
||||
@ -15,10 +15,9 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
||||
const { t } = useTranslation();
|
||||
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
const currentModelKey = useAppSelector((s) => s.generation.model?.key);
|
||||
const triggerPhrases = useAppSelector((s) => s.generation.triggerPhrases);
|
||||
|
||||
const { data, isLoading } = useGetTextualInversionModelsQuery();
|
||||
const { data: metadata } = useGetModelMetadataQuery(currentModelKey ?? skipToken);
|
||||
|
||||
const _onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
@ -32,34 +31,28 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
||||
[onSelect]
|
||||
);
|
||||
|
||||
const embeddingOptions = useMemo(() => {
|
||||
if (!data) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const compatibleEmbeddingsArray = map(data.entities).filter((model) => model.base === currentBaseModel);
|
||||
|
||||
return [
|
||||
{
|
||||
label: t('prompt.compatibleEmbeddings'),
|
||||
options: compatibleEmbeddingsArray.map((model) => ({ label: model.name, value: `<${model.name}>` })),
|
||||
},
|
||||
];
|
||||
}, [data, currentBaseModel, t]);
|
||||
|
||||
const options = useMemo(() => {
|
||||
if (!metadata || !metadata.trigger_phrases) {
|
||||
return [...embeddingOptions];
|
||||
let embeddingOptions: OptionsOrGroups<ComboboxOption, GroupBase<ComboboxOption>> = [];
|
||||
|
||||
if (data) {
|
||||
const compatibleEmbeddingsArray = map(data.entities).filter((model) => model.base === currentBaseModel);
|
||||
|
||||
embeddingOptions = [
|
||||
{
|
||||
label: t('prompt.compatibleEmbeddings'),
|
||||
options: compatibleEmbeddingsArray.map((model) => ({ label: model.name, value: `<${model.name}>` })),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
const metadataOptions = [
|
||||
{
|
||||
label: t('modelManager.triggerPhrases'),
|
||||
options: metadata.trigger_phrases.map((phrase) => ({ label: phrase, value: phrase })),
|
||||
options: triggerPhrases.map((phrase) => ({ label: phrase, value: phrase })),
|
||||
},
|
||||
];
|
||||
return [...metadataOptions, ...embeddingOptions];
|
||||
}, [embeddingOptions, metadata, t]);
|
||||
}, [data, currentBaseModel, triggerPhrases, t]);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
|
@ -4135,6 +4135,22 @@ export type components = {
|
||||
*/
|
||||
type: "gradient_mask_output";
|
||||
};
|
||||
/**
|
||||
* GradientMaskOutput
|
||||
* @description Outputs a denoise mask and an image representing the total gradient of the mask.
|
||||
*/
|
||||
GradientMaskOutput: {
|
||||
/** @description Mask for denoise model run */
|
||||
denoise_mask: components["schemas"]["DenoiseMaskField"];
|
||||
/** @description Image representing the total gradient area of the mask. For paste-back purposes. */
|
||||
expanded_mask_area: components["schemas"]["ImageField"];
|
||||
/**
|
||||
* type
|
||||
* @default gradient_mask_output
|
||||
* @constant
|
||||
*/
|
||||
type: "gradient_mask_output";
|
||||
};
|
||||
/** Graph */
|
||||
Graph: {
|
||||
/**
|
||||
@ -7377,6 +7393,21 @@ export type components = {
|
||||
/** Cfg Rescale Multiplier */
|
||||
cfg_rescale_multiplier: number | null;
|
||||
};
|
||||
/** ModelDefaultSettings */
|
||||
ModelDefaultSettings: {
|
||||
/** Vae */
|
||||
vae: string | null;
|
||||
/** Vae Precision */
|
||||
vae_precision: string | null;
|
||||
/** Scheduler */
|
||||
scheduler: ("ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm") | null;
|
||||
/** Steps */
|
||||
steps: number | null;
|
||||
/** Cfg Scale */
|
||||
cfg_scale: number | null;
|
||||
/** Cfg Rescale Multiplier */
|
||||
cfg_rescale_multiplier: number | null;
|
||||
};
|
||||
/**
|
||||
* ModelFormat
|
||||
* @description Storage format of model.
|
||||
@ -7526,9 +7557,12 @@ export type components = {
|
||||
* @description A set of changes to apply to model metadata.
|
||||
*
|
||||
* Only limited changes are valid:
|
||||
* - `default_settings`: the user-configured default settings for this model
|
||||
* - `trigger_phrases`: the list of trigger phrases for this model
|
||||
*/
|
||||
ModelMetadataChanges: {
|
||||
/** @description The user-configured default settings for this model */
|
||||
default_settings?: components["schemas"]["ModelDefaultSettings"] | null;
|
||||
/**
|
||||
* Trigger Phrases
|
||||
* @description The model's list of trigger phrases
|
||||
|
Loading…
Reference in New Issue
Block a user