mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
adding support for ESRGAN denoising strength (#2598)
pulling in denoising support from upstream (its already there, invoke just isn't using it). I've enabled this as a command line argument as construction of the ESRGAN handler happens once. Ideally this would be a UI option that could be adjusted for each upscaling task. Unfortunately that is beyond my current level of InvokeAI-foo. Upstream reference is here, starting on line 99 "use dni to control the denoise strength" https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan.py
This commit is contained in:
commit
4e95a68582
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,4 +1,5 @@
|
||||
# ignore default image save location and model symbolic link
|
||||
.idea/
|
||||
embeddings/
|
||||
outputs/
|
||||
models/ldm/stable-diffusion-v1/model.ckpt
|
||||
@ -232,4 +233,4 @@ installer/update.bat
|
||||
installer/update.sh
|
||||
|
||||
# no longer stored in source directory
|
||||
models
|
||||
models
|
||||
|
@ -680,7 +680,8 @@ class InvokeAIWebServer:
|
||||
image = self.esrgan.process(
|
||||
image=image,
|
||||
upsampler_scale=postprocessing_parameters["upscale"][0],
|
||||
strength=postprocessing_parameters["upscale"][1],
|
||||
denoise_str=postprocessing_parameters["upscale"][1],
|
||||
strength=postprocessing_parameters["upscale"][2],
|
||||
seed=seed,
|
||||
)
|
||||
elif postprocessing_parameters["type"] == "gfpgan":
|
||||
@ -1064,6 +1065,7 @@ class InvokeAIWebServer:
|
||||
image = self.esrgan.process(
|
||||
image=image,
|
||||
upsampler_scale=esrgan_parameters["level"],
|
||||
denoise_str=esrgan_parameters['denoise_str'],
|
||||
strength=esrgan_parameters["strength"],
|
||||
seed=seed,
|
||||
)
|
||||
@ -1071,6 +1073,7 @@ class InvokeAIWebServer:
|
||||
postprocessing = True
|
||||
all_parameters["upscale"] = [
|
||||
esrgan_parameters["level"],
|
||||
esrgan_parameters['denoise_str'],
|
||||
esrgan_parameters["strength"],
|
||||
]
|
||||
|
||||
@ -1287,7 +1290,8 @@ class InvokeAIWebServer:
|
||||
{
|
||||
"type": "esrgan",
|
||||
"scale": int(parameters["upscale"][0]),
|
||||
"strength": float(parameters["upscale"][1]),
|
||||
"denoise_str": int(parameters["upscale"][1]),
|
||||
"strength": float(parameters["upscale"][2]),
|
||||
}
|
||||
)
|
||||
|
||||
@ -1361,7 +1365,8 @@ class InvokeAIWebServer:
|
||||
if parameters["type"] == "esrgan":
|
||||
postprocessing_metadata["type"] = "esrgan"
|
||||
postprocessing_metadata["scale"] = parameters["upscale"][0]
|
||||
postprocessing_metadata["strength"] = parameters["upscale"][1]
|
||||
postprocessing_metadata["denoise_str"] = parameters["upscale"][1]
|
||||
postprocessing_metadata["strength"] = parameters["upscale"][2]
|
||||
elif parameters["type"] == "gfpgan":
|
||||
postprocessing_metadata["type"] = "gfpgan"
|
||||
postprocessing_metadata["strength"] = parameters["facetool_strength"]
|
||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1
invokeai/frontend/dist/assets/index-fecb6dd4.css
vendored
Normal file
1
invokeai/frontend/dist/assets/index-fecb6dd4.css
vendored
Normal file
File diff suppressed because one or more lines are too long
4
invokeai/frontend/dist/index.html
vendored
4
invokeai/frontend/dist/index.html
vendored
@ -5,8 +5,8 @@
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>InvokeAI - A Stable Diffusion Toolkit</title>
|
||||
<link rel="shortcut icon" type="icon" href="./assets/favicon-0d253ced.ico" />
|
||||
<script type="module" crossorigin src="./assets/index-252612ad.js"></script>
|
||||
<link rel="stylesheet" href="./assets/index-b0bf79f4.css">
|
||||
<script type="module" crossorigin src="./assets/index-ad762ffd.js"></script>
|
||||
<link rel="stylesheet" href="./assets/index-fecb6dd4.css">
|
||||
</head>
|
||||
|
||||
<body>
|
||||
|
@ -20,6 +20,7 @@
|
||||
"upscaling": "Upscaling",
|
||||
"upscale": "Upscale",
|
||||
"upscaleImage": "Upscale Image",
|
||||
"denoisingStrength": "Denoising Strength",
|
||||
"scale": "Scale",
|
||||
"otherOptions": "Other Options",
|
||||
"seamlessTiling": "Seamless Tiling",
|
||||
|
@ -20,6 +20,7 @@
|
||||
"upscaling": "Upscaling",
|
||||
"upscale": "Upscale",
|
||||
"upscaleImage": "Upscale Image",
|
||||
"denoisingStrength": "Denoising Strength",
|
||||
"scale": "Scale",
|
||||
"otherOptions": "Other Options",
|
||||
"seamlessTiling": "Seamless Tiling",
|
||||
|
1
invokeai/frontend/src/app/invokeai.d.ts
vendored
1
invokeai/frontend/src/app/invokeai.d.ts
vendored
@ -92,6 +92,7 @@ export declare type ESRGANMetadata = CommonPostProcessedImageMetadata & {
|
||||
type: 'esrgan';
|
||||
scale: 2 | 4;
|
||||
strength: number;
|
||||
denoise_str: number;
|
||||
};
|
||||
|
||||
export declare type FacetoolMetadata = CommonPostProcessedImageMetadata & {
|
||||
|
@ -93,11 +93,15 @@ const makeSocketIOEmitters = (
|
||||
dispatch(setIsProcessing(true));
|
||||
|
||||
const {
|
||||
postprocessing: { upscalingLevel, upscalingStrength },
|
||||
postprocessing: {
|
||||
upscalingLevel,
|
||||
upscalingDenoising,
|
||||
upscalingStrength,
|
||||
},
|
||||
} = getState();
|
||||
|
||||
const esrganParameters = {
|
||||
upscale: [upscalingLevel, upscalingStrength],
|
||||
upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
|
||||
};
|
||||
socketio.emit('runPostprocessing', imageToProcess, {
|
||||
type: 'esrgan',
|
||||
|
@ -69,6 +69,7 @@ export type BackendGenerationParameters = {
|
||||
|
||||
export type BackendEsrGanParameters = {
|
||||
level: UpscalingLevel;
|
||||
denoise_str: number;
|
||||
strength: number;
|
||||
};
|
||||
|
||||
@ -111,13 +112,12 @@ export const frontendToBackendParameters = (
|
||||
shouldRunFacetool,
|
||||
upscalingLevel,
|
||||
upscalingStrength,
|
||||
upscalingDenoising,
|
||||
} = postprocessingState;
|
||||
|
||||
const {
|
||||
cfgScale,
|
||||
|
||||
height,
|
||||
|
||||
img2imgStrength,
|
||||
infillMethod,
|
||||
initialImage,
|
||||
@ -136,11 +136,9 @@ export const frontendToBackendParameters = (
|
||||
shouldFitToWidthHeight,
|
||||
shouldGenerateVariations,
|
||||
shouldRandomizeSeed,
|
||||
|
||||
steps,
|
||||
threshold,
|
||||
tileSize,
|
||||
|
||||
variationAmount,
|
||||
width,
|
||||
} = generationState;
|
||||
@ -190,6 +188,7 @@ export const frontendToBackendParameters = (
|
||||
if (shouldRunESRGAN) {
|
||||
esrganParameters = {
|
||||
level: upscalingLevel,
|
||||
denoise_str: upscalingDenoising,
|
||||
strength: upscalingStrength,
|
||||
};
|
||||
}
|
||||
|
@ -34,6 +34,7 @@ import {
|
||||
setFacetoolStrength,
|
||||
setFacetoolType,
|
||||
setHiresFix,
|
||||
setUpscalingDenoising,
|
||||
setUpscalingLevel,
|
||||
setUpscalingStrength,
|
||||
} from 'features/parameters/store/postprocessingSlice';
|
||||
@ -147,11 +148,11 @@ const ImageMetadataViewer = memo(
|
||||
postprocessing,
|
||||
prompt,
|
||||
sampler,
|
||||
scale,
|
||||
seamless,
|
||||
seed,
|
||||
steps,
|
||||
strength,
|
||||
denoise_str,
|
||||
threshold,
|
||||
type,
|
||||
variations,
|
||||
@ -184,27 +185,6 @@ const ImageMetadataViewer = memo(
|
||||
{['esrgan', 'gfpgan'].includes(type) && (
|
||||
<MetadataItem label="Original image" value={orig_path} />
|
||||
)}
|
||||
{type === 'gfpgan' && strength !== undefined && (
|
||||
<MetadataItem
|
||||
label="Fix faces strength"
|
||||
value={strength}
|
||||
onClick={() => dispatch(setFacetoolStrength(strength))}
|
||||
/>
|
||||
)}
|
||||
{type === 'esrgan' && scale !== undefined && (
|
||||
<MetadataItem
|
||||
label="Upscaling scale"
|
||||
value={scale}
|
||||
onClick={() => dispatch(setUpscalingLevel(scale))}
|
||||
/>
|
||||
)}
|
||||
{type === 'esrgan' && strength !== undefined && (
|
||||
<MetadataItem
|
||||
label="Upscaling strength"
|
||||
value={strength}
|
||||
onClick={() => dispatch(setUpscalingStrength(strength))}
|
||||
/>
|
||||
)}
|
||||
{prompt && (
|
||||
<MetadataItem
|
||||
label="Prompt"
|
||||
@ -331,7 +311,7 @@ const ImageMetadataViewer = memo(
|
||||
i: number
|
||||
) => {
|
||||
if (postprocess.type === 'esrgan') {
|
||||
const { scale, strength } = postprocess;
|
||||
const { scale, strength, denoise_str } = postprocess;
|
||||
return (
|
||||
<Flex
|
||||
key={i}
|
||||
@ -354,6 +334,15 @@ const ImageMetadataViewer = memo(
|
||||
dispatch(setUpscalingStrength(strength))
|
||||
}
|
||||
/>
|
||||
{denoise_str !== undefined && (
|
||||
<MetadataItem
|
||||
label="Denoising strength"
|
||||
value={denoise_str}
|
||||
onClick={() =>
|
||||
dispatch(setUpscalingDenoising(denoise_str))
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
} else if (postprocess.type === 'gfpgan') {
|
||||
|
@ -1,5 +0,0 @@
|
||||
.upscale-settings {
|
||||
display: grid;
|
||||
grid-template-columns: auto 1fr;
|
||||
column-gap: 1rem;
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
|
||||
import {
|
||||
setUpscalingDenoising,
|
||||
setUpscalingLevel,
|
||||
setUpscalingStrength,
|
||||
UpscalingLevel,
|
||||
@ -8,20 +9,25 @@ import {
|
||||
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { UPSCALING_LEVELS } from 'app/constants';
|
||||
import IAINumberInput from 'common/components/IAINumberInput';
|
||||
import IAISelect from 'common/components/IAISelect';
|
||||
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { isEqual } from 'lodash';
|
||||
import { ChangeEvent } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
|
||||
const parametersSelector = createSelector(
|
||||
[postprocessingSelector, systemSelector],
|
||||
|
||||
({ upscalingLevel, upscalingStrength }, { isESRGANAvailable }) => {
|
||||
(
|
||||
{ upscalingLevel, upscalingStrength, upscalingDenoising },
|
||||
{ isESRGANAvailable }
|
||||
) => {
|
||||
return {
|
||||
upscalingLevel,
|
||||
upscalingDenoising,
|
||||
upscalingStrength,
|
||||
isESRGANAvailable,
|
||||
};
|
||||
@ -38,8 +44,12 @@ const parametersSelector = createSelector(
|
||||
*/
|
||||
const UpscaleSettings = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { upscalingLevel, upscalingStrength, isESRGANAvailable } =
|
||||
useAppSelector(parametersSelector);
|
||||
const {
|
||||
upscalingLevel,
|
||||
upscalingStrength,
|
||||
upscalingDenoising,
|
||||
isESRGANAvailable,
|
||||
} = useAppSelector(parametersSelector);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
@ -49,7 +59,7 @@ const UpscaleSettings = () => {
|
||||
const handleChangeStrength = (v: number) => dispatch(setUpscalingStrength(v));
|
||||
|
||||
return (
|
||||
<div className="upscale-settings">
|
||||
<Flex flexDir="column" rowGap="1rem" minWidth="20rem">
|
||||
<IAISelect
|
||||
isDisabled={!isESRGANAvailable}
|
||||
label={t('parameters:scale')}
|
||||
@ -57,17 +67,39 @@ const UpscaleSettings = () => {
|
||||
onChange={handleChangeLevel}
|
||||
validValues={UPSCALING_LEVELS}
|
||||
/>
|
||||
<IAINumberInput
|
||||
isDisabled={!isESRGANAvailable}
|
||||
label={t('parameters:strength')}
|
||||
step={0.05}
|
||||
<IAISlider
|
||||
label={t('parameters:denoisingStrength')}
|
||||
value={upscalingDenoising}
|
||||
min={0}
|
||||
max={1}
|
||||
onChange={handleChangeStrength}
|
||||
value={upscalingStrength}
|
||||
isInteger={false}
|
||||
step={0.01}
|
||||
onChange={(v) => {
|
||||
dispatch(setUpscalingDenoising(v));
|
||||
}}
|
||||
handleReset={() => dispatch(setUpscalingDenoising(0.75))}
|
||||
withSliderMarks
|
||||
withInput
|
||||
withReset
|
||||
isSliderDisabled={!isESRGANAvailable}
|
||||
isInputDisabled={!isESRGANAvailable}
|
||||
isResetDisabled={!isESRGANAvailable}
|
||||
/>
|
||||
</div>
|
||||
<IAISlider
|
||||
label={`${t('parameters:upscale')} ${t('parameters:strength')}`}
|
||||
value={upscalingStrength}
|
||||
min={0}
|
||||
max={1}
|
||||
step={0.05}
|
||||
onChange={handleChangeStrength}
|
||||
handleReset={() => dispatch(setUpscalingStrength(0.75))}
|
||||
withSliderMarks
|
||||
withInput
|
||||
withReset
|
||||
isSliderDisabled={!isESRGANAvailable}
|
||||
isInputDisabled={!isESRGANAvailable}
|
||||
isResetDisabled={!isESRGANAvailable}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -16,6 +16,7 @@ export interface PostprocessingState {
|
||||
shouldRunESRGAN: boolean;
|
||||
shouldRunFacetool: boolean;
|
||||
upscalingLevel: UpscalingLevel;
|
||||
upscalingDenoising: number;
|
||||
upscalingStrength: number;
|
||||
}
|
||||
|
||||
@ -29,6 +30,7 @@ const initialPostprocessingState: PostprocessingState = {
|
||||
shouldRunESRGAN: false,
|
||||
shouldRunFacetool: false,
|
||||
upscalingLevel: 4,
|
||||
upscalingDenoising: 0.75,
|
||||
upscalingStrength: 0.75,
|
||||
};
|
||||
|
||||
@ -47,6 +49,9 @@ export const postprocessingSlice = createSlice({
|
||||
setUpscalingLevel: (state, action: PayloadAction<UpscalingLevel>) => {
|
||||
state.upscalingLevel = action.payload;
|
||||
},
|
||||
setUpscalingDenoising: (state, action: PayloadAction<number>) => {
|
||||
state.upscalingDenoising = action.payload;
|
||||
},
|
||||
setUpscalingStrength: (state, action: PayloadAction<number>) => {
|
||||
state.upscalingStrength = action.payload;
|
||||
},
|
||||
@ -88,6 +93,7 @@ export const {
|
||||
setShouldRunESRGAN,
|
||||
setShouldRunFacetool,
|
||||
setUpscalingLevel,
|
||||
setUpscalingDenoising,
|
||||
setUpscalingStrength,
|
||||
} = postprocessingSlice.actions;
|
||||
|
||||
|
@ -27,7 +27,6 @@
|
||||
@use '../features/parameters/components/ProcessButtons/ProcessButtons.scss';
|
||||
@use '../features/parameters/components/MainParameters/MainParameters.scss';
|
||||
@use '../features/parameters/components/AccordionItems/AdvancedSettings.scss';
|
||||
@use '../features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings.scss';
|
||||
@use '../features/parameters/components/AdvancedParameters/Canvas/BoundingBox/BoundingBoxSettings.scss';
|
||||
|
||||
// gallery
|
||||
|
File diff suppressed because one or more lines are too long
@ -671,6 +671,12 @@ class Args(object):
|
||||
default=400,
|
||||
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--esrgan_denoise_str',
|
||||
type=float,
|
||||
default=0.75,
|
||||
help='esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--gfpgan_model_path',
|
||||
type=str,
|
||||
|
@ -128,7 +128,7 @@ script do it for you. Manual installation is described at:
|
||||
|
||||
https://invoke-ai.github.io/InvokeAI/installation/020_INSTALL_MANUAL/
|
||||
|
||||
You may download the recommended models (about 15GB total), install all models (40 GB!!)
|
||||
You may download the recommended models (about 15GB total), install all models (40 GB!!)
|
||||
select a customized set, or completely skip this step.
|
||||
"""
|
||||
)
|
||||
@ -583,7 +583,7 @@ def new_config_file_contents(successfully_downloaded: dict, config_file: Path, o
|
||||
# model is a diffusers (indicated with a path)
|
||||
if conf.get(model) and Path(successfully_downloaded[model]).is_dir():
|
||||
offer_to_delete_weights(model, conf[model], opt.yes_to_all)
|
||||
|
||||
|
||||
stanza = {}
|
||||
mod = Datasets[model]
|
||||
stanza["description"] = mod["description"]
|
||||
@ -635,7 +635,7 @@ def offer_to_delete_weights(model_name: str, conf_stanza: dict, yes_to_all: bool
|
||||
weights.unlink()
|
||||
except OSError as e:
|
||||
print(str(e))
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
# this will preload the Bert tokenizer fles
|
||||
def download_bert():
|
||||
@ -683,10 +683,18 @@ def download_clip():
|
||||
def download_realesrgan():
|
||||
print("Installing models from RealESRGAN...", file=sys.stderr)
|
||||
model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
|
||||
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
|
||||
|
||||
model_dest = os.path.join(
|
||||
Globals.root, "models/realesrgan/realesr-general-x4v3.pth"
|
||||
)
|
||||
|
||||
wdn_model_dest = os.path.join(
|
||||
Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||
)
|
||||
|
||||
download_with_progress_bar(model_url, model_dest, "RealESRGAN")
|
||||
download_with_progress_bar(wdn_model_url, wdn_model_dest, "RealESRGANwdn")
|
||||
|
||||
|
||||
def download_gfpgan():
|
||||
|
@ -16,7 +16,7 @@ class ESRGAN():
|
||||
else:
|
||||
use_half_precision = True
|
||||
|
||||
def load_esrgan_bg_upsampler(self):
|
||||
def load_esrgan_bg_upsampler(self, denoise_str):
|
||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
||||
use_half_precision = False
|
||||
else:
|
||||
@ -26,14 +26,16 @@ class ESRGAN():
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||
model_path = os.path.join(Globals.root,'models/realesrgan/realesr-general-x4v3.pth')
|
||||
model_path = os.path.join(Globals.root, 'models/realesrgan/realesr-general-x4v3.pth')
|
||||
wdn_model_path = os.path.join(Globals.root, 'models/realesrgan/realesr-general-wdn-x4v3.pth')
|
||||
scale = 4
|
||||
|
||||
bg_upsampler = RealESRGANer(
|
||||
scale=scale,
|
||||
model_path=model_path,
|
||||
model_path=[model_path, wdn_model_path],
|
||||
model=model,
|
||||
tile=self.bg_tile_size,
|
||||
dni_weight=[denoise_str, 1 - denoise_str],
|
||||
tile_pad=10,
|
||||
pre_pad=0,
|
||||
half=use_half_precision,
|
||||
@ -41,13 +43,13 @@ class ESRGAN():
|
||||
|
||||
return bg_upsampler
|
||||
|
||||
def process(self, image: ImageType, strength: float, seed: str = None, upsampler_scale: int = 2):
|
||||
def process(self, image: ImageType, strength: float, seed: str = None, upsampler_scale: int = 2, denoise_str: float = 0.75):
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
try:
|
||||
upsampler = self.load_esrgan_bg_upsampler()
|
||||
upsampler = self.load_esrgan_bg_upsampler(denoise_str)
|
||||
except Exception:
|
||||
import traceback
|
||||
import sys
|
||||
@ -60,7 +62,7 @@ class ESRGAN():
|
||||
|
||||
if seed is not None:
|
||||
print(
|
||||
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
|
||||
f'>> Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}'
|
||||
)
|
||||
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
|
||||
image = image.convert("RGB")
|
||||
|
Loading…
Reference in New Issue
Block a user