mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): sampler
--> scheduler
This commit is contained in:
parent
da87378713
commit
6fe62a2705
@ -450,7 +450,7 @@
|
|||||||
"cfgScale": "CFG Scale",
|
"cfgScale": "CFG Scale",
|
||||||
"width": "Width",
|
"width": "Width",
|
||||||
"height": "Height",
|
"height": "Height",
|
||||||
"sampler": "Sampler",
|
"scheduler": "Scheduler",
|
||||||
"seed": "Seed",
|
"seed": "Seed",
|
||||||
"imageToImage": "Image to Image",
|
"imageToImage": "Image to Image",
|
||||||
"randomizeSeed": "Randomize Seed",
|
"randomizeSeed": "Randomize Seed",
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
// TODO: use Enums?
|
// TODO: use Enums?
|
||||||
|
|
||||||
export const SCHEDULERS: Array<string> = [
|
export const SCHEDULERS = [
|
||||||
'ddim',
|
'ddim',
|
||||||
'lms',
|
'lms',
|
||||||
'euler',
|
'euler',
|
||||||
@ -17,7 +17,12 @@ export const SCHEDULERS: Array<string> = [
|
|||||||
'heun',
|
'heun',
|
||||||
'heun_k',
|
'heun_k',
|
||||||
'unipc',
|
'unipc',
|
||||||
];
|
] as const;
|
||||||
|
|
||||||
|
export type Scheduler = (typeof SCHEDULERS)[number];
|
||||||
|
|
||||||
|
export const isScheduler = (x: string): x is Scheduler =>
|
||||||
|
SCHEDULERS.includes(x as Scheduler);
|
||||||
|
|
||||||
// Valid image widths
|
// Valid image widths
|
||||||
export const WIDTHS: Array<number> = Array.from(Array(64)).map(
|
export const WIDTHS: Array<number> = Array.from(Array(64)).map(
|
||||||
|
@ -19,7 +19,7 @@ import {
|
|||||||
setHeight,
|
setHeight,
|
||||||
setImg2imgStrength,
|
setImg2imgStrength,
|
||||||
setPerlin,
|
setPerlin,
|
||||||
setSampler,
|
setScheduler,
|
||||||
setSeamless,
|
setSeamless,
|
||||||
setSeed,
|
setSeed,
|
||||||
setSeedWeights,
|
setSeedWeights,
|
||||||
@ -202,9 +202,9 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
|||||||
)}
|
)}
|
||||||
{node.scheduler && (
|
{node.scheduler && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Sampler"
|
label="Scheduler"
|
||||||
value={node.scheduler}
|
value={node.scheduler}
|
||||||
onClick={() => dispatch(setSampler(node.scheduler))}
|
onClick={() => dispatch(setScheduler(node.scheduler))}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{node.steps && (
|
{node.steps && (
|
||||||
|
@ -26,16 +26,18 @@ const buildBaseNode = (
|
|||||||
| ImageToImageInvocation
|
| ImageToImageInvocation
|
||||||
| InpaintInvocation
|
| InpaintInvocation
|
||||||
| undefined => {
|
| undefined => {
|
||||||
|
const dimensionsOverride = state.canvas.boundingBoxDimensions;
|
||||||
|
|
||||||
if (nodeType === 'txt2img') {
|
if (nodeType === 'txt2img') {
|
||||||
return buildTxt2ImgNode(state, state.canvas.boundingBoxDimensions);
|
return buildTxt2ImgNode(state, dimensionsOverride);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (nodeType === 'img2img') {
|
if (nodeType === 'img2img') {
|
||||||
return buildImg2ImgNode(state, state.canvas.boundingBoxDimensions);
|
return buildImg2ImgNode(state, dimensionsOverride);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (nodeType === 'inpaint' || nodeType === 'outpaint') {
|
if (nodeType === 'inpaint' || nodeType === 'outpaint') {
|
||||||
return buildInpaintNode(state, state.canvas.boundingBoxDimensions);
|
return buildInpaintNode(state, dimensionsOverride);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ export const buildImg2ImgNode = (
|
|||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
cfgScale,
|
cfgScale,
|
||||||
sampler,
|
scheduler,
|
||||||
model,
|
model,
|
||||||
img2imgStrength: strength,
|
img2imgStrength: strength,
|
||||||
shouldFitToWidthHeight: fit,
|
shouldFitToWidthHeight: fit,
|
||||||
@ -43,14 +43,14 @@ export const buildImg2ImgNode = (
|
|||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
cfg_scale: cfgScale,
|
cfg_scale: cfgScale,
|
||||||
scheduler: sampler as ImageToImageInvocation['scheduler'],
|
scheduler,
|
||||||
model,
|
model,
|
||||||
strength,
|
strength,
|
||||||
fit,
|
fit,
|
||||||
};
|
};
|
||||||
|
|
||||||
// on Canvas tab, we do not manually specific init image
|
// on Canvas tab, we do not manually specific init image
|
||||||
if (activeTabName === 'img2img') {
|
if (activeTabName !== 'unifiedCanvas') {
|
||||||
if (!initialImage) {
|
if (!initialImage) {
|
||||||
// TODO: handle this more better
|
// TODO: handle this more better
|
||||||
throw 'no initial image';
|
throw 'no initial image';
|
||||||
|
@ -1,17 +1,16 @@
|
|||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { InpaintInvocation } from 'services/api';
|
import { InpaintInvocation } from 'services/api';
|
||||||
import { initialImageSelector } from 'features/parameters/store/generationSelectors';
|
|
||||||
import { O } from 'ts-toolbelt';
|
import { O } from 'ts-toolbelt';
|
||||||
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
|
|
||||||
export const buildInpaintNode = (
|
export const buildInpaintNode = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
overrides: O.Partial<InpaintInvocation, 'deep'> = {}
|
overrides: O.Partial<InpaintInvocation, 'deep'> = {}
|
||||||
): InpaintInvocation => {
|
): InpaintInvocation => {
|
||||||
const nodeId = uuidv4();
|
const nodeId = uuidv4();
|
||||||
const { generation, models } = state;
|
const { generation } = state;
|
||||||
|
const activeTabName = activeTabNameSelector(state);
|
||||||
const { selectedModelName } = models;
|
|
||||||
|
|
||||||
const {
|
const {
|
||||||
prompt,
|
prompt,
|
||||||
@ -21,21 +20,15 @@ export const buildInpaintNode = (
|
|||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
cfgScale,
|
cfgScale,
|
||||||
sampler,
|
scheduler,
|
||||||
seamless,
|
model,
|
||||||
img2imgStrength: strength,
|
img2imgStrength: strength,
|
||||||
shouldFitToWidthHeight: fit,
|
shouldFitToWidthHeight: fit,
|
||||||
shouldRandomizeSeed,
|
shouldRandomizeSeed,
|
||||||
|
initialImage,
|
||||||
} = generation;
|
} = generation;
|
||||||
|
|
||||||
const initialImage = initialImageSelector(state);
|
const inpaintNode: InpaintInvocation = {
|
||||||
|
|
||||||
if (!initialImage) {
|
|
||||||
// TODO: handle this
|
|
||||||
// throw 'no initial image';
|
|
||||||
}
|
|
||||||
|
|
||||||
const imageToImageNode: InpaintInvocation = {
|
|
||||||
id: nodeId,
|
id: nodeId,
|
||||||
type: 'inpaint',
|
type: 'inpaint',
|
||||||
prompt: `${prompt} [${negativePrompt}]`,
|
prompt: `${prompt} [${negativePrompt}]`,
|
||||||
@ -43,25 +36,30 @@ export const buildInpaintNode = (
|
|||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
cfg_scale: cfgScale,
|
cfg_scale: cfgScale,
|
||||||
scheduler: sampler as InpaintInvocation['scheduler'],
|
scheduler,
|
||||||
seamless,
|
model,
|
||||||
model: selectedModelName,
|
|
||||||
progress_images: true,
|
|
||||||
image: initialImage
|
|
||||||
? {
|
|
||||||
image_name: initialImage.name,
|
|
||||||
image_type: initialImage.type,
|
|
||||||
}
|
|
||||||
: undefined,
|
|
||||||
strength,
|
strength,
|
||||||
fit,
|
fit,
|
||||||
};
|
};
|
||||||
|
|
||||||
if (!shouldRandomizeSeed) {
|
// on Canvas tab, we do not manually specific init image
|
||||||
imageToImageNode.seed = seed;
|
if (activeTabName !== 'unifiedCanvas') {
|
||||||
|
if (!initialImage) {
|
||||||
|
// TODO: handle this more better
|
||||||
|
throw 'no initial image';
|
||||||
|
}
|
||||||
|
|
||||||
|
inpaintNode.image = {
|
||||||
|
image_name: initialImage.name,
|
||||||
|
image_type: initialImage.type,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
Object.assign(imageToImageNode, overrides);
|
if (!shouldRandomizeSeed) {
|
||||||
|
inpaintNode.seed = seed;
|
||||||
|
}
|
||||||
|
|
||||||
return imageToImageNode;
|
Object.assign(inpaintNode, overrides);
|
||||||
|
|
||||||
|
return inpaintNode;
|
||||||
};
|
};
|
||||||
|
@ -18,7 +18,7 @@ export const buildTxt2ImgNode = (
|
|||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
sampler,
|
scheduler,
|
||||||
shouldRandomizeSeed,
|
shouldRandomizeSeed,
|
||||||
model,
|
model,
|
||||||
} = generation;
|
} = generation;
|
||||||
@ -31,7 +31,7 @@ export const buildTxt2ImgNode = (
|
|||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
scheduler: sampler as TextToImageInvocation['scheduler'],
|
scheduler,
|
||||||
model,
|
model,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1,15 +1,15 @@
|
|||||||
|
import { Scheduler } from 'app/constants';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAICustomSelect from 'common/components/IAICustomSelect';
|
import IAICustomSelect from 'common/components/IAICustomSelect';
|
||||||
import IAISelect from 'common/components/IAISelect';
|
import { setScheduler } from 'features/parameters/store/generationSlice';
|
||||||
import { setSampler } from 'features/parameters/store/generationSlice';
|
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { ChangeEvent, memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
const ParamSampler = () => {
|
const ParamScheduler = () => {
|
||||||
const sampler = useAppSelector(
|
const scheduler = useAppSelector(
|
||||||
(state: RootState) => state.generation.sampler
|
(state: RootState) => state.generation.scheduler
|
||||||
);
|
);
|
||||||
|
|
||||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||||
@ -28,15 +28,15 @@ const ParamSampler = () => {
|
|||||||
if (!v) {
|
if (!v) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(setSampler(v));
|
dispatch(setScheduler(v as Scheduler));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAICustomSelect
|
<IAICustomSelect
|
||||||
label={t('parameters.sampler')}
|
label={t('parameters.scheduler')}
|
||||||
selectedItem={sampler}
|
selectedItem={scheduler}
|
||||||
setSelectedItem={handleChange}
|
setSelectedItem={handleChange}
|
||||||
items={
|
items={
|
||||||
['img2img', 'unifiedCanvas'].includes(activeTabName)
|
['img2img', 'unifiedCanvas'].includes(activeTabName)
|
||||||
@ -48,4 +48,4 @@ const ParamSampler = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(ParamSampler);
|
export default memo(ParamScheduler);
|
@ -1,13 +1,13 @@
|
|||||||
import { Box, Flex } from '@chakra-ui/react';
|
import { Box, Flex } from '@chakra-ui/react';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import ParamSampler from './ParamSampler';
|
|
||||||
import ModelSelect from 'features/system/components/ModelSelect';
|
import ModelSelect from 'features/system/components/ModelSelect';
|
||||||
|
import ParamScheduler from './ParamScheduler';
|
||||||
|
|
||||||
const ParamSchedulerAndModel = () => {
|
const ParamSchedulerAndModel = () => {
|
||||||
return (
|
return (
|
||||||
<Flex gap={3} w="full">
|
<Flex gap={3} w="full">
|
||||||
<Box w="16rem">
|
<Box w="16rem">
|
||||||
<ParamSampler />
|
<ParamScheduler />
|
||||||
</Box>
|
</Box>
|
||||||
<Box w="full">
|
<Box w="full">
|
||||||
<ModelSelect />
|
<ModelSelect />
|
||||||
|
@ -5,6 +5,7 @@ import promptToString from 'common/util/promptToString';
|
|||||||
import { clamp, sample } from 'lodash-es';
|
import { clamp, sample } from 'lodash-es';
|
||||||
import { setAllParametersReducer } from './setAllParametersReducer';
|
import { setAllParametersReducer } from './setAllParametersReducer';
|
||||||
import { receivedModels } from 'services/thunks/model';
|
import { receivedModels } from 'services/thunks/model';
|
||||||
|
import { Scheduler } from 'app/constants';
|
||||||
|
|
||||||
export interface GenerationState {
|
export interface GenerationState {
|
||||||
cfgScale: number;
|
cfgScale: number;
|
||||||
@ -16,7 +17,7 @@ export interface GenerationState {
|
|||||||
perlin: number;
|
perlin: number;
|
||||||
prompt: string;
|
prompt: string;
|
||||||
negativePrompt: string;
|
negativePrompt: string;
|
||||||
sampler: string;
|
scheduler: Scheduler;
|
||||||
seamBlur: number;
|
seamBlur: number;
|
||||||
seamSize: number;
|
seamSize: number;
|
||||||
seamSteps: number;
|
seamSteps: number;
|
||||||
@ -50,7 +51,7 @@ export const initialGenerationState: GenerationState = {
|
|||||||
perlin: 0,
|
perlin: 0,
|
||||||
prompt: '',
|
prompt: '',
|
||||||
negativePrompt: '',
|
negativePrompt: '',
|
||||||
sampler: 'lms',
|
scheduler: 'lms',
|
||||||
seamBlur: 16,
|
seamBlur: 16,
|
||||||
seamSize: 96,
|
seamSize: 96,
|
||||||
seamSteps: 30,
|
seamSteps: 30,
|
||||||
@ -133,8 +134,8 @@ export const generationSlice = createSlice({
|
|||||||
setWidth: (state, action: PayloadAction<number>) => {
|
setWidth: (state, action: PayloadAction<number>) => {
|
||||||
state.width = action.payload;
|
state.width = action.payload;
|
||||||
},
|
},
|
||||||
setSampler: (state, action: PayloadAction<string>) => {
|
setScheduler: (state, action: PayloadAction<Scheduler>) => {
|
||||||
state.sampler = action.payload;
|
state.scheduler = action.payload;
|
||||||
},
|
},
|
||||||
setSeed: (state, action: PayloadAction<number>) => {
|
setSeed: (state, action: PayloadAction<number>) => {
|
||||||
state.seed = action.payload;
|
state.seed = action.payload;
|
||||||
@ -244,7 +245,7 @@ export const {
|
|||||||
setPerlin,
|
setPerlin,
|
||||||
setPrompt,
|
setPrompt,
|
||||||
setNegativePrompt,
|
setNegativePrompt,
|
||||||
setSampler,
|
setScheduler,
|
||||||
setSeamBlur,
|
setSeamBlur,
|
||||||
setSeamSize,
|
setSeamSize,
|
||||||
setSeamSteps,
|
setSeamSteps,
|
||||||
|
@ -2,6 +2,7 @@ import { Draft, PayloadAction } from '@reduxjs/toolkit';
|
|||||||
import { Image } from 'app/types/invokeai';
|
import { Image } from 'app/types/invokeai';
|
||||||
import { GenerationState } from './generationSlice';
|
import { GenerationState } from './generationSlice';
|
||||||
import { ImageToImageInvocation } from 'services/api';
|
import { ImageToImageInvocation } from 'services/api';
|
||||||
|
import { isScheduler } from 'app/constants';
|
||||||
|
|
||||||
export const setAllParametersReducer = (
|
export const setAllParametersReducer = (
|
||||||
state: Draft<GenerationState>,
|
state: Draft<GenerationState>,
|
||||||
@ -34,7 +35,10 @@ export const setAllParametersReducer = (
|
|||||||
state.prompt = String(prompt);
|
state.prompt = String(prompt);
|
||||||
}
|
}
|
||||||
if (scheduler !== undefined) {
|
if (scheduler !== undefined) {
|
||||||
state.sampler = String(scheduler);
|
const schedulerString = String(scheduler);
|
||||||
|
if (isScheduler(schedulerString)) {
|
||||||
|
state.scheduler = schedulerString;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (seed !== undefined) {
|
if (seed !== undefined) {
|
||||||
state.seed = Number(seed);
|
state.seed = Number(seed);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user