GFPGAN and Real ESRGAN Implementation Refactor

This commit is contained in:
blessedcoolant 2022-09-14 05:17:14 +12:00 committed by Lincoln Stein
parent e8bb39370c
commit 1b5013ab72
11 changed files with 532 additions and 377 deletions

View File

@ -4,10 +4,16 @@ title: Upscale
# :material-image-size-select-large: Upscale # :material-image-size-select-large: Upscale
## **Intro**
The script provides the ability to restore faces and upscale.
You can enable these features by passing `--restore` and `--esrgan` to your launch script to enable
face restoration modules and upscaling modules respectively.
## **GFPGAN and Real-ESRGAN Support** ## **GFPGAN and Real-ESRGAN Support**
The script also provides the ability to do face restoration and upscaling with the help of GFPGAN The default face restoration module is GFPGAN and the default upscaling module is ESRGAN.
and Real-ESRGAN respectively.
As of version 1.14, environment.yaml will install the Real-ESRGAN package into the standard install As of version 1.14, environment.yaml will install the Real-ESRGAN package into the standard install
location for python packages, and will put GFPGAN into a subdirectory of "src" in the location for python packages, and will put GFPGAN into a subdirectory of "src" in the

View File

@ -370,16 +370,19 @@ class Args(object):
type=str, type=str,
help='Path to a pre-trained embedding manager checkpoint - can only be set on command line', help='Path to a pre-trained embedding manager checkpoint - can only be set on command line',
) )
# GFPGAN related args # Restoration related args
postprocessing_group.add_argument( postprocessing_group.add_argument(
'--gfpgan_bg_upsampler', '--restore',
type=str, action='store_true',
default='realesrgan', help='Enable Face Restoration',
help='Background upsampler. Default: realesrgan. Options: realesrgan, none.',
) )
postprocessing_group.add_argument( postprocessing_group.add_argument(
'--gfpgan_bg_tile', '--esrgan',
action='store_true',
help='Enable Upscaling',
)
postprocessing_group.add_argument(
'--esrgan_bg_tile',
type=int, type=int,
default=400, default=400,
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.', help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',

View File

@ -10,6 +10,7 @@ from ldm.dream.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.dream.generator.img2img import Img2Img from ldm.dream.generator.img2img import Img2Img
class Embiggen(Generator): class Embiggen(Generator):
def __init__(self, model, precision): def __init__(self, model, precision):
super().__init__(model, precision) super().__init__(model, precision)
@ -42,7 +43,8 @@ class Embiggen(Generator):
embiggen = [1.0] # If not specified, assume no scaling embiggen = [1.0] # If not specified, assume no scaling
elif embiggen[0] < 0: elif embiggen[0] < 0:
embiggen[0] = 1.0 embiggen[0] = 1.0
print('>> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !') print(
'>> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !')
if len(embiggen) < 2: if len(embiggen) < 2:
embiggen.append(0.75) embiggen.append(0.75)
elif embiggen[1] > 1.0 or embiggen[1] < 0: elif embiggen[1] > 1.0 or embiggen[1] < 0:
@ -77,28 +79,29 @@ class Embiggen(Generator):
initsuperwidth = round(initsuperwidth*embiggen[0]) initsuperwidth = round(initsuperwidth*embiggen[0])
initsuperheight = round(initsuperheight*embiggen[0]) initsuperheight = round(initsuperheight*embiggen[0])
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
from ldm.gfpgan.gfpgan_tools import ( from ldm.restoration.realesrgan import ESRGAN
real_esrgan_upscale, esrgan = ESRGAN()
) print(
print(f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}') f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}')
if embiggen[0] > 2: if embiggen[0] > 2:
initsuperimage = real_esrgan_upscale( initsuperimage = esrgan.process(
initsuperimage, initsuperimage,
embiggen[1], # upscale strength embiggen[1], # upscale strength
4, # upscale scale
self.seed, self.seed,
4, # upscale scale
) )
else: else:
initsuperimage = real_esrgan_upscale( initsuperimage = esrgan.process(
initsuperimage, initsuperimage,
embiggen[1], # upscale strength embiggen[1], # upscale strength
2, # upscale scale
self.seed, self.seed,
2, # upscale scale
) )
# We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x # We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x
# but from personal experience it doesn't greatly improve anything after 4x # but from personal experience it doesn't greatly improve anything after 4x
# Resize to target scaling factor resolution # Resize to target scaling factor resolution
initsuperimage = initsuperimage.resize((initsuperwidth, initsuperheight), Image.Resampling.LANCZOS) initsuperimage = initsuperimage.resize(
(initsuperwidth, initsuperheight), Image.Resampling.LANCZOS)
# Use width and height as tile widths and height # Use width and height as tile widths and height
# Determine buffer size in pixels # Determine buffer size in pixels
@ -121,16 +124,19 @@ class Embiggen(Generator):
emb_tiles_x = 1 emb_tiles_x = 1
emb_tiles_y = 1 emb_tiles_y = 1
if (initsuperwidth - width) > 0: if (initsuperwidth - width) > 0:
emb_tiles_x = ceildiv(initsuperwidth - width, width - overlap_size_x) + 1 emb_tiles_x = ceildiv(initsuperwidth - width,
width - overlap_size_x) + 1
if (initsuperheight - height) > 0: if (initsuperheight - height) > 0:
emb_tiles_y = ceildiv(initsuperheight - height, height - overlap_size_y) + 1 emb_tiles_y = ceildiv(initsuperheight - height,
height - overlap_size_y) + 1
# Sanity # Sanity
assert emb_tiles_x > 1 or emb_tiles_y > 1, f'ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don\'t need to Embiggen! Check your arguments.' assert emb_tiles_x > 1 or emb_tiles_y > 1, f'ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don\'t need to Embiggen! Check your arguments.'
# Prep alpha layers -------------- # Prep alpha layers --------------
# https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil # https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
# agradientL is Left-side transparent # agradientL is Left-side transparent
agradientL = Image.linear_gradient('L').rotate(90).resize((overlap_size_x, height)) agradientL = Image.linear_gradient('L').rotate(
90).resize((overlap_size_x, height))
# agradientT is Top-side transparent # agradientT is Top-side transparent
agradientT = Image.linear_gradient('L').resize((width, overlap_size_y)) agradientT = Image.linear_gradient('L').resize((width, overlap_size_y))
# radial corner is the left-top corner, made full circle then cut to just the left-top quadrant # radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
@ -154,59 +160,79 @@ class Embiggen(Generator):
alphaLayerT.paste(agradientT, (0, 0)) alphaLayerT.paste(agradientT, (0, 0))
alphaLayerLTC.paste(agradientL, (0, 0)) alphaLayerLTC.paste(agradientL, (0, 0))
alphaLayerLTC.paste(agradientT, (0, 0)) alphaLayerLTC.paste(agradientT, (0, 0))
alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)) alphaLayerLTC.paste(agradientC.resize(
(overlap_size_x, overlap_size_y)), (0, 0))
if embiggen_tiles: if embiggen_tiles:
# Individual unconnected sides # Individual unconnected sides
alphaLayerR = Image.new("L", (width, height), 255) alphaLayerR = Image.new("L", (width, height), 255)
alphaLayerR.paste(agradientL.rotate(180), (width - overlap_size_x, 0)) alphaLayerR.paste(agradientL.rotate(
180), (width - overlap_size_x, 0))
alphaLayerB = Image.new("L", (width, height), 255) alphaLayerB = Image.new("L", (width, height), 255)
alphaLayerB.paste(agradientT.rotate(180), (0, height - overlap_size_y)) alphaLayerB.paste(agradientT.rotate(
180), (0, height - overlap_size_y))
alphaLayerTB = Image.new("L", (width, height), 255) alphaLayerTB = Image.new("L", (width, height), 255)
alphaLayerTB.paste(agradientT, (0, 0)) alphaLayerTB.paste(agradientT, (0, 0))
alphaLayerTB.paste(agradientT.rotate(180), (0, height - overlap_size_y)) alphaLayerTB.paste(agradientT.rotate(
180), (0, height - overlap_size_y))
alphaLayerLR = Image.new("L", (width, height), 255) alphaLayerLR = Image.new("L", (width, height), 255)
alphaLayerLR.paste(agradientL, (0, 0)) alphaLayerLR.paste(agradientL, (0, 0))
alphaLayerLR.paste(agradientL.rotate(180), (width - overlap_size_x, 0)) alphaLayerLR.paste(agradientL.rotate(
180), (width - overlap_size_x, 0))
# Sides and corner Layers # Sides and corner Layers
alphaLayerRBC = Image.new("L", (width, height), 255) alphaLayerRBC = Image.new("L", (width, height), 255)
alphaLayerRBC.paste(agradientL.rotate(180), (width - overlap_size_x, 0)) alphaLayerRBC.paste(agradientL.rotate(
alphaLayerRBC.paste(agradientT.rotate(180), (0, height - overlap_size_y)) 180), (width - overlap_size_x, 0))
alphaLayerRBC.paste(agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y)) alphaLayerRBC.paste(agradientT.rotate(
180), (0, height - overlap_size_y))
alphaLayerRBC.paste(agradientC.rotate(180).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
alphaLayerLBC = Image.new("L", (width, height), 255) alphaLayerLBC = Image.new("L", (width, height), 255)
alphaLayerLBC.paste(agradientL, (0, 0)) alphaLayerLBC.paste(agradientL, (0, 0))
alphaLayerLBC.paste(agradientT.rotate(180), (0, height - overlap_size_y)) alphaLayerLBC.paste(agradientT.rotate(
alphaLayerLBC.paste(agradientC.rotate(90).resize((overlap_size_x, overlap_size_y)), (0, height - overlap_size_y)) 180), (0, height - overlap_size_y))
alphaLayerLBC.paste(agradientC.rotate(90).resize(
(overlap_size_x, overlap_size_y)), (0, height - overlap_size_y))
alphaLayerRTC = Image.new("L", (width, height), 255) alphaLayerRTC = Image.new("L", (width, height), 255)
alphaLayerRTC.paste(agradientL.rotate(180), (width - overlap_size_x, 0)) alphaLayerRTC.paste(agradientL.rotate(
180), (width - overlap_size_x, 0))
alphaLayerRTC.paste(agradientT, (0, 0)) alphaLayerRTC.paste(agradientT, (0, 0))
alphaLayerRTC.paste(agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0)) alphaLayerRTC.paste(agradientC.rotate(270).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
# All but X layers # All but X layers
alphaLayerABT = Image.new("L", (width, height), 255) alphaLayerABT = Image.new("L", (width, height), 255)
alphaLayerABT.paste(alphaLayerLBC, (0, 0)) alphaLayerABT.paste(alphaLayerLBC, (0, 0))
alphaLayerABT.paste(agradientL.rotate(180), (width - overlap_size_x, 0)) alphaLayerABT.paste(agradientL.rotate(
alphaLayerABT.paste(agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y)) 180), (width - overlap_size_x, 0))
alphaLayerABT.paste(agradientC.rotate(180).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
alphaLayerABL = Image.new("L", (width, height), 255) alphaLayerABL = Image.new("L", (width, height), 255)
alphaLayerABL.paste(alphaLayerRTC, (0, 0)) alphaLayerABL.paste(alphaLayerRTC, (0, 0))
alphaLayerABL.paste(agradientT.rotate(180), (0, height - overlap_size_y)) alphaLayerABL.paste(agradientT.rotate(
alphaLayerABL.paste(agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y)) 180), (0, height - overlap_size_y))
alphaLayerABL.paste(agradientC.rotate(180).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
alphaLayerABR = Image.new("L", (width, height), 255) alphaLayerABR = Image.new("L", (width, height), 255)
alphaLayerABR.paste(alphaLayerLBC, (0, 0)) alphaLayerABR.paste(alphaLayerLBC, (0, 0))
alphaLayerABR.paste(agradientT, (0, 0)) alphaLayerABR.paste(agradientT, (0, 0))
alphaLayerABR.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)) alphaLayerABR.paste(agradientC.resize(
(overlap_size_x, overlap_size_y)), (0, 0))
alphaLayerABB = Image.new("L", (width, height), 255) alphaLayerABB = Image.new("L", (width, height), 255)
alphaLayerABB.paste(alphaLayerRTC, (0, 0)) alphaLayerABB.paste(alphaLayerRTC, (0, 0))
alphaLayerABB.paste(agradientL, (0, 0)) alphaLayerABB.paste(agradientL, (0, 0))
alphaLayerABB.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)) alphaLayerABB.paste(agradientC.resize(
(overlap_size_x, overlap_size_y)), (0, 0))
# All-around layer # All-around layer
alphaLayerAA = Image.new("L", (width, height), 255) alphaLayerAA = Image.new("L", (width, height), 255)
alphaLayerAA.paste(alphaLayerABT, (0, 0)) alphaLayerAA.paste(alphaLayerABT, (0, 0))
alphaLayerAA.paste(agradientT, (0, 0)) alphaLayerAA.paste(agradientT, (0, 0))
alphaLayerAA.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)) alphaLayerAA.paste(agradientC.resize(
alphaLayerAA.paste(agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0)) (overlap_size_x, overlap_size_y)), (0, 0))
alphaLayerAA.paste(agradientC.rotate(270).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
# Clean up temporary gradients # Clean up temporary gradients
del agradientL del agradientL
@ -218,7 +244,8 @@ class Embiggen(Generator):
if embiggen_tiles: if embiggen_tiles:
print(f'>> Making {len(embiggen_tiles)} Embiggen tiles...') print(f'>> Making {len(embiggen_tiles)} Embiggen tiles...')
else: else:
print(f'>> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})...') print(
f'>> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})...')
emb_tile_store = [] emb_tile_store = []
for tile in range(emb_tiles_x * emb_tiles_y): for tile in range(emb_tiles_x * emb_tiles_y):
@ -248,12 +275,15 @@ class Embiggen(Generator):
# newinitimage.save(newinitimagepath) # newinitimage.save(newinitimagepath)
if embiggen_tiles: if embiggen_tiles:
print(f'Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)') print(
f'Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)')
else: else:
print(f'Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles') print(
f'Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles')
# create a torch tensor from an Image # create a torch tensor from an Image
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0 newinitimage = np.array(
newinitimage).astype(np.float32) / 255.0
newinitimage = newinitimage[None].transpose(0, 3, 1, 2) newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
newinitimage = torch.from_numpy(newinitimage) newinitimage = torch.from_numpy(newinitimage)
newinitimage = 2.0 * newinitimage - 1.0 newinitimage = 2.0 * newinitimage - 1.0
@ -285,9 +315,11 @@ class Embiggen(Generator):
# Sanity check we have them all # Sanity check we have them all
if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)): if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)):
outputsuperimage = Image.new("RGBA", (initsuperwidth, initsuperheight)) outputsuperimage = Image.new(
"RGBA", (initsuperwidth, initsuperheight))
if embiggen_tiles: if embiggen_tiles:
outputsuperimage.alpha_composite(initsuperimage.convert('RGBA'), (0, 0)) outputsuperimage.alpha_composite(
initsuperimage.convert('RGBA'), (0, 0))
for tile in range(emb_tiles_x * emb_tiles_y): for tile in range(emb_tiles_x * emb_tiles_y):
if embiggen_tiles: if embiggen_tiles:
if tile in embiggen_tiles: if tile in embiggen_tiles:
@ -308,7 +340,8 @@ class Embiggen(Generator):
if emb_column_i + 1 == emb_tiles_x: if emb_column_i + 1 == emb_tiles_x:
left = initsuperwidth - width left = initsuperwidth - width
else: else:
left = round(emb_column_i * (width - overlap_size_x)) left = round(emb_column_i *
(width - overlap_size_x))
if emb_row_i + 1 == emb_tiles_y: if emb_row_i + 1 == emb_tiles_y:
top = initsuperheight - height top = initsuperheight - height
else: else:

