mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add button to set default settings in parameters
This commit is contained in:
parent
c46b2b6fa6
commit
f0bfa7f0e0
@ -854,6 +854,7 @@
|
||||
"upcastAttention": "Upcast Attention",
|
||||
"updateModel": "Update Model",
|
||||
"useCustomConfig": "Use Custom Config",
|
||||
"useDefaultSettings": "Use Default Settings",
|
||||
"v1": "v1",
|
||||
"v2_768": "v2 (768px)",
|
||||
"v2_base": "v2 (512px)",
|
||||
|
@ -55,6 +55,8 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
|
||||
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
|
||||
@ -153,3 +155,5 @@ addUpscaleRequestedListener(startAppListening);
|
||||
|
||||
// Dynamic prompts
|
||||
addDynamicPromptsListener(startAppListening);
|
||||
|
||||
addSetDefaultSettingsListener(startAppListening)
|
||||
|
@ -0,0 +1,88 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { setDefaultSettings } from 'features/parameters/store/actions';
|
||||
import { setCfgRescaleMultiplier, setCfgScale, setScheduler, setSteps, vaePrecisionChanged, vaeSelected } from 'features/parameters/store/generationSlice';
|
||||
import { isParameterCFGRescaleMultiplier, isParameterCFGScale, isParameterPrecision, isParameterScheduler, isParameterSteps, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { map } from 'lodash-es';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
|
||||
|
||||
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: setDefaultSettings,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
|
||||
const currentModel = state.generation.model;
|
||||
|
||||
if (!currentModel) {
|
||||
return
|
||||
}
|
||||
|
||||
const metadata = await dispatch(
|
||||
modelsApi.endpoints.getModelMetadata.initiate(currentModel.key)
|
||||
).unwrap();
|
||||
|
||||
console.log({ metadata })
|
||||
|
||||
|
||||
if (!metadata || !metadata.default_settings) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = metadata.default_settings
|
||||
|
||||
if (vae) {
|
||||
// we store this as "default" within default settings
|
||||
// to distinguish it from no default set
|
||||
if (vae === "default") {
|
||||
dispatch(vaeSelected(null))
|
||||
} else {
|
||||
const { data } = modelsApi.endpoints.getVaeModels.select()(state)
|
||||
const vaeArray = map(data?.entities)
|
||||
const validVae = vaeArray.find(model => model.key === vae)
|
||||
|
||||
const result = zParameterVAEModel.safeParse(validVae);
|
||||
if (!result.success) {
|
||||
return;
|
||||
}
|
||||
dispatch(vaeSelected(result.data));
|
||||
}
|
||||
}
|
||||
|
||||
if (vae_precision) {
|
||||
if (isParameterPrecision(vae_precision)) {
|
||||
dispatch(vaePrecisionChanged(vae_precision));
|
||||
}
|
||||
}
|
||||
|
||||
if (cfg_scale) {
|
||||
if (isParameterCFGScale(cfg_scale)) {
|
||||
dispatch(setCfgScale(cfg_scale));
|
||||
}
|
||||
}
|
||||
|
||||
if (cfg_rescale_multiplier) {
|
||||
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
|
||||
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
|
||||
}
|
||||
}
|
||||
|
||||
if (steps) {
|
||||
if (isParameterSteps(steps)) {
|
||||
dispatch(setSteps(steps));
|
||||
}
|
||||
}
|
||||
|
||||
if (scheduler) {
|
||||
if (isParameterScheduler(scheduler)) {
|
||||
dispatch(setScheduler(scheduler));
|
||||
}
|
||||
}
|
||||
|
||||
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: "Default settings" }) })))
|
||||
},
|
||||
});
|
||||
};
|
@ -90,7 +90,7 @@ export const DefaultSettingsForm = ({
|
||||
}
|
||||
});
|
||||
},
|
||||
[selectedModelKey, dispatch, editModelMetadata]
|
||||
[selectedModelKey, dispatch, editModelMetadata, t]
|
||||
);
|
||||
|
||||
return (
|
||||
|
@ -0,0 +1,28 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setDefaultSettings } from 'features/parameters/store/actions';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { RiSparklingFill } from 'react-icons/ri';
|
||||
|
||||
export const UseDefaultSettingsButton = () => {
|
||||
const model = useAppSelector((s) => s.generation.model);
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleClickDefaultSettings = useCallback(() => {
|
||||
dispatch(setDefaultSettings());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
icon={<RiSparklingFill />}
|
||||
tooltip={t('modelManager.useDefaultSettings')}
|
||||
aria-label={t('modelManager.useDefaultSettings')}
|
||||
isDisabled={!model}
|
||||
onClick={handleClickDefaultSettings}
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
/>
|
||||
);
|
||||
};
|
@ -5,3 +5,5 @@ import type { ImageDTO } from 'services/api/types';
|
||||
export const initialImageSelected = createAction<ImageDTO | undefined>('generation/initialImageSelected');
|
||||
|
||||
export const modelSelected = createAction<ParameterModel>('generation/modelSelected');
|
||||
|
||||
export const setDefaultSettings = createAction('generation/setDefaultSettings');
|
@ -21,6 +21,7 @@ import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
|
||||
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
|
||||
import ParamSteps from 'features/parameters/components/Core/ParamSteps';
|
||||
import ParamMainModelSelect from 'features/parameters/components/MainModel/ParamMainModelSelect';
|
||||
import { UseDefaultSettingsButton } from 'features/parameters/components/MainModel/UseDefaultSettingsButton';
|
||||
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
|
||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||
import { filter } from 'lodash-es';
|
||||
@ -71,8 +72,11 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
<TabPanel overflow="visible" px={4} pt={4}>
|
||||
<Flex gap={4} alignItems="center">
|
||||
<ParamMainModelSelect />
|
||||
<Flex>
|
||||
<UseDefaultSettingsButton />
|
||||
<SyncModelsIconButton />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Expander isOpen={isOpenExpander} onToggle={onToggleExpander}>
|
||||
<Flex gap={4} flexDir="column" pb={4}>
|
||||
<FormControlGroup formLabelProps={formLabelProps}>
|
||||
|
Loading…
Reference in New Issue
Block a user