Merge branch 'main' into patch-1

This commit is contained in:
Lincoln Stein 2023-02-06 12:54:07 -05:00 committed by GitHub
commit 13474e985b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 269 additions and 103 deletions

View File

@ -249,6 +249,7 @@ class InvokeAiInstance:
"--require-virtualenv", "--require-virtualenv",
"torch", "torch",
"torchvision", "torchvision",
"--force-reinstall",
"--find-links" if find_links is not None else None, "--find-links" if find_links is not None else None,
find_links, find_links,
"--extra-index-url" if extra_index_url is not None else None, "--extra-index-url" if extra_index_url is not None else None,
@ -325,6 +326,7 @@ class InvokeAiInstance:
Configure the InvokeAI runtime directory Configure the InvokeAI runtime directory
""" """
# set sys.argv to a consistent state
new_argv = [sys.argv[0]] new_argv = [sys.argv[0]]
for i in range(1,len(sys.argv)): for i in range(1,len(sys.argv)):
el = sys.argv[i] el = sys.argv[i]
@ -344,9 +346,6 @@ class InvokeAiInstance:
# NOTE: currently the config script does its own arg parsing! this means the command-line switches # NOTE: currently the config script does its own arg parsing! this means the command-line switches
# from the installer will also automatically propagate down to the config script. # from the installer will also automatically propagate down to the config script.
# this may change in the future with config refactoring! # this may change in the future with config refactoring!
# set sys.argv to a consistent state
invokeai_configure.main() invokeai_configure.main()
def install_user_scripts(self): def install_user_scripts(self):

View File

@ -1208,12 +1208,18 @@ class InvokeAIWebServer:
) )
except KeyboardInterrupt: except KeyboardInterrupt:
# Clear the CUDA cache on an exception
self.empty_cuda_cache()
self.socketio.emit("processingCanceled") self.socketio.emit("processingCanceled")
raise raise
except CanceledException: except CanceledException:
# Clear the CUDA cache on an exception
self.empty_cuda_cache()
self.socketio.emit("processingCanceled") self.socketio.emit("processingCanceled")
pass pass
except Exception as e: except Exception as e:
# Clear the CUDA cache on an exception
self.empty_cuda_cache()
print(e) print(e)
self.socketio.emit("error", {"message": (str(e))}) self.socketio.emit("error", {"message": (str(e))})
print("\n") print("\n")
@ -1221,6 +1227,12 @@ class InvokeAIWebServer:
traceback.print_exc() traceback.print_exc()
print("\n") print("\n")
def empty_cuda_cache(self):
if self.generate.device.type == "cuda":
import torch.cuda
torch.cuda.empty_cache()
def parameters_to_generated_image_metadata(self, parameters): def parameters_to_generated_image_metadata(self, parameters):
try: try:
# top-level metadata minus `image` or `images` # top-level metadata minus `image` or `images`

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -7,8 +7,8 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>InvokeAI - A Stable Diffusion Toolkit</title> <title>InvokeAI - A Stable Diffusion Toolkit</title>
<link rel="shortcut icon" type="icon" href="./assets/favicon.0d253ced.ico" /> <link rel="shortcut icon" type="icon" href="./assets/favicon.0d253ced.ico" />
<script type="module" crossorigin src="./assets/index.dd4ad8a1.js"></script> <script type="module" crossorigin src="./assets/index.9310184f.js"></script>
<link rel="stylesheet" href="./assets/index.8badc8b4.css"> <link rel="stylesheet" href="./assets/index.1536494e.css">
<script type="module">try{import.meta.url;import("_").catch(()=>1);}catch(e){}window.__vite_is_modern_browser=true;</script> <script type="module">try{import.meta.url;import("_").catch(()=>1);}catch(e){}window.__vite_is_modern_browser=true;</script>
<script type="module">!function(){if(window.__vite_is_modern_browser)return;console.warn("vite: loading legacy build because dynamic import or import.meta.url is unsupported, syntax error above should be ignored");var e=document.getElementById("vite-legacy-polyfill"),n=document.createElement("script");n.src=e.src,n.onload=function(){System.import(document.getElementById('vite-legacy-entry').getAttribute('data-src'))},document.body.appendChild(n)}();</script> <script type="module">!function(){if(window.__vite_is_modern_browser)return;console.warn("vite: loading legacy build because dynamic import or import.meta.url is unsupported, syntax error above should be ignored");var e=document.getElementById("vite-legacy-polyfill"),n=document.createElement("script");n.src=e.src,n.onload=function(){System.import(document.getElementById('vite-legacy-entry').getAttribute('data-src'))},document.body.appendChild(n)}();</script>
</head> </head>
@ -18,6 +18,6 @@
<script nomodule>!function(){var e=document,t=e.createElement("script");if(!("noModule"in t)&&"onbeforeload"in t){var n=!1;e.addEventListener("beforeload",(function(e){if(e.target===t)n=!0;else if(!e.target.hasAttribute("nomodule")||!n)return;e.preventDefault()}),!0),t.type="module",t.src=".",e.head.appendChild(t),t.remove()}}();</script> <script nomodule>!function(){var e=document,t=e.createElement("script");if(!("noModule"in t)&&"onbeforeload"in t){var n=!1;e.addEventListener("beforeload",(function(e){if(e.target===t)n=!0;else if(!e.target.hasAttribute("nomodule")||!n)return;e.preventDefault()}),!0),t.type="module",t.src=".",e.head.appendChild(t),t.remove()}}();</script>
<script nomodule crossorigin id="vite-legacy-polyfill" src="./assets/polyfills-legacy-dde3a68a.js"></script> <script nomodule crossorigin id="vite-legacy-polyfill" src="./assets/polyfills-legacy-dde3a68a.js"></script>
<script nomodule crossorigin id="vite-legacy-entry" data-src="./assets/index-legacy-8219c08f.js">System.import(document.getElementById('vite-legacy-entry').getAttribute('data-src'))</script> <script nomodule crossorigin id="vite-legacy-entry" data-src="./assets/index-legacy-a33ada34.js">System.import(document.getElementById('vite-legacy-entry').getAttribute('data-src'))</script>
</body> </body>
</html> </html>

