feat(ui): add w/h to default model settings

This commit is contained in:
psychedelicious 2024-03-12 20:45:49 +11:00
parent 1adaf63253
commit 2584a950aa
5 changed files with 226 additions and 11 deletions

View File

@ -1,19 +1,23 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setDefaultSettings } from 'features/parameters/store/actions'; import { setDefaultSettings } from 'features/parameters/store/actions';
import { import {
heightChanged,
setCfgRescaleMultiplier, setCfgRescaleMultiplier,
setCfgScale, setCfgScale,
setScheduler, setScheduler,
setSteps, setSteps,
vaePrecisionChanged, vaePrecisionChanged,
vaeSelected, vaeSelected,
widthChanged,
} from 'features/parameters/store/generationSlice'; } from 'features/parameters/store/generationSlice';
import { import {
isParameterCFGRescaleMultiplier, isParameterCFGRescaleMultiplier,
isParameterCFGScale, isParameterCFGScale,
isParameterHeight,
isParameterPrecision, isParameterPrecision,
isParameterScheduler, isParameterScheduler,
isParameterSteps, isParameterSteps,
isParameterWidth,
zParameterVAEModel, zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas'; } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
@ -42,7 +46,7 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
} }
if (isNonRefinerMainModelConfig(modelConfig) && modelConfig.default_settings) { if (isNonRefinerMainModelConfig(modelConfig) && modelConfig.default_settings) {
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler, width, height } =
modelConfig.default_settings; modelConfig.default_settings;
if (vae) { if (vae) {
@ -93,6 +97,18 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
} }
} }
if (width) {
if (isParameterWidth(width)) {
dispatch(widthChanged(width));
}
}
if (height) {
if (isParameterHeight(height)) {
dispatch(heightChanged(height));
}
}
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) }))); dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
} }
}, },

View File

@ -1,6 +1,7 @@
import { skipToken } from '@reduxjs/toolkit/query'; import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { selectConfigSlice } from 'features/system/store/configSlice'; import { selectConfigSlice } from 'features/system/store/configSlice';
import { isNil } from 'lodash-es'; import { isNil } from 'lodash-es';
import { useMemo } from 'react'; import { useMemo } from 'react';
@ -8,7 +9,7 @@ import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelCo
import { isNonRefinerMainModelConfig } from 'services/api/types'; import { isNonRefinerMainModelConfig } from 'services/api/types';
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => { const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision } = config.sd; const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision, width, height } = config.sd;
return { return {
initialSteps: steps.initial, initialSteps: steps.initial,
@ -16,14 +17,23 @@ const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config)
initialScheduler: scheduler, initialScheduler: scheduler,
initialCfgRescaleMultiplier: cfgRescaleMultiplier.initial, initialCfgRescaleMultiplier: cfgRescaleMultiplier.initial,
initialVaePrecision: vaePrecision, initialVaePrecision: vaePrecision,
initialWidth: width.initial,
initialHeight: height.initial,
}; };
}); });
export const useMainModelDefaultSettings = (modelKey?: string | null) => { export const useMainModelDefaultSettings = (modelKey?: string | null) => {
const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(modelKey ?? skipToken, isNonRefinerMainModelConfig); const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(modelKey ?? skipToken, isNonRefinerMainModelConfig);
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } = const {
useAppSelector(initialStatesSelector); initialSteps,
initialCfg,
initialScheduler,
initialCfgRescaleMultiplier,
initialVaePrecision,
initialWidth,
initialHeight,
} = useAppSelector(initialStatesSelector);
const defaultSettingsDefaults = useMemo(() => { const defaultSettingsDefaults = useMemo(() => {
return { return {
@ -51,15 +61,25 @@ export const useMainModelDefaultSettings = (modelKey?: string | null) => {
isEnabled: !isNil(modelConfig?.default_settings?.cfg_rescale_multiplier), isEnabled: !isNil(modelConfig?.default_settings?.cfg_rescale_multiplier),
value: modelConfig?.default_settings?.cfg_rescale_multiplier || initialCfgRescaleMultiplier, value: modelConfig?.default_settings?.cfg_rescale_multiplier || initialCfgRescaleMultiplier,
}, },
width: {
isEnabled: !isNil(modelConfig?.default_settings?.width),
value: modelConfig?.default_settings?.width || initialWidth,
},
height: {
isEnabled: !isNil(modelConfig?.default_settings?.height),
value: modelConfig?.default_settings?.height || initialHeight,
},
}; };
}, [ }, [
modelConfig?.default_settings, modelConfig,
initialVaePrecision,
initialScheduler,
initialSteps, initialSteps,
initialCfg, initialCfg,
initialScheduler,
initialCfgRescaleMultiplier, initialCfgRescaleMultiplier,
initialVaePrecision, initialWidth,
initialHeight,
]); ]);
return { defaultSettingsDefaults, isLoading }; return { defaultSettingsDefaults, isLoading, optimalDimension: getOptimalDimension(modelConfig) };
}; };

View File

