feat(ui): sampler --> scheduler

This commit is contained in:
psychedelicious 2023-05-16 10:40:26 +10:00
parent da87378713
commit 6fe62a2705
11 changed files with 70 additions and 60 deletions

View File

@ -450,7 +450,7 @@
"cfgScale": "CFG Scale", "cfgScale": "CFG Scale",
"width": "Width", "width": "Width",
"height": "Height", "height": "Height",
"sampler": "Sampler", "scheduler": "Scheduler",
"seed": "Seed", "seed": "Seed",
"imageToImage": "Image to Image", "imageToImage": "Image to Image",
"randomizeSeed": "Randomize Seed", "randomizeSeed": "Randomize Seed",

View File

@ -1,6 +1,6 @@
// TODO: use Enums? // TODO: use Enums?
export const SCHEDULERS: Array<string> = [ export const SCHEDULERS = [
'ddim', 'ddim',
'lms', 'lms',
'euler', 'euler',
@ -17,7 +17,12 @@ export const SCHEDULERS: Array<string> = [
'heun', 'heun',
'heun_k', 'heun_k',
'unipc', 'unipc',
]; ] as const;
export type Scheduler = (typeof SCHEDULERS)[number];
export const isScheduler = (x: string): x is Scheduler =>
SCHEDULERS.includes(x as Scheduler);
// Valid image widths // Valid image widths
export const WIDTHS: Array<number> = Array.from(Array(64)).map( export const WIDTHS: Array<number> = Array.from(Array(64)).map(

View File

@ -19,7 +19,7 @@ import {
setHeight, setHeight,
setImg2imgStrength, setImg2imgStrength,
setPerlin, setPerlin,
setSampler, setScheduler,
setSeamless, setSeamless,
setSeed, setSeed,
setSeedWeights, setSeedWeights,
@ -202,9 +202,9 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
)} )}
{node.scheduler && ( {node.scheduler && (
<MetadataItem <MetadataItem
label="Sampler" label="Scheduler"
value={node.scheduler} value={node.scheduler}
onClick={() => dispatch(setSampler(node.scheduler))} onClick={() => dispatch(setScheduler(node.scheduler))}
/> />
)} )}
{node.steps && ( {node.steps && (

View File

@ -26,16 +26,18 @@ const buildBaseNode = (
| ImageToImageInvocation | ImageToImageInvocation
| InpaintInvocation | InpaintInvocation
| undefined => { | undefined => {
const dimensionsOverride = state.canvas.boundingBoxDimensions;
if (nodeType === 'txt2img') { if (nodeType === 'txt2img') {
return buildTxt2ImgNode(state, state.canvas.boundingBoxDimensions); return buildTxt2ImgNode(state, dimensionsOverride);
} }
if (nodeType === 'img2img') { if (nodeType === 'img2img') {
return buildImg2ImgNode(state, state.canvas.boundingBoxDimensions); return buildImg2ImgNode(state, dimensionsOverride);
} }
if (nodeType === 'inpaint' || nodeType === 'outpaint') { if (nodeType === 'inpaint' || nodeType === 'outpaint') {
return buildInpaintNode(state, state.canvas.boundingBoxDimensions); return buildInpaintNode(state, dimensionsOverride);
} }
}; };

View File

@ -25,7 +25,7 @@ export const buildImg2ImgNode = (
width, width,
height, height,
cfgScale, cfgScale,
sampler, scheduler,
model, model,
img2imgStrength: strength, img2imgStrength: strength,
shouldFitToWidthHeight: fit, shouldFitToWidthHeight: fit,
@ -43,14 +43,14 @@ export const buildImg2ImgNode = (
width, width,
height, height,
cfg_scale: cfgScale, cfg_scale: cfgScale,
scheduler: sampler as ImageToImageInvocation['scheduler'], scheduler,
model, model,
strength, strength,
fit, fit,
}; };
// on Canvas tab, we do not manually specific init image // on Canvas tab, we do not manually specific init image
if (activeTabName === 'img2img') { if (activeTabName !== 'unifiedCanvas') {
if (!initialImage) { if (!initialImage) {
// TODO: handle this more better // TODO: handle this more better
throw 'no initial image'; throw 'no initial image';

View File

@ -1,17 +1,16 @@
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { InpaintInvocation } from 'services/api'; import { InpaintInvocation } from 'services/api';
import { initialImageSelector } from 'features/parameters/store/generationSelectors';
import { O } from 'ts-toolbelt'; import { O } from 'ts-toolbelt';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
export const buildInpaintNode = ( export const buildInpaintNode = (
state: RootState, state: RootState,
overrides: O.Partial<InpaintInvocation, 'deep'> = {} overrides: O.Partial<InpaintInvocation, 'deep'> = {}
): InpaintInvocation => { ): InpaintInvocation => {
const nodeId = uuidv4(); const nodeId = uuidv4();
const { generation, models } = state; const { generation } = state;
const activeTabName = activeTabNameSelector(state);
const { selectedModelName } = models;
const { const {
prompt, prompt,
@ -21,21 +20,15 @@ export const buildInpaintNode = (
width, width,
height, height,
cfgScale, cfgScale,
sampler, scheduler,
seamless, model,
img2imgStrength: strength, img2imgStrength: strength,
shouldFitToWidthHeight: fit, shouldFitToWidthHeight: fit,
shouldRandomizeSeed, shouldRandomizeSeed,
initialImage,
} = generation; } = generation;
const initialImage = initialImageSelector(state); const inpaintNode: InpaintInvocation = {
if (!initialImage) {
// TODO: handle this
// throw 'no initial image';
}
const imageToImageNode: InpaintInvocation = {
id: nodeId, id: nodeId,
type: 'inpaint', type: 'inpaint',
prompt: `${prompt} [${negativePrompt}]`, prompt: `${prompt} [${negativePrompt}]`,
@ -43,25 +36,30 @@ export const buildInpaintNode = (
width, width,
height, height,
cfg_scale: cfgScale, cfg_scale: cfgScale,
scheduler: sampler as InpaintInvocation['scheduler'], scheduler,
seamless, model,
model: selectedModelName,
progress_images: true,
image: initialImage
? {
image_name: initialImage.name,
image_type: initialImage.type,
}
: undefined,
strength, strength,
fit, fit,
}; };
if (!shouldRandomizeSeed) { // on Canvas tab, we do not manually specific init image
imageToImageNode.seed = seed; if (activeTabName !== 'unifiedCanvas') {
if (!initialImage) {
// TODO: handle this more better
throw 'no initial image';
}
inpaintNode.image = {
image_name: initialImage.name,
image_type: initialImage.type,
};
} }
Object.assign(imageToImageNode, overrides); if (!shouldRandomizeSeed) {
inpaintNode.seed = seed;
}
return imageToImageNode; Object.assign(inpaintNode, overrides);
return inpaintNode;
}; };

View File

@ -18,7 +18,7 @@ export const buildTxt2ImgNode = (
width, width,
height, height,
cfgScale: cfg_scale, cfgScale: cfg_scale,
sampler, scheduler,
shouldRandomizeSeed, shouldRandomizeSeed,
model, model,
} = generation; } = generation;
@ -31,7 +31,7 @@ export const buildTxt2ImgNode = (
width, width,
height, height,
cfg_scale, cfg_scale,
scheduler: sampler as TextToImageInvocation['scheduler'], scheduler,
model, model,
}; };

View File

@ -1,15 +1,15 @@
import { Scheduler } from 'app/constants';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAICustomSelect from 'common/components/IAICustomSelect'; import IAICustomSelect from 'common/components/IAICustomSelect';
import IAISelect from 'common/components/IAISelect'; import { setScheduler } from 'features/parameters/store/generationSlice';
import { setSampler } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { ChangeEvent, memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
const ParamSampler = () => { const ParamScheduler = () => {
const sampler = useAppSelector( const scheduler = useAppSelector(
(state: RootState) => state.generation.sampler (state: RootState) => state.generation.scheduler
); );
const activeTabName = useAppSelector(activeTabNameSelector); const activeTabName = useAppSelector(activeTabNameSelector);
@ -28,15 +28,15 @@ const ParamSampler = () => {
if (!v) { if (!v) {
return; return;
} }
dispatch(setSampler(v)); dispatch(setScheduler(v as Scheduler));
}, },
[dispatch] [dispatch]
); );
return ( return (
<IAICustomSelect <IAICustomSelect
label={t('parameters.sampler')} label={t('parameters.scheduler')}
selectedItem={sampler} selectedItem={scheduler}
setSelectedItem={handleChange} setSelectedItem={handleChange}
items={ items={
['img2img', 'unifiedCanvas'].includes(activeTabName) ['img2img', 'unifiedCanvas'].includes(activeTabName)
@ -48,4 +48,4 @@ const ParamSampler = () => {
); );
}; };
export default memo(ParamSampler); export default memo(ParamScheduler);

View File

@ -1,13 +1,13 @@
import { Box, Flex } from '@chakra-ui/react'; import { Box, Flex } from '@chakra-ui/react';
import { memo } from 'react'; import { memo } from 'react';
import ParamSampler from './ParamSampler';
import ModelSelect from 'features/system/components/ModelSelect'; import ModelSelect from 'features/system/components/ModelSelect';
import ParamScheduler from './ParamScheduler';
const ParamSchedulerAndModel = () => { const ParamSchedulerAndModel = () => {
return ( return (
<Flex gap={3} w="full"> <Flex gap={3} w="full">
<Box w="16rem"> <Box w="16rem">
<ParamSampler /> <ParamScheduler />
</Box> </Box>
<Box w="full"> <Box w="full">
<ModelSelect /> <ModelSelect />

View File

@ -5,6 +5,7 @@ import promptToString from 'common/util/promptToString';
import { clamp, sample } from 'lodash-es'; import { clamp, sample } from 'lodash-es';
import { setAllParametersReducer } from './setAllParametersReducer'; import { setAllParametersReducer } from './setAllParametersReducer';
import { receivedModels } from 'services/thunks/model'; import { receivedModels } from 'services/thunks/model';
import { Scheduler } from 'app/constants';
export interface GenerationState { export interface GenerationState {
cfgScale: number; cfgScale: number;
@ -16,7 +17,7 @@ export interface GenerationState {
perlin: number; perlin: number;
prompt: string; prompt: string;
negativePrompt: string; negativePrompt: string;
sampler: string; scheduler: Scheduler;
seamBlur: number; seamBlur: number;
seamSize: number; seamSize: number;
seamSteps: number; seamSteps: number;
@ -50,7 +51,7 @@ export const initialGenerationState: GenerationState = {
perlin: 0, perlin: 0,
prompt: '', prompt: '',
negativePrompt: '', negativePrompt: '',
sampler: 'lms', scheduler: 'lms',
seamBlur: 16, seamBlur: 16,
seamSize: 96, seamSize: 96,
seamSteps: 30, seamSteps: 30,
@ -133,8 +134,8 @@ export const generationSlice = createSlice({
setWidth: (state, action: PayloadAction<number>) => { setWidth: (state, action: PayloadAction<number>) => {
state.width = action.payload; state.width = action.payload;
}, },
setSampler: (state, action: PayloadAction<string>) => { setScheduler: (state, action: PayloadAction<Scheduler>) => {
state.sampler = action.payload; state.scheduler = action.payload;
}, },
setSeed: (state, action: PayloadAction<number>) => { setSeed: (state, action: PayloadAction<number>) => {
state.seed = action.payload; state.seed = action.payload;
@ -244,7 +245,7 @@ export const {
setPerlin, setPerlin,
setPrompt, setPrompt,
setNegativePrompt, setNegativePrompt,
setSampler, setScheduler,
setSeamBlur, setSeamBlur,
setSeamSize, setSeamSize,
setSeamSteps, setSeamSteps,

View File

@ -2,6 +2,7 @@ import { Draft, PayloadAction } from '@reduxjs/toolkit';
import { Image } from 'app/types/invokeai'; import { Image } from 'app/types/invokeai';
import { GenerationState } from './generationSlice'; import { GenerationState } from './generationSlice';
import { ImageToImageInvocation } from 'services/api'; import { ImageToImageInvocation } from 'services/api';
import { isScheduler } from 'app/constants';
export const setAllParametersReducer = ( export const setAllParametersReducer = (
state: Draft<GenerationState>, state: Draft<GenerationState>,
@ -34,7 +35,10 @@ export const setAllParametersReducer = (
state.prompt = String(prompt); state.prompt = String(prompt);
} }
if (scheduler !== undefined) { if (scheduler !== undefined) {
state.sampler = String(scheduler); const schedulerString = String(scheduler);
if (isScheduler(schedulerString)) {
state.scheduler = schedulerString;
}
} }
if (seed !== undefined) { if (seed !== undefined) {
state.seed = Number(seed); state.seed = Number(seed);