View File

@ -37,6 +37,8 @@ def build_opt(post_data, seed, gfpgan_model_exists):
setattr(opt, 'seed', None if int(post_data['seed']) == -1 else int(post_data['seed'])) setattr(opt, 'seed', None if int(post_data['seed']) == -1 else int(post_data['seed']))
setattr(opt, 'variation_amount', float(post_data['variation_amount']) if int(post_data['seed']) != -1 else 0) setattr(opt, 'variation_amount', float(post_data['variation_amount']) if int(post_data['seed']) != -1 else 0)
setattr(opt, 'with_variations', []) setattr(opt, 'with_variations', [])
setattr(opt, 'embiggen', None)
setattr(opt, 'embiggen_tiles', None)
broken = False broken = False
if int(post_data['seed']) != -1 and post_data['with_variations'] != '': if int(post_data['seed']) != -1 and post_data['with_variations'] != '':
@ -80,12 +82,11 @@ class DreamServer(BaseHTTPRequestHandler):
self.wfile.write(content.read()) self.wfile.write(content.read())
elif self.path == "/config.js": elif self.path == "/config.js":
# unfortunately this import can't be at the top level, since that would cause a circular import # unfortunately this import can't be at the top level, since that would cause a circular import
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
self.send_response(200) self.send_response(200)
self.send_header("Content-type", "application/javascript") self.send_header("Content-type", "application/javascript")
self.end_headers() self.end_headers()
config = { config = {
'gfpgan_model_exists': gfpgan_model_exists 'gfpgan_model_exists': self.gfpgan_model_exists
} }
self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8")) self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8"))
elif self.path == "/run_log.json": elif self.path == "/run_log.json":
@ -138,11 +139,10 @@ class DreamServer(BaseHTTPRequestHandler):
self.end_headers() self.end_headers()
# unfortunately this import can't be at the top level, since that would cause a circular import # unfortunately this import can't be at the top level, since that would cause a circular import
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
content_length = int(self.headers['Content-Length']) content_length = int(self.headers['Content-Length'])
post_data = json.loads(self.rfile.read(content_length)) post_data = json.loads(self.rfile.read(content_length))
opt = build_opt(post_data, self.model.seed, gfpgan_model_exists) opt = build_opt(post_data, self.model.seed, self.gfpgan_model_exists)
self.canceled.clear() self.canceled.clear()
# In order to handle upscaled images, the PngWriter needs to maintain state # In order to handle upscaled images, the PngWriter needs to maintain state