@ -0,0 +1,81 @@
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings';
type DefaultHeightType = MainModelDefaultSettingsFormData['height'];
type Props = {
control: UseControllerProps<MainModelDefaultSettingsFormData>['control'];
optimalDimension: number;
};
export function DefaultHeight({ control, optimalDimension }: Props) {
const { field } = useController({ control, name: 'height' });
const sliderMin = useAppSelector((s) => s.config.sd.height.sliderMin);
const sliderMax = useAppSelector((s) => s.config.sd.height.sliderMax);
const numberInputMin = useAppSelector((s) => s.config.sd.height.numberInputMin);
const numberInputMax = useAppSelector((s) => s.config.sd.height.numberInputMax);
const coarseStep = useAppSelector((s) => s.config.sd.height.coarseStep);
const fineStep = useAppSelector((s) => s.config.sd.height.fineStep);
const { t } = useTranslation();
const marks = useMemo(() => [sliderMin, optimalDimension, sliderMax], [sliderMin, optimalDimension, sliderMax]);
const onChange = useCallback(
(v: number) => {
const updatedValue = {
...(field.value as DefaultHeightType),
value: v,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => {
return field.value.value;
}, [field.value]);
const isDisabled = useMemo(() => {
return !field.value.isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={2} alignItems="flex-start">
<Flex justifyContent="space-between" w="full">
<InformationalPopover feature="paramHeight">
<FormLabel>{t('parameters.height')}</FormLabel>
</InformationalPopover>
<SettingToggle control={control} name="height" />
</Flex>
<Flex w="full" gap={4}>
<CompositeSlider
value={value}
min={sliderMin}
max={sliderMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
marks={marks}
isDisabled={isDisabled}
/>
<CompositeNumberInput
value={value}
min={numberInputMin}
max={numberInputMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
}

View File

@ -0,0 +1,81 @@
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings';
type DefaultWidthType = MainModelDefaultSettingsFormData['width'];
type Props = {
control: UseControllerProps<MainModelDefaultSettingsFormData>['control'];
optimalDimension: number;
};
export function DefaultWidth({ control, optimalDimension }: Props) {
const { field } = useController({ control, name: 'width' });
const sliderMin = useAppSelector((s) => s.config.sd.width.sliderMin);
const sliderMax = useAppSelector((s) => s.config.sd.width.sliderMax);
const numberInputMin = useAppSelector((s) => s.config.sd.width.numberInputMin);
const numberInputMax = useAppSelector((s) => s.config.sd.width.numberInputMax);
const coarseStep = useAppSelector((s) => s.config.sd.width.coarseStep);
const fineStep = useAppSelector((s) => s.config.sd.width.fineStep);
const { t } = useTranslation();
const marks = useMemo(() => [sliderMin, optimalDimension, sliderMax], [sliderMin, optimalDimension, sliderMax]);
const onChange = useCallback(
(v: number) => {
const updatedValue = {
...(field.value as DefaultWidthType),
value: v,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => {
return (field.value as DefaultWidthType).value;
}, [field.value]);
const isDisabled = useMemo(() => {
return !(field.value as DefaultWidthType).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={2} alignItems="flex-start">
<Flex justifyContent="space-between" w="full">
<InformationalPopover feature="paramWidth">
<FormLabel>{t('parameters.width')}</FormLabel>
</InformationalPopover>
<SettingToggle control={control} name="width" />
</Flex>
<Flex w="full" gap={4}>
<CompositeSlider
value={value}
min={sliderMin}
max={sliderMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
marks={marks}
isDisabled={isDisabled}
/>
<CompositeNumberInput
value={value}
min={numberInputMin}
max={numberInputMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
}

View File

@ -1,6 +1,8 @@
import { Button, Flex, Heading, Text } from '@invoke-ai/ui-library'; import { Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings'; import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings';
import { DefaultHeight } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultHeight';
import { DefaultWidth } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultWidth';
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas'; import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
@ -25,11 +27,13 @@ export interface FormField<T> {
export type MainModelDefaultSettingsFormData = { export type MainModelDefaultSettingsFormData = {
vae: FormField<string>; vae: FormField<string>;
vaePrecision: FormField<string>; vaePrecision: FormField<'fp16' | 'fp32'>;
scheduler: FormField<ParameterScheduler>; scheduler: FormField<ParameterScheduler>;
steps: FormField<number>; steps: FormField<number>;
cfgScale: FormField<number>; cfgScale: FormField<number>;
cfgRescaleMultiplier: FormField<number>; cfgRescaleMultiplier: FormField<number>;
width: FormField<number>;
height: FormField<number>;
}; };
export const MainModelDefaultSettings = () => { export const MainModelDefaultSettings = () => {
@ -37,8 +41,11 @@ export const MainModelDefaultSettings = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { defaultSettingsDefaults, isLoading: isLoadingDefaultSettings } = const {
useMainModelDefaultSettings(selectedModelKey); defaultSettingsDefaults,
isLoading: isLoadingDefaultSettings,
optimalDimension,
} = useMainModelDefaultSettings(selectedModelKey);
const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation(); const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();
@ -59,6 +66,8 @@ export const MainModelDefaultSettings = () => {
cfg_rescale_multiplier: data.cfgRescaleMultiplier.isEnabled ? data.cfgRescaleMultiplier.value : null, cfg_rescale_multiplier: data.cfgRescaleMultiplier.isEnabled ? data.cfgRescaleMultiplier.value : null,
steps: data.steps.isEnabled ? data.steps.value : null, steps: data.steps.isEnabled ? data.steps.value : null,
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null, scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
width: data.width.isEnabled ? data.width.value : null,
height: data.height.isEnabled ? data.height.value : null,
}; };
updateModel({ updateModel({
@ -139,6 +148,14 @@ export const MainModelDefaultSettings = () => {
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" /> <DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
</Flex> </Flex>
</Flex> </Flex>
<Flex gap={8}>
<Flex gap={4} w="full">
<DefaultWidth control={control} optimalDimension={optimalDimension} />
</Flex>
<Flex gap={4} w="full">
<DefaultHeight control={control} optimalDimension={optimalDimension} />
</Flex>
</Flex>
</Flex> </Flex>
</> </>
); );