View File

@ -43,6 +43,7 @@
"invoke": "Invoke", "invoke": "Invoke",
"cancel": "Cancel", "cancel": "Cancel",
"promptPlaceholder": "Type prompt here. [negative tokens], (upweight)++, (downweight)--, swap and blend are available (see docs)", "promptPlaceholder": "Type prompt here. [negative tokens], (upweight)++, (downweight)--, swap and blend are available (see docs)",
"negativePrompts": "Negative Prompts",
"sendTo": "Send to", "sendTo": "Send to",
"sendToImg2Img": "Send to Image to Image", "sendToImg2Img": "Send to Image to Image",
"sendToUnifiedCanvas": "Send To Unified Canvas", "sendToUnifiedCanvas": "Send To Unified Canvas",

View File

@ -43,6 +43,7 @@
"invoke": "Invoke", "invoke": "Invoke",
"cancel": "Cancel", "cancel": "Cancel",
"promptPlaceholder": "Type prompt here. [negative tokens], (upweight)++, (downweight)--, swap and blend are available (see docs)", "promptPlaceholder": "Type prompt here. [negative tokens], (upweight)++, (downweight)--, swap and blend are available (see docs)",
"negativePrompts": "Negative Prompts",
"sendTo": "Send to", "sendTo": "Send to",
"sendToImg2Img": "Send to Image to Image", "sendToImg2Img": "Send to Image to Image",
"sendToUnifiedCanvas": "Send To Unified Canvas", "sendToUnifiedCanvas": "Send To Unified Canvas",

View File

