From 7b0cbb34d618098b4072f14870937ee9eb4369a1 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 14 Sep 2022 05:17:14 +1200 Subject: [PATCH] GFPGAN and Real ESRGAN Implementation Refactor --- docs/features/UPSCALE.md | 10 +- ldm/dream/args.py | 17 +- ldm/dream/generator/embiggen.py | 217 ++++++++++++--------- ldm/dream/server.py | 8 +- ldm/generate.py | 234 +++++++++++++---------- ldm/gfpgan/gfpgan_tools.py | 168 ---------------- ldm/restoration/codeformer/codeformer.py | 12 +- ldm/restoration/gfpgan/gfpgan.py | 76 ++++++++ ldm/restoration/realesrgan/realesrgan.py | 102 ++++++++++ ldm/restoration/restoration.py | 34 ++++ scripts/dream.py | 34 +++- 11 files changed, 526 insertions(+), 386 deletions(-) delete mode 100644 ldm/gfpgan/gfpgan_tools.py create mode 100644 ldm/restoration/gfpgan/gfpgan.py create mode 100644 ldm/restoration/realesrgan/realesrgan.py create mode 100644 ldm/restoration/restoration.py diff --git a/docs/features/UPSCALE.md b/docs/features/UPSCALE.md index 28d85c1d71..259b569e88 100644 --- a/docs/features/UPSCALE.md +++ b/docs/features/UPSCALE.md @@ -2,10 +2,16 @@ title: Upscale --- +## **Intro** + +The script provides the ability to restore faces and upscale. + +You can enable these features by passing `--restore` and `--esrgan` to your launch script to enable +face restoration modules and upscaling modules respectively. + ## **GFPGAN and Real-ESRGAN Support** -The script also provides the ability to do face restoration and upscaling with the help of GFPGAN -and Real-ESRGAN respectively. +The default face restoration module is GFPGAN and the default upscaling module is ESRGAN. 
As of version 1.14, environment.yaml will install the Real-ESRGAN package into the standard install location for python packages, and will put GFPGAN into a subdirectory of "src" in the diff --git a/ldm/dream/args.py b/ldm/dream/args.py index db6d963645..f0feacad73 100644 --- a/ldm/dream/args.py +++ b/ldm/dream/args.py @@ -348,16 +348,19 @@ class Args(object): type=str, help='Path to a pre-trained embedding manager checkpoint - can only be set on command line', ) - # GFPGAN related args + # Restoration related args postprocessing_group.add_argument( - '--gfpgan_bg_upsampler', - type=str, - default='realesrgan', - help='Background upsampler. Default: realesrgan. Options: realesrgan, none.', - + '--restore', + action='store_true', + help='Enable Face Restoration', ) postprocessing_group.add_argument( - '--gfpgan_bg_tile', + '--esrgan', + action='store_true', + help='Enable Upscaling', + ) + postprocessing_group.add_argument( + '--esrgan_bg_tile', type=int, default=400, help='Tile size for background sampler, 0 for no tile during testing. 
Default: 400.', diff --git a/ldm/dream/generator/embiggen.py b/ldm/dream/generator/embiggen.py index cb9c029a66..e196e3005f 100644 --- a/ldm/dream/generator/embiggen.py +++ b/ldm/dream/generator/embiggen.py @@ -4,16 +4,17 @@ and generates with ldm.dream.generator.img2img ''' import torch -import numpy as np +import numpy as np from PIL import Image -from ldm.dream.generator.base import Generator -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.dream.generator.img2img import Img2Img +from ldm.dream.generator.base import Generator +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.dream.generator.img2img import Img2Img + class Embiggen(Generator): - def __init__(self,model): + def __init__(self, model): super().__init__(model) - self.init_latent = None + self.init_latent = None @torch.no_grad() def get_make_image( @@ -38,19 +39,20 @@ class Embiggen(Generator): Return value depends on the seed at the time you call it """ # Construct embiggen arg array, and sanity check arguments - if embiggen == None: # embiggen can also be called with just embiggen_tiles - embiggen = [1.0] # If not specified, assume no scaling - elif embiggen[0] < 0 : + if embiggen == None: # embiggen can also be called with just embiggen_tiles + embiggen = [1.0] # If not specified, assume no scaling + elif embiggen[0] < 0: embiggen[0] = 1.0 - print('>> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !') + print( + '>> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !') if len(embiggen) < 2: embiggen.append(0.75) - elif embiggen[1] > 1.0 or embiggen[1] < 0 : + elif embiggen[1] > 1.0 or embiggen[1] < 0: embiggen[1] = 0.75 print('>> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !') if len(embiggen) < 3: embiggen.append(0.25) - elif embiggen[2] < 0 : + elif embiggen[2] < 0: embiggen[2] = 0.25 print('>> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a 
number of pixels, fell back to the default of 0.25 !') @@ -76,29 +78,30 @@ class Embiggen(Generator): if embiggen[0] != 1.0: initsuperwidth = round(initsuperwidth*embiggen[0]) initsuperheight = round(initsuperheight*embiggen[0]) - if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero - from ldm.gfpgan.gfpgan_tools import ( - real_esrgan_upscale, - ) - print(f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}') + if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero + from ldm.restoration.realesrgan import ESRGAN + esrgan = ESRGAN() + print( + f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}') if embiggen[0] > 2: - initsuperimage = real_esrgan_upscale( + initsuperimage = esrgan.process( initsuperimage, - embiggen[1], # upscale strength - 4, # upscale scale + embiggen[1], # upscale strength self.seed, + 4, # upscale scale ) else: - initsuperimage = real_esrgan_upscale( + initsuperimage = esrgan.process( initsuperimage, - embiggen[1], # upscale strength - 2, # upscale scale + embiggen[1], # upscale strength self.seed, + 2, # upscale scale ) # We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x # but from personal experiance it doesn't greatly improve anything after 4x # Resize to target scaling factor resolution - initsuperimage = initsuperimage.resize((initsuperwidth, initsuperheight), Image.Resampling.LANCZOS) + initsuperimage = initsuperimage.resize( + (initsuperwidth, initsuperheight), Image.Resampling.LANCZOS) # Use width and height as tile widths and height # Determine buffer size in pixels @@ -121,28 +124,31 @@ class Embiggen(Generator): emb_tiles_x = 1 emb_tiles_y = 1 if (initsuperwidth - width) > 0: - emb_tiles_x = ceildiv(initsuperwidth - width, width - overlap_size_x) + 1 + emb_tiles_x = ceildiv(initsuperwidth - width, + width - overlap_size_x) + 1 if (initsuperheight - height) > 0: - emb_tiles_y = 
ceildiv(initsuperheight - height, height - overlap_size_y) + 1 + emb_tiles_y = ceildiv(initsuperheight - height, + height - overlap_size_y) + 1 # Sanity assert emb_tiles_x > 1 or emb_tiles_y > 1, f'ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don\'t need to Embiggen! Check your arguments.' # Prep alpha layers -------------- # https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil # agradientL is Left-side transparent - agradientL = Image.linear_gradient('L').rotate(90).resize((overlap_size_x, height)) + agradientL = Image.linear_gradient('L').rotate( + 90).resize((overlap_size_x, height)) # agradientT is Top-side transparent agradientT = Image.linear_gradient('L').resize((width, overlap_size_y)) # radial corner is the left-top corner, made full circle then cut to just the left-top quadrant agradientC = Image.new('L', (256, 256)) for y in range(256): for x in range(256): - #Find distance to lower right corner (numpy takes arrays) + # Find distance to lower right corner (numpy takes arrays) distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0] - #Clamp values to max 255 + # Clamp values to max 255 if distanceToLR > 255: distanceToLR = 255 - #Place the pixel as invert of distance + # Place the pixel as invert of distance agradientC.putpixel((x, y), int(255 - distanceToLR)) # Create alpha layers default fully white @@ -154,59 +160,79 @@ class Embiggen(Generator): alphaLayerT.paste(agradientT, (0, 0)) alphaLayerLTC.paste(agradientL, (0, 0)) alphaLayerLTC.paste(agradientT, (0, 0)) - alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)) + alphaLayerLTC.paste(agradientC.resize( + (overlap_size_x, overlap_size_y)), (0, 0)) if embiggen_tiles: # Individual unconnected sides alphaLayerR = Image.new("L", (width, height), 255) - alphaLayerR.paste(agradientL.rotate(180), (width - overlap_size_x, 0)) + 
alphaLayerR.paste(agradientL.rotate( + 180), (width - overlap_size_x, 0)) alphaLayerB = Image.new("L", (width, height), 255) - alphaLayerB.paste(agradientT.rotate(180), (0, height - overlap_size_y)) + alphaLayerB.paste(agradientT.rotate( + 180), (0, height - overlap_size_y)) alphaLayerTB = Image.new("L", (width, height), 255) alphaLayerTB.paste(agradientT, (0, 0)) - alphaLayerTB.paste(agradientT.rotate(180), (0, height - overlap_size_y)) + alphaLayerTB.paste(agradientT.rotate( + 180), (0, height - overlap_size_y)) alphaLayerLR = Image.new("L", (width, height), 255) alphaLayerLR.paste(agradientL, (0, 0)) - alphaLayerLR.paste(agradientL.rotate(180), (width - overlap_size_x, 0)) + alphaLayerLR.paste(agradientL.rotate( + 180), (width - overlap_size_x, 0)) # Sides and corner Layers alphaLayerRBC = Image.new("L", (width, height), 255) - alphaLayerRBC.paste(agradientL.rotate(180), (width - overlap_size_x, 0)) - alphaLayerRBC.paste(agradientT.rotate(180), (0, height - overlap_size_y)) - alphaLayerRBC.paste(agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y)) + alphaLayerRBC.paste(agradientL.rotate( + 180), (width - overlap_size_x, 0)) + alphaLayerRBC.paste(agradientT.rotate( + 180), (0, height - overlap_size_y)) + alphaLayerRBC.paste(agradientC.rotate(180).resize( + (overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y)) alphaLayerLBC = Image.new("L", (width, height), 255) alphaLayerLBC.paste(agradientL, (0, 0)) - alphaLayerLBC.paste(agradientT.rotate(180), (0, height - overlap_size_y)) - alphaLayerLBC.paste(agradientC.rotate(90).resize((overlap_size_x, overlap_size_y)), (0, height - overlap_size_y)) + alphaLayerLBC.paste(agradientT.rotate( + 180), (0, height - overlap_size_y)) + alphaLayerLBC.paste(agradientC.rotate(90).resize( + (overlap_size_x, overlap_size_y)), (0, height - overlap_size_y)) alphaLayerRTC = Image.new("L", (width, height), 255) - 
alphaLayerRTC.paste(agradientL.rotate(180), (width - overlap_size_x, 0)) + alphaLayerRTC.paste(agradientL.rotate( + 180), (width - overlap_size_x, 0)) alphaLayerRTC.paste(agradientT, (0, 0)) - alphaLayerRTC.paste(agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0)) + alphaLayerRTC.paste(agradientC.rotate(270).resize( + (overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0)) # All but X layers alphaLayerABT = Image.new("L", (width, height), 255) alphaLayerABT.paste(alphaLayerLBC, (0, 0)) - alphaLayerABT.paste(agradientL.rotate(180), (width - overlap_size_x, 0)) - alphaLayerABT.paste(agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y)) + alphaLayerABT.paste(agradientL.rotate( + 180), (width - overlap_size_x, 0)) + alphaLayerABT.paste(agradientC.rotate(180).resize( + (overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y)) alphaLayerABL = Image.new("L", (width, height), 255) alphaLayerABL.paste(alphaLayerRTC, (0, 0)) - alphaLayerABL.paste(agradientT.rotate(180), (0, height - overlap_size_y)) - alphaLayerABL.paste(agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y)) + alphaLayerABL.paste(agradientT.rotate( + 180), (0, height - overlap_size_y)) + alphaLayerABL.paste(agradientC.rotate(180).resize( + (overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y)) alphaLayerABR = Image.new("L", (width, height), 255) alphaLayerABR.paste(alphaLayerLBC, (0, 0)) alphaLayerABR.paste(agradientT, (0, 0)) - alphaLayerABR.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)) + alphaLayerABR.paste(agradientC.resize( + (overlap_size_x, overlap_size_y)), (0, 0)) alphaLayerABB = Image.new("L", (width, height), 255) alphaLayerABB.paste(alphaLayerRTC, (0, 0)) alphaLayerABB.paste(agradientL, (0, 0)) - 
alphaLayerABB.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)) + alphaLayerABB.paste(agradientC.resize( + (overlap_size_x, overlap_size_y)), (0, 0)) # All-around layer alphaLayerAA = Image.new("L", (width, height), 255) alphaLayerAA.paste(alphaLayerABT, (0, 0)) alphaLayerAA.paste(agradientT, (0, 0)) - alphaLayerAA.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)) - alphaLayerAA.paste(agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0)) + alphaLayerAA.paste(agradientC.resize( + (overlap_size_x, overlap_size_y)), (0, 0)) + alphaLayerAA.paste(agradientC.rotate(270).resize( + (overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0)) # Clean up temporary gradients del agradientL @@ -218,7 +244,8 @@ class Embiggen(Generator): if embiggen_tiles: print(f'>> Making {len(embiggen_tiles)} Embiggen tiles...') else: - print(f'>> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})...') + print( + f'>> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})...') emb_tile_store = [] for tile in range(emb_tiles_x * emb_tiles_y): @@ -240,20 +267,23 @@ class Embiggen(Generator): top = round(emb_row_i * (height - overlap_size_y)) right = left + width bottom = top + height - + # Cropped image of above dimension (does not modify the original) newinitimage = initsuperimage.crop((left, top, right, bottom)) # DEBUG: # newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png' # newinitimage.save(newinitimagepath) - + if embiggen_tiles: - print(f'Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)') + print( + f'Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)') else: - print(f'Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles') + print( + f'Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles') # create a torch tensor from an Image - newinitimage = 
np.array(newinitimage).astype(np.float32) / 255.0 + newinitimage = np.array( + newinitimage).astype(np.float32) / 255.0 newinitimage = newinitimage[None].transpose(0, 3, 1, 2) newinitimage = torch.from_numpy(newinitimage) newinitimage = 2.0 * newinitimage - 1.0 @@ -261,33 +291,35 @@ class Embiggen(Generator): tile_results = gen_img2img.generate( prompt, - iterations = 1, - seed = self.seed, - sampler = sampler, - steps = steps, - cfg_scale = cfg_scale, - conditioning = conditioning, - ddim_eta = ddim_eta, - image_callback = None, # called only after the final image is generated - step_callback = step_callback, # called after each intermediate image is generated - width = width, - height = height, - init_img = init_img, # img2img doesn't need this, but it might in the future - init_image = newinitimage, # notice that init_image is different from init_img - mask_image = None, - strength = strength, + iterations=1, + seed=self.seed, + sampler=sampler, + steps=steps, + cfg_scale=cfg_scale, + conditioning=conditioning, + ddim_eta=ddim_eta, + image_callback=None, # called only after the final image is generated + step_callback=step_callback, # called after each intermediate image is generated + width=width, + height=height, + init_img=init_img, # img2img doesn't need this, but it might in the future + init_image=newinitimage, # notice that init_image is different from init_img + mask_image=None, + strength=strength, ) emb_tile_store.append(tile_results[0][0]) # DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite # emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png') del newinitimage - + # Sanity check we have them all if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)): - outputsuperimage = Image.new("RGBA", (initsuperwidth, initsuperheight)) + outputsuperimage = Image.new( + "RGBA", (initsuperwidth, initsuperheight)) if 
embiggen_tiles: - outputsuperimage.alpha_composite(initsuperimage.convert('RGBA'), (0, 0)) + outputsuperimage.alpha_composite( + initsuperimage.convert('RGBA'), (0, 0)) for tile in range(emb_tiles_x * emb_tiles_y): if embiggen_tiles: if tile in embiggen_tiles: @@ -308,7 +340,8 @@ class Embiggen(Generator): if emb_column_i + 1 == emb_tiles_x: left = initsuperwidth - width else: - left = round(emb_column_i * (width - overlap_size_x)) + left = round(emb_column_i * + (width - overlap_size_x)) if emb_row_i + 1 == emb_tiles_y: top = initsuperheight - height else: @@ -319,33 +352,33 @@ class Embiggen(Generator): # top of image if emb_row_i == 0: if emb_column_i == 0: - if (tile+1) in embiggen_tiles: # Look-ahead right - if (tile+emb_tiles_x) not in embiggen_tiles: # Look-ahead down + if (tile+1) in embiggen_tiles: # Look-ahead right + if (tile+emb_tiles_x) not in embiggen_tiles: # Look-ahead down intileimage.putalpha(alphaLayerB) # Otherwise do nothing on this tile - elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only + elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only intileimage.putalpha(alphaLayerR) else: intileimage.putalpha(alphaLayerRBC) elif emb_column_i == emb_tiles_x - 1: - if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down + if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down intileimage.putalpha(alphaLayerL) else: intileimage.putalpha(alphaLayerLBC) else: - if (tile+1) in embiggen_tiles: # Look-ahead right - if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down + if (tile+1) in embiggen_tiles: # Look-ahead right + if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down intileimage.putalpha(alphaLayerL) else: intileimage.putalpha(alphaLayerLBC) - elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only + elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only intileimage.putalpha(alphaLayerLR) else: intileimage.putalpha(alphaLayerABT) # bottom of image elif emb_row_i == emb_tiles_y - 1: if 
emb_column_i == 0: - if (tile+1) in embiggen_tiles: # Look-ahead right + if (tile+1) in embiggen_tiles: # Look-ahead right intileimage.putalpha(alphaLayerT) else: intileimage.putalpha(alphaLayerRTC) @@ -353,34 +386,34 @@ class Embiggen(Generator): # No tiles to look ahead to intileimage.putalpha(alphaLayerLTC) else: - if (tile+1) in embiggen_tiles: # Look-ahead right + if (tile+1) in embiggen_tiles: # Look-ahead right intileimage.putalpha(alphaLayerLTC) else: intileimage.putalpha(alphaLayerABB) # vertical middle of image else: if emb_column_i == 0: - if (tile+1) in embiggen_tiles: # Look-ahead right - if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down + if (tile+1) in embiggen_tiles: # Look-ahead right + if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down intileimage.putalpha(alphaLayerT) else: intileimage.putalpha(alphaLayerTB) - elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only + elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only intileimage.putalpha(alphaLayerRTC) else: intileimage.putalpha(alphaLayerABL) elif emb_column_i == emb_tiles_x - 1: - if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down + if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down intileimage.putalpha(alphaLayerLTC) else: intileimage.putalpha(alphaLayerABR) else: - if (tile+1) in embiggen_tiles: # Look-ahead right - if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down + if (tile+1) in embiggen_tiles: # Look-ahead right + if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down intileimage.putalpha(alphaLayerLTC) else: intileimage.putalpha(alphaLayerABR) - elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only + elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only intileimage.putalpha(alphaLayerABB) else: intileimage.putalpha(alphaLayerAA) @@ -400,4 +433,4 @@ class Embiggen(Generator): # after internal loops and patching up return Embiggen image return outputsuperimage # end of function declaration - 
return make_image \ No newline at end of file + return make_image diff --git a/ldm/dream/server.py b/ldm/dream/server.py index 9e37c070d1..03114ac9d2 100644 --- a/ldm/dream/server.py +++ b/ldm/dream/server.py @@ -37,6 +37,8 @@ def build_opt(post_data, seed, gfpgan_model_exists): setattr(opt, 'seed', None if int(post_data['seed']) == -1 else int(post_data['seed'])) setattr(opt, 'variation_amount', float(post_data['variation_amount']) if int(post_data['seed']) != -1 else 0) setattr(opt, 'with_variations', []) + setattr(opt, 'embiggen', None) + setattr(opt, 'embiggen_tiles', None) broken = False if int(post_data['seed']) != -1 and post_data['with_variations'] != '': @@ -80,12 +82,11 @@ class DreamServer(BaseHTTPRequestHandler): self.wfile.write(content.read()) elif self.path == "/config.js": # unfortunately this import can't be at the top level, since that would cause a circular import - from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists self.send_response(200) self.send_header("Content-type", "application/javascript") self.end_headers() config = { - 'gfpgan_model_exists': gfpgan_model_exists + 'gfpgan_model_exists': self.gfpgan_model_exists } self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8")) elif self.path == "/run_log.json": @@ -138,11 +139,10 @@ class DreamServer(BaseHTTPRequestHandler): self.end_headers() # unfortunately this import can't be at the top level, since that would cause a circular import - from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists content_length = int(self.headers['Content-Length']) post_data = json.loads(self.rfile.read(content_length)) - opt = build_opt(post_data, self.model.seed, gfpgan_model_exists) + opt = build_opt(post_data, self.model.seed, self.gfpgan_model_exists) self.canceled.clear() # In order to handle upscaled images, the PngWriter needs to maintain state diff --git a/ldm/generate.py b/ldm/generate.py index a470648cdc..2bd53ac57b 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -23,14 
+23,32 @@ from PIL import Image, ImageOps from torch import nn from pytorch_lightning import seed_everything, logging -from ldm.util import instantiate_from_config -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.plms import PLMSSampler +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.ksampler import KSampler -from ldm.dream.pngwriter import PngWriter -from ldm.dream.image_util import InitImageResizer -from ldm.dream.devices import choose_torch_device -from ldm.dream.conditioning import get_uc_and_c +from ldm.dream.pngwriter import PngWriter +from ldm.dream.image_util import InitImageResizer +from ldm.dream.devices import choose_torch_device +from ldm.dream.conditioning import get_uc_and_c + +def fix_func(orig): + if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + def new_func(*args, **kw): + device = kw.get("device", "mps") + kw["device"]="cpu" + return orig(*args, **kw).to(device) + return new_func + return orig + +torch.rand = fix_func(torch.rand) +torch.rand_like = fix_func(torch.rand_like) +torch.randn = fix_func(torch.randn) +torch.randn_like = fix_func(torch.randn_like) +torch.randint = fix_func(torch.randint) +torch.randint_like = fix_func(torch.randint_like) +torch.bernoulli = fix_func(torch.bernoulli) +torch.multinomial = fix_func(torch.multinomial) def fix_func(orig): if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): @@ -133,6 +151,9 @@ class Generate: # these are deprecated; if present they override values in the conf file weights = None, config = None, + gfpgan=None, + codeformer=None, + esrgan=None ): models = OmegaConf.load(conf) mconfig = models[model] @@ -156,6 +177,9 @@ class Generate: self.generators = {} self.base_generator = None self.seed = None + self.gfpgan = gfpgan + self.codeformer = codeformer + self.esrgan = esrgan # Note that in previous 
versions, there was an option to pass the # device to Generate(). However the device was then ignored, so @@ -224,8 +248,8 @@ class Generate: strength = None, init_color = None, # these are specific to embiggen (which also relies on img2img args) - embiggen = None, - embiggen_tiles = None, + embiggen=None, + embiggen_tiles=None, # these are specific to GFPGAN/ESRGAN facetool = None, gfpgan_strength = 0, @@ -274,15 +298,15 @@ class Generate: write the prompt into the PNG metadata. """ # TODO: convert this into a getattr() loop - steps = steps or self.steps - width = width or self.width - height = height or self.height - seamless = seamless or self.seamless - cfg_scale = cfg_scale or self.cfg_scale - ddim_eta = ddim_eta or self.ddim_eta - iterations = iterations or self.iterations - strength = strength or self.strength - self.seed = seed + steps = steps or self.steps + width = width or self.width + height = height or self.height + seamless = seamless or self.seamless + cfg_scale = cfg_scale or self.cfg_scale + ddim_eta = ddim_eta or self.ddim_eta + iterations = iterations or self.iterations + strength = strength or self.strength + self.seed = seed self.log_tokenization = log_tokenization with_variations = [] if with_variations is None else with_variations @@ -292,16 +316,17 @@ class Generate: for m in model.modules(): if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): m.padding_mode = 'circular' if seamless else m._orig_padding_mode - + assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0' assert ( 0.0 < strength < 1.0 ), 'img2img and inpaint strength can only work with 0.0 < strength < 1.0' assert ( - 0.0 <= variation_amount <= 1.0 + 0.0 <= variation_amount <= 1.0 ), '-v --variation_amount must be in [0.0, 1.0]' assert ( - (embiggen == None and embiggen_tiles == None) or ((embiggen != None or embiggen_tiles != None) and init_img != None) + (embiggen == None and embiggen_tiles == None) or ( + (embiggen != None or embiggen_tiles != None) and init_img != None) ), 
'Embiggen requires an init/input image to be specified' if len(with_variations) > 0 or variation_amount > 1.0: @@ -323,9 +348,9 @@ class Generate: if self._has_cuda(): torch.cuda.reset_peak_memory_stats() - results = list() - init_image = None - mask_image = None + results = list() + init_image = None + mask_image = None try: uc, c = get_uc_and_c( @@ -334,8 +359,9 @@ class Generate: log_tokens =self.log_tokenization ) - (init_image,mask_image) = self._make_images(init_img,init_mask, width, height, fit) - + (init_image, mask_image) = self._make_images( + init_img, init_mask, width, height, fit) + if (init_image is not None) and (mask_image is not None): generator = self._make_inpaint() elif (embiggen != None or embiggen_tiles != None): @@ -345,26 +371,27 @@ class Generate: else: generator = self._make_txt2img() - generator.set_variation(self.seed, variation_amount, with_variations) + generator.set_variation( + self.seed, variation_amount, with_variations) results = generator.generate( prompt, - iterations = iterations, - seed = self.seed, - sampler = self.sampler, - steps = steps, - cfg_scale = cfg_scale, - conditioning = (uc,c), - ddim_eta = ddim_eta, - image_callback = image_callback, # called after the final image is generated - step_callback = step_callback, # called after each intermediate image is generated - width = width, - height = height, - init_img = init_img, # embiggen needs to manipulate from the unmodified init_img - init_image = init_image, # notice that init_image is different from init_img - mask_image = mask_image, - strength = strength, - embiggen = embiggen, - embiggen_tiles = embiggen_tiles, + iterations=iterations, + seed=self.seed, + sampler=self.sampler, + steps=steps, + cfg_scale=cfg_scale, + conditioning=(uc, c), + ddim_eta=ddim_eta, + image_callback=image_callback, # called after the final image is generated + step_callback=step_callback, # called after each intermediate image is generated + width=width, + height=height, + 
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img + init_image=init_image, # notice that init_image is different from init_img + mask_image=mask_image, + strength=strength, + embiggen=embiggen, + embiggen_tiles=embiggen_tiles, ) if init_color: @@ -393,7 +420,8 @@ class Generate: toc = time.time() print('>> Usage stats:') print( - f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic) + f'>> {len(results)} image(s) generated in', '%4.2fs' % ( + toc - tic) ) if self._has_cuda(): print( @@ -413,36 +441,42 @@ class Generate: return results def _make_images(self, img_path, mask_path, width, height, fit=False): - init_image = None - init_mask = None + init_image = None + init_mask = None if not img_path: - return None,None + return None, None - image = self._load_img(img_path, width, height, fit=fit) # this returns an Image - init_image = self._create_init_image(image) # this returns a torch tensor + image = self._load_img(img_path, width, height, + fit=fit) # this returns an Image + # this returns a torch tensor + init_image = self._create_init_image(image) - if self._has_transparency(image) and not mask_path: # if image has a transparent area and no mask was provided, then try to generate mask - print('>> Initial image has transparent areas. Will inpaint in these regions.') + # if image has a transparent area and no mask was provided, then try to generate mask + if self._has_transparency(image) and not mask_path: + print( + '>> Initial image has transparent areas. Will inpaint in these regions.') if self._check_for_erasure(image): print( '>> WARNING: Colors underneath the transparent region seem to have been erased.\n', '>> Inpainting will be suboptimal. Please preserve the colors when making\n', '>> a transparency mask, or provide mask explicitly using --init_mask (-M).' 
) - init_mask = self._create_init_mask(image) # this returns a torch tensor + # this returns a torch tensor + init_mask = self._create_init_mask(image) if mask_path: - mask_image = self._load_img(mask_path, width, height, fit=fit) # this returns an Image - init_mask = self._create_init_mask(mask_image) + mask_image = self._load_img( + mask_path, width, height, fit=fit) # this returns an Image + init_mask = self._create_init_mask(mask_image) - return init_image,init_mask + return init_image, init_mask def _make_img2img(self): if not self.generators.get('img2img'): from ldm.dream.generator.img2img import Img2Img self.generators['img2img'] = Img2Img(self.model) return self.generators['img2img'] - + def _make_embiggen(self): if not self.generators.get('embiggen'): from ldm.dream.generator.embiggen import Embiggen @@ -517,38 +551,26 @@ class Generate: codeformer_fidelity = 0.75, save_original = False, image_callback = None): - try: - if upscale is not None: - from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale - if strength > 0: - if facetool == 'codeformer': - from ldm.restoration.codeformer.codeformer import CodeFormerRestoration - else: - from ldm.gfpgan.gfpgan_tools import run_gfpgan - except (ModuleNotFoundError, ImportError): - print(traceback.format_exc(), file=sys.stderr) - print('>> You may need to install the ESRGAN and/or GFPGAN modules') - return for r in image_list: image, seed = r try: if upscale is not None: - if len(upscale) < 2: - upscale.append(0.75) - image = real_esrgan_upscale( - image, - upscale[1], - int(upscale[0]), - seed, - ) - if strength > 0: - if facetool == 'codeformer': - image = CodeFormerRestoration().process(image=image, strength=strength, device=self.device, seed=seed, fidelity=codeformer_fidelity) + if self.esrgan is not None: + if len(upscale) < 2: + upscale.append(0.75) + image = self.esrgan.process( + image, upscale[1], seed, int(upscale[0])) else: - image = run_gfpgan( - image, strength, seed, 1 - ) + print(">> ESRGAN is 
disabled. Image not upscaled.") + if strength > 0: + if self.gfpgan is not None and self.codeformer is not None: + if facetool == 'codeformer': + image = self.codeformer.process(image=image, strength=strength, device=self.device, seed=seed, fidelity=codeformer_fidelity) + else: + image = self.gfpgan.process(image, strength, seed) + else: + print(">> Face Restoration is disabled.") except Exception as e: print( f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}' @@ -560,10 +582,10 @@ class Generate: r[0] = image # to help WebGUI - front end to generator util function - def sample_to_image(self,samples): + def sample_to_image(self, samples): return self._sample_to_image(samples) - def _sample_to_image(self,samples): + def _sample_to_image(self, samples): if not self.base_generator: from ldm.dream.generator import Generator self.base_generator = Generator(self.model) @@ -606,7 +628,7 @@ class Generate: # for usage statistics device_type = choose_torch_device() if device_type == 'cuda': - torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_peak_memory_stats() tic = time.time() # this does the work @@ -657,12 +679,12 @@ class Generate: f'>> loaded input image of size {image.width}x{image.height} from {path}' ) if fit: - image = self._fit_image(image,(width,height)) + image = self._fit_image(image, (width, height)) else: image = self._squeeze_image(image) return image - def _create_init_image(self,image): + def _create_init_image(self, image): image = image.convert('RGB') # print( # f'>> DEBUG: writing the image to img.png' @@ -671,7 +693,7 @@ class Generate: image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) - image = 2.0 * image - 1.0 + image = 2.0 * image - 1.0 return image.to(self.device) def _create_init_mask(self, image): @@ -680,7 +702,8 @@ class Generate: image = image.convert('RGB') # BUG: We need to use the model's downsample factor rather than hardcoding "8" 
from ldm.dream.generator.base import downsampling - image = image.resize((image.width//downsampling, image.height//downsampling), resample=Image.Resampling.LANCZOS) + image = image.resize((image.width//downsampling, image.height // + downsampling), resample=Image.Resampling.LANCZOS) # print( # f'>> DEBUG: writing the mask to mask.png' # ) @@ -702,7 +725,7 @@ class Generate: mask = ImageOps.invert(mask) return mask - def _has_transparency(self,image): + def _has_transparency(self, image): if image.info.get("transparency", None) is not None: return True if image.mode == "P": @@ -716,11 +739,10 @@ class Generate: return True return False - - def _check_for_erasure(self,image): + def _check_for_erasure(self, image): width, height = image.size - pixdata = image.load() - colored = 0 + pixdata = image.load() + colored = 0 for y in range(height): for x in range(width): if pixdata[x, y][3] == 0: @@ -730,28 +752,28 @@ class Generate: colored += 1 return colored == 0 - def _squeeze_image(self,image): - x,y,resize_needed = self._resolution_check(image.width,image.height) + def _squeeze_image(self, image): + x, y, resize_needed = self._resolution_check(image.width, image.height) if resize_needed: - return InitImageResizer(image).resize(x,y) + return InitImageResizer(image).resize(x, y) return image - - def _fit_image(self,image,max_dimensions): - w,h = max_dimensions + def _fit_image(self, image, max_dimensions): + w, h = max_dimensions print( f'>> image will be resized to fit inside a box {w}x{h} in size.' 
) if image.width > image.height: - h = None # by setting h to none, we tell InitImageResizer to fit into the width and calculate height + h = None # by setting h to none, we tell InitImageResizer to fit into the width and calculate height elif image.height > image.width: - w = None # ditto for w + w = None # ditto for w else: pass - image = InitImageResizer(image).resize(w,h) # note that InitImageResizer does the multiple of 64 truncation internally + # note that InitImageResizer does the multiple of 64 truncation internally + image = InitImageResizer(image).resize(w, h) print( f'>> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}' - ) + ) return image def _resolution_check(self, width, height, log=False): @@ -765,7 +787,7 @@ class Generate: f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}' ) height = h - width = w + width = w resize_needed = True if (width * height) > (self.width * self.height): diff --git a/ldm/gfpgan/gfpgan_tools.py b/ldm/gfpgan/gfpgan_tools.py deleted file mode 100644 index 3adfc907a4..0000000000 --- a/ldm/gfpgan/gfpgan_tools.py +++ /dev/null @@ -1,168 +0,0 @@ -import torch -import warnings -import os -import sys -import numpy as np - -from PIL import Image -#from scripts.dream import create_argv_parser -from ldm.dream.args import Args - -opt = Args() -opt.parse_args() -model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path) -gfpgan_model_exists = os.path.isfile(model_path) - -def run_gfpgan(image, strength, seed, upsampler_scale=4): - print(f'>> GFPGAN - Restoring Faces for image seed:{seed}') - gfpgan = None - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=DeprecationWarning) - warnings.filterwarnings('ignore', category=UserWarning) - - try: - if not gfpgan_model_exists: - raise Exception('GFPGAN model not found at path ' + model_path) - - sys.path.append(os.path.abspath(opt.gfpgan_dir)) - from gfpgan import GFPGANer - - 
bg_upsampler = _load_gfpgan_bg_upsampler( - opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile - ) - - gfpgan = GFPGANer( - model_path=model_path, - upscale=upsampler_scale, - arch='clean', - channel_multiplier=2, - bg_upsampler=bg_upsampler, - ) - except Exception: - import traceback - - print('>> Error loading GFPGAN:', file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - - if gfpgan is None: - print( - f'>> WARNING: GFPGAN not initialized.' - ) - print( - f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth to {model_path}, \nor change GFPGAN directory with --gfpgan_dir.' - ) - return image - - image = image.convert('RGB') - - cropped_faces, restored_faces, restored_img = gfpgan.enhance( - np.array(image, dtype=np.uint8), - has_aligned=False, - only_center_face=False, - paste_back=True, - ) - res = Image.fromarray(restored_img) - - if strength < 1.0: - # Resize the image to the new image if the sizes have changed - if restored_img.size != image.size: - image = image.resize(res.size) - res = Image.blend(image, res, strength) - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gfpgan = None - - return res - - -def _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400): - if bg_upsampler == 'realesrgan': - if not torch.cuda.is_available(): # CPU or MPS on M1 - use_half_precision = False - else: - use_half_precision = True - - model_path = { - 2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', - 4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth', - } - - if upsampler_scale not in model_path: - return None - - from basicsr.archs.rrdbnet_arch import RRDBNet - from realesrgan import RealESRGANer - - if upsampler_scale == 4: - model = RRDBNet( - num_in_ch=3, - num_out_ch=3, - num_feat=64, - num_block=23, - num_grow_ch=32, - scale=4, - ) - if upsampler_scale == 2: - model = RRDBNet( - num_in_ch=3, - 
num_out_ch=3, - num_feat=64, - num_block=23, - num_grow_ch=32, - scale=2, - ) - - bg_upsampler = RealESRGANer( - scale=upsampler_scale, - model_path=model_path[upsampler_scale], - model=model, - tile=bg_tile, - tile_pad=10, - pre_pad=0, - half=use_half_precision, - ) - else: - bg_upsampler = None - - return bg_upsampler - - -def real_esrgan_upscale(image, strength, upsampler_scale, seed): - print( - f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x' - ) - - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=DeprecationWarning) - warnings.filterwarnings('ignore', category=UserWarning) - - try: - upsampler = _load_gfpgan_bg_upsampler( - opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile - ) - except Exception: - import traceback - - print('>> Error loading Real-ESRGAN:', file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - - output, img_mode = upsampler.enhance( - np.array(image, dtype=np.uint8), - outscale=upsampler_scale, - alpha_upsampler=opt.gfpgan_bg_upsampler, - ) - - res = Image.fromarray(output) - - if strength < 1.0: - # Resize the image to the new image if the sizes have changed - if output.size != image.size: - image = image.resize(res.size) - res = Image.blend(image, res, strength) - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - upsampler = None - - return res diff --git a/ldm/restoration/codeformer/codeformer.py b/ldm/restoration/codeformer/codeformer.py index ff81085793..f725ef9144 100644 --- a/ldm/restoration/codeformer/codeformer.py +++ b/ldm/restoration/codeformer/codeformer.py @@ -2,12 +2,20 @@ import os import torch import numpy as np import warnings +import sys pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' class CodeFormerRestoration(): - def __init__(self) -> None: - pass + def __init__(self, + codeformer_dir='ldm/restoration/codeformer', + codeformer_model_path='weights/codeformer.pth') -> None: + 
self.model_path = os.path.join(codeformer_dir, codeformer_model_path) + self.codeformer_model_exists = os.path.isfile(self.model_path) + + if not self.codeformer_model_exists: + print('## NOT FOUND: CodeFormer model not found at ' + self.model_path) + sys.path.append(os.path.abspath(codeformer_dir)) def process(self, image, strength, device, seed=None, fidelity=0.75): if seed is not None: diff --git a/ldm/restoration/gfpgan/gfpgan.py b/ldm/restoration/gfpgan/gfpgan.py new file mode 100644 index 0000000000..643d1e9559 --- /dev/null +++ b/ldm/restoration/gfpgan/gfpgan.py @@ -0,0 +1,76 @@ +import torch +import warnings +import os +import sys +import numpy as np + +from PIL import Image + + +class GFPGAN(): + def __init__( + self, + gfpgan_dir='src/gfpgan', + gfpgan_model_path='experiments/pretrained_models/GFPGANv1.3.pth') -> None: + + self.model_path = os.path.join(gfpgan_dir, gfpgan_model_path) + self.gfpgan_model_exists = os.path.isfile(self.model_path) + + if not self.gfpgan_model_exists: + raise Exception( + 'GFPGAN model not found at path ' + self.model_path) + sys.path.append(os.path.abspath(gfpgan_dir)) + + def model_exists(self): + return os.path.isfile(self.model_path) + + def process(self, image, strength: float, seed: str = None): + if seed is not None: + print(f'>> GFPGAN - Restoring Faces for image seed:{seed}') + + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + warnings.filterwarnings('ignore', category=UserWarning) + try: + from gfpgan import GFPGANer + self.gfpgan = GFPGANer( + model_path=self.model_path, + upscale=1, + arch='clean', + channel_multiplier=2, + bg_upsampler=None, + ) + except Exception: + import traceback + print('>> Error loading GFPGAN:', file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + if self.gfpgan is None: + print( + f'>> WARNING: GFPGAN not initialized.' 
+ ) + print( + f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth to {self.model_path}, \nor change GFPGAN directory with --gfpgan_dir.' + ) + + image = image.convert('RGB') + + _, _, restored_img = self.gfpgan.enhance( + np.array(image, dtype=np.uint8), + has_aligned=False, + only_center_face=False, + paste_back=True, + ) + res = Image.fromarray(restored_img) + + if strength < 1.0: + # Resize the image to the new image if the sizes have changed + if restored_img.size != image.size: + image = image.resize(res.size) + res = Image.blend(image, res, strength) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + self.gfpgan = None + + return res diff --git a/ldm/restoration/realesrgan/realesrgan.py b/ldm/restoration/realesrgan/realesrgan.py new file mode 100644 index 0000000000..9823a2cbf4 --- /dev/null +++ b/ldm/restoration/realesrgan/realesrgan.py @@ -0,0 +1,102 @@ +import torch +import warnings +import numpy as np + +from PIL import Image + + +class ESRGAN(): + def __init__(self, bg_tile_size=400) -> None: + self.bg_tile_size = bg_tile_size + + if not torch.cuda.is_available(): # CPU or MPS on M1 + use_half_precision = False + else: + use_half_precision = True + + def load_esrgan_bg_upsampler(self, upsampler_scale): + if not torch.cuda.is_available(): # CPU or MPS on M1 + use_half_precision = False + else: + use_half_precision = True + + model_path = { + 2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', + 4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth', + } + + if upsampler_scale not in model_path: + return None + else: + from basicsr.archs.rrdbnet_arch import RRDBNet + from realesrgan import RealESRGANer + + if upsampler_scale == 4: + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=4, + ) + if upsampler_scale == 2: + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + 
num_feat=64, + num_block=23, + num_grow_ch=32, + scale=2, + ) + + bg_upsampler = RealESRGANer( + scale=upsampler_scale, + model_path=model_path[upsampler_scale], + model=model, + tile=self.bg_tile_size, + tile_pad=10, + pre_pad=0, + half=use_half_precision, + ) + + return bg_upsampler + + def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2): + if seed is not None: + print( + f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x' + ) + + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + warnings.filterwarnings('ignore', category=UserWarning) + + try: + upsampler = self.load_esrgan_bg_upsampler(upsampler_scale) + except Exception: + import traceback + import sys + + print('>> Error loading Real-ESRGAN:', file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + output, _ = upsampler.enhance( + np.array(image, dtype=np.uint8), + outscale=upsampler_scale, + alpha_upsampler='realesrgan', + ) + + res = Image.fromarray(output) + + if strength < 1.0: + # Resize the image to the new image if the sizes have changed + if output.size != image.size: + image = image.resize(res.size) + res = Image.blend(image, res, strength) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + upsampler = None + + return res diff --git a/ldm/restoration/restoration.py b/ldm/restoration/restoration.py new file mode 100644 index 0000000000..d9caebd4fa --- /dev/null +++ b/ldm/restoration/restoration.py @@ -0,0 +1,34 @@ +class Restoration(): + def __init__(self, gfpgan_dir='./src/gfpgan', gfpgan_model_path='experiments/pretrained_models/GFPGANv1.3.pth', esrgan_bg_tile=400) -> None: + self.gfpgan_dir = gfpgan_dir + self.gfpgan_model_path = gfpgan_model_path + self.esrgan_bg_tile = esrgan_bg_tile + + def load_face_restore_models(self): + # Load GFPGAN + gfpgan = self.load_gfpgan() + if gfpgan.gfpgan_model_exists: + print('>> GFPGAN Initialized') + + # Load CodeFormer + codeformer = 
self.load_codeformer() + if codeformer.codeformer_model_exists: + print('>> CodeFormer Initialized') + + return gfpgan, codeformer + + # Face Restore Models + def load_gfpgan(self): + from ldm.restoration.gfpgan.gfpgan import GFPGAN + return GFPGAN(self.gfpgan_dir, self.gfpgan_model_path) + + def load_codeformer(self): + from ldm.restoration.codeformer.codeformer import CodeFormerRestoration + return CodeFormerRestoration() + + # Upscale Models + def load_esrgan(self): + from ldm.restoration.realesrgan.realesrgan import ESRGAN + esrgan = ESRGAN(self.esrgan_bg_tile) + print('>> ESRGAN Initialized') + return esrgan \ No newline at end of file diff --git a/scripts/dream.py b/scripts/dream.py index 857b5637aa..dcc54aa15f 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -42,7 +42,25 @@ def main(): import transformers transformers.logging.set_verbosity_error() - # creating a simple Generate object with a handful of + # Loading Face Restoration and ESRGAN Modules + try: + gfpgan, codeformer, esrgan = None, None, None + from ldm.restoration.restoration import Restoration + restoration = Restoration(opt.gfpgan_dir, opt.gfpgan_model_path, opt.esrgan_bg_tile) + if opt.restore: + gfpgan, codeformer = restoration.load_face_restore_models() + else: + print('>> Face Restoration Disabled') + if opt.esrgan: + esrgan = restoration.load_esrgan() + else: + print('>> ESRGAN Disabled') + except (ModuleNotFoundError, ImportError): + import traceback + print(traceback.format_exc(), file=sys.stderr) + print('>> You may need to install the ESRGAN and/or GFPGAN modules') + + # creating a simple text2image object with a handful of # defaults passed on the command line.
# additional parameters will be added (or overriden) during # the user input loop @@ -53,6 +71,9 @@ def main(): sampler_name = opt.sampler_name, embedding_path = opt.embedding_path, full_precision = opt.full_precision, + gfpgan=gfpgan, + codeformer=codeformer, + esrgan=esrgan ) except (FileNotFoundError, IOError, KeyError) as e: print(f'{e}. Aborting.') @@ -89,7 +110,7 @@ def main(): # web server loops forever if opt.web: - dream_server_loop(gen, opt.host, opt.port, opt.outdir) + dream_server_loop(gen, opt.host, opt.port, opt.outdir, gfpgan) sys.exit(0) main_loop(gen, opt, infile) @@ -312,7 +333,7 @@ def get_next_command(infile=None) -> str: # command string print(f'#{command}') return command -def dream_server_loop(gen, host, port, outdir): +def dream_server_loop(gen, host, port, outdir, gfpgan): print('\n* --web was specified, starting web server...') # Change working directory to the stable-diffusion directory os.chdir( @@ -322,6 +343,10 @@ def dream_server_loop(gen, host, port, outdir): # Start server DreamServer.model = gen # misnomer in DreamServer - this is not the model you are looking for DreamServer.outdir = outdir + DreamServer.gfpgan_model_exists = False + if gfpgan is not None: + DreamServer.gfpgan_model_exists = gfpgan.gfpgan_model_exists + dream_server = ThreadingDreamServer((host, port)) print(">> Started Stable Diffusion dream server!") if host == '0.0.0.0': @@ -345,8 +370,7 @@ def write_log_message(results, log_path): log_lines = [f'{path}: {prompt}\n' for path, prompt in results] for l in log_lines: output_cntr += 1 - print(f'[{output_cntr}] {l}',end='') - + print(f'[{output_cntr}] {l}', end='') with open(log_path, 'a', encoding='utf-8') as file: file.writelines(log_lines)