Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

GFPGAN and Real ESRGAN Implementation Refactor

This commit is contained in: parent e8bb39370c · commit 1b5013ab72
@@ -4,10 +4,16 @@ title: Upscale
 
 # :material-image-size-select-large: Upscale
 
+## **Intro**
+
+The script provides the ability to restore faces and upscale.
+
+You can enable these features by passing `--restore` and `--esrgan` to your launch script to enable
+face restoration modules and upscaling modules respectively.
+
 ## **GFPGAN and Real-ESRGAN Support**
 
 The script also provides the ability to do face restoration and upscaling with the help of GFPGAN
 and Real-ESRGAN respectively.
+The default face restoration module is GFPGAN and the default upscaling module is ESRGAN.
 
 As of version 1.14, environment.yaml will install the Real-ESRGAN package into the standard install
 location for python packages, and will put GFPGAN into a subdirectory of "src" in the
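Under the hood, the two flags gate the new restoration classes introduced later in this commit. As a minimal sketch of what they enable (import paths and signatures are the ones added below; it assumes the GFPGAN and Real-ESRGAN weights are already installed, and `sample.png` is a hypothetical input):

```python
from PIL import Image

from ldm.restoration.gfpgan.gfpgan import GFPGAN
from ldm.restoration.realesrgan.realesrgan import ESRGAN

img = Image.open('sample.png')                             # hypothetical input image
img = GFPGAN().process(img, strength=0.8, seed=42)         # face restoration
img = ESRGAN().process(img, 0.75, 42, upsampler_scale=2)   # 2x upscale
img.save('sample_restored.png')
```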
@@ -370,16 +370,19 @@ class Args(object):
             type=str,
             help='Path to a pre-trained embedding manager checkpoint - can only be set on command line',
         )
-        # GFPGAN related args
+        # Restoration related args
         postprocessing_group.add_argument(
-            '--gfpgan_bg_upsampler',
-            type=str,
-            default='realesrgan',
-            help='Background upsampler. Default: realesrgan. Options: realesrgan, none.',
+            '--restore',
+            action='store_true',
+            help='Enable Face Restoration',
         )
         postprocessing_group.add_argument(
-            '--gfpgan_bg_tile',
+            '--esrgan',
+            action='store_true',
+            help='Enable Upscaling',
+        )
+        postprocessing_group.add_argument(
+            '--esrgan_bg_tile',
             type=int,
             default=400,
             help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
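Outside the Args class, the same three options can be reproduced with plain argparse; a stand-alone sketch (the group name is illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('postprocessing')  # illustrative group name
group.add_argument('--restore', action='store_true',
                   help='Enable Face Restoration')
group.add_argument('--esrgan', action='store_true',
                   help='Enable Upscaling')
group.add_argument('--esrgan_bg_tile', type=int, default=400,
                   help='Tile size for background sampler, 0 for no tile during testing. Default: 400.')

opt = parser.parse_args(['--restore', '--esrgan'])
print(opt.restore, opt.esrgan, opt.esrgan_bg_tile)   # True True 400
```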
@@ -4,11 +4,12 @@ and generates with ldm.dream.generator.img2img
 '''

 import torch
-import numpy as np
+import numpy as np
 from PIL import Image
-from ldm.dream.generator.base import Generator
-from ldm.models.diffusion.ddim import DDIMSampler
-from ldm.dream.generator.img2img import Img2Img
+from ldm.dream.generator.base import Generator
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.dream.generator.img2img import Img2Img
+

 class Embiggen(Generator):
     def __init__(self, model, precision):
@@ -38,19 +39,20 @@ class Embiggen(Generator):
         Return value depends on the seed at the time you call it
         """
         # Construct embiggen arg array, and sanity check arguments
-        if embiggen == None: # embiggen can also be called with just embiggen_tiles
-            embiggen = [1.0] # If not specified, assume no scaling
-        elif embiggen[0] < 0 :
+        if embiggen == None:  # embiggen can also be called with just embiggen_tiles
+            embiggen = [1.0]  # If not specified, assume no scaling
+        elif embiggen[0] < 0:
             embiggen[0] = 1.0
-            print('>> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !')
+            print(
+                '>> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !')
         if len(embiggen) < 2:
             embiggen.append(0.75)
-        elif embiggen[1] > 1.0 or embiggen[1] < 0 :
+        elif embiggen[1] > 1.0 or embiggen[1] < 0:
             embiggen[1] = 0.75
             print('>> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !')
         if len(embiggen) < 3:
             embiggen.append(0.25)
-        elif embiggen[2] < 0 :
+        elif embiggen[2] < 0:
             embiggen[2] = 0.25
             print('>> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !')
@@ -76,29 +78,30 @@ class Embiggen(Generator):
         if embiggen[0] != 1.0:
             initsuperwidth = round(initsuperwidth*embiggen[0])
             initsuperheight = round(initsuperheight*embiggen[0])
-            if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
-                from ldm.gfpgan.gfpgan_tools import (
-                    real_esrgan_upscale,
-                )
-                print(f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}')
+            if embiggen[1] > 0:  # No point in ESRGAN upscaling if strength is set zero
+                from ldm.restoration.realesrgan import ESRGAN
+                esrgan = ESRGAN()
+                print(
+                    f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}')
                 if embiggen[0] > 2:
-                    initsuperimage = real_esrgan_upscale(
+                    initsuperimage = esrgan.process(
                         initsuperimage,
-                        embiggen[1], # upscale strength
-                        4, # upscale scale
+                        embiggen[1],  # upscale strength
                         self.seed,
+                        4,  # upscale scale
                     )
                 else:
-                    initsuperimage = real_esrgan_upscale(
+                    initsuperimage = esrgan.process(
                         initsuperimage,
-                        embiggen[1], # upscale strength
-                        2, # upscale scale
+                        embiggen[1],  # upscale strength
                         self.seed,
+                        2,  # upscale scale
                     )
             # We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x
             # but from personal experiance it doesn't greatly improve anything after 4x
             # Resize to target scaling factor resolution
-            initsuperimage = initsuperimage.resize((initsuperwidth, initsuperheight), Image.Resampling.LANCZOS)
+            initsuperimage = initsuperimage.resize(
+                (initsuperwidth, initsuperheight), Image.Resampling.LANCZOS)

         # Use width and height as tile widths and height
         # Determine buffer size in pixels
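Call sites porting from the removed helper to the new class should note the argument order: `real_esrgan_upscale(image, strength, scale, seed)` becomes `esrgan.process(image, strength, seed, scale)`, with seed and scale swapped. A minimal sketch using the import path from this hunk (placeholder input; weights assumed present):

```python
from PIL import Image
from ldm.restoration.realesrgan import ESRGAN   # import path used in this hunk

esrgan = ESRGAN()
img = Image.new('RGB', (512, 512))      # placeholder for the real init image
img = esrgan.process(img, 0.75, 42, 4)  # strength, seed, then 4x scale
```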
@@ -121,28 +124,31 @@ class Embiggen(Generator):
         emb_tiles_x = 1
         emb_tiles_y = 1
         if (initsuperwidth - width) > 0:
-            emb_tiles_x = ceildiv(initsuperwidth - width, width - overlap_size_x) + 1
+            emb_tiles_x = ceildiv(initsuperwidth - width,
+                                  width - overlap_size_x) + 1
         if (initsuperheight - height) > 0:
-            emb_tiles_y = ceildiv(initsuperheight - height, height - overlap_size_y) + 1
+            emb_tiles_y = ceildiv(initsuperheight - height,
+                                  height - overlap_size_y) + 1
         # Sanity
         assert emb_tiles_x > 1 or emb_tiles_y > 1, f'ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don\'t need to Embiggen! Check your arguments.'

         # Prep alpha layers --------------
         # https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
         # agradientL is Left-side transparent
-        agradientL = Image.linear_gradient('L').rotate(90).resize((overlap_size_x, height))
+        agradientL = Image.linear_gradient('L').rotate(
+            90).resize((overlap_size_x, height))
         # agradientT is Top-side transparent
         agradientT = Image.linear_gradient('L').resize((width, overlap_size_y))
         # radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
         agradientC = Image.new('L', (256, 256))
         for y in range(256):
             for x in range(256):
-                #Find distance to lower right corner (numpy takes arrays)
+                # Find distance to lower right corner (numpy takes arrays)
                 distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0]
-                #Clamp values to max 255
+                # Clamp values to max 255
                 if distanceToLR > 255:
                     distanceToLR = 255
-                #Place the pixel as invert of distance
+                # Place the pixel as invert of distance
                 agradientC.putpixel((x, y), int(255 - distanceToLR))

         # Create alpha layers default fully white
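A quick numeric check of the tile-count arithmetic above, as a self-contained sketch with illustrative numbers (`ceildiv` re-implemented the usual way):

```python
def ceildiv(a, b):
    # ceiling division, as commonly implemented
    return -(a // -b)

# illustrative numbers: 1024px super-image, 512px tiles, 128px overlap
initsuperwidth, width, overlap_size_x = 1024, 512, 128
emb_tiles_x = ceildiv(initsuperwidth - width, width - overlap_size_x) + 1
print(emb_tiles_x)  # 3 tiles cover 1024px when each shares 128px with its neighbour
```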
@@ -154,59 +160,79 @@ class Embiggen(Generator):
         alphaLayerT.paste(agradientT, (0, 0))
         alphaLayerLTC.paste(agradientL, (0, 0))
         alphaLayerLTC.paste(agradientT, (0, 0))
-        alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
+        alphaLayerLTC.paste(agradientC.resize(
+            (overlap_size_x, overlap_size_y)), (0, 0))

         if embiggen_tiles:
             # Individual unconnected sides
             alphaLayerR = Image.new("L", (width, height), 255)
-            alphaLayerR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
+            alphaLayerR.paste(agradientL.rotate(
+                180), (width - overlap_size_x, 0))
             alphaLayerB = Image.new("L", (width, height), 255)
-            alphaLayerB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
+            alphaLayerB.paste(agradientT.rotate(
+                180), (0, height - overlap_size_y))
             alphaLayerTB = Image.new("L", (width, height), 255)
             alphaLayerTB.paste(agradientT, (0, 0))
-            alphaLayerTB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
+            alphaLayerTB.paste(agradientT.rotate(
+                180), (0, height - overlap_size_y))
             alphaLayerLR = Image.new("L", (width, height), 255)
             alphaLayerLR.paste(agradientL, (0, 0))
-            alphaLayerLR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
+            alphaLayerLR.paste(agradientL.rotate(
+                180), (width - overlap_size_x, 0))

             # Sides and corner Layers
             alphaLayerRBC = Image.new("L", (width, height), 255)
-            alphaLayerRBC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
-            alphaLayerRBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
-            alphaLayerRBC.paste(agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
+            alphaLayerRBC.paste(agradientL.rotate(
+                180), (width - overlap_size_x, 0))
+            alphaLayerRBC.paste(agradientT.rotate(
+                180), (0, height - overlap_size_y))
+            alphaLayerRBC.paste(agradientC.rotate(180).resize(
+                (overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
             alphaLayerLBC = Image.new("L", (width, height), 255)
             alphaLayerLBC.paste(agradientL, (0, 0))
-            alphaLayerLBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
-            alphaLayerLBC.paste(agradientC.rotate(90).resize((overlap_size_x, overlap_size_y)), (0, height - overlap_size_y))
+            alphaLayerLBC.paste(agradientT.rotate(
+                180), (0, height - overlap_size_y))
+            alphaLayerLBC.paste(agradientC.rotate(90).resize(
+                (overlap_size_x, overlap_size_y)), (0, height - overlap_size_y))
             alphaLayerRTC = Image.new("L", (width, height), 255)
-            alphaLayerRTC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
+            alphaLayerRTC.paste(agradientL.rotate(
+                180), (width - overlap_size_x, 0))
             alphaLayerRTC.paste(agradientT, (0, 0))
-            alphaLayerRTC.paste(agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
+            alphaLayerRTC.paste(agradientC.rotate(270).resize(
+                (overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))

             # All but X layers
             alphaLayerABT = Image.new("L", (width, height), 255)
             alphaLayerABT.paste(alphaLayerLBC, (0, 0))
-            alphaLayerABT.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
-            alphaLayerABT.paste(agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
+            alphaLayerABT.paste(agradientL.rotate(
+                180), (width - overlap_size_x, 0))
+            alphaLayerABT.paste(agradientC.rotate(180).resize(
+                (overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
             alphaLayerABL = Image.new("L", (width, height), 255)
             alphaLayerABL.paste(alphaLayerRTC, (0, 0))
-            alphaLayerABL.paste(agradientT.rotate(180), (0, height - overlap_size_y))
-            alphaLayerABL.paste(agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
+            alphaLayerABL.paste(agradientT.rotate(
+                180), (0, height - overlap_size_y))
+            alphaLayerABL.paste(agradientC.rotate(180).resize(
+                (overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
             alphaLayerABR = Image.new("L", (width, height), 255)
             alphaLayerABR.paste(alphaLayerLBC, (0, 0))
             alphaLayerABR.paste(agradientT, (0, 0))
-            alphaLayerABR.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
+            alphaLayerABR.paste(agradientC.resize(
+                (overlap_size_x, overlap_size_y)), (0, 0))
             alphaLayerABB = Image.new("L", (width, height), 255)
             alphaLayerABB.paste(alphaLayerRTC, (0, 0))
             alphaLayerABB.paste(agradientL, (0, 0))
-            alphaLayerABB.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
+            alphaLayerABB.paste(agradientC.resize(
+                (overlap_size_x, overlap_size_y)), (0, 0))

             # All-around layer
             alphaLayerAA = Image.new("L", (width, height), 255)
             alphaLayerAA.paste(alphaLayerABT, (0, 0))
             alphaLayerAA.paste(agradientT, (0, 0))
-            alphaLayerAA.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
-            alphaLayerAA.paste(agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
+            alphaLayerAA.paste(agradientC.resize(
+                (overlap_size_x, overlap_size_y)), (0, 0))
+            alphaLayerAA.paste(agradientC.rotate(270).resize(
+                (overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))

         # Clean up temporary gradients
         del agradientL
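The gradient trick behind these alpha layers is easy to demonstrate in isolation; a self-contained sketch (sizes illustrative):

```python
from PIL import Image

# A 'L'-mode linear gradient is black (0) at the top and white (255) at the
# bottom; rotated 90 degrees it fades across the horizontal axis. Pasted into
# an otherwise-opaque alpha layer, it feathers one edge of a tile so that
# neighbouring tiles blend smoothly in their overlap region.
overlap_size_x, width, height = 64, 512, 512
agradientL = Image.linear_gradient('L').rotate(90).resize((overlap_size_x, height))

tile = Image.new('RGB', (width, height), 'red').convert('RGBA')
alphaLayerL = Image.new('L', (width, height), 255)   # fully opaque
alphaLayerL.paste(agradientL, (0, 0))                # transparent left edge
tile.putalpha(alphaLayerL)
tile.save('feathered_tile.png')
```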
@@ -218,7 +244,8 @@ class Embiggen(Generator):
         if embiggen_tiles:
             print(f'>> Making {len(embiggen_tiles)} Embiggen tiles...')
         else:
-            print(f'>> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})...')
+            print(
+                f'>> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})...')

         emb_tile_store = []
         for tile in range(emb_tiles_x * emb_tiles_y):
@@ -240,20 +267,23 @@ class Embiggen(Generator):
             top = round(emb_row_i * (height - overlap_size_y))
             right = left + width
             bottom = top + height

             # Cropped image of above dimension (does not modify the original)
             newinitimage = initsuperimage.crop((left, top, right, bottom))
             # DEBUG:
             # newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png'
             # newinitimage.save(newinitimagepath)

             if embiggen_tiles:
-                print(f'Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)')
+                print(
+                    f'Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)')
             else:
-                print(f'Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles')
+                print(
+                    f'Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles')

             # create a torch tensor from an Image
-            newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
+            newinitimage = np.array(
+                newinitimage).astype(np.float32) / 255.0
             newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
             newinitimage = torch.from_numpy(newinitimage)
             newinitimage = 2.0 * newinitimage - 1.0
@@ -261,33 +291,35 @@ class Embiggen(Generator):

             tile_results = gen_img2img.generate(
                 prompt,
-                iterations = 1,
-                seed = self.seed,
-                sampler = sampler,
-                steps = steps,
-                cfg_scale = cfg_scale,
-                conditioning = conditioning,
-                ddim_eta = ddim_eta,
-                image_callback = None, # called only after the final image is generated
-                step_callback = step_callback, # called after each intermediate image is generated
-                width = width,
-                height = height,
-                init_img = init_img, # img2img doesn't need this, but it might in the future
-                init_image = newinitimage, # notice that init_image is different from init_img
-                mask_image = None,
-                strength = strength,
+                iterations=1,
+                seed=self.seed,
+                sampler=sampler,
+                steps=steps,
+                cfg_scale=cfg_scale,
+                conditioning=conditioning,
+                ddim_eta=ddim_eta,
+                image_callback=None,  # called only after the final image is generated
+                step_callback=step_callback,  # called after each intermediate image is generated
+                width=width,
+                height=height,
+                init_img=init_img,  # img2img doesn't need this, but it might in the future
+                init_image=newinitimage,  # notice that init_image is different from init_img
+                mask_image=None,
+                strength=strength,
             )

             emb_tile_store.append(tile_results[0][0])
             # DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite
             # emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png')
             del newinitimage

         # Sanity check we have them all
         if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)):
-            outputsuperimage = Image.new("RGBA", (initsuperwidth, initsuperheight))
+            outputsuperimage = Image.new(
+                "RGBA", (initsuperwidth, initsuperheight))
             if embiggen_tiles:
-                outputsuperimage.alpha_composite(initsuperimage.convert('RGBA'), (0, 0))
+                outputsuperimage.alpha_composite(
+                    initsuperimage.convert('RGBA'), (0, 0))
             for tile in range(emb_tiles_x * emb_tiles_y):
                 if embiggen_tiles:
                     if tile in embiggen_tiles:
@@ -308,7 +340,8 @@ class Embiggen(Generator):
                 if emb_column_i + 1 == emb_tiles_x:
                     left = initsuperwidth - width
                 else:
-                    left = round(emb_column_i * (width - overlap_size_x))
+                    left = round(emb_column_i *
+                                 (width - overlap_size_x))
                 if emb_row_i + 1 == emb_tiles_y:
                     top = initsuperheight - height
                 else:
@@ -319,33 +352,33 @@ class Embiggen(Generator):
                 # top of image
                 if emb_row_i == 0:
                     if emb_column_i == 0:
-                        if (tile+1) in embiggen_tiles: # Look-ahead right
-                            if (tile+emb_tiles_x) not in embiggen_tiles: # Look-ahead down
+                        if (tile+1) in embiggen_tiles:  # Look-ahead right
+                            if (tile+emb_tiles_x) not in embiggen_tiles:  # Look-ahead down
                                 intileimage.putalpha(alphaLayerB)
                             # Otherwise do nothing on this tile
-                        elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
+                        elif (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down only
                             intileimage.putalpha(alphaLayerR)
                         else:
                             intileimage.putalpha(alphaLayerRBC)
                     elif emb_column_i == emb_tiles_x - 1:
-                        if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
+                        if (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down
                             intileimage.putalpha(alphaLayerL)
                         else:
                             intileimage.putalpha(alphaLayerLBC)
                     else:
-                        if (tile+1) in embiggen_tiles: # Look-ahead right
-                            if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
+                        if (tile+1) in embiggen_tiles:  # Look-ahead right
+                            if (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down
                                 intileimage.putalpha(alphaLayerL)
                             else:
                                 intileimage.putalpha(alphaLayerLBC)
-                        elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
+                        elif (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down only
                             intileimage.putalpha(alphaLayerLR)
                         else:
                             intileimage.putalpha(alphaLayerABT)
                 # bottom of image
                 elif emb_row_i == emb_tiles_y - 1:
                     if emb_column_i == 0:
-                        if (tile+1) in embiggen_tiles: # Look-ahead right
+                        if (tile+1) in embiggen_tiles:  # Look-ahead right
                             intileimage.putalpha(alphaLayerT)
                         else:
                             intileimage.putalpha(alphaLayerRTC)
@@ -353,34 +386,34 @@ class Embiggen(Generator):
                         # No tiles to look ahead to
                         intileimage.putalpha(alphaLayerLTC)
                     else:
-                        if (tile+1) in embiggen_tiles: # Look-ahead right
+                        if (tile+1) in embiggen_tiles:  # Look-ahead right
                             intileimage.putalpha(alphaLayerLTC)
                         else:
                             intileimage.putalpha(alphaLayerABB)
                 # vertical middle of image
                 else:
                     if emb_column_i == 0:
-                        if (tile+1) in embiggen_tiles: # Look-ahead right
-                            if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
+                        if (tile+1) in embiggen_tiles:  # Look-ahead right
+                            if (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down
                                 intileimage.putalpha(alphaLayerT)
                             else:
                                 intileimage.putalpha(alphaLayerTB)
-                        elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
+                        elif (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down only
                             intileimage.putalpha(alphaLayerRTC)
                         else:
                             intileimage.putalpha(alphaLayerABL)
                     elif emb_column_i == emb_tiles_x - 1:
-                        if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
+                        if (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down
                             intileimage.putalpha(alphaLayerLTC)
                         else:
                             intileimage.putalpha(alphaLayerABR)
                     else:
-                        if (tile+1) in embiggen_tiles: # Look-ahead right
-                            if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
+                        if (tile+1) in embiggen_tiles:  # Look-ahead right
+                            if (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down
                                 intileimage.putalpha(alphaLayerLTC)
                             else:
                                 intileimage.putalpha(alphaLayerABR)
-                        elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
+                        elif (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down only
                             intileimage.putalpha(alphaLayerABB)
                         else:
                             intileimage.putalpha(alphaLayerAA)
@@ -400,4 +433,4 @@ class Embiggen(Generator):
         # after internal loops and patching up return Embiggen image
         return outputsuperimage
     # end of function declaration
-        return make_image
+    return make_image
@@ -37,6 +37,8 @@ def build_opt(post_data, seed, gfpgan_model_exists):
    setattr(opt, 'seed', None if int(post_data['seed']) == -1 else int(post_data['seed']))
    setattr(opt, 'variation_amount', float(post_data['variation_amount']) if int(post_data['seed']) != -1 else 0)
    setattr(opt, 'with_variations', [])
+   setattr(opt, 'embiggen', None)
+   setattr(opt, 'embiggen_tiles', None)

    broken = False
    if int(post_data['seed']) != -1 and post_data['with_variations'] != '':
@@ -80,12 +82,11 @@ class DreamServer(BaseHTTPRequestHandler):
             self.wfile.write(content.read())
         elif self.path == "/config.js":
             # unfortunately this import can't be at the top level, since that would cause a circular import
-            from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
             self.send_response(200)
             self.send_header("Content-type", "application/javascript")
             self.end_headers()
             config = {
-                'gfpgan_model_exists': gfpgan_model_exists
+                'gfpgan_model_exists': self.gfpgan_model_exists
             }
             self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8"))
         elif self.path == "/run_log.json":
@@ -138,11 +139,10 @@ class DreamServer(BaseHTTPRequestHandler):
             self.end_headers()

-            # unfortunately this import can't be at the top level, since that would cause a circular import
-            from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
-
             content_length = int(self.headers['Content-Length'])
             post_data = json.loads(self.rfile.read(content_length))
-            opt = build_opt(post_data, self.model.seed, gfpgan_model_exists)
+            opt = build_opt(post_data, self.model.seed, self.gfpgan_model_exists)

             self.canceled.clear()
             # In order to handle upscaled images, the PngWriter needs to maintain state
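With the circular import gone, the handler reads `self.gfpgan_model_exists`, a class attribute that the launcher stamps onto `DreamServer` before serving (see the scripts/dream.py hunks below). A toy illustration of that hand-off pattern (`Handler` and `configure` are illustrative stand-ins):

```python
class Handler:
    # class-level default, as on DreamServer
    gfpgan_model_exists = False

def configure(gfpgan=None):
    # launcher-side hand-off, mirroring dream_server_loop() in scripts/dream.py
    Handler.gfpgan_model_exists = gfpgan is not None and gfpgan.gfpgan_model_exists

configure()                         # no --restore: gfpgan stays None
print(Handler.gfpgan_model_exists)  # False
```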
ldm/generate.py (224 changed lines)
@@ -23,9 +23,9 @@ from PIL import Image, ImageOps
 from torch import nn
 from pytorch_lightning import seed_everything, logging

-from ldm.util import instantiate_from_config
-from ldm.models.diffusion.ddim import DDIMSampler
-from ldm.models.diffusion.plms import PLMSSampler
+from ldm.util import instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
 from ldm.models.diffusion.ksampler import KSampler
 from ldm.dream.pngwriter import PngWriter, retrieve_metadata
 from ldm.dream.args import metadata_loads
@@ -51,6 +51,24 @@ torch.randint_like = fix_func(torch.randint_like)
 torch.bernoulli = fix_func(torch.bernoulli)
 torch.multinomial = fix_func(torch.multinomial)

+def fix_func(orig):
+    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+        def new_func(*args, **kw):
+            device = kw.get("device", "mps")
+            kw["device"]="cpu"
+            return orig(*args, **kw).to(device)
+        return new_func
+    return orig
+
+torch.rand = fix_func(torch.rand)
+torch.rand_like = fix_func(torch.rand_like)
+torch.randn = fix_func(torch.randn)
+torch.randn_like = fix_func(torch.randn_like)
+torch.randint = fix_func(torch.randint)
+torch.randint_like = fix_func(torch.randint_like)
+torch.bernoulli = fix_func(torch.bernoulli)
+torch.multinomial = fix_func(torch.multinomial)
+
 """Simplified text to image API for stable diffusion/latent diffusion

 Example Usage:
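The inserted block re-wraps torch's RNG entry points so that, on Apple's MPS backend, random tensors are drawn on the CPU and then moved to the requested device (MPS RNG support was incomplete at the time). Note the context lines above the insertion already contain the same assignments, so this commit leaves the module with two copies of the block, which looks like a merge leftover. A self-contained illustration of the wrapper (guarded, so it also runs on machines without MPS):

```python
import torch

def fix_func(orig):
    # route RNG ops through the CPU on MPS, then move the result back
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        def new_func(*args, **kw):
            device = kw.get("device", "mps")
            kw["device"] = "cpu"
            return orig(*args, **kw).to(device)
        return new_func
    return orig

randn = fix_func(torch.randn)
x = randn(2, 3)    # on Apple silicon: computed on CPU, then .to('mps')
print(x.shape)     # torch.Size([2, 3])
```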
@@ -135,6 +153,9 @@ class Generate:
         # these are deprecated; if present they override values in the conf file
         weights = None,
         config = None,
+        gfpgan=None,
+        codeformer=None,
+        esrgan=None
     ):
         models = OmegaConf.load(conf)
         mconfig = models[model]
@@ -158,6 +179,9 @@ class Generate:
         self.generators = {}
         self.base_generator = None
         self.seed = None
+        self.gfpgan = gfpgan
+        self.codeformer = codeformer
+        self.esrgan = esrgan

         # Note that in previous versions, there was an option to pass the
         # device to Generate(). However the device was then ignored, so
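Dependency injection replaces the old lazy imports: the restoration modules are built once and handed to Generate. A minimal sketch of the new construction path (it assumes the default model config and downloaded weights are in place, and mirrors the scripts/dream.py hunks below):

```python
from ldm.generate import Generate
from ldm.restoration.restoration import Restoration

restoration = Restoration()                      # default dirs and tile size
gfpgan, codeformer = restoration.load_face_restore_models()
esrgan = restoration.load_ersgan()               # spelling as used in this commit

gen = Generate(gfpgan=gfpgan, codeformer=codeformer, esrgan=esrgan)
```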
@@ -234,8 +258,8 @@ class Generate:
             strength = None,
             init_color = None,
             # these are specific to embiggen (which also relies on img2img args)
-            embiggen = None,
-            embiggen_tiles = None,
+            embiggen=None,
+            embiggen_tiles=None,
             # these are specific to GFPGAN/ESRGAN
             facetool = None,
             gfpgan_strength = 0,
@@ -284,15 +308,15 @@ class Generate:
         write the prompt into the PNG metadata.
         """
         # TODO: convert this into a getattr() loop
-        steps      = steps      or self.steps
-        width      = width      or self.width
-        height     = height     or self.height
-        seamless   = seamless   or self.seamless
-        cfg_scale  = cfg_scale  or self.cfg_scale
-        ddim_eta   = ddim_eta   or self.ddim_eta
-        iterations = iterations or self.iterations
-        strength   = strength   or self.strength
-        self.seed  = seed
+        steps = steps or self.steps
+        width = width or self.width
+        height = height or self.height
+        seamless = seamless or self.seamless
+        cfg_scale = cfg_scale or self.cfg_scale
+        ddim_eta = ddim_eta or self.ddim_eta
+        iterations = iterations or self.iterations
+        strength = strength or self.strength
+        self.seed = seed
         self.log_tokenization = log_tokenization
         self.step_callback = step_callback
         with_variations = [] if with_variations is None else with_variations
@@ -303,16 +327,17 @@ class Generate:
         for m in model.modules():
             if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                 m.padding_mode = 'circular' if seamless else m._orig_padding_mode

         assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
         assert (
             0.0 < strength < 1.0
         ), 'img2img and inpaint strength can only work with 0.0 < strength < 1.0'
         assert (
-            0.0 <= variation_amount <= 1.0
+            0.0 <= variation_amount <= 1.0
         ), '-v --variation_amount must be in [0.0, 1.0]'
         assert (
-            (embiggen == None and embiggen_tiles == None) or ((embiggen != None or embiggen_tiles != None) and init_img != None)
+            (embiggen == None and embiggen_tiles == None) or (
+                (embiggen != None or embiggen_tiles != None) and init_img != None)
         ), 'Embiggen requires an init/input image to be specified'

         if len(with_variations) > 0 or variation_amount > 1.0:
@@ -334,9 +359,9 @@ class Generate:
         if self._has_cuda():
             torch.cuda.reset_peak_memory_stats()

-        results    = list()
-        init_image = None
-        mask_image = None
+        results = list()
+        init_image = None
+        mask_image = None

         try:
             uc, c = get_uc_and_c(
@@ -345,8 +370,9 @@ class Generate:
                 log_tokens =self.log_tokenization
             )

-            (init_image,mask_image) = self._make_images(init_img,init_mask, width, height, fit)
-
+            (init_image, mask_image) = self._make_images(
+                init_img, init_mask, width, height, fit)
+
             if (init_image is not None) and (mask_image is not None):
                 generator = self._make_inpaint()
             elif (embiggen != None or embiggen_tiles != None):
@@ -356,26 +382,27 @@ class Generate:
             else:
                 generator = self._make_txt2img()

-            generator.set_variation(self.seed, variation_amount, with_variations)
+            generator.set_variation(
+                self.seed, variation_amount, with_variations)
             results = generator.generate(
                 prompt,
-                iterations = iterations,
-                seed = self.seed,
-                sampler = self.sampler,
-                steps = steps,
-                cfg_scale = cfg_scale,
-                conditioning = (uc,c),
-                ddim_eta = ddim_eta,
-                image_callback = image_callback, # called after the final image is generated
-                step_callback = step_callback, # called after each intermediate image is generated
-                width = width,
-                height = height,
-                init_img = init_img, # embiggen needs to manipulate from the unmodified init_img
-                init_image = init_image, # notice that init_image is different from init_img
-                mask_image = mask_image,
-                strength = strength,
-                embiggen = embiggen,
-                embiggen_tiles = embiggen_tiles,
+                iterations=iterations,
+                seed=self.seed,
+                sampler=self.sampler,
+                steps=steps,
+                cfg_scale=cfg_scale,
+                conditioning=(uc, c),
+                ddim_eta=ddim_eta,
+                image_callback=image_callback,  # called after the final image is generated
+                step_callback=step_callback,  # called after each intermediate image is generated
+                width=width,
+                height=height,
+                init_img=init_img,  # embiggen needs to manipulate from the unmodified init_img
+                init_image=init_image,  # notice that init_image is different from init_img
+                mask_image=mask_image,
+                strength=strength,
+                embiggen=embiggen,
+                embiggen_tiles=embiggen_tiles,
             )

             if init_color:
@@ -404,7 +431,8 @@ class Generate:
         toc = time.time()
         print('>> Usage stats:')
         print(
-            f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
+            f'>> {len(results)} image(s) generated in', '%4.2fs' % (
+                toc - tic)
         )
         if self._has_cuda():
             print(
@@ -515,29 +543,35 @@ class Generate:

     def _make_images(self, img_path, mask_path, width, height, fit=False):
-        init_image      = None
-        init_mask       = None
+        init_image = None
+        init_mask = None
         if not img_path:
-            return None,None
+            return None, None

-        image = self._load_img(img_path, width, height, fit=fit) # this returns an Image
-        init_image = self._create_init_image(image)              # this returns a torch tensor
+        image = self._load_img(img_path, width, height,
+                               fit=fit)  # this returns an Image
+        # this returns a torch tensor
+        init_image = self._create_init_image(image)

-        if self._has_transparency(image) and not mask_path:      # if image has a transparent area and no mask was provided, then try to generate mask
-            print('>> Initial image has transparent areas. Will inpaint in these regions.')
+        # if image has a transparent area and no mask was provided, then try to generate mask
+        if self._has_transparency(image) and not mask_path:
+            print(
+                '>> Initial image has transparent areas. Will inpaint in these regions.')
             if self._check_for_erasure(image):
                 print(
                     '>> WARNING: Colors underneath the transparent region seem to have been erased.\n',
                     '>> Inpainting will be suboptimal. Please preserve the colors when making\n',
                     '>> a transparency mask, or provide mask explicitly using --init_mask (-M).'
                 )
-            init_mask = self._create_init_mask(image)            # this returns a torch tensor
+            # this returns a torch tensor
+            init_mask = self._create_init_mask(image)

         if mask_path:
-            mask_image = self._load_img(mask_path, width, height, fit=fit) # this returns an Image
-            init_mask = self._create_init_mask(mask_image)
+            mask_image = self._load_img(
+                mask_path, width, height, fit=fit)  # this returns an Image
+            init_mask = self._create_init_mask(mask_image)

-        return init_image,init_mask
+        return init_image, init_mask

     def _make_img2img(self):
         if not self.generators.get('img2img'):
@@ -619,38 +653,26 @@ class Generate:
             codeformer_fidelity = 0.75,
             save_original = False,
             image_callback = None):
-        try:
-            if upscale is not None:
-                from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
-            if strength > 0:
-                if facetool == 'codeformer':
-                    from ldm.restoration.codeformer.codeformer import CodeFormerRestoration
-                else:
-                    from ldm.gfpgan.gfpgan_tools import run_gfpgan
-        except (ModuleNotFoundError, ImportError):
-            print(traceback.format_exc(), file=sys.stderr)
-            print('>> You may need to install the ESRGAN and/or GFPGAN modules')
-            return
-
         for r in image_list:
             image, seed = r
             try:
                 if upscale is not None:
-                    if len(upscale) < 2:
-                        upscale.append(0.75)
-                    image = real_esrgan_upscale(
-                        image,
-                        upscale[1],
-                        int(upscale[0]),
-                        seed,
-                    )
-                if strength > 0:
-                    if facetool == 'codeformer':
-                        image = CodeFormerRestoration().process(image=image, strength=strength, device=self.device, seed=seed, fidelity=codeformer_fidelity)
+                    if self.esrgan is not None:
+                        if len(upscale) < 2:
+                            upscale.append(0.75)
+                        image = self.esrgan.process(
+                            image, upscale[1], seed, int(upscale[0]))
                     else:
-                        image = run_gfpgan(
-                            image, strength, seed, 1
-                        )
+                        print(">> ESRGAN is disabled. Image not upscaled.")
+                if strength > 0:
+                    if self.gfpgan is not None and self.codeformer is not None:
+                        if facetool == 'codeformer':
+                            image = self.codeformer.process(image=image, strength=strength, device=self.device, seed=seed, fidelity=codeformer_fidelity)
+                        else:
+                            image = self.gfpgan.process(image, strength, seed)
+                    else:
+                        print(">> Face Restoration is disabled.")
             except Exception as e:
                 print(
                     f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}'
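The `upscale` option stays a `[scale, strength]` pair throughout this refactor, with the blend strength defaulting to 0.75 when only the scale is given. A small illustration of that normalization (values are arbitrary):

```python
def normalize_upscale(upscale):
    # [scale] or [scale, strength] -> (int scale, float strength)
    upscale = list(upscale)
    if len(upscale) < 2:
        upscale.append(0.75)   # default blend strength, as in the code above
    return int(upscale[0]), upscale[1]

print(normalize_upscale([2]))       # (2, 0.75)
print(normalize_upscale([4, 0.9]))  # (4, 0.9)
```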
@@ -662,10 +684,10 @@ class Generate:
                 r[0] = image

     # to help WebGUI - front end to generator util function
-    def sample_to_image(self,samples):
+    def sample_to_image(self, samples):
         return self._sample_to_image(samples)

-    def _sample_to_image(self,samples):
+    def _sample_to_image(self, samples):
         if not self.base_generator:
             from ldm.dream.generator import Generator
             self.base_generator = Generator(self.model)
@@ -708,7 +730,7 @@ class Generate:
         # for usage statistics
         device_type = choose_torch_device()
         if device_type == 'cuda':
-            torch.cuda.reset_peak_memory_stats()
+            torch.cuda.reset_peak_memory_stats()
         tic = time.time()

         # this does the work
@@ -756,12 +778,12 @@ class Generate:
             f'>> loaded input image of size {image.width}x{image.height} from {path}'
         )
         if fit:
-            image = self._fit_image(image,(width,height))
+            image = self._fit_image(image, (width, height))
         else:
             image = self._squeeze_image(image)
         return image

-    def _create_init_image(self,image):
+    def _create_init_image(self, image):
         image = image.convert('RGB')
         # print(
         #     f'>> DEBUG: writing the image to img.png'
@@ -770,7 +792,7 @@ class Generate:
         image = np.array(image).astype(np.float32) / 255.0
         image = image[None].transpose(0, 3, 1, 2)
         image = torch.from_numpy(image)
-        image = 2.0 * image - 1.0
+        image = 2.0 * image - 1.0
         return image.to(self.device)

     def _create_init_mask(self, image):
@@ -779,7 +801,8 @@ class Generate:
         image = image.convert('RGB')
         # BUG: We need to use the model's downsample factor rather than hardcoding "8"
         from ldm.dream.generator.base import downsampling
-        image = image.resize((image.width//downsampling, image.height//downsampling), resample=Image.Resampling.LANCZOS)
+        image = image.resize((image.width//downsampling, image.height //
+                             downsampling), resample=Image.Resampling.LANCZOS)
         # print(
         #     f'>> DEBUG: writing the mask to mask.png'
         # )
@@ -801,7 +824,7 @@ class Generate:
         mask = ImageOps.invert(mask)
         return mask

-    def _has_transparency(self,image):
+    def _has_transparency(self, image):
         if image.info.get("transparency", None) is not None:
             return True
         if image.mode == "P":
@@ -815,11 +838,10 @@ class Generate:
                 return True
         return False

-
-    def _check_for_erasure(self,image):
+    def _check_for_erasure(self, image):
         width, height = image.size
-        pixdata       = image.load()
-        colored       = 0
+        pixdata = image.load()
+        colored = 0
         for y in range(height):
             for x in range(width):
                 if pixdata[x, y][3] == 0:
@@ -829,28 +851,28 @@ class Generate:
                     colored += 1
         return colored == 0

-    def _squeeze_image(self,image):
-        x,y,resize_needed = self._resolution_check(image.width,image.height)
+    def _squeeze_image(self, image):
+        x, y, resize_needed = self._resolution_check(image.width, image.height)
         if resize_needed:
-            return InitImageResizer(image).resize(x,y)
+            return InitImageResizer(image).resize(x, y)
         return image

-
-    def _fit_image(self,image,max_dimensions):
-        w,h = max_dimensions
+    def _fit_image(self, image, max_dimensions):
+        w, h = max_dimensions
         print(
             f'>> image will be resized to fit inside a box {w}x{h} in size.'
         )
         if image.width > image.height:
-            h   = None   # by setting h to none, we tell InitImageResizer to fit into the width and calculate height
+            h = None  # by setting h to none, we tell InitImageResizer to fit into the width and calculate height
         elif image.height > image.width:
-            w   = None   # ditto for w
+            w = None  # ditto for w
         else:
             pass
-        image = InitImageResizer(image).resize(w,h)   # note that InitImageResizer does the multiple of 64 truncation internally
+        # note that InitImageResizer does the multiple of 64 truncation internally
+        image = InitImageResizer(image).resize(w, h)
         print(
             f'>> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}'
-        )
+        )
         return image

     def _resolution_check(self, width, height, log=False):
@@ -864,7 +886,7 @@ class Generate:
                 f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}'
             )
             height = h
-            width  = w
+            width = w
             resize_needed = True

         if (width * height) > (self.width * self.height):
ldm/gfpgan/gfpgan_tools.py (deleted file)
@@ -1,168 +0,0 @@
-import torch
-import warnings
-import os
-import sys
-import numpy as np
-
-from PIL import Image
-#from scripts.dream import create_argv_parser
-from ldm.dream.args import Args
-
-opt = Args()
-opt.parse_args()
-model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path)
-gfpgan_model_exists = os.path.isfile(model_path)
-
-def run_gfpgan(image, strength, seed, upsampler_scale=4):
-    print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')
-    gfpgan = None
-    with warnings.catch_warnings():
-        warnings.filterwarnings('ignore', category=DeprecationWarning)
-        warnings.filterwarnings('ignore', category=UserWarning)
-
-        try:
-            if not gfpgan_model_exists:
-                raise Exception('GFPGAN model not found at path ' + model_path)
-
-            sys.path.append(os.path.abspath(opt.gfpgan_dir))
-            from gfpgan import GFPGANer
-
-            bg_upsampler = _load_gfpgan_bg_upsampler(
-                opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile
-            )
-
-            gfpgan = GFPGANer(
-                model_path=model_path,
-                upscale=upsampler_scale,
-                arch='clean',
-                channel_multiplier=2,
-                bg_upsampler=bg_upsampler,
-            )
-        except Exception:
-            import traceback
-
-            print('>> Error loading GFPGAN:', file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
-
-    if gfpgan is None:
-        print(
-            f'>> WARNING: GFPGAN not initialized.'
-        )
-        print(
-            f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth to {model_path}, \nor change GFPGAN directory with --gfpgan_dir.'
-        )
-        return image
-
-    image = image.convert('RGB')
-
-    cropped_faces, restored_faces, restored_img = gfpgan.enhance(
-        np.array(image, dtype=np.uint8),
-        has_aligned=False,
-        only_center_face=False,
-        paste_back=True,
-    )
-    res = Image.fromarray(restored_img)
-
-    if strength < 1.0:
-        # Resize the image to the new image if the sizes have changed
-        if restored_img.size != image.size:
-            image = image.resize(res.size)
-        res = Image.blend(image, res, strength)
-
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-    gfpgan = None
-
-    return res
-
-
-def _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400):
-    if bg_upsampler == 'realesrgan':
-        if not torch.cuda.is_available():  # CPU or MPS on M1
-            use_half_precision = False
-        else:
-            use_half_precision = True
-
-        model_path = {
-            2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
-            4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
-        }
-
-        if upsampler_scale not in model_path:
-            return None
-
-        from basicsr.archs.rrdbnet_arch import RRDBNet
-        from realesrgan import RealESRGANer
-
-        if upsampler_scale == 4:
-            model = RRDBNet(
-                num_in_ch=3,
-                num_out_ch=3,
-                num_feat=64,
-                num_block=23,
-                num_grow_ch=32,
-                scale=4,
-            )
-        if upsampler_scale == 2:
-            model = RRDBNet(
-                num_in_ch=3,
-                num_out_ch=3,
-                num_feat=64,
-                num_block=23,
-                num_grow_ch=32,
-                scale=2,
-            )
-
-        bg_upsampler = RealESRGANer(
-            scale=upsampler_scale,
-            model_path=model_path[upsampler_scale],
-            model=model,
-            tile=bg_tile,
-            tile_pad=10,
-            pre_pad=0,
-            half=use_half_precision,
-        )
-    else:
-        bg_upsampler = None
-
-    return bg_upsampler
-
-
-def real_esrgan_upscale(image, strength, upsampler_scale, seed):
-    print(
-        f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
-    )
-
-    with warnings.catch_warnings():
-        warnings.filterwarnings('ignore', category=DeprecationWarning)
-        warnings.filterwarnings('ignore', category=UserWarning)
-
-        try:
-            upsampler = _load_gfpgan_bg_upsampler(
-                opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile
-            )
-        except Exception:
-            import traceback
-
-            print('>> Error loading Real-ESRGAN:', file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
-
-    output, img_mode = upsampler.enhance(
-        np.array(image, dtype=np.uint8),
-        outscale=upsampler_scale,
-        alpha_upsampler=opt.gfpgan_bg_upsampler,
-    )
-
-    res = Image.fromarray(output)
-
-    if strength < 1.0:
-        # Resize the image to the new image if the sizes have changed
-        if output.size != image.size:
-            image = image.resize(res.size)
-        res = Image.blend(image, res, strength)
-
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-    upsampler = None
-
-    return res
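Everything in this deleted module resurfaces as a method on the new classes: `run_gfpgan` becomes `GFPGAN.process`, `real_esrgan_upscale` becomes `ESRGAN.process`, `_load_gfpgan_bg_upsampler` becomes `ESRGAN.load_esrgan_bg_upsampler`, and the module-level `gfpgan_model_exists` flag becomes the `gfpgan_model_exists` attribute on `GFPGAN`. Call sites that previously did a deferred `from ldm.gfpgan.gfpgan_tools import ...` now receive already-constructed instances instead.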
@@ -2,12 +2,20 @@ import os
 import torch
 import numpy as np
 import warnings
+import sys

 pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'

 class CodeFormerRestoration():
-    def __init__(self) -> None:
-        pass
+    def __init__(self,
+                 codeformer_dir='ldm/restoration/codeformer',
+                 codeformer_model_path='weights/codeformer.pth') -> None:
+        self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
+        self.codeformer_model_exists = os.path.isfile(self.model_path)
+
+        if not self.codeformer_model_exists:
+            print('## NOT FOUND: CodeFormer model not found at ' + self.model_path)
+        sys.path.append(os.path.abspath(codeformer_dir))

     def process(self, image, strength, device, seed=None, fidelity=0.75):
         if seed is not None:
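A minimal sketch of the constructor-plus-process flow this gives (paths are the defaults above; it assumes the CodeFormer weights were fetched to `ldm/restoration/codeformer/weights/codeformer.pth`, a CUDA device is available, and `face.png` is a hypothetical input):

```python
from PIL import Image
from ldm.restoration.codeformer.codeformer import CodeFormerRestoration

cf = CodeFormerRestoration()          # checks for weights, warns if missing
img = Image.open('face.png')          # hypothetical input
out = cf.process(image=img, strength=0.9, device='cuda', seed=42, fidelity=0.75)
out.save('face_codeformer.png')
```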
ldm/restoration/gfpgan/gfpgan.py (new file, 76 lines)
@@ -0,0 +1,76 @@
+import torch
+import warnings
+import os
+import sys
+import numpy as np
+
+from PIL import Image
+
+
+class GFPGAN():
+    def __init__(
+            self,
+            gfpgan_dir='src/gfpgan',
+            gfpgan_model_path='experiments/pretrained_models/GFPGANv1.3.pth') -> None:
+
+        self.model_path = os.path.join(gfpgan_dir, gfpgan_model_path)
+        self.gfpgan_model_exists = os.path.isfile(self.model_path)
+
+        if not self.gfpgan_model_exists:
+            raise Exception(
+                'GFPGAN model not found at path ' + self.model_path)
+        sys.path.append(os.path.abspath(gfpgan_dir))
+
+    def model_exists(self):
+        return os.path.isfile(self.model_path)
+
+    def process(self, image, strength: float, seed: str = None):
+        if seed is not None:
+            print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')
+
+        with warnings.catch_warnings():
+            warnings.filterwarnings('ignore', category=DeprecationWarning)
+            warnings.filterwarnings('ignore', category=UserWarning)
+            try:
+                from gfpgan import GFPGANer
+                self.gfpgan = GFPGANer(
+                    model_path=self.model_path,
+                    upscale=1,
+                    arch='clean',
+                    channel_multiplier=2,
+                    bg_upsampler=None,
+                )
+            except Exception:
+                import traceback
+                print('>> Error loading GFPGAN:', file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+
+        if self.gfpgan is None:
+            print(
+                f'>> WARNING: GFPGAN not initialized.'
+            )
+            print(
+                f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth to {self.model_path}, \nor change GFPGAN directory with --gfpgan_dir.'
+            )
+
+        image = image.convert('RGB')
+
+        _, _, restored_img = self.gfpgan.enhance(
+            np.array(image, dtype=np.uint8),
+            has_aligned=False,
+            only_center_face=False,
+            paste_back=True,
+        )
+        res = Image.fromarray(restored_img)
+
+        if strength < 1.0:
+            # Resize the image to the new image if the sizes have changed
+            if restored_img.size != image.size:
+                image = image.resize(res.size)
+            res = Image.blend(image, res, strength)
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        self.gfpgan = None
+
+        return res
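One caveat worth noting in the class above: `self.gfpgan` is only assigned inside the `try` block, so if the `gfpgan` import itself fails, the subsequent `if self.gfpgan is None` check raises `AttributeError` rather than printing the download hint. A hedged sketch of a more defensive variant of the loader (function-level, names illustrative):

```python
import sys
import traceback

def load_gfpganer(model_path):
    gfpganer = None                      # pre-seed so the None check below is always safe
    try:
        from gfpgan import GFPGANer      # same dependency the class uses
        gfpganer = GFPGANer(model_path=model_path, upscale=1, arch='clean',
                            channel_multiplier=2, bg_upsampler=None)
    except Exception:
        print('>> Error loading GFPGAN:', file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
    if gfpganer is None:
        print(f'>> WARNING: GFPGAN not initialized. Download the model to {model_path}.')
    return gfpganer
```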
ldm/restoration/realesrgan/realesrgan.py (new file, 102 lines)
@@ -0,0 +1,102 @@
+import torch
+import warnings
+import numpy as np
+
+from PIL import Image
+
+
+class ESRGAN():
+    def __init__(self, bg_tile_size=400) -> None:
+        self.bg_tile_size = bg_tile_size
+
+        if not torch.cuda.is_available():  # CPU or MPS on M1
+            use_half_precision = False
+        else:
+            use_half_precision = True
+
+    def load_esrgan_bg_upsampler(self, upsampler_scale):
+        if not torch.cuda.is_available():  # CPU or MPS on M1
+            use_half_precision = False
+        else:
+            use_half_precision = True
+
+        model_path = {
+            2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
+            4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
+        }
+
+        if upsampler_scale not in model_path:
+            return None
+        else:
+            from basicsr.archs.rrdbnet_arch import RRDBNet
+            from realesrgan import RealESRGANer
+
+            if upsampler_scale == 4:
+                model = RRDBNet(
+                    num_in_ch=3,
+                    num_out_ch=3,
+                    num_feat=64,
+                    num_block=23,
+                    num_grow_ch=32,
+                    scale=4,
+                )
+            if upsampler_scale == 2:
+                model = RRDBNet(
+                    num_in_ch=3,
+                    num_out_ch=3,
+                    num_feat=64,
+                    num_block=23,
+                    num_grow_ch=32,
+                    scale=2,
+                )
+
+            bg_upsampler = RealESRGANer(
+                scale=upsampler_scale,
+                model_path=model_path[upsampler_scale],
+                model=model,
+                tile=self.bg_tile_size,
+                tile_pad=10,
+                pre_pad=0,
+                half=use_half_precision,
+            )
+
+        return bg_upsampler
+
+    def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2):
+        if seed is not None:
+            print(
+                f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
+            )
+
+        with warnings.catch_warnings():
+            warnings.filterwarnings('ignore', category=DeprecationWarning)
+            warnings.filterwarnings('ignore', category=UserWarning)
+
+            try:
+                upsampler = self.load_esrgan_bg_upsampler(upsampler_scale)
+            except Exception:
+                import traceback
+                import sys
+
+                print('>> Error loading Real-ESRGAN:', file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+
+        output, _ = upsampler.enhance(
+            np.array(image, dtype=np.uint8),
+            outscale=upsampler_scale,
+            alpha_upsampler='realesrgan',
+        )
+
+        res = Image.fromarray(output)
+
+        if strength < 1.0:
+            # Resize the image to the new image if the sizes have changed
+            if output.size != image.size:
+                image = image.resize(res.size)
+            res = Image.blend(image, res, strength)
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        upsampler = None
+
+        return res
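The `strength` parameter in both new classes is simply an alpha blend between the (resized) original and the upscaled or restored result; the `use_half_precision` computed in `__init__` above is unused, only the copy inside `load_esrgan_bg_upsampler` matters. A self-contained illustration of the blend semantics (stand-in images):

```python
from PIL import Image

original = Image.new('RGB', (64, 64), 'gray')
upscaled = Image.new('RGB', (128, 128), 'white')   # stand-in for model output

# strength=0.75 keeps 75% of the new result and 25% of the original,
# exactly as Image.blend is used in process() above
blended = Image.blend(original.resize(upscaled.size), upscaled, 0.75)
print(blended.size)  # (128, 128)
```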
ldm/restoration/restoration.py (new file, 34 lines)
@@ -0,0 +1,34 @@
+class Restoration():
+    def __init__(self, gfpgan_dir='./src/gfpgan', gfpgan_model_path='experiments/pretrained_models/GFPGANv1.3.pth', esrgan_bg_tile=400) -> None:
+        self.gfpgan_dir = gfpgan_dir
+        self.gfpgan_model_path = gfpgan_model_path
+        self.esrgan_bg_tile = esrgan_bg_tile
+
+    def load_face_restore_models(self):
+        # Load GFPGAN
+        gfpgan = self.load_gfpgan()
+        if gfpgan.gfpgan_model_exists:
+            print('>> GFPGAN Initialized')
+
+        # Load CodeFormer
+        codeformer = self.load_codeformer()
+        if codeformer.codeformer_model_exists:
+            print('>> CodeFormer Initialized')
+
+        return gfpgan, codeformer
+
+    # Face Restore Models
+    def load_gfpgan(self):
+        from ldm.restoration.gfpgan.gfpgan import GFPGAN
+        return GFPGAN(self.gfpgan_dir, self.gfpgan_model_path)
+
+    def load_codeformer(self):
+        from ldm.restoration.codeformer.codeformer import CodeFormerRestoration
+        return CodeFormerRestoration()
+
+    # Upscale Models
+    def load_ersgan(self):
+        from ldm.restoration.realesrgan.realesrgan import ESRGAN
+        esrgan = ESRGAN(self.esrgan_bg_tile)
+        print('>> ESRGAN Initialized')
+        return esrgan;
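Note the upscaler loader is spelled `load_ersgan` (sic) and its `return esrgan;` carries a stray semicolon; both are harmless in Python, but call sites, like the one in scripts/dream.py below, must match the misspelled name.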
@@ -43,7 +43,25 @@ def main():
    import transformers
    transformers.logging.set_verbosity_error()

-    # creating a simple Generate object with a handful of
+    # Loading Face Restoration and ESRGAN Modules
+    try:
+        gfpgan, codeformer, esrgan = None, None, None
+        from ldm.restoration.restoration import Restoration
+        restoration = Restoration(opt.gfpgan_dir, opt.gfpgan_model_path, opt.esrgan_bg_tile)
+        if opt.restore:
+            gfpgan, codeformer = restoration.load_face_restore_models()
+        else:
+            print('>> Face Restoration Disabled')
+        if opt.esrgan:
+            esrgan = restoration.load_ersgan()
+        else:
+            print('>> ESRGAN Disabled')
+    except (ModuleNotFoundError, ImportError):
+        import traceback
+        print(traceback.format_exc(), file=sys.stderr)
+        print('>> You may need to install the ESRGAN and/or GFPGAN modules')
+
+    # creating a simple text2image object with a handful of
    # defaults passed on the command line.
    # additional parameters will be added (or overriden) during
    # the user input loop
@@ -55,6 +73,9 @@ def main():
            embedding_path = opt.embedding_path,
            full_precision = opt.full_precision,
            precision = opt.precision,
+           gfpgan=gfpgan,
+           codeformer=codeformer,
+           esrgan=esrgan
            )
    except (FileNotFoundError, IOError, KeyError) as e:
        print(f'{e}. Aborting.')
@@ -91,7 +112,7 @@ def main():

    # web server loops forever
    if opt.web:
-        dream_server_loop(gen, opt.host, opt.port, opt.outdir)
+        dream_server_loop(gen, opt.host, opt.port, opt.outdir, gfpgan)
        sys.exit(0)

    main_loop(gen, opt, infile)
@@ -347,7 +368,7 @@ def get_next_command(infile=None) -> str:  # command string
    print(f'#{command}')
    return command

-def dream_server_loop(gen, host, port, outdir):
+def dream_server_loop(gen, host, port, outdir, gfpgan):
    print('\n* --web was specified, starting web server...')
    # Change working directory to the stable-diffusion directory
    os.chdir(
@@ -357,6 +378,10 @@ def dream_server_loop(gen, host, port, outdir):
    # Start server
    DreamServer.model = gen  # misnomer in DreamServer - this is not the model you are looking for
    DreamServer.outdir = outdir
+    DreamServer.gfpgan_model_exists = False
+    if gfpgan is not None:
+        DreamServer.gfpgan_model_exists = gfpgan.gfpgan_model_exists

    dream_server = ThreadingDreamServer((host, port))
    print(">> Started Stable Diffusion dream server!")
    if host == '0.0.0.0':
@@ -374,5 +399,19 @@ def dream_server_loop(gen, host, port, outdir):
    dream_server.server_close()


+<<<<<<< HEAD
+=======
+def write_log_message(results, log_path):
+    """logs the name of the output image, prompt, and prompt args to the terminal and log file"""
+    global output_cntr
+    log_lines = [f'{path}: {prompt}\n' for path, prompt in results]
+    for l in log_lines:
+        output_cntr += 1
+        print(f'[{output_cntr}] {l}', end='')
+
+    with open(log_path, 'a', encoding='utf-8') as file:
+        file.writelines(log_lines)
+
+>>>>>>> GFPGAN and Real ESRGAN Implementation Refactor
 if __name__ == '__main__':
     main()

(The `<<<<<<< HEAD` / `=======` / `>>>>>>>` lines in the last hunk are unresolved merge-conflict markers committed as-is; they are a syntax error in Python and need to be removed before scripts/dream.py can run.)