@ -11,7 +11,6 @@ const useClickOutsideWatcher = () => {
function handleClickOutside(e: MouseEvent) { function handleClickOutside(e: MouseEvent) {
watchers.forEach(({ ref, enable, callback }) => { watchers.forEach(({ ref, enable, callback }) => {
if (enable && ref.current && !ref.current.contains(e.target as Node)) { if (enable && ref.current && !ref.current.contains(e.target as Node)) {
console.log('callback');
callback(); callback();
} }
}); });

View File

@ -0,0 +1,20 @@
import * as InvokeAI from 'app/invokeai';
import promptToString from './promptToString';
export function getPromptAndNegative(input_prompt: InvokeAI.Prompt) {
let prompt: string = promptToString(input_prompt);
let negativePrompt: string | null = null;
const negativePromptRegExp = new RegExp(/(?<=\[)[^\][]*(?=])/, 'gi');
const negativePromptMatches = [...prompt.matchAll(negativePromptRegExp)];
if (negativePromptMatches && negativePromptMatches.length > 0) {
negativePrompt = negativePromptMatches.join(', ');
prompt = prompt
.replaceAll(negativePromptRegExp, '')
.replaceAll('[]', '')
.trim();
}
return [prompt, negativePrompt];
}

View File

@ -106,6 +106,7 @@ export const frontendToBackendParameters = (
iterations, iterations,
perlin, perlin,
prompt, prompt,
negativePrompt,
sampler, sampler,
seamBlur, seamBlur,
seamless, seamless,
@ -155,6 +156,10 @@ export const frontendToBackendParameters = (
let esrganParameters: false | BackendEsrGanParameters = false; let esrganParameters: false | BackendEsrGanParameters = false;
let facetoolParameters: false | BackendFacetoolParameters = false; let facetoolParameters: false | BackendFacetoolParameters = false;
if (negativePrompt !== '') {
generationParameters.prompt = `${prompt} [${negativePrompt}]`;
}
generationParameters.seed = shouldRandomizeSeed generationParameters.seed = shouldRandomizeSeed
? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX) ? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)
: seed; : seed;

View File

@ -9,6 +9,7 @@ import {
setAllParameters, setAllParameters,
setInitialImage, setInitialImage,
setIsLightBoxOpen, setIsLightBoxOpen,
setNegativePrompt,
setPrompt, setPrompt,
setSeed, setSeed,
setShouldShowImageDetails, setShouldShowImageDetails,
@ -44,6 +45,7 @@ import { GalleryState } from 'features/gallery/store/gallerySlice';
import { activeTabNameSelector } from 'features/options/store/optionsSelectors'; import { activeTabNameSelector } from 'features/options/store/optionsSelectors';
import IAIPopover from 'common/components/IAIPopover'; import IAIPopover from 'common/components/IAIPopover';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { getPromptAndNegative } from 'common/util/getPromptAndNegative';
const systemSelector = createSelector( const systemSelector = createSelector(
[ [
@ -241,9 +243,18 @@ const CurrentImageButtons = () => {
[currentImage] [currentImage]
); );
const handleClickUsePrompt = () => const handleClickUsePrompt = () => {
currentImage?.metadata?.image?.prompt && if (currentImage?.metadata?.image?.prompt) {
dispatch(setPrompt(currentImage.metadata.image.prompt)); const [prompt, negativePrompt] = getPromptAndNegative(
currentImage?.metadata?.image?.prompt
);
prompt && dispatch(setPrompt(prompt));
negativePrompt
? dispatch(setNegativePrompt(negativePrompt))
: dispatch(setNegativePrompt(''));
}
};
useHotkeys( useHotkeys(
'p', 'p',

View File

@ -10,9 +10,10 @@ import { DragEvent, memo, useState } from 'react';
import { import {
setActiveTab, setActiveTab,
setAllImageToImageParameters, setAllImageToImageParameters,
setAllTextToImageParameters, setAllParameters,
setInitialImage, setInitialImage,
setIsLightBoxOpen, setIsLightBoxOpen,
setNegativePrompt,
setPrompt, setPrompt,
setSeed, setSeed,
} from 'features/options/store/optionsSlice'; } from 'features/options/store/optionsSlice';
@ -24,6 +25,7 @@ import {
} from 'features/canvas/store/canvasSlice'; } from 'features/canvas/store/canvasSlice';
import { hoverableImageSelector } from 'features/gallery/store/gallerySliceSelectors'; import { hoverableImageSelector } from 'features/gallery/store/gallerySliceSelectors';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { getPromptAndNegative } from 'common/util/getPromptAndNegative';
interface HoverableImageProps { interface HoverableImageProps {
image: InvokeAI.Image; image: InvokeAI.Image;
@ -62,7 +64,17 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const handleMouseOut = () => setIsHovered(false); const handleMouseOut = () => setIsHovered(false);
const handleUsePrompt = () => { const handleUsePrompt = () => {
image.metadata && dispatch(setPrompt(image.metadata.image.prompt)); if (image.metadata) {
const [prompt, negativePrompt] = getPromptAndNegative(
image.metadata?.image?.prompt
);
prompt && dispatch(setPrompt(prompt));
negativePrompt
? dispatch(setNegativePrompt(negativePrompt))
: dispatch(setNegativePrompt(''));
}
toast({ toast({
title: t('toast:promptSet'), title: t('toast:promptSet'),
status: 'success', status: 'success',
@ -115,7 +127,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
}; };
const handleUseAllParameters = () => { const handleUseAllParameters = () => {
metadata && dispatch(setAllTextToImageParameters(metadata)); metadata && dispatch(setAllParameters(metadata));
toast({ toast({
title: t('toast:parametersSet'), title: t('toast:parametersSet'),
status: 'success', status: 'success',

View File

@ -38,7 +38,6 @@ export const uploadImage =
}); });
const image = (await response.json()) as InvokeAI.ImageUploadResponse; const image = (await response.json()) as InvokeAI.ImageUploadResponse;
console.log(image);
const newImage: InvokeAI.Image = { const newImage: InvokeAI.Image = {
uuid: uuidv4(), uuid: uuidv4(),
category: 'user', category: 'user',

View File

@ -0,0 +1,38 @@
import { FormControl, Textarea } from '@chakra-ui/react';
import type { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { setNegativePrompt } from 'features/options/store/optionsSlice';
import { useTranslation } from 'react-i18next';
export function NegativePromptInput() {
const negativePrompt = useAppSelector(
(state: RootState) => state.options.negativePrompt
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
return (
<FormControl>
<Textarea
id="negativePrompt"
name="negativePrompt"
value={negativePrompt}
onChange={(e) => dispatch(setNegativePrompt(e.target.value))}
background="var(--prompt-bg-color)"
placeholder={t('options:negativePrompts')}
_placeholder={{ fontSize: '0.8rem' }}
borderColor="var(--border-color)"
_hover={{
borderColor: 'var(--border-color-light)',
}}
_focusVisible={{
borderColor: 'var(--border-color-invalid)',
boxShadow: '0 0 10px var(--box-shadow-color-invalid)',
}}
fontSize="0.9rem"
color="var(--text-color-secondary)"
/>
</FormControl>
);
}

View File

@ -5,6 +5,7 @@ import promptToString from 'common/util/promptToString';
import { seedWeightsToString } from 'common/util/seedWeightPairs'; import { seedWeightsToString } from 'common/util/seedWeightPairs';
import { FACETOOL_TYPES } from 'app/constants'; import { FACETOOL_TYPES } from 'app/constants';
import { InvokeTabName, tabMap } from 'features/tabs/tabMap'; import { InvokeTabName, tabMap } from 'features/tabs/tabMap';
import { getPromptAndNegative } from 'common/util/getPromptAndNegative';
export type UpscalingLevel = 2 | 4; export type UpscalingLevel = 2 | 4;
@ -28,6 +29,7 @@ export interface OptionsState {
optionsPanelScrollPosition: number; optionsPanelScrollPosition: number;
perlin: number; perlin: number;
prompt: string; prompt: string;
negativePrompt: string;
sampler: string; sampler: string;
seamBlur: number; seamBlur: number;
seamless: boolean; seamless: boolean;
@ -77,6 +79,7 @@ const initialOptionsState: OptionsState = {
optionsPanelScrollPosition: 0, optionsPanelScrollPosition: 0,
perlin: 0, perlin: 0,
prompt: '', prompt: '',
negativePrompt: '',
sampler: 'k_lms', sampler: 'k_lms',
seamBlur: 16, seamBlur: 16,
seamless: false, seamless: false,
@ -123,6 +126,17 @@ export const optionsSlice = createSlice({
state.prompt = promptToString(newPrompt); state.prompt = promptToString(newPrompt);
} }
}, },
setNegativePrompt: (
state,
action: PayloadAction<string | InvokeAI.Prompt>
) => {
const newPrompt = action.payload;
if (typeof newPrompt === 'string') {
state.negativePrompt = newPrompt;
} else {
state.negativePrompt = promptToString(newPrompt);
}
},
setIterations: (state, action: PayloadAction<number>) => { setIterations: (state, action: PayloadAction<number>) => {
state.iterations = action.payload; state.iterations = action.payload;
}, },
@ -307,7 +321,14 @@ export const optionsSlice = createSlice({
state.shouldRandomizeSeed = false; state.shouldRandomizeSeed = false;
} }
if (prompt) state.prompt = promptToString(prompt); if (prompt) {
const [promptOnly, negativePrompt] = getPromptAndNegative(prompt);
if (promptOnly) state.prompt = promptOnly;
negativePrompt
? (state.negativePrompt = negativePrompt)
: (state.negativePrompt = '');
}
if (sampler) state.sampler = sampler; if (sampler) state.sampler = sampler;
if (steps) state.steps = steps; if (steps) state.steps = steps;
if (cfg_scale) state.cfgScale = cfg_scale; if (cfg_scale) state.cfgScale = cfg_scale;
@ -448,6 +469,7 @@ export const {
setParameter, setParameter,
setPerlin, setPerlin,
setPrompt, setPrompt,
setNegativePrompt,
setSampler, setSampler,
setSeamBlur, setSeamBlur,
setSeamless, setSeamless,

View File

@ -13,16 +13,16 @@ export default function LanguagePicker() {
const LANGUAGES = { const LANGUAGES = {
en: t('common:langEnglish'), en: t('common:langEnglish'),
ru: t('common:langRussian'),
it: t('common:langItalian'),
pt_br: t('common:langBrPortuguese'),
de: t('common:langGerman'),
pl: t('common:langPolish'),
zh_cn: t('common:langSimplifiedChinese'),
es: t('common:langSpanish'),
ja: t('common:langJapanese'),
nl: t('common:langDutch'), nl: t('common:langDutch'),
fr: t('common:langFrench'), fr: t('common:langFrench'),
de: t('common:langGerman'),
it: t('common:langItalian'),
ja: t('common:langJapanese'),
pl: t('common:langPolish'),
pt_br: t('common:langBrPortuguese'),
ru: t('common:langRussian'),
zh_cn: t('common:langSimplifiedChinese'),
es: t('common:langSpanish'),
ua: t('common:langUkranian'), ua: t('common:langUkranian'),
}; };

View File

@ -316,7 +316,6 @@ export default function CheckpointModelEdit() {
) : ( ) : (
<Flex <Flex
width="100%" width="100%"
height="250px"
justifyContent="center" justifyContent="center"
alignItems="center" alignItems="center"
backgroundColor="var(--background-color)" backgroundColor="var(--background-color)"

View File

@ -271,7 +271,6 @@ export default function DiffusersModelEdit() {
) : ( ) : (
<Flex <Flex
width="100%" width="100%"
height="250px"
justifyContent="center" justifyContent="center"
alignItems="center" alignItems="center"
backgroundColor="var(--background-color)" backgroundColor="var(--background-color)"

View File

@ -19,6 +19,8 @@ import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import InvokeOptionsPanel from 'features/tabs/components/InvokeOptionsPanel'; import InvokeOptionsPanel from 'features/tabs/components/InvokeOptionsPanel';
import { activeTabNameSelector } from 'features/options/store/optionsSelectors'; import { activeTabNameSelector } from 'features/options/store/optionsSelectors';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { Flex } from '@chakra-ui/react';
import { NegativePromptInput } from 'features/options/components/PromptInput/NegativePromptInput';
export default function ImageToImagePanel() { export default function ImageToImagePanel() {
const { t } = useTranslation(); const { t } = useTranslation();
@ -67,7 +69,10 @@ export default function ImageToImagePanel() {
return ( return (
<InvokeOptionsPanel> <InvokeOptionsPanel>
<Flex flexDir="column" rowGap="0.5rem">
<PromptInput /> <PromptInput />
<NegativePromptInput />
</Flex>
<ProcessButtons /> <ProcessButtons />
<MainOptions /> <MainOptions />
<ImageToImageStrength <ImageToImageStrength

View File

@ -24,8 +24,8 @@
} }
svg { svg {
width: 26px; width: 24px;
height: 26px; height: 24px;
} }
&[aria-selected='true'] { &[aria-selected='true'] {

View File

@ -1,3 +1,4 @@
import { Flex } from '@chakra-ui/react';
import { Feature } from 'app/features'; import { Feature } from 'app/features';
import FaceRestoreOptions from 'features/options/components/AdvancedOptions/FaceRestore/FaceRestoreOptions'; import FaceRestoreOptions from 'features/options/components/AdvancedOptions/FaceRestore/FaceRestoreOptions';
import FaceRestoreToggle from 'features/options/components/AdvancedOptions/FaceRestore/FaceRestoreToggle'; import FaceRestoreToggle from 'features/options/components/AdvancedOptions/FaceRestore/FaceRestoreToggle';
@ -10,6 +11,7 @@ import VariationsOptions from 'features/options/components/AdvancedOptions/Varia
import MainOptions from 'features/options/components/MainOptions/MainOptions'; import MainOptions from 'features/options/components/MainOptions/MainOptions';
import OptionsAccordion from 'features/options/components/OptionsAccordion'; import OptionsAccordion from 'features/options/components/OptionsAccordion';
import ProcessButtons from 'features/options/components/ProcessButtons/ProcessButtons'; import ProcessButtons from 'features/options/components/ProcessButtons/ProcessButtons';
import { NegativePromptInput } from 'features/options/components/PromptInput/NegativePromptInput';
import PromptInput from 'features/options/components/PromptInput/PromptInput'; import PromptInput from 'features/options/components/PromptInput/PromptInput';
import InvokeOptionsPanel from 'features/tabs/components/InvokeOptionsPanel'; import InvokeOptionsPanel from 'features/tabs/components/InvokeOptionsPanel';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -50,7 +52,10 @@ export default function TextToImagePanel() {
return ( return (
<InvokeOptionsPanel> <InvokeOptionsPanel>
<Flex flexDir="column" rowGap="0.5rem">
<PromptInput /> <PromptInput />
<NegativePromptInput />
</Flex>
<ProcessButtons /> <ProcessButtons />
<MainOptions /> <MainOptions />
<OptionsAccordion accordionInfo={textToImageAccordions} /> <OptionsAccordion accordionInfo={textToImageAccordions} />

View File

@ -13,6 +13,8 @@ import InvokeOptionsPanel from 'features/tabs/components/InvokeOptionsPanel';
import BoundingBoxSettings from 'features/options/components/AdvancedOptions/Canvas/BoundingBoxSettings/BoundingBoxSettings'; import BoundingBoxSettings from 'features/options/components/AdvancedOptions/Canvas/BoundingBoxSettings/BoundingBoxSettings';
import InfillAndScalingOptions from 'features/options/components/AdvancedOptions/Canvas/InfillAndScalingOptions'; import InfillAndScalingOptions from 'features/options/components/AdvancedOptions/Canvas/InfillAndScalingOptions';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { Flex } from '@chakra-ui/react';
import { NegativePromptInput } from 'features/options/components/PromptInput/NegativePromptInput';
export default function UnifiedCanvasPanel() { export default function UnifiedCanvasPanel() {
const { t } = useTranslation(); const { t } = useTranslation();
@ -48,7 +50,10 @@ export default function UnifiedCanvasPanel() {
return ( return (
<InvokeOptionsPanel> <InvokeOptionsPanel>
<Flex flexDir="column" rowGap="0.5rem">
<PromptInput /> <PromptInput />
<NegativePromptInput />
</Flex>
<ProcessButtons /> <ProcessButtons />
<MainOptions /> <MainOptions />
<ImageToImageStrength <ImageToImageStrength

View File

@ -344,6 +344,7 @@ class Generate:
**args, **args,
): # eat up additional cruft ): # eat up additional cruft
self.clear_cuda_stats()
""" """
ldm.generate.prompt2image() is the common entry point for txt2img() and img2img() ldm.generate.prompt2image() is the common entry point for txt2img() and img2img()
It takes the following arguments: It takes the following arguments:
@ -548,6 +549,7 @@ class Generate:
inpaint_width = inpaint_width, inpaint_width = inpaint_width,
enable_image_debugging = enable_image_debugging, enable_image_debugging = enable_image_debugging,
free_gpu_mem=self.free_gpu_mem, free_gpu_mem=self.free_gpu_mem,
clear_cuda_cache=self.clear_cuda_cache
) )
if init_color: if init_color:
@ -565,11 +567,17 @@ class Generate:
image_callback = image_callback) image_callback = image_callback)
except KeyboardInterrupt: except KeyboardInterrupt:
# Clear the CUDA cache on an exception
self.clear_cuda_cache()
if catch_interrupts: if catch_interrupts:
print('**Interrupted** Partial results will be returned.') print('**Interrupted** Partial results will be returned.')
else: else:
raise KeyboardInterrupt raise KeyboardInterrupt
except RuntimeError: except RuntimeError:
# Clear the CUDA cache on an exception
self.clear_cuda_cache()
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
print('>> Could not generate image.') print('>> Could not generate image.')
@ -579,22 +587,42 @@ class Generate:
f'>> {len(results)} image(s) generated in', '%4.2fs' % ( f'>> {len(results)} image(s) generated in', '%4.2fs' % (
toc - tic) toc - tic)
) )
self.print_cuda_stats()
return results
def clear_cuda_cache(self):
if self._has_cuda():
self.max_memory_allocated = max(
self.max_memory_allocated,
torch.cuda.max_memory_allocated()
)
self.memory_allocated = max(
self.memory_allocated,
torch.cuda.memory_allocated()
)
self.session_peakmem = max(
self.session_peakmem,
torch.cuda.max_memory_allocated()
)
torch.cuda.empty_cache()
def clear_cuda_stats(self):
self.max_memory_allocated = 0
self.memory_allocated = 0
def print_cuda_stats(self):
if self._has_cuda(): if self._has_cuda():
print( print(
'>> Max VRAM used for this generation:', '>> Max VRAM used for this generation:',
'%4.2fG.' % (torch.cuda.max_memory_allocated() / 1e9), '%4.2fG.' % (self.max_memory_allocated / 1e9),
'Current VRAM utilization:', 'Current VRAM utilization:',
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9), '%4.2fG' % (self.memory_allocated / 1e9),
) )
self.session_peakmem = max(
self.session_peakmem, torch.cuda.max_memory_allocated()
)
print( print(
'>> Max VRAM used since script start: ', '>> Max VRAM used since script start: ',
'%4.2fG' % (self.session_peakmem / 1e9), '%4.2fG' % (self.session_peakmem / 1e9),
) )
return results
# this needs to be generalized to all sorts of postprocessors, which should be wrapped # this needs to be generalized to all sorts of postprocessors, which should be wrapped
# in a nice harmonized call signature. For now we have a bunch of if/elses! # in a nice harmonized call signature. For now we have a bunch of if/elses!

View File

@ -123,8 +123,9 @@ class Generator:
seed = self.new_seed() seed = self.new_seed()
# Free up memory from the last generation. # Free up memory from the last generation.
if self.model.device.type == 'cuda': clear_cuda_cache = kwargs['clear_cuda_cache'] or None
torch.cuda.empty_cache() if clear_cuda_cache is not None:
clear_cuda_cache()
return results return results

View File

@ -65,6 +65,11 @@ class Txt2Img2Img(Generator):
mode="bilinear" mode="bilinear"
) )
# Free up memory from the last generation.
clear_cuda_cache = kwargs['clear_cuda_cache'] or None
if clear_cuda_cache is not None:
clear_cuda_cache()
second_pass_noise = self.get_noise_like(resized_latents) second_pass_noise = self.get_noise_like(resized_latents)
verbosity = get_verbosity() verbosity = get_verbosity()