View File

@ -51,6 +51,24 @@ torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli) torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial) torch.multinomial = fix_func(torch.multinomial)
def fix_func(orig):
    """Wrap a torch random-sampling function so it always samples on the CPU.

    Several torch RNG entry points misbehave on the Apple-Silicon (MPS)
    backend, so when MPS is available the returned wrapper forces the draw
    onto the CPU and then moves the result to the device the caller asked
    for (defaulting to 'mps').  On every other backend the original
    function is returned untouched.
    """
    mps = getattr(torch.backends, 'mps', None)
    if mps is None or not torch.backends.mps.is_available():
        return orig

    def sample_on_cpu(*args, **kw):
        # Remember the requested device, draw on the CPU, then transfer.
        target_device = kw.get("device", "mps")
        kw["device"] = "cpu"
        return orig(*args, **kw).to(target_device)

    return sample_on_cpu
# Route every torch random-sampling entry point through fix_func so they
# behave correctly on the MPS backend (fix_func is a no-op elsewhere).
for _rng_name in (
    'rand', 'rand_like', 'randn', 'randn_like',
    'randint', 'randint_like', 'bernoulli', 'multinomial',
):
    setattr(torch, _rng_name, fix_func(getattr(torch, _rng_name)))
del _rng_name
"""Simplified text to image API for stable diffusion/latent diffusion """Simplified text to image API for stable diffusion/latent diffusion
Example Usage: Example Usage:
@ -135,6 +153,9 @@ class Generate:
# these are deprecated; if present they override values in the conf file # these are deprecated; if present they override values in the conf file
weights = None, weights = None,
config = None, config = None,
gfpgan=None,
codeformer=None,
esrgan=None
): ):
models = OmegaConf.load(conf) models = OmegaConf.load(conf)
mconfig = models[model] mconfig = models[model]
@ -158,6 +179,9 @@ class Generate:
self.generators = {} self.generators = {}
self.base_generator = None self.base_generator = None
self.seed = None self.seed = None
self.gfpgan = gfpgan
self.codeformer = codeformer
self.esrgan = esrgan
# Note that in previous versions, there was an option to pass the # Note that in previous versions, there was an option to pass the
# device to Generate(). However the device was then ignored, so # device to Generate(). However the device was then ignored, so
@ -312,7 +336,8 @@ class Generate:
0.0 <= variation_amount <= 1.0 0.0 <= variation_amount <= 1.0
), '-v --variation_amount must be in [0.0, 1.0]' ), '-v --variation_amount must be in [0.0, 1.0]'
assert ( assert (
(embiggen == None and embiggen_tiles == None) or ((embiggen != None or embiggen_tiles != None) and init_img != None) (embiggen == None and embiggen_tiles == None) or (
(embiggen != None or embiggen_tiles != None) and init_img != None)
), 'Embiggen requires an init/input image to be specified' ), 'Embiggen requires an init/input image to be specified'
if len(with_variations) > 0 or variation_amount > 1.0: if len(with_variations) > 0 or variation_amount > 1.0:
@ -345,7 +370,8 @@ class Generate:
log_tokens =self.log_tokenization log_tokens =self.log_tokenization
) )
(init_image,mask_image) = self._make_images(init_img,init_mask, width, height, fit) (init_image, mask_image) = self._make_images(
init_img, init_mask, width, height, fit)
if (init_image is not None) and (mask_image is not None): if (init_image is not None) and (mask_image is not None):
generator = self._make_inpaint() generator = self._make_inpaint()
@ -356,7 +382,8 @@ class Generate:
else: else:
generator = self._make_txt2img() generator = self._make_txt2img()
generator.set_variation(self.seed, variation_amount, with_variations) generator.set_variation(
self.seed, variation_amount, with_variations)
results = generator.generate( results = generator.generate(
prompt, prompt,
iterations=iterations, iterations=iterations,
@ -404,7 +431,8 @@ class Generate:
toc = time.time() toc = time.time()
print('>> Usage stats:') print('>> Usage stats:')
print( print(
f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic) f'>> {len(results)} image(s) generated in', '%4.2fs' % (
toc - tic)
) )
if self._has_cuda(): if self._has_cuda():
print( print(
@ -520,21 +548,27 @@ class Generate:
if not img_path: if not img_path:
return None, None return None, None
image = self._load_img(img_path, width, height, fit=fit) # this returns an Image image = self._load_img(img_path, width, height,
init_image = self._create_init_image(image) # this returns a torch tensor fit=fit) # this returns an Image
# this returns a torch tensor
init_image = self._create_init_image(image)
if self._has_transparency(image) and not mask_path: # if image has a transparent area and no mask was provided, then try to generate mask # if image has a transparent area and no mask was provided, then try to generate mask
print('>> Initial image has transparent areas. Will inpaint in these regions.') if self._has_transparency(image) and not mask_path:
print(
'>> Initial image has transparent areas. Will inpaint in these regions.')
if self._check_for_erasure(image): if self._check_for_erasure(image):
print( print(
'>> WARNING: Colors underneath the transparent region seem to have been erased.\n', '>> WARNING: Colors underneath the transparent region seem to have been erased.\n',
'>> Inpainting will be suboptimal. Please preserve the colors when making\n', '>> Inpainting will be suboptimal. Please preserve the colors when making\n',
'>> a transparency mask, or provide mask explicitly using --init_mask (-M).' '>> a transparency mask, or provide mask explicitly using --init_mask (-M).'
) )
init_mask = self._create_init_mask(image) # this returns a torch tensor # this returns a torch tensor
init_mask = self._create_init_mask(image)
if mask_path: if mask_path:
mask_image = self._load_img(mask_path, width, height, fit=fit) # this returns an Image mask_image = self._load_img(
mask_path, width, height, fit=fit) # this returns an Image
init_mask = self._create_init_mask(mask_image) init_mask = self._create_init_mask(mask_image)
return init_image, init_mask return init_image, init_mask
@ -619,38 +653,26 @@ class Generate:
codeformer_fidelity = 0.75, codeformer_fidelity = 0.75,
save_original = False, save_original = False,
image_callback = None): image_callback = None):
try:
if upscale is not None:
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
if strength > 0:
if facetool == 'codeformer':
from ldm.restoration.codeformer.codeformer import CodeFormerRestoration
else:
from ldm.gfpgan.gfpgan_tools import run_gfpgan
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
return
for r in image_list: for r in image_list:
image, seed = r image, seed = r
try: try:
if upscale is not None: if upscale is not None:
if self.esrgan is not None:
if len(upscale) < 2: if len(upscale) < 2:
upscale.append(0.75) upscale.append(0.75)
image = real_esrgan_upscale( image = self.esrgan.process(
image, image, upscale[1], seed, int(upscale[0]))
upscale[1],
int(upscale[0]),
seed,
)
if strength > 0:
if facetool == 'codeformer':
image = CodeFormerRestoration().process(image=image, strength=strength, device=self.device, seed=seed, fidelity=codeformer_fidelity)
else: else:
image = run_gfpgan( print(">> ESRGAN is disabled. Image not upscaled.")
image, strength, seed, 1 if strength > 0:
) if self.gfpgan is not None and self.codeformer is not None:
if facetool == 'codeformer':
image = self.codeformer.process(image=image, strength=strength, device=self.device, seed=seed, fidelity=codeformer_fidelity)
else:
image = self.gfpgan.process(image, strength, seed)
else:
print(">> Face Restoration is disabled.")
except Exception as e: except Exception as e:
print( print(
f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}' f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}'
@ -779,7 +801,8 @@ class Generate:
image = image.convert('RGB') image = image.convert('RGB')
# BUG: We need to use the model's downsample factor rather than hardcoding "8" # BUG: We need to use the model's downsample factor rather than hardcoding "8"
from ldm.dream.generator.base import downsampling from ldm.dream.generator.base import downsampling
image = image.resize((image.width//downsampling, image.height//downsampling), resample=Image.Resampling.LANCZOS) image = image.resize((image.width//downsampling, image.height //
downsampling), resample=Image.Resampling.LANCZOS)
# print( # print(
# f'>> DEBUG: writing the mask to mask.png' # f'>> DEBUG: writing the mask to mask.png'
# ) # )
@ -815,7 +838,6 @@ class Generate:
return True return True
return False return False
def _check_for_erasure(self, image): def _check_for_erasure(self, image):
width, height = image.size width, height = image.size
pixdata = image.load() pixdata = image.load()
@ -835,7 +857,6 @@ class Generate:
return InitImageResizer(image).resize(x, y) return InitImageResizer(image).resize(x, y)
return image return image
def _fit_image(self, image, max_dimensions): def _fit_image(self, image, max_dimensions):
w, h = max_dimensions w, h = max_dimensions
print( print(
@ -847,7 +868,8 @@ class Generate:
w = None # ditto for w w = None # ditto for w
else: else:
pass pass
image = InitImageResizer(image).resize(w,h) # note that InitImageResizer does the multiple of 64 truncation internally # note that InitImageResizer does the multiple of 64 truncation internally
image = InitImageResizer(image).resize(w, h)
print( print(
f'>> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}' f'>> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}'
) )

View File

@ -1,168 +0,0 @@
import torch
import warnings
import os
import sys
import numpy as np
from PIL import Image

#from scripts.dream import create_argv_parser
from ldm.dream.args import Args

# Parse command-line options once at import time so this module can locate
# the GFPGAN weights.  NOTE(review): parsing argv as an import side effect
# is fragile — confirm no other entry point imports this module with
# unrelated argv.
opt = Args()
opt.parse_args()

# Absolute path to the GFPGAN checkpoint, derived from --gfpgan_dir and
# --gfpgan_model_path.
model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path)
# True when the checkpoint file is already present on disk.
gfpgan_model_exists = os.path.isfile(model_path)
def run_gfpgan(image, strength, seed, upsampler_scale=4):
    """Restore faces in *image* with GFPGAN and blend with the original.

    Args:
        image: PIL image to restore.
        strength: blend factor in [0, 1]; 1.0 keeps the fully restored
            result, smaller values blend it back toward the input.
        seed: seed of the generated image (used only for the log line).
        upsampler_scale: upscale factor handed to GFPGAN and to the
            Real-ESRGAN background upsampler.

    Returns:
        A PIL image; the input image unchanged when GFPGAN cannot be loaded.
    """
    print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')
    gfpgan = None
    with warnings.catch_warnings():
        # GFPGAN / basicsr emit noisy deprecation warnings on import.
        warnings.filterwarnings('ignore', category=DeprecationWarning)
        warnings.filterwarnings('ignore', category=UserWarning)

        try:
            if not gfpgan_model_exists:
                raise Exception('GFPGAN model not found at path ' + model_path)

            sys.path.append(os.path.abspath(opt.gfpgan_dir))
            from gfpgan import GFPGANer

            bg_upsampler = _load_gfpgan_bg_upsampler(
                opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile
            )

            gfpgan = GFPGANer(
                model_path=model_path,
                upscale=upsampler_scale,
                arch='clean',
                channel_multiplier=2,
                bg_upsampler=bg_upsampler,
            )
        except Exception:
            import traceback

            print('>> Error loading GFPGAN:', file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    if gfpgan is None:
        print(
            f'>> WARNING: GFPGAN not initialized.'
        )
        print(
            f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth to {model_path}, \nor change GFPGAN directory with --gfpgan_dir.'
        )
        return image

    image = image.convert('RGB')

    cropped_faces, restored_faces, restored_img = gfpgan.enhance(
        np.array(image, dtype=np.uint8),
        has_aligned=False,
        only_center_face=False,
        paste_back=True,
    )
    res = Image.fromarray(restored_img)

    if strength < 1.0:
        # BUG FIX: the original compared `restored_img.size` (numpy
        # ndarray.size is the total element *count*, an int) against
        # `image.size` (a PIL (w, h) tuple), which is always unequal.
        # Compare the two PIL images instead, and resize the original to
        # match before blending if the dimensions actually changed.
        if res.size != image.size:
            image = image.resize(res.size)
        res = Image.blend(image, res, strength)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Drop the model reference so its VRAM can be reclaimed.
    gfpgan = None

    return res
def _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400):
if bg_upsampler == 'realesrgan':
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
model_path = {
2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
}
if upsampler_scale not in model_path:
return None
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
if upsampler_scale == 4:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
)
if upsampler_scale == 2:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
bg_upsampler = RealESRGANer(
scale=upsampler_scale,
model_path=model_path[upsampler_scale],
model=model,
tile=bg_tile,
tile_pad=10,
pre_pad=0,
half=use_half_precision,
)
else:
bg_upsampler = None
return bg_upsampler
def real_esrgan_upscale(image, strength, upsampler_scale, seed):
    """Upscale *image* with Real-ESRGAN and blend with the original.

    Args:
        image: PIL image to upscale.
        strength: blend factor in [0, 1]; 1.0 keeps the fully upscaled
            result, smaller values blend it back toward the input.
        upsampler_scale: upscale factor (2 or 4).
        seed: seed of the generated image (used only for the log line).

    Returns:
        A PIL image; the input image unchanged when the upsampler cannot
        be loaded.
    """
    print(
        f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
    )
    upsampler = None
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=DeprecationWarning)
        warnings.filterwarnings('ignore', category=UserWarning)

        try:
            upsampler = _load_gfpgan_bg_upsampler(
                opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile
            )
        except Exception:
            import traceback

            print('>> Error loading Real-ESRGAN:', file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    # BUG FIX: the original dereferenced `upsampler` unconditionally, which
    # raised NameError when loading threw (the except block swallowed the
    # error without binding the name) and AttributeError when the loader
    # returned None (unsupported scale/upsampler).  Fail soft instead.
    if upsampler is None:
        print('>> WARNING: Real-ESRGAN not initialized. Image not upscaled.')
        return image

    output, img_mode = upsampler.enhance(
        np.array(image, dtype=np.uint8),
        outscale=upsampler_scale,
        alpha_upsampler=opt.gfpgan_bg_upsampler,
    )
    res = Image.fromarray(output)

    if strength < 1.0:
        # BUG FIX: the original compared `output.size` (numpy ndarray.size
        # is an int element count) against `image.size` (a PIL (w, h)
        # tuple), which is always unequal.  Compare the PIL images instead.
        if res.size != image.size:
            image = image.resize(res.size)
        res = Image.blend(image, res, strength)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Drop the upsampler reference so its VRAM can be reclaimed.
    upsampler = None

    return res

View File

@ -2,12 +2,20 @@ import os
import torch import torch
import numpy as np import numpy as np
import warnings import warnings
import sys
pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
class CodeFormerRestoration(): class CodeFormerRestoration():
def __init__(self) -> None: def __init__(self,
pass codeformer_dir='ldm/restoration/codeformer',
codeformer_model_path='weights/codeformer.pth') -> None:
self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
self.codeformer_model_exists = os.path.isfile(self.model_path)
if not self.codeformer_model_exists:
print('## NOT FOUND: CodeFormer model not found at ' + self.model_path)
sys.path.append(os.path.abspath(codeformer_dir))
def process(self, image, strength, device, seed=None, fidelity=0.75): def process(self, image, strength, device, seed=None, fidelity=0.75):
if seed is not None: if seed is not None:

View File

@ -0,0 +1,76 @@
import torch
import warnings
import os
import sys
import numpy as np
from PIL import Image
class GFPGAN():
    """Face restoration using TencentARC's GFPGAN model."""

    def __init__(
            self,
            gfpgan_dir='src/gfpgan',
            gfpgan_model_path='experiments/pretrained_models/GFPGANv1.3.pth') -> None:
        # Absolute/relative path to the GFPGAN checkpoint file.
        self.model_path = os.path.join(gfpgan_dir, gfpgan_model_path)
        # Cached existence check performed at construction time.
        self.gfpgan_model_exists = os.path.isfile(self.model_path)
        # Lazily-loaded GFPGANer instance; created per process() call.
        self.gfpgan = None

        if not self.gfpgan_model_exists:
            raise Exception(
                'GFPGAN model not found at path ' + self.model_path)
        sys.path.append(os.path.abspath(gfpgan_dir))

    def model_exists(self):
        """Return True if the GFPGAN checkpoint file is present on disk."""
        return os.path.isfile(self.model_path)

    def process(self, image, strength: float, seed: str = None):
        """Restore faces in *image* and blend with the original.

        Args:
            image: Source PIL image.
            strength: Blend factor in [0, 1]; below 1.0 the restored
                result is blended with the (resized) original.
            seed: Seed of the generated image, used only for logging.

        Returns:
            The face-restored PIL image, or the unmodified input when
            GFPGAN could not be initialized.
        """
        if seed is not None:
            print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')

        # Reset so a failed load below is detectable instead of leaving
        # the attribute unset (which would raise AttributeError).
        self.gfpgan = None
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', category=DeprecationWarning)
            warnings.filterwarnings('ignore', category=UserWarning)

            try:
                from gfpgan import GFPGANer
                self.gfpgan = GFPGANer(
                    model_path=self.model_path,
                    upscale=1,
                    arch='clean',
                    channel_multiplier=2,
                    bg_upsampler=None,
                )
            except Exception:
                import traceback

                print('>> Error loading GFPGAN:', file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

        if self.gfpgan is None:
            print(
                f'>> WARNING: GFPGAN not initialized.'
            )
            print(
                f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth to {self.model_path}, \nor change GFPGAN directory with --gfpgan_dir.'
            )
            # Fall back to the unrestored image instead of crashing on
            # self.gfpgan.enhance below.
            return image

        image = image.convert('RGB')

        _, _, restored_img = self.gfpgan.enhance(
            np.array(image, dtype=np.uint8),
            has_aligned=False,
            only_center_face=False,
            paste_back=True,
        )
        res = Image.fromarray(restored_img)

        if strength < 1.0:
            # `restored_img` is a numpy array whose `.size` is a scalar
            # element count — compare PIL sizes so a dimension change is
            # actually detected before blending.
            if res.size != image.size:
                image = image.resize(res.size)
            res = Image.blend(image, res, strength)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        self.gfpgan = None

        return res

View File

@ -0,0 +1,102 @@
import torch
import warnings
import numpy as np
from PIL import Image
class ESRGAN():
    """Image upscaling using Real-ESRGAN."""

    def __init__(self, bg_tile_size=400) -> None:
        # Tile size for tiled upsampling; smaller tiles bound GPU memory.
        # (The half-precision decision is made per-load, not here.)
        self.bg_tile_size = bg_tile_size

    def load_esrgan_bg_upsampler(self, upsampler_scale):
        """Return a configured RealESRGANer for scale 2 or 4, else None."""
        # Half precision is only used on CUDA; CPU or MPS on M1 stays fp32.
        use_half_precision = torch.cuda.is_available()

        model_path = {
            2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
            4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
        }

        if upsampler_scale not in model_path:
            return None

        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        # Both published checkpoints share the same RRDBNet architecture
        # and differ only in their native scale.
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=upsampler_scale,
        )

        return RealESRGANer(
            scale=upsampler_scale,
            model_path=model_path[upsampler_scale],
            model=model,
            tile=self.bg_tile_size,
            tile_pad=10,
            pre_pad=0,
            half=use_half_precision,
        )

    def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2):
        """Upscale *image* with Real-ESRGAN and blend with the original.

        Args:
            image: Source PIL image.
            strength: Blend factor in [0, 1]; below 1.0 the upscaled
                result is blended with the (resized) original.
            seed: Seed of the generated image, used only for logging.
            upsampler_scale: Upscale factor (2 or 4).

        Returns:
            The upscaled PIL image, or the unmodified input when the
            upsampler could not be initialized.
        """
        if seed is not None:
            print(
                f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
            )

        upsampler = None
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', category=DeprecationWarning)
            warnings.filterwarnings('ignore', category=UserWarning)

            try:
                upsampler = self.load_esrgan_bg_upsampler(upsampler_scale)
            except Exception:
                import traceback
                import sys

                print('>> Error loading Real-ESRGAN:', file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

        # Loader returns None for unsupported scales and `upsampler`
        # stays None when loading raised; bail out instead of crashing
        # on upsampler.enhance below.
        if upsampler is None:
            print('>> WARNING: Real-ESRGAN not initialized, skipping upscale.')
            return image

        output, _ = upsampler.enhance(
            np.array(image, dtype=np.uint8),
            outscale=upsampler_scale,
            alpha_upsampler='realesrgan',
        )

        res = Image.fromarray(output)

        if strength < 1.0:
            # `output` is a numpy array whose `.size` is a scalar element
            # count — compare PIL sizes so a dimension change is actually
            # detected before blending.
            if res.size != image.size:
                image = image.resize(res.size)
            res = Image.blend(image, res, strength)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        upsampler = None

        return res

View File

@ -0,0 +1,34 @@
class Restoration():
    """Factory that lazily loads the optional post-processing models
    (GFPGAN and CodeFormer for face restoration, ESRGAN for upscaling)."""

    def __init__(self, gfpgan_dir='./src/gfpgan', gfpgan_model_path='experiments/pretrained_models/GFPGANv1.3.pth', esrgan_bg_tile=400) -> None:
        # Directory containing the GFPGAN sources/weights.
        self.gfpgan_dir = gfpgan_dir
        # Checkpoint path relative to gfpgan_dir.
        self.gfpgan_model_path = gfpgan_model_path
        # Tile size handed to ESRGAN to bound GPU memory usage.
        self.esrgan_bg_tile = esrgan_bg_tile

    def load_face_restore_models(self):
        """Load and return the (gfpgan, codeformer) face-restoration pair."""
        # Load GFPGAN
        gfpgan = self.load_gfpgan()
        if gfpgan.gfpgan_model_exists:
            print('>> GFPGAN Initialized')

        # Load CodeFormer
        codeformer = self.load_codeformer()
        if codeformer.codeformer_model_exists:
            print('>> CodeFormer Initialized')

        return gfpgan, codeformer

    # Face Restore Models
    def load_gfpgan(self):
        """Construct a GFPGAN restorer from this factory's configuration."""
        from ldm.restoration.gfpgan.gfpgan import GFPGAN
        return GFPGAN(self.gfpgan_dir, self.gfpgan_model_path)

    def load_codeformer(self):
        """Construct a CodeFormer restorer (uses its own default paths)."""
        from ldm.restoration.codeformer.codeformer import CodeFormerRestoration
        return CodeFormerRestoration()

    # Upscale Models
    def load_ersgan(self):
        """Construct an ESRGAN upscaler.

        NOTE(review): the method name keeps the existing misspelling
        ('ersgan') because callers already use it; renaming would break
        the public interface.
        """
        from ldm.restoration.realesrgan.realesrgan import ESRGAN
        esrgan = ESRGAN(self.esrgan_bg_tile)
        print('>> ESRGAN Initialized')
        return esrgan

View File

@ -43,7 +43,25 @@ def main():
import transformers import transformers
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
# creating a simple Generate object with a handful of # Loading Face Restoration and ESRGAN Modules
try:
gfpgan, codeformer, esrgan = None, None, None
from ldm.restoration.restoration import Restoration
restoration = Restoration(opt.gfpgan_dir, opt.gfpgan_model_path, opt.esrgan_bg_tile)
if opt.restore:
gfpgan, codeformer = restoration.load_face_restore_models()
else:
print('>> Face Restoration Disabled')
if opt.esrgan:
esrgan = restoration.load_ersgan()
else:
print('>> ESRGAN Disabled')
except (ModuleNotFoundError, ImportError):
import traceback
print(traceback.format_exc(), file=sys.stderr)
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
# creating a simple text2image object with a handful of
# defaults passed on the command line. # defaults passed on the command line.
# additional parameters will be added (or overriden) during # additional parameters will be added (or overriden) during
# the user input loop # the user input loop
@ -55,6 +73,9 @@ def main():
embedding_path = opt.embedding_path, embedding_path = opt.embedding_path,
full_precision = opt.full_precision, full_precision = opt.full_precision,
precision = opt.precision, precision = opt.precision,
gfpgan=gfpgan,
codeformer=codeformer,
esrgan=esrgan
) )
except (FileNotFoundError, IOError, KeyError) as e: except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.') print(f'{e}. Aborting.')
@ -91,7 +112,7 @@ def main():
# web server loops forever # web server loops forever
if opt.web: if opt.web:
dream_server_loop(gen, opt.host, opt.port, opt.outdir) dream_server_loop(gen, opt.host, opt.port, opt.outdir, gfpgan)
sys.exit(0) sys.exit(0)
main_loop(gen, opt, infile) main_loop(gen, opt, infile)
@ -347,7 +368,7 @@ def get_next_command(infile=None) -> str: # command string
print(f'#{command}') print(f'#{command}')
return command return command
def dream_server_loop(gen, host, port, outdir): def dream_server_loop(gen, host, port, outdir, gfpgan):
print('\n* --web was specified, starting web server...') print('\n* --web was specified, starting web server...')
# Change working directory to the stable-diffusion directory # Change working directory to the stable-diffusion directory
os.chdir( os.chdir(
@ -357,6 +378,10 @@ def dream_server_loop(gen, host, port, outdir):
# Start server # Start server
DreamServer.model = gen # misnomer in DreamServer - this is not the model you are looking for DreamServer.model = gen # misnomer in DreamServer - this is not the model you are looking for
DreamServer.outdir = outdir DreamServer.outdir = outdir
DreamServer.gfpgan_model_exists = False
if gfpgan is not None:
DreamServer.gfpgan_model_exists = gfpgan.gfpgan_model_exists
dream_server = ThreadingDreamServer((host, port)) dream_server = ThreadingDreamServer((host, port))
print(">> Started Stable Diffusion dream server!") print(">> Started Stable Diffusion dream server!")
if host == '0.0.0.0': if host == '0.0.0.0':
@ -374,5 +399,19 @@ def dream_server_loop(gen, host, port, outdir):
dream_server.server_close() dream_server.server_close()
<<<<<<< HEAD
=======
def write_log_message(results, log_path):
"""logs the name of the output image, prompt, and prompt args to the terminal and log file"""
global output_cntr
log_lines = [f'{path}: {prompt}\n' for path, prompt in results]
for l in log_lines:
output_cntr += 1
print(f'[{output_cntr}] {l}', end='')
with open(log_path, 'a', encoding='utf-8') as file:
file.writelines(log_lines)
>>>>>>> GFPGAN and Real ESRGAN Implementation Refactor
if __name__ == '__main__': if __name__ == '__main__':
main() main()