GFPGAN and Real ESRGAN Implementation Refactor

This commit is contained in:
blessedcoolant 2022-09-14 05:17:14 +12:00 committed by Lincoln Stein
parent e8bb39370c
commit 1b5013ab72
11 changed files with 532 additions and 377 deletions

View File

@ -4,10 +4,16 @@ title: Upscale
# :material-image-size-select-large: Upscale
## **Intro**
The script provides the ability to restore faces and upscale.
You can enable these features by passing `--restore` and `--esrgan` to your launch script to enable
face restoration modules and upscaling modules respectively.
## **GFPGAN and Real-ESRGAN Support**
The script also provides the ability to do face restoration and upscaling with the help of GFPGAN
and Real-ESRGAN respectively.
The default face restoration module is GFPGAN and the default upscaling module is ESRGAN.
As of version 1.14, environment.yaml will install the Real-ESRGAN package into the standard install
location for python packages, and will put GFPGAN into a subdirectory of "src" in the

View File

@ -370,16 +370,19 @@ class Args(object):
type=str,
help='Path to a pre-trained embedding manager checkpoint - can only be set on command line',
)
# GFPGAN related args
# Restoration related args
postprocessing_group.add_argument(
'--gfpgan_bg_upsampler',
type=str,
default='realesrgan',
help='Background upsampler. Default: realesrgan. Options: realesrgan, none.',
'--restore',
action='store_true',
help='Enable Face Restoration',
)
postprocessing_group.add_argument(
'--gfpgan_bg_tile',
'--esrgan',
action='store_true',
help='Enable Upscaling',
)
postprocessing_group.add_argument(
'--esrgan_bg_tile',
type=int,
default=400,
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',

View File

@ -4,11 +4,12 @@ and generates with ldm.dream.generator.img2img
'''
import torch
import numpy as np
import numpy as np
from PIL import Image
from ldm.dream.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.dream.generator.img2img import Img2Img
from ldm.dream.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.dream.generator.img2img import Img2Img
class Embiggen(Generator):
def __init__(self, model, precision):
@ -38,19 +39,20 @@ class Embiggen(Generator):
Return value depends on the seed at the time you call it
"""
# Construct embiggen arg array, and sanity check arguments
if embiggen == None: # embiggen can also be called with just embiggen_tiles
embiggen = [1.0] # If not specified, assume no scaling
elif embiggen[0] < 0 :
if embiggen == None: # embiggen can also be called with just embiggen_tiles
embiggen = [1.0] # If not specified, assume no scaling
elif embiggen[0] < 0:
embiggen[0] = 1.0
print('>> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !')
print(
'>> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !')
if len(embiggen) < 2:
embiggen.append(0.75)
elif embiggen[1] > 1.0 or embiggen[1] < 0 :
elif embiggen[1] > 1.0 or embiggen[1] < 0:
embiggen[1] = 0.75
print('>> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !')
if len(embiggen) < 3:
embiggen.append(0.25)
elif embiggen[2] < 0 :
elif embiggen[2] < 0:
embiggen[2] = 0.25
print('>> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !')
@ -76,29 +78,30 @@ class Embiggen(Generator):
if embiggen[0] != 1.0:
initsuperwidth = round(initsuperwidth*embiggen[0])
initsuperheight = round(initsuperheight*embiggen[0])
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
from ldm.gfpgan.gfpgan_tools import (
real_esrgan_upscale,
)
print(f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}')
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
from ldm.restoration.realesrgan import ESRGAN
esrgan = ESRGAN()
print(
f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}')
if embiggen[0] > 2:
initsuperimage = real_esrgan_upscale(
initsuperimage = esrgan.process(
initsuperimage,
embiggen[1], # upscale strength
4, # upscale scale
embiggen[1], # upscale strength
self.seed,
4, # upscale scale
)
else:
initsuperimage = real_esrgan_upscale(
initsuperimage = esrgan.process(
initsuperimage,
embiggen[1], # upscale strength
2, # upscale scale
embiggen[1], # upscale strength
self.seed,
2, # upscale scale
)
# We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x
# but from personal experiance it doesn't greatly improve anything after 4x
# Resize to target scaling factor resolution
initsuperimage = initsuperimage.resize((initsuperwidth, initsuperheight), Image.Resampling.LANCZOS)
initsuperimage = initsuperimage.resize(
(initsuperwidth, initsuperheight), Image.Resampling.LANCZOS)
# Use width and height as tile widths and height
# Determine buffer size in pixels
@ -121,28 +124,31 @@ class Embiggen(Generator):
emb_tiles_x = 1
emb_tiles_y = 1
if (initsuperwidth - width) > 0:
emb_tiles_x = ceildiv(initsuperwidth - width, width - overlap_size_x) + 1
emb_tiles_x = ceildiv(initsuperwidth - width,
width - overlap_size_x) + 1
if (initsuperheight - height) > 0:
emb_tiles_y = ceildiv(initsuperheight - height, height - overlap_size_y) + 1
emb_tiles_y = ceildiv(initsuperheight - height,
height - overlap_size_y) + 1
# Sanity
assert emb_tiles_x > 1 or emb_tiles_y > 1, f'ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don\'t need to Embiggen! Check your arguments.'
# Prep alpha layers --------------
# https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
# agradientL is Left-side transparent
agradientL = Image.linear_gradient('L').rotate(90).resize((overlap_size_x, height))
agradientL = Image.linear_gradient('L').rotate(
90).resize((overlap_size_x, height))
# agradientT is Top-side transparent
agradientT = Image.linear_gradient('L').resize((width, overlap_size_y))
# radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
agradientC = Image.new('L', (256, 256))
for y in range(256):
for x in range(256):
#Find distance to lower right corner (numpy takes arrays)
# Find distance to lower right corner (numpy takes arrays)
distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0]
#Clamp values to max 255
# Clamp values to max 255
if distanceToLR > 255:
distanceToLR = 255
#Place the pixel as invert of distance
# Place the pixel as invert of distance
agradientC.putpixel((x, y), int(255 - distanceToLR))
# Create alpha layers default fully white
@ -154,59 +160,79 @@ class Embiggen(Generator):
alphaLayerT.paste(agradientT, (0, 0))
alphaLayerLTC.paste(agradientL, (0, 0))
alphaLayerLTC.paste(agradientT, (0, 0))
alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
alphaLayerLTC.paste(agradientC.resize(
(overlap_size_x, overlap_size_y)), (0, 0))
if embiggen_tiles:
# Individual unconnected sides
alphaLayerR = Image.new("L", (width, height), 255)
alphaLayerR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
alphaLayerR.paste(agradientL.rotate(
180), (width - overlap_size_x, 0))
alphaLayerB = Image.new("L", (width, height), 255)
alphaLayerB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerB.paste(agradientT.rotate(
180), (0, height - overlap_size_y))
alphaLayerTB = Image.new("L", (width, height), 255)
alphaLayerTB.paste(agradientT, (0, 0))
alphaLayerTB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerTB.paste(agradientT.rotate(
180), (0, height - overlap_size_y))
alphaLayerLR = Image.new("L", (width, height), 255)
alphaLayerLR.paste(agradientL, (0, 0))
alphaLayerLR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
alphaLayerLR.paste(agradientL.rotate(
180), (width - overlap_size_x, 0))
# Sides and corner Layers
alphaLayerRBC = Image.new("L", (width, height), 255)
alphaLayerRBC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
alphaLayerRBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerRBC.paste(agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
alphaLayerRBC.paste(agradientL.rotate(
180), (width - overlap_size_x, 0))
alphaLayerRBC.paste(agradientT.rotate(
180), (0, height - overlap_size_y))
alphaLayerRBC.paste(agradientC.rotate(180).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
alphaLayerLBC = Image.new("L", (width, height), 255)
alphaLayerLBC.paste(agradientL, (0, 0))
alphaLayerLBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerLBC.paste(agradientC.rotate(90).resize((overlap_size_x, overlap_size_y)), (0, height - overlap_size_y))
alphaLayerLBC.paste(agradientT.rotate(
180), (0, height - overlap_size_y))
alphaLayerLBC.paste(agradientC.rotate(90).resize(
(overlap_size_x, overlap_size_y)), (0, height - overlap_size_y))
alphaLayerRTC = Image.new("L", (width, height), 255)
alphaLayerRTC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
alphaLayerRTC.paste(agradientL.rotate(
180), (width - overlap_size_x, 0))
alphaLayerRTC.paste(agradientT, (0, 0))
alphaLayerRTC.paste(agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
alphaLayerRTC.paste(agradientC.rotate(270).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
# All but X layers
alphaLayerABT = Image.new("L", (width, height), 255)
alphaLayerABT.paste(alphaLayerLBC, (0, 0))
alphaLayerABT.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
alphaLayerABT.paste(agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
alphaLayerABT.paste(agradientL.rotate(
180), (width - overlap_size_x, 0))
alphaLayerABT.paste(agradientC.rotate(180).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
alphaLayerABL = Image.new("L", (width, height), 255)
alphaLayerABL.paste(alphaLayerRTC, (0, 0))
alphaLayerABL.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerABL.paste(agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
alphaLayerABL.paste(agradientT.rotate(
180), (0, height - overlap_size_y))
alphaLayerABL.paste(agradientC.rotate(180).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
alphaLayerABR = Image.new("L", (width, height), 255)
alphaLayerABR.paste(alphaLayerLBC, (0, 0))
alphaLayerABR.paste(agradientT, (0, 0))
alphaLayerABR.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
alphaLayerABR.paste(agradientC.resize(
(overlap_size_x, overlap_size_y)), (0, 0))
alphaLayerABB = Image.new("L", (width, height), 255)
alphaLayerABB.paste(alphaLayerRTC, (0, 0))
alphaLayerABB.paste(agradientL, (0, 0))
alphaLayerABB.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
alphaLayerABB.paste(agradientC.resize(
(overlap_size_x, overlap_size_y)), (0, 0))
# All-around layer
alphaLayerAA = Image.new("L", (width, height), 255)
alphaLayerAA.paste(alphaLayerABT, (0, 0))
alphaLayerAA.paste(agradientT, (0, 0))
alphaLayerAA.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
alphaLayerAA.paste(agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
alphaLayerAA.paste(agradientC.resize(
(overlap_size_x, overlap_size_y)), (0, 0))
alphaLayerAA.paste(agradientC.rotate(270).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
# Clean up temporary gradients
del agradientL
@ -218,7 +244,8 @@ class Embiggen(Generator):
if embiggen_tiles:
print(f'>> Making {len(embiggen_tiles)} Embiggen tiles...')
else:
print(f'>> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})...')
print(
f'>> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})...')
emb_tile_store = []
for tile in range(emb_tiles_x * emb_tiles_y):
@ -240,20 +267,23 @@ class Embiggen(Generator):
top = round(emb_row_i * (height - overlap_size_y))
right = left + width
bottom = top + height
# Cropped image of above dimension (does not modify the original)
newinitimage = initsuperimage.crop((left, top, right, bottom))
# DEBUG:
# newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png'
# newinitimage.save(newinitimagepath)
if embiggen_tiles:
print(f'Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)')
print(
f'Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)')
else:
print(f'Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles')
print(
f'Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles')
# create a torch tensor from an Image
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
newinitimage = np.array(
newinitimage).astype(np.float32) / 255.0
newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
newinitimage = torch.from_numpy(newinitimage)
newinitimage = 2.0 * newinitimage - 1.0
@ -261,33 +291,35 @@ class Embiggen(Generator):
tile_results = gen_img2img.generate(
prompt,
iterations = 1,
seed = self.seed,
sampler = sampler,
steps = steps,
cfg_scale = cfg_scale,
conditioning = conditioning,
ddim_eta = ddim_eta,
image_callback = None, # called only after the final image is generated
step_callback = step_callback, # called after each intermediate image is generated
width = width,
height = height,
init_img = init_img, # img2img doesn't need this, but it might in the future
init_image = newinitimage, # notice that init_image is different from init_img
mask_image = None,
strength = strength,
iterations=1,
seed=self.seed,
sampler=sampler,
steps=steps,
cfg_scale=cfg_scale,
conditioning=conditioning,
ddim_eta=ddim_eta,
image_callback=None, # called only after the final image is generated
step_callback=step_callback, # called after each intermediate image is generated
width=width,
height=height,
init_img=init_img, # img2img doesn't need this, but it might in the future
init_image=newinitimage, # notice that init_image is different from init_img
mask_image=None,
strength=strength,
)
emb_tile_store.append(tile_results[0][0])
# DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite
# emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png')
del newinitimage
# Sanity check we have them all
if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)):
outputsuperimage = Image.new("RGBA", (initsuperwidth, initsuperheight))
outputsuperimage = Image.new(
"RGBA", (initsuperwidth, initsuperheight))
if embiggen_tiles:
outputsuperimage.alpha_composite(initsuperimage.convert('RGBA'), (0, 0))
outputsuperimage.alpha_composite(
initsuperimage.convert('RGBA'), (0, 0))
for tile in range(emb_tiles_x * emb_tiles_y):
if embiggen_tiles:
if tile in embiggen_tiles:
@ -308,7 +340,8 @@ class Embiggen(Generator):
if emb_column_i + 1 == emb_tiles_x:
left = initsuperwidth - width
else:
left = round(emb_column_i * (width - overlap_size_x))
left = round(emb_column_i *
(width - overlap_size_x))
if emb_row_i + 1 == emb_tiles_y:
top = initsuperheight - height
else:
@ -319,33 +352,33 @@ class Embiggen(Generator):
# top of image
if emb_row_i == 0:
if emb_column_i == 0:
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile+emb_tiles_x) not in embiggen_tiles: # Look-ahead down
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile+emb_tiles_x) not in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerB)
# Otherwise do nothing on this tile
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerR)
else:
intileimage.putalpha(alphaLayerRBC)
elif emb_column_i == emb_tiles_x - 1:
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerL)
else:
intileimage.putalpha(alphaLayerLBC)
else:
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerL)
else:
intileimage.putalpha(alphaLayerLBC)
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerLR)
else:
intileimage.putalpha(alphaLayerABT)
# bottom of image
elif emb_row_i == emb_tiles_y - 1:
if emb_column_i == 0:
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile+1) in embiggen_tiles: # Look-ahead right
intileimage.putalpha(alphaLayerT)
else:
intileimage.putalpha(alphaLayerRTC)
@ -353,34 +386,34 @@ class Embiggen(Generator):
# No tiles to look ahead to
intileimage.putalpha(alphaLayerLTC)
else:
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile+1) in embiggen_tiles: # Look-ahead right
intileimage.putalpha(alphaLayerLTC)
else:
intileimage.putalpha(alphaLayerABB)
# vertical middle of image
else:
if emb_column_i == 0:
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerT)
else:
intileimage.putalpha(alphaLayerTB)
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerRTC)
else:
intileimage.putalpha(alphaLayerABL)
elif emb_column_i == emb_tiles_x - 1:
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerLTC)
else:
intileimage.putalpha(alphaLayerABR)
else:
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerLTC)
else:
intileimage.putalpha(alphaLayerABR)
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerABB)
else:
intileimage.putalpha(alphaLayerAA)
@ -400,4 +433,4 @@ class Embiggen(Generator):
# after internal loops and patching up return Embiggen image
return outputsuperimage
# end of function declaration
return make_image
return make_image

View File

@ -37,6 +37,8 @@ def build_opt(post_data, seed, gfpgan_model_exists):
setattr(opt, 'seed', None if int(post_data['seed']) == -1 else int(post_data['seed']))
setattr(opt, 'variation_amount', float(post_data['variation_amount']) if int(post_data['seed']) != -1 else 0)
setattr(opt, 'with_variations', [])
setattr(opt, 'embiggen', None)
setattr(opt, 'embiggen_tiles', None)
broken = False
if int(post_data['seed']) != -1 and post_data['with_variations'] != '':
@ -80,12 +82,11 @@ class DreamServer(BaseHTTPRequestHandler):
self.wfile.write(content.read())
elif self.path == "/config.js":
# unfortunately this import can't be at the top level, since that would cause a circular import
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
self.send_response(200)
self.send_header("Content-type", "application/javascript")
self.end_headers()
config = {
'gfpgan_model_exists': gfpgan_model_exists
'gfpgan_model_exists': self.gfpgan_model_exists
}
self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8"))
elif self.path == "/run_log.json":
@ -138,11 +139,10 @@ class DreamServer(BaseHTTPRequestHandler):
self.end_headers()
# unfortunately this import can't be at the top level, since that would cause a circular import
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
content_length = int(self.headers['Content-Length'])
post_data = json.loads(self.rfile.read(content_length))
opt = build_opt(post_data, self.model.seed, gfpgan_model_exists)
opt = build_opt(post_data, self.model.seed, self.gfpgan_model_exists)
self.canceled.clear()
# In order to handle upscaled images, the PngWriter needs to maintain state

View File

@ -23,9 +23,9 @@ from PIL import Image, ImageOps
from torch import nn
from pytorch_lightning import seed_everything, logging
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ksampler import KSampler
from ldm.dream.pngwriter import PngWriter, retrieve_metadata
from ldm.dream.args import metadata_loads
@ -51,6 +51,24 @@ torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)
def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
def new_func(*args, **kw):
device = kw.get("device", "mps")
kw["device"]="cpu"
return orig(*args, **kw).to(device)
return new_func
return orig
torch.rand = fix_func(torch.rand)
torch.rand_like = fix_func(torch.rand_like)
torch.randn = fix_func(torch.randn)
torch.randn_like = fix_func(torch.randn_like)
torch.randint = fix_func(torch.randint)
torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)
"""Simplified text to image API for stable diffusion/latent diffusion
Example Usage:
@ -135,6 +153,9 @@ class Generate:
# these are deprecated; if present they override values in the conf file
weights = None,
config = None,
gfpgan=None,
codeformer=None,
esrgan=None
):
models = OmegaConf.load(conf)
mconfig = models[model]
@ -158,6 +179,9 @@ class Generate:
self.generators = {}
self.base_generator = None
self.seed = None
self.gfpgan = gfpgan
self.codeformer = codeformer
self.esrgan = esrgan
# Note that in previous versions, there was an option to pass the
# device to Generate(). However the device was then ignored, so
@ -234,8 +258,8 @@ class Generate:
strength = None,
init_color = None,
# these are specific to embiggen (which also relies on img2img args)
embiggen = None,
embiggen_tiles = None,
embiggen=None,
embiggen_tiles=None,
# these are specific to GFPGAN/ESRGAN
facetool = None,
gfpgan_strength = 0,
@ -284,15 +308,15 @@ class Generate:
write the prompt into the PNG metadata.
"""
# TODO: convert this into a getattr() loop
steps = steps or self.steps
width = width or self.width
height = height or self.height
seamless = seamless or self.seamless
cfg_scale = cfg_scale or self.cfg_scale
ddim_eta = ddim_eta or self.ddim_eta
iterations = iterations or self.iterations
strength = strength or self.strength
self.seed = seed
steps = steps or self.steps
width = width or self.width
height = height or self.height
seamless = seamless or self.seamless
cfg_scale = cfg_scale or self.cfg_scale
ddim_eta = ddim_eta or self.ddim_eta
iterations = iterations or self.iterations
strength = strength or self.strength
self.seed = seed
self.log_tokenization = log_tokenization
self.step_callback = step_callback
with_variations = [] if with_variations is None else with_variations
@ -303,16 +327,17 @@ class Generate:
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m.padding_mode = 'circular' if seamless else m._orig_padding_mode
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
assert (
0.0 < strength < 1.0
), 'img2img and inpaint strength can only work with 0.0 < strength < 1.0'
assert (
0.0 <= variation_amount <= 1.0
0.0 <= variation_amount <= 1.0
), '-v --variation_amount must be in [0.0, 1.0]'
assert (
(embiggen == None and embiggen_tiles == None) or ((embiggen != None or embiggen_tiles != None) and init_img != None)
(embiggen == None and embiggen_tiles == None) or (
(embiggen != None or embiggen_tiles != None) and init_img != None)
), 'Embiggen requires an init/input image to be specified'
if len(with_variations) > 0 or variation_amount > 1.0:
@ -334,9 +359,9 @@ class Generate:
if self._has_cuda():
torch.cuda.reset_peak_memory_stats()
results = list()
init_image = None
mask_image = None
results = list()
init_image = None
mask_image = None
try:
uc, c = get_uc_and_c(
@ -345,8 +370,9 @@ class Generate:
log_tokens =self.log_tokenization
)
(init_image,mask_image) = self._make_images(init_img,init_mask, width, height, fit)
(init_image, mask_image) = self._make_images(
init_img, init_mask, width, height, fit)
if (init_image is not None) and (mask_image is not None):
generator = self._make_inpaint()
elif (embiggen != None or embiggen_tiles != None):
@ -356,26 +382,27 @@ class Generate:
else:
generator = self._make_txt2img()
generator.set_variation(self.seed, variation_amount, with_variations)
generator.set_variation(
self.seed, variation_amount, with_variations)
results = generator.generate(
prompt,
iterations = iterations,
seed = self.seed,
sampler = self.sampler,
steps = steps,
cfg_scale = cfg_scale,
conditioning = (uc,c),
ddim_eta = ddim_eta,
image_callback = image_callback, # called after the final image is generated
step_callback = step_callback, # called after each intermediate image is generated
width = width,
height = height,
init_img = init_img, # embiggen needs to manipulate from the unmodified init_img
init_image = init_image, # notice that init_image is different from init_img
mask_image = mask_image,
strength = strength,
embiggen = embiggen,
embiggen_tiles = embiggen_tiles,
iterations=iterations,
seed=self.seed,
sampler=self.sampler,
steps=steps,
cfg_scale=cfg_scale,
conditioning=(uc, c),
ddim_eta=ddim_eta,
image_callback=image_callback, # called after the final image is generated
step_callback=step_callback, # called after each intermediate image is generated
width=width,
height=height,
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
init_image=init_image, # notice that init_image is different from init_img
mask_image=mask_image,
strength=strength,
embiggen=embiggen,
embiggen_tiles=embiggen_tiles,
)
if init_color:
@ -404,7 +431,8 @@ class Generate:
toc = time.time()
print('>> Usage stats:')
print(
f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
f'>> {len(results)} image(s) generated in', '%4.2fs' % (
toc - tic)
)
if self._has_cuda():
print(
@ -515,29 +543,35 @@ class Generate:
def _make_images(self, img_path, mask_path, width, height, fit=False):
init_image = None
init_mask = None
init_image = None
init_mask = None
if not img_path:
return None,None
return None, None
image = self._load_img(img_path, width, height, fit=fit) # this returns an Image
init_image = self._create_init_image(image) # this returns a torch tensor
image = self._load_img(img_path, width, height,
fit=fit) # this returns an Image
# this returns a torch tensor
init_image = self._create_init_image(image)
if self._has_transparency(image) and not mask_path: # if image has a transparent area and no mask was provided, then try to generate mask
print('>> Initial image has transparent areas. Will inpaint in these regions.')
# if image has a transparent area and no mask was provided, then try to generate mask
if self._has_transparency(image) and not mask_path:
print(
'>> Initial image has transparent areas. Will inpaint in these regions.')
if self._check_for_erasure(image):
print(
'>> WARNING: Colors underneath the transparent region seem to have been erased.\n',
'>> Inpainting will be suboptimal. Please preserve the colors when making\n',
'>> a transparency mask, or provide mask explicitly using --init_mask (-M).'
)
init_mask = self._create_init_mask(image) # this returns a torch tensor
# this returns a torch tensor
init_mask = self._create_init_mask(image)
if mask_path:
mask_image = self._load_img(mask_path, width, height, fit=fit) # this returns an Image
init_mask = self._create_init_mask(mask_image)
mask_image = self._load_img(
mask_path, width, height, fit=fit) # this returns an Image
init_mask = self._create_init_mask(mask_image)
return init_image,init_mask
return init_image, init_mask
def _make_img2img(self):
if not self.generators.get('img2img'):
@ -619,38 +653,26 @@ class Generate:
codeformer_fidelity = 0.75,
save_original = False,
image_callback = None):
try:
if upscale is not None:
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
if strength > 0:
if facetool == 'codeformer':
from ldm.restoration.codeformer.codeformer import CodeFormerRestoration
else:
from ldm.gfpgan.gfpgan_tools import run_gfpgan
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
return
for r in image_list:
image, seed = r
try:
if upscale is not None:
if len(upscale) < 2:
upscale.append(0.75)
image = real_esrgan_upscale(
image,
upscale[1],
int(upscale[0]),
seed,
)
if strength > 0:
if facetool == 'codeformer':
image = CodeFormerRestoration().process(image=image, strength=strength, device=self.device, seed=seed, fidelity=codeformer_fidelity)
if self.esrgan is not None:
if len(upscale) < 2:
upscale.append(0.75)
image = self.esrgan.process(
image, upscale[1], seed, int(upscale[0]))
else:
image = run_gfpgan(
image, strength, seed, 1
)
print(">> ESRGAN is disabled. Image not upscaled.")
if strength > 0:
if self.gfpgan is not None and self.codeformer is not None:
if facetool == 'codeformer':
image = self.codeformer.process(image=image, strength=strength, device=self.device, seed=seed, fidelity=codeformer_fidelity)
else:
image = self.gfpgan.process(image, strength, seed)
else:
print(">> Face Restoration is disabled.")
except Exception as e:
print(
f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}'
@ -662,10 +684,10 @@ class Generate:
r[0] = image
# to help WebGUI - front end to generator util function
def sample_to_image(self,samples):
def sample_to_image(self, samples):
return self._sample_to_image(samples)
def _sample_to_image(self,samples):
def _sample_to_image(self, samples):
if not self.base_generator:
from ldm.dream.generator import Generator
self.base_generator = Generator(self.model)
@ -708,7 +730,7 @@ class Generate:
# for usage statistics
device_type = choose_torch_device()
if device_type == 'cuda':
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_peak_memory_stats()
tic = time.time()
# this does the work
@ -756,12 +778,12 @@ class Generate:
f'>> loaded input image of size {image.width}x{image.height} from {path}'
)
if fit:
image = self._fit_image(image,(width,height))
image = self._fit_image(image, (width, height))
else:
image = self._squeeze_image(image)
return image
def _create_init_image(self,image):
def _create_init_image(self, image):
image = image.convert('RGB')
# print(
# f'>> DEBUG: writing the image to img.png'
@ -770,7 +792,7 @@ class Generate:
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
image = 2.0 * image - 1.0
image = 2.0 * image - 1.0
return image.to(self.device)
def _create_init_mask(self, image):
@ -779,7 +801,8 @@ class Generate:
image = image.convert('RGB')
# BUG: We need to use the model's downsample factor rather than hardcoding "8"
from ldm.dream.generator.base import downsampling
image = image.resize((image.width//downsampling, image.height//downsampling), resample=Image.Resampling.LANCZOS)
image = image.resize((image.width//downsampling, image.height //
downsampling), resample=Image.Resampling.LANCZOS)
# print(
# f'>> DEBUG: writing the mask to mask.png'
# )
@ -801,7 +824,7 @@ class Generate:
mask = ImageOps.invert(mask)
return mask
def _has_transparency(self,image):
def _has_transparency(self, image):
if image.info.get("transparency", None) is not None:
return True
if image.mode == "P":
@ -815,11 +838,10 @@ class Generate:
return True
return False
def _check_for_erasure(self,image):
def _check_for_erasure(self, image):
width, height = image.size
pixdata = image.load()
colored = 0
pixdata = image.load()
colored = 0
for y in range(height):
for x in range(width):
if pixdata[x, y][3] == 0:
@ -829,28 +851,28 @@ class Generate:
colored += 1
return colored == 0
def _squeeze_image(self,image):
x,y,resize_needed = self._resolution_check(image.width,image.height)
def _squeeze_image(self, image):
x, y, resize_needed = self._resolution_check(image.width, image.height)
if resize_needed:
return InitImageResizer(image).resize(x,y)
return InitImageResizer(image).resize(x, y)
return image
def _fit_image(self,image,max_dimensions):
w,h = max_dimensions
def _fit_image(self, image, max_dimensions):
w, h = max_dimensions
print(
f'>> image will be resized to fit inside a box {w}x{h} in size.'
)
if image.width > image.height:
h = None # by setting h to none, we tell InitImageResizer to fit into the width and calculate height
h = None # by setting h to none, we tell InitImageResizer to fit into the width and calculate height
elif image.height > image.width:
w = None # ditto for w
w = None # ditto for w
else:
pass
image = InitImageResizer(image).resize(w,h) # note that InitImageResizer does the multiple of 64 truncation internally
# note that InitImageResizer does the multiple of 64 truncation internally
image = InitImageResizer(image).resize(w, h)
print(
f'>> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}'
)
)
return image
def _resolution_check(self, width, height, log=False):
@ -864,7 +886,7 @@ class Generate:
f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}'
)
height = h
width = w
width = w
resize_needed = True
if (width * height) > (self.width * self.height):

View File

@ -1,168 +0,0 @@
import torch
import warnings
import os
import sys
import numpy as np
from PIL import Image
#from scripts.dream import create_argv_parser
from ldm.dream.args import Args
opt = Args()
opt.parse_args()
model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path)
gfpgan_model_exists = os.path.isfile(model_path)
def run_gfpgan(image, strength, seed, upsampler_scale=4):
print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')
gfpgan = None
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
try:
if not gfpgan_model_exists:
raise Exception('GFPGAN model not found at path ' + model_path)
sys.path.append(os.path.abspath(opt.gfpgan_dir))
from gfpgan import GFPGANer
bg_upsampler = _load_gfpgan_bg_upsampler(
opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile
)
gfpgan = GFPGANer(
model_path=model_path,
upscale=upsampler_scale,
arch='clean',
channel_multiplier=2,
bg_upsampler=bg_upsampler,
)
except Exception:
import traceback
print('>> Error loading GFPGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if gfpgan is None:
print(
f'>> WARNING: GFPGAN not initialized.'
)
print(
f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth to {model_path}, \nor change GFPGAN directory with --gfpgan_dir.'
)
return image
image = image.convert('RGB')
cropped_faces, restored_faces, restored_img = gfpgan.enhance(
np.array(image, dtype=np.uint8),
has_aligned=False,
only_center_face=False,
paste_back=True,
)
res = Image.fromarray(restored_img)
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if restored_img.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
gfpgan = None
return res
def _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400):
if bg_upsampler == 'realesrgan':
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
model_path = {
2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
}
if upsampler_scale not in model_path:
return None
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
if upsampler_scale == 4:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
)
if upsampler_scale == 2:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
bg_upsampler = RealESRGANer(
scale=upsampler_scale,
model_path=model_path[upsampler_scale],
model=model,
tile=bg_tile,
tile_pad=10,
pre_pad=0,
half=use_half_precision,
)
else:
bg_upsampler = None
return bg_upsampler
def real_esrgan_upscale(image, strength, upsampler_scale, seed):
print(
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
try:
upsampler = _load_gfpgan_bg_upsampler(
opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile
)
except Exception:
import traceback
print('>> Error loading Real-ESRGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
output, img_mode = upsampler.enhance(
np.array(image, dtype=np.uint8),
outscale=upsampler_scale,
alpha_upsampler=opt.gfpgan_bg_upsampler,
)
res = Image.fromarray(output)
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if output.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
upsampler = None
return res

View File

@ -2,12 +2,20 @@ import os
import torch
import numpy as np
import warnings
import sys
pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
class CodeFormerRestoration():
def __init__(self) -> None:
pass
def __init__(self,
codeformer_dir='ldm/restoration/codeformer',
codeformer_model_path='weights/codeformer.pth') -> None:
self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
self.codeformer_model_exists = os.path.isfile(self.model_path)
if not self.codeformer_model_exists:
print('## NOT FOUND: CodeFormer model not found at ' + self.model_path)
sys.path.append(os.path.abspath(codeformer_dir))
def process(self, image, strength, device, seed=None, fidelity=0.75):
if seed is not None:

View File

@ -0,0 +1,76 @@
import torch
import warnings
import os
import sys
import numpy as np
from PIL import Image
class GFPGAN():
def __init__(
self,
gfpgan_dir='src/gfpgan',
gfpgan_model_path='experiments/pretrained_models/GFPGANv1.3.pth') -> None:
self.model_path = os.path.join(gfpgan_dir, gfpgan_model_path)
self.gfpgan_model_exists = os.path.isfile(self.model_path)
if not self.gfpgan_model_exists:
raise Exception(
'GFPGAN model not found at path ' + self.model_path)
sys.path.append(os.path.abspath(gfpgan_dir))
def model_exists(self):
return os.path.isfile(self.model_path)
def process(self, image, strength: float, seed: str = None):
if seed is not None:
print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
try:
from gfpgan import GFPGANer
self.gfpgan = GFPGANer(
model_path=self.model_path,
upscale=1,
arch='clean',
channel_multiplier=2,
bg_upsampler=None,
)
except Exception:
import traceback
print('>> Error loading GFPGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if self.gfpgan is None:
print(
f'>> WARNING: GFPGAN not initialized.'
)
print(
f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth to {self.model_path}, \nor change GFPGAN directory with --gfpgan_dir.'
)
image = image.convert('RGB')
_, _, restored_img = self.gfpgan.enhance(
np.array(image, dtype=np.uint8),
has_aligned=False,
only_center_face=False,
paste_back=True,
)
res = Image.fromarray(restored_img)
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if restored_img.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.gfpgan = None
return res

View File

@ -0,0 +1,102 @@
import torch
import warnings
import numpy as np
from PIL import Image
class ESRGAN():
def __init__(self, bg_tile_size=400) -> None:
self.bg_tile_size = bg_tile_size
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
def load_esrgan_bg_upsampler(self, upsampler_scale):
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
model_path = {
2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
}
if upsampler_scale not in model_path:
return None
else:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
if upsampler_scale == 4:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
)
if upsampler_scale == 2:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
bg_upsampler = RealESRGANer(
scale=upsampler_scale,
model_path=model_path[upsampler_scale],
model=model,
tile=self.bg_tile_size,
tile_pad=10,
pre_pad=0,
half=use_half_precision,
)
return bg_upsampler
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2):
if seed is not None:
print(
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
try:
upsampler = self.load_esrgan_bg_upsampler(upsampler_scale)
except Exception:
import traceback
import sys
print('>> Error loading Real-ESRGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
output, _ = upsampler.enhance(
np.array(image, dtype=np.uint8),
outscale=upsampler_scale,
alpha_upsampler='realesrgan',
)
res = Image.fromarray(output)
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
if output.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
upsampler = None
return res

View File

@ -0,0 +1,34 @@
class Restoration():
def __init__(self, gfpgan_dir='./src/gfpgan', gfpgan_model_path='experiments/pretrained_models/GFPGANv1.3.pth', esrgan_bg_tile=400) -> None:
self.gfpgan_dir = gfpgan_dir
self.gfpgan_model_path = gfpgan_model_path
self.esrgan_bg_tile = esrgan_bg_tile
def load_face_restore_models(self):
# Load GFPGAN
gfpgan = self.load_gfpgan()
if gfpgan.gfpgan_model_exists:
print('>> GFPGAN Initialized')
# Load CodeFormer
codeformer = self.load_codeformer()
if codeformer.codeformer_model_exists:
print('>> CodeFormer Initialized')
return gfpgan, codeformer
# Face Restore Models
def load_gfpgan(self):
from ldm.restoration.gfpgan.gfpgan import GFPGAN
return GFPGAN(self.gfpgan_dir, self.gfpgan_model_path)
def load_codeformer(self):
from ldm.restoration.codeformer.codeformer import CodeFormerRestoration
return CodeFormerRestoration()
# Upscale Models
def load_ersgan(self):
from ldm.restoration.realesrgan.realesrgan import ESRGAN
esrgan = ESRGAN(self.esrgan_bg_tile)
print('>> ESRGAN Initialized')
return esrgan;

View File

@ -43,7 +43,25 @@ def main():
import transformers
transformers.logging.set_verbosity_error()
# creating a simple Generate object with a handful of
# Loading Face Restoration and ESRGAN Modules
try:
gfpgan, codeformer, esrgan = None, None, None
from ldm.restoration.restoration import Restoration
restoration = Restoration(opt.gfpgan_dir, opt.gfpgan_model_path, opt.esrgan_bg_tile)
if opt.restore:
gfpgan, codeformer = restoration.load_face_restore_models()
else:
print('>> Face Restoration Disabled')
if opt.esrgan:
esrgan = restoration.load_ersgan()
else:
print('>> ESRGAN Disabled')
except (ModuleNotFoundError, ImportError):
import traceback
print(traceback.format_exc(), file=sys.stderr)
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
# creating a simple text2image object with a handful of
# defaults passed on the command line.
# additional parameters will be added (or overriden) during
# the user input loop
@ -55,6 +73,9 @@ def main():
embedding_path = opt.embedding_path,
full_precision = opt.full_precision,
precision = opt.precision,
gfpgan=gfpgan,
codeformer=codeformer,
esrgan=esrgan
)
except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.')
@ -91,7 +112,7 @@ def main():
# web server loops forever
if opt.web:
dream_server_loop(gen, opt.host, opt.port, opt.outdir)
dream_server_loop(gen, opt.host, opt.port, opt.outdir, gfpgan)
sys.exit(0)
main_loop(gen, opt, infile)
@ -347,7 +368,7 @@ def get_next_command(infile=None) -> str: # command string
print(f'#{command}')
return command
def dream_server_loop(gen, host, port, outdir):
def dream_server_loop(gen, host, port, outdir, gfpgan):
print('\n* --web was specified, starting web server...')
# Change working directory to the stable-diffusion directory
os.chdir(
@ -357,6 +378,10 @@ def dream_server_loop(gen, host, port, outdir):
# Start server
DreamServer.model = gen # misnomer in DreamServer - this is not the model you are looking for
DreamServer.outdir = outdir
DreamServer.gfpgan_model_exists = False
if gfpgan is not None:
DreamServer.gfpgan_model_exists = gfpgan.gfpgan_model_exists
dream_server = ThreadingDreamServer((host, port))
print(">> Started Stable Diffusion dream server!")
if host == '0.0.0.0':
@ -374,5 +399,19 @@ def dream_server_loop(gen, host, port, outdir):
dream_server.server_close()
<<<<<<< HEAD
=======
def write_log_message(results, log_path):
"""logs the name of the output image, prompt, and prompt args to the terminal and log file"""
global output_cntr
log_lines = [f'{path}: {prompt}\n' for path, prompt in results]
for l in log_lines:
output_cntr += 1
print(f'[{output_cntr}] {l}', end='')
with open(log_path, 'a', encoding='utf-8') as file:
file.writelines(log_lines)
>>>>>>> GFPGAN and Real ESRGAN Implementation Refactor
if __name__ == '__main__':
main()