mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
GFPGAN and Real ESRGAN Implementation Refactor
This commit is contained in:
parent
e8bb39370c
commit
1b5013ab72
@ -4,10 +4,16 @@ title: Upscale
|
|||||||
|
|
||||||
# :material-image-size-select-large: Upscale
|
# :material-image-size-select-large: Upscale
|
||||||
|
|
||||||
|
## **Intro**
|
||||||
|
|
||||||
|
The script provides the ability to restore faces and upscale.
|
||||||
|
|
||||||
|
You can enable these features by passing `--restore` and `--esrgan` to your launch script to enable
|
||||||
|
face restoration modules and upscaling modules respectively.
|
||||||
|
|
||||||
## **GFPGAN and Real-ESRGAN Support**
|
## **GFPGAN and Real-ESRGAN Support**
|
||||||
|
|
||||||
The script also provides the ability to do face restoration and upscaling with the help of GFPGAN
|
The default face restoration module is GFPGAN and the default upscaling module is ESRGAN.
|
||||||
and Real-ESRGAN respectively.
|
|
||||||
|
|
||||||
As of version 1.14, environment.yaml will install the Real-ESRGAN package into the standard install
|
As of version 1.14, environment.yaml will install the Real-ESRGAN package into the standard install
|
||||||
location for python packages, and will put GFPGAN into a subdirectory of "src" in the
|
location for python packages, and will put GFPGAN into a subdirectory of "src" in the
|
||||||
|
@ -370,16 +370,19 @@ class Args(object):
|
|||||||
type=str,
|
type=str,
|
||||||
help='Path to a pre-trained embedding manager checkpoint - can only be set on command line',
|
help='Path to a pre-trained embedding manager checkpoint - can only be set on command line',
|
||||||
)
|
)
|
||||||
# GFPGAN related args
|
# Restoration related args
|
||||||
postprocessing_group.add_argument(
|
postprocessing_group.add_argument(
|
||||||
'--gfpgan_bg_upsampler',
|
'--restore',
|
||||||
type=str,
|
action='store_true',
|
||||||
default='realesrgan',
|
help='Enable Face Restoration',
|
||||||
help='Background upsampler. Default: realesrgan. Options: realesrgan, none.',
|
|
||||||
|
|
||||||
)
|
)
|
||||||
postprocessing_group.add_argument(
|
postprocessing_group.add_argument(
|
||||||
'--gfpgan_bg_tile',
|
'--esrgan',
|
||||||
|
action='store_true',
|
||||||
|
help='Enable Upscaling',
|
||||||
|
)
|
||||||
|
postprocessing_group.add_argument(
|
||||||
|
'--esrgan_bg_tile',
|
||||||
type=int,
|
type=int,
|
||||||
default=400,
|
default=400,
|
||||||
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
|
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
|
||||||
|
@ -4,11 +4,12 @@ and generates with ldm.dream.generator.img2img
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from ldm.dream.generator.base import Generator
|
from ldm.dream.generator.base import Generator
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from ldm.dream.generator.img2img import Img2Img
|
from ldm.dream.generator.img2img import Img2Img
|
||||||
|
|
||||||
|
|
||||||
class Embiggen(Generator):
|
class Embiggen(Generator):
|
||||||
def __init__(self, model, precision):
|
def __init__(self, model, precision):
|
||||||
@ -38,19 +39,20 @@ class Embiggen(Generator):
|
|||||||
Return value depends on the seed at the time you call it
|
Return value depends on the seed at the time you call it
|
||||||
"""
|
"""
|
||||||
# Construct embiggen arg array, and sanity check arguments
|
# Construct embiggen arg array, and sanity check arguments
|
||||||
if embiggen == None: # embiggen can also be called with just embiggen_tiles
|
if embiggen == None: # embiggen can also be called with just embiggen_tiles
|
||||||
embiggen = [1.0] # If not specified, assume no scaling
|
embiggen = [1.0] # If not specified, assume no scaling
|
||||||
elif embiggen[0] < 0 :
|
elif embiggen[0] < 0:
|
||||||
embiggen[0] = 1.0
|
embiggen[0] = 1.0
|
||||||
print('>> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !')
|
print(
|
||||||
|
'>> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !')
|
||||||
if len(embiggen) < 2:
|
if len(embiggen) < 2:
|
||||||
embiggen.append(0.75)
|
embiggen.append(0.75)
|
||||||
elif embiggen[1] > 1.0 or embiggen[1] < 0 :
|
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
||||||
embiggen[1] = 0.75
|
embiggen[1] = 0.75
|
||||||
print('>> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !')
|
print('>> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !')
|
||||||
if len(embiggen) < 3:
|
if len(embiggen) < 3:
|
||||||
embiggen.append(0.25)
|
embiggen.append(0.25)
|
||||||
elif embiggen[2] < 0 :
|
elif embiggen[2] < 0:
|
||||||
embiggen[2] = 0.25
|
embiggen[2] = 0.25
|
||||||
print('>> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !')
|
print('>> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !')
|
||||||
|
|
||||||
@ -76,29 +78,30 @@ class Embiggen(Generator):
|
|||||||
if embiggen[0] != 1.0:
|
if embiggen[0] != 1.0:
|
||||||
initsuperwidth = round(initsuperwidth*embiggen[0])
|
initsuperwidth = round(initsuperwidth*embiggen[0])
|
||||||
initsuperheight = round(initsuperheight*embiggen[0])
|
initsuperheight = round(initsuperheight*embiggen[0])
|
||||||
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
|
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
|
||||||
from ldm.gfpgan.gfpgan_tools import (
|
from ldm.restoration.realesrgan import ESRGAN
|
||||||
real_esrgan_upscale,
|
esrgan = ESRGAN()
|
||||||
)
|
print(
|
||||||
print(f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}')
|
f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}')
|
||||||
if embiggen[0] > 2:
|
if embiggen[0] > 2:
|
||||||
initsuperimage = real_esrgan_upscale(
|
initsuperimage = esrgan.process(
|
||||||
initsuperimage,
|
initsuperimage,
|
||||||
embiggen[1], # upscale strength
|
embiggen[1], # upscale strength
|
||||||
4, # upscale scale
|
|
||||||
self.seed,
|
self.seed,
|
||||||
|
4, # upscale scale
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
initsuperimage = real_esrgan_upscale(
|
initsuperimage = esrgan.process(
|
||||||
initsuperimage,
|
initsuperimage,
|
||||||
embiggen[1], # upscale strength
|
embiggen[1], # upscale strength
|
||||||
2, # upscale scale
|
|
||||||
self.seed,
|
self.seed,
|
||||||
|
2, # upscale scale
|
||||||
)
|
)
|
||||||
# We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x
|
# We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x
|
||||||
# but from personal experiance it doesn't greatly improve anything after 4x
|
# but from personal experiance it doesn't greatly improve anything after 4x
|
||||||
# Resize to target scaling factor resolution
|
# Resize to target scaling factor resolution
|
||||||
initsuperimage = initsuperimage.resize((initsuperwidth, initsuperheight), Image.Resampling.LANCZOS)
|
initsuperimage = initsuperimage.resize(
|
||||||
|
(initsuperwidth, initsuperheight), Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
# Use width and height as tile widths and height
|
# Use width and height as tile widths and height
|
||||||
# Determine buffer size in pixels
|
# Determine buffer size in pixels
|
||||||
@ -121,28 +124,31 @@ class Embiggen(Generator):
|
|||||||
emb_tiles_x = 1
|
emb_tiles_x = 1
|
||||||
emb_tiles_y = 1
|
emb_tiles_y = 1
|
||||||
if (initsuperwidth - width) > 0:
|
if (initsuperwidth - width) > 0:
|
||||||
emb_tiles_x = ceildiv(initsuperwidth - width, width - overlap_size_x) + 1
|
emb_tiles_x = ceildiv(initsuperwidth - width,
|
||||||
|
width - overlap_size_x) + 1
|
||||||
if (initsuperheight - height) > 0:
|
if (initsuperheight - height) > 0:
|
||||||
emb_tiles_y = ceildiv(initsuperheight - height, height - overlap_size_y) + 1
|
emb_tiles_y = ceildiv(initsuperheight - height,
|
||||||
|
height - overlap_size_y) + 1
|
||||||
# Sanity
|
# Sanity
|
||||||
assert emb_tiles_x > 1 or emb_tiles_y > 1, f'ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don\'t need to Embiggen! Check your arguments.'
|
assert emb_tiles_x > 1 or emb_tiles_y > 1, f'ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don\'t need to Embiggen! Check your arguments.'
|
||||||
|
|
||||||
# Prep alpha layers --------------
|
# Prep alpha layers --------------
|
||||||
# https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
|
# https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
|
||||||
# agradientL is Left-side transparent
|
# agradientL is Left-side transparent
|
||||||
agradientL = Image.linear_gradient('L').rotate(90).resize((overlap_size_x, height))
|
agradientL = Image.linear_gradient('L').rotate(
|
||||||
|
90).resize((overlap_size_x, height))
|
||||||
# agradientT is Top-side transparent
|
# agradientT is Top-side transparent
|
||||||
agradientT = Image.linear_gradient('L').resize((width, overlap_size_y))
|
agradientT = Image.linear_gradient('L').resize((width, overlap_size_y))
|
||||||
# radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
|
# radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
|
||||||
agradientC = Image.new('L', (256, 256))
|
agradientC = Image.new('L', (256, 256))
|
||||||
for y in range(256):
|
for y in range(256):
|
||||||
for x in range(256):
|
for x in range(256):
|
||||||
#Find distance to lower right corner (numpy takes arrays)
|
# Find distance to lower right corner (numpy takes arrays)
|
||||||
distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0]
|
distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0]
|
||||||
#Clamp values to max 255
|
# Clamp values to max 255
|
||||||
if distanceToLR > 255:
|
if distanceToLR > 255:
|
||||||
distanceToLR = 255
|
distanceToLR = 255
|
||||||
#Place the pixel as invert of distance
|
# Place the pixel as invert of distance
|
||||||
agradientC.putpixel((x, y), int(255 - distanceToLR))
|
agradientC.putpixel((x, y), int(255 - distanceToLR))
|
||||||
|
|
||||||
# Create alpha layers default fully white
|
# Create alpha layers default fully white
|
||||||
@ -154,59 +160,79 @@ class Embiggen(Generator):
|
|||||||
alphaLayerT.paste(agradientT, (0, 0))
|
alphaLayerT.paste(agradientT, (0, 0))
|
||||||
alphaLayerLTC.paste(agradientL, (0, 0))
|
alphaLayerLTC.paste(agradientL, (0, 0))
|
||||||
alphaLayerLTC.paste(agradientT, (0, 0))
|
alphaLayerLTC.paste(agradientT, (0, 0))
|
||||||
alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
|
alphaLayerLTC.paste(agradientC.resize(
|
||||||
|
(overlap_size_x, overlap_size_y)), (0, 0))
|
||||||
|
|
||||||
if embiggen_tiles:
|
if embiggen_tiles:
|
||||||
# Individual unconnected sides
|
# Individual unconnected sides
|
||||||
alphaLayerR = Image.new("L", (width, height), 255)
|
alphaLayerR = Image.new("L", (width, height), 255)
|
||||||
alphaLayerR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
alphaLayerR.paste(agradientL.rotate(
|
||||||
|
180), (width - overlap_size_x, 0))
|
||||||
alphaLayerB = Image.new("L", (width, height), 255)
|
alphaLayerB = Image.new("L", (width, height), 255)
|
||||||
alphaLayerB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
alphaLayerB.paste(agradientT.rotate(
|
||||||
|
180), (0, height - overlap_size_y))
|
||||||
alphaLayerTB = Image.new("L", (width, height), 255)
|
alphaLayerTB = Image.new("L", (width, height), 255)
|
||||||
alphaLayerTB.paste(agradientT, (0, 0))
|
alphaLayerTB.paste(agradientT, (0, 0))
|
||||||
alphaLayerTB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
alphaLayerTB.paste(agradientT.rotate(
|
||||||
|
180), (0, height - overlap_size_y))
|
||||||
alphaLayerLR = Image.new("L", (width, height), 255)
|
alphaLayerLR = Image.new("L", (width, height), 255)
|
||||||
alphaLayerLR.paste(agradientL, (0, 0))
|
alphaLayerLR.paste(agradientL, (0, 0))
|
||||||
alphaLayerLR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
alphaLayerLR.paste(agradientL.rotate(
|
||||||
|
180), (width - overlap_size_x, 0))
|
||||||
|
|
||||||
# Sides and corner Layers
|
# Sides and corner Layers
|
||||||
alphaLayerRBC = Image.new("L", (width, height), 255)
|
alphaLayerRBC = Image.new("L", (width, height), 255)
|
||||||
alphaLayerRBC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
alphaLayerRBC.paste(agradientL.rotate(
|
||||||
alphaLayerRBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
180), (width - overlap_size_x, 0))
|
||||||
alphaLayerRBC.paste(agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
|
alphaLayerRBC.paste(agradientT.rotate(
|
||||||
|
180), (0, height - overlap_size_y))
|
||||||
|
alphaLayerRBC.paste(agradientC.rotate(180).resize(
|
||||||
|
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
|
||||||
alphaLayerLBC = Image.new("L", (width, height), 255)
|
alphaLayerLBC = Image.new("L", (width, height), 255)
|
||||||
alphaLayerLBC.paste(agradientL, (0, 0))
|
alphaLayerLBC.paste(agradientL, (0, 0))
|
||||||
alphaLayerLBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
alphaLayerLBC.paste(agradientT.rotate(
|
||||||
alphaLayerLBC.paste(agradientC.rotate(90).resize((overlap_size_x, overlap_size_y)), (0, height - overlap_size_y))
|
180), (0, height - overlap_size_y))
|
||||||
|
alphaLayerLBC.paste(agradientC.rotate(90).resize(
|
||||||
|
(overlap_size_x, overlap_size_y)), (0, height - overlap_size_y))
|
||||||
alphaLayerRTC = Image.new("L", (width, height), 255)
|
alphaLayerRTC = Image.new("L", (width, height), 255)
|
||||||
alphaLayerRTC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
alphaLayerRTC.paste(agradientL.rotate(
|
||||||
|
180), (width - overlap_size_x, 0))
|
||||||
alphaLayerRTC.paste(agradientT, (0, 0))
|
alphaLayerRTC.paste(agradientT, (0, 0))
|
||||||
alphaLayerRTC.paste(agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
alphaLayerRTC.paste(agradientC.rotate(270).resize(
|
||||||
|
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
||||||
|
|
||||||
# All but X layers
|
# All but X layers
|
||||||
alphaLayerABT = Image.new("L", (width, height), 255)
|
alphaLayerABT = Image.new("L", (width, height), 255)
|
||||||
alphaLayerABT.paste(alphaLayerLBC, (0, 0))
|
alphaLayerABT.paste(alphaLayerLBC, (0, 0))
|
||||||
alphaLayerABT.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
alphaLayerABT.paste(agradientL.rotate(
|
||||||
alphaLayerABT.paste(agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
|
180), (width - overlap_size_x, 0))
|
||||||
|
alphaLayerABT.paste(agradientC.rotate(180).resize(
|
||||||
|
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
|
||||||
alphaLayerABL = Image.new("L", (width, height), 255)
|
alphaLayerABL = Image.new("L", (width, height), 255)
|
||||||
alphaLayerABL.paste(alphaLayerRTC, (0, 0))
|
alphaLayerABL.paste(alphaLayerRTC, (0, 0))
|
||||||
alphaLayerABL.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
alphaLayerABL.paste(agradientT.rotate(
|
||||||
alphaLayerABL.paste(agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
|
180), (0, height - overlap_size_y))
|
||||||
|
alphaLayerABL.paste(agradientC.rotate(180).resize(
|
||||||
|
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
|
||||||
alphaLayerABR = Image.new("L", (width, height), 255)
|
alphaLayerABR = Image.new("L", (width, height), 255)
|
||||||
alphaLayerABR.paste(alphaLayerLBC, (0, 0))
|
alphaLayerABR.paste(alphaLayerLBC, (0, 0))
|
||||||
alphaLayerABR.paste(agradientT, (0, 0))
|
alphaLayerABR.paste(agradientT, (0, 0))
|
||||||
alphaLayerABR.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
|
alphaLayerABR.paste(agradientC.resize(
|
||||||
|
(overlap_size_x, overlap_size_y)), (0, 0))
|
||||||
alphaLayerABB = Image.new("L", (width, height), 255)
|
alphaLayerABB = Image.new("L", (width, height), 255)
|
||||||
alphaLayerABB.paste(alphaLayerRTC, (0, 0))
|
alphaLayerABB.paste(alphaLayerRTC, (0, 0))
|
||||||
alphaLayerABB.paste(agradientL, (0, 0))
|
alphaLayerABB.paste(agradientL, (0, 0))
|
||||||
alphaLayerABB.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
|
alphaLayerABB.paste(agradientC.resize(
|
||||||
|
(overlap_size_x, overlap_size_y)), (0, 0))
|
||||||
|
|
||||||
# All-around layer
|
# All-around layer
|
||||||
alphaLayerAA = Image.new("L", (width, height), 255)
|
alphaLayerAA = Image.new("L", (width, height), 255)
|
||||||
alphaLayerAA.paste(alphaLayerABT, (0, 0))
|
alphaLayerAA.paste(alphaLayerABT, (0, 0))
|
||||||
alphaLayerAA.paste(agradientT, (0, 0))
|
alphaLayerAA.paste(agradientT, (0, 0))
|
||||||
alphaLayerAA.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
|
alphaLayerAA.paste(agradientC.resize(
|
||||||
alphaLayerAA.paste(agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
(overlap_size_x, overlap_size_y)), (0, 0))
|
||||||
|
alphaLayerAA.paste(agradientC.rotate(270).resize(
|
||||||
|
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
||||||
|
|
||||||
# Clean up temporary gradients
|
# Clean up temporary gradients
|
||||||
del agradientL
|
del agradientL
|
||||||
@ -218,7 +244,8 @@ class Embiggen(Generator):
|
|||||||
if embiggen_tiles:
|
if embiggen_tiles:
|
||||||
print(f'>> Making {len(embiggen_tiles)} Embiggen tiles...')
|
print(f'>> Making {len(embiggen_tiles)} Embiggen tiles...')
|
||||||
else:
|
else:
|
||||||
print(f'>> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})...')
|
print(
|
||||||
|
f'>> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})...')
|
||||||
|
|
||||||
emb_tile_store = []
|
emb_tile_store = []
|
||||||
for tile in range(emb_tiles_x * emb_tiles_y):
|
for tile in range(emb_tiles_x * emb_tiles_y):
|
||||||
@ -240,20 +267,23 @@ class Embiggen(Generator):
|
|||||||
top = round(emb_row_i * (height - overlap_size_y))
|
top = round(emb_row_i * (height - overlap_size_y))
|
||||||
right = left + width
|
right = left + width
|
||||||
bottom = top + height
|
bottom = top + height
|
||||||
|
|
||||||
# Cropped image of above dimension (does not modify the original)
|
# Cropped image of above dimension (does not modify the original)
|
||||||
newinitimage = initsuperimage.crop((left, top, right, bottom))
|
newinitimage = initsuperimage.crop((left, top, right, bottom))
|
||||||
# DEBUG:
|
# DEBUG:
|
||||||
# newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png'
|
# newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png'
|
||||||
# newinitimage.save(newinitimagepath)
|
# newinitimage.save(newinitimagepath)
|
||||||
|
|
||||||
if embiggen_tiles:
|
if embiggen_tiles:
|
||||||
print(f'Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)')
|
print(
|
||||||
|
f'Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)')
|
||||||
else:
|
else:
|
||||||
print(f'Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles')
|
print(
|
||||||
|
f'Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles')
|
||||||
|
|
||||||
# create a torch tensor from an Image
|
# create a torch tensor from an Image
|
||||||
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
|
newinitimage = np.array(
|
||||||
|
newinitimage).astype(np.float32) / 255.0
|
||||||
newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
|
newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
|
||||||
newinitimage = torch.from_numpy(newinitimage)
|
newinitimage = torch.from_numpy(newinitimage)
|
||||||
newinitimage = 2.0 * newinitimage - 1.0
|
newinitimage = 2.0 * newinitimage - 1.0
|
||||||
@ -261,33 +291,35 @@ class Embiggen(Generator):
|
|||||||
|
|
||||||
tile_results = gen_img2img.generate(
|
tile_results = gen_img2img.generate(
|
||||||
prompt,
|
prompt,
|
||||||
iterations = 1,
|
iterations=1,
|
||||||
seed = self.seed,
|
seed=self.seed,
|
||||||
sampler = sampler,
|
sampler=sampler,
|
||||||
steps = steps,
|
steps=steps,
|
||||||
cfg_scale = cfg_scale,
|
cfg_scale=cfg_scale,
|
||||||
conditioning = conditioning,
|
conditioning=conditioning,
|
||||||
ddim_eta = ddim_eta,
|
ddim_eta=ddim_eta,
|
||||||
image_callback = None, # called only after the final image is generated
|
image_callback=None, # called only after the final image is generated
|
||||||
step_callback = step_callback, # called after each intermediate image is generated
|
step_callback=step_callback, # called after each intermediate image is generated
|
||||||
width = width,
|
width=width,
|
||||||
height = height,
|
height=height,
|
||||||
init_img = init_img, # img2img doesn't need this, but it might in the future
|
init_img=init_img, # img2img doesn't need this, but it might in the future
|
||||||
init_image = newinitimage, # notice that init_image is different from init_img
|
init_image=newinitimage, # notice that init_image is different from init_img
|
||||||
mask_image = None,
|
mask_image=None,
|
||||||
strength = strength,
|
strength=strength,
|
||||||
)
|
)
|
||||||
|
|
||||||
emb_tile_store.append(tile_results[0][0])
|
emb_tile_store.append(tile_results[0][0])
|
||||||
# DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite
|
# DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite
|
||||||
# emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png')
|
# emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png')
|
||||||
del newinitimage
|
del newinitimage
|
||||||
|
|
||||||
# Sanity check we have them all
|
# Sanity check we have them all
|
||||||
if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)):
|
if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)):
|
||||||
outputsuperimage = Image.new("RGBA", (initsuperwidth, initsuperheight))
|
outputsuperimage = Image.new(
|
||||||
|
"RGBA", (initsuperwidth, initsuperheight))
|
||||||
if embiggen_tiles:
|
if embiggen_tiles:
|
||||||
outputsuperimage.alpha_composite(initsuperimage.convert('RGBA'), (0, 0))
|
outputsuperimage.alpha_composite(
|
||||||
|
initsuperimage.convert('RGBA'), (0, 0))
|
||||||
for tile in range(emb_tiles_x * emb_tiles_y):
|
for tile in range(emb_tiles_x * emb_tiles_y):
|
||||||
if embiggen_tiles:
|
if embiggen_tiles:
|
||||||
if tile in embiggen_tiles:
|
if tile in embiggen_tiles:
|
||||||
@ -308,7 +340,8 @@ class Embiggen(Generator):
|
|||||||
if emb_column_i + 1 == emb_tiles_x:
|
if emb_column_i + 1 == emb_tiles_x:
|
||||||
left = initsuperwidth - width
|
left = initsuperwidth - width
|
||||||
else:
|
else:
|
||||||
left = round(emb_column_i * (width - overlap_size_x))
|
left = round(emb_column_i *
|
||||||
|
(width - overlap_size_x))
|
||||||
if emb_row_i + 1 == emb_tiles_y:
|
if emb_row_i + 1 == emb_tiles_y:
|
||||||
top = initsuperheight - height
|
top = initsuperheight - height
|
||||||
else:
|
else:
|
||||||
@ -319,33 +352,33 @@ class Embiggen(Generator):
|
|||||||
# top of image
|
# top of image
|
||||||
if emb_row_i == 0:
|
if emb_row_i == 0:
|
||||||
if emb_column_i == 0:
|
if emb_column_i == 0:
|
||||||
if (tile+1) in embiggen_tiles: # Look-ahead right
|
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||||
if (tile+emb_tiles_x) not in embiggen_tiles: # Look-ahead down
|
if (tile+emb_tiles_x) not in embiggen_tiles: # Look-ahead down
|
||||||
intileimage.putalpha(alphaLayerB)
|
intileimage.putalpha(alphaLayerB)
|
||||||
# Otherwise do nothing on this tile
|
# Otherwise do nothing on this tile
|
||||||
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
||||||
intileimage.putalpha(alphaLayerR)
|
intileimage.putalpha(alphaLayerR)
|
||||||
else:
|
else:
|
||||||
intileimage.putalpha(alphaLayerRBC)
|
intileimage.putalpha(alphaLayerRBC)
|
||||||
elif emb_column_i == emb_tiles_x - 1:
|
elif emb_column_i == emb_tiles_x - 1:
|
||||||
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||||
intileimage.putalpha(alphaLayerL)
|
intileimage.putalpha(alphaLayerL)
|
||||||
else:
|
else:
|
||||||
intileimage.putalpha(alphaLayerLBC)
|
intileimage.putalpha(alphaLayerLBC)
|
||||||
else:
|
else:
|
||||||
if (tile+1) in embiggen_tiles: # Look-ahead right
|
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||||
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||||
intileimage.putalpha(alphaLayerL)
|
intileimage.putalpha(alphaLayerL)
|
||||||
else:
|
else:
|
||||||
intileimage.putalpha(alphaLayerLBC)
|
intileimage.putalpha(alphaLayerLBC)
|
||||||
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
||||||
intileimage.putalpha(alphaLayerLR)
|
intileimage.putalpha(alphaLayerLR)
|
||||||
else:
|
else:
|
||||||
intileimage.putalpha(alphaLayerABT)
|
intileimage.putalpha(alphaLayerABT)
|
||||||
# bottom of image
|
# bottom of image
|
||||||
elif emb_row_i == emb_tiles_y - 1:
|
elif emb_row_i == emb_tiles_y - 1:
|
||||||
if emb_column_i == 0:
|
if emb_column_i == 0:
|
||||||
if (tile+1) in embiggen_tiles: # Look-ahead right
|
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||||
intileimage.putalpha(alphaLayerT)
|
intileimage.putalpha(alphaLayerT)
|
||||||
else:
|
else:
|
||||||
intileimage.putalpha(alphaLayerRTC)
|
intileimage.putalpha(alphaLayerRTC)
|
||||||
@ -353,34 +386,34 @@ class Embiggen(Generator):
|
|||||||
# No tiles to look ahead to
|
# No tiles to look ahead to
|
||||||
intileimage.putalpha(alphaLayerLTC)
|
intileimage.putalpha(alphaLayerLTC)
|
||||||
else:
|
else:
|
||||||
if (tile+1) in embiggen_tiles: # Look-ahead right
|
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||||
intileimage.putalpha(alphaLayerLTC)
|
intileimage.putalpha(alphaLayerLTC)
|
||||||
else:
|
else:
|
||||||
intileimage.putalpha(alphaLayerABB)
|
intileimage.putalpha(alphaLayerABB)
|
||||||
# vertical middle of image
|
# vertical middle of image
|
||||||
else:
|
else:
|
||||||
if emb_column_i == 0:
|
if emb_column_i == 0:
|
||||||
if (tile+1) in embiggen_tiles: # Look-ahead right
|
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||||
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||||
intileimage.putalpha(alphaLayerT)
|
intileimage.putalpha(alphaLayerT)
|
||||||
else:
|
else:
|
||||||
intileimage.putalpha(alphaLayerTB)
|
intileimage.putalpha(alphaLayerTB)
|
||||||
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
||||||
intileimage.putalpha(alphaLayerRTC)
|
intileimage.putalpha(alphaLayerRTC)
|
||||||
else:
|
else:
|
||||||
intileimage.putalpha(alphaLayerABL)
|
intileimage.putalpha(alphaLayerABL)
|
||||||
elif emb_column_i == emb_tiles_x - 1:
|
elif emb_column_i == emb_tiles_x - 1:
|
||||||
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||||
intileimage.putalpha(alphaLayerLTC)
|
intileimage.putalpha(alphaLayerLTC)
|
||||||
else:
|
else:
|
||||||
intileimage.putalpha(alphaLayerABR)
|
intileimage.putalpha(alphaLayerABR)
|
||||||
else:
|
else:
|
||||||
if (tile+1) in embiggen_tiles: # Look-ahead right
|
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||||
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||||
intileimage.putalpha(alphaLayerLTC)
|
intileimage.putalpha(alphaLayerLTC)
|
||||||
else:
|
else:
|
||||||
intileimage.putalpha(alphaLayerABR)
|
intileimage.putalpha(alphaLayerABR)
|
||||||
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
||||||
intileimage.putalpha(alphaLayerABB)
|
intileimage.putalpha(alphaLayerABB)
|
||||||
else:
|
else:
|
||||||
intileimage.putalpha(alphaLayerAA)
|
intileimage.putalpha(alphaLayerAA)
|
||||||
@ -400,4 +433,4 @@ class Embiggen(Generator):
|
|||||||
# after internal loops and patching up return Embiggen image
|
# after internal loops and patching up return Embiggen image
|
||||||
return outputsuperimage
|
return outputsuperimage
|
||||||
# end of function declaration
|
# end of function declaration
|
||||||
return make_image
|
return make_image
|
||||||
|
@ -37,6 +37,8 @@ def build_opt(post_data, seed, gfpgan_model_exists):
|
|||||||
setattr(opt, 'seed', None if int(post_data['seed']) == -1 else int(post_data['seed']))
|
setattr(opt, 'seed', None if int(post_data['seed']) == -1 else int(post_data['seed']))
|
||||||
setattr(opt, 'variation_amount', float(post_data['variation_amount']) if int(post_data['seed']) != -1 else 0)
|
setattr(opt, 'variation_amount', float(post_data['variation_amount']) if int(post_data['seed']) != -1 else 0)
|
||||||
setattr(opt, 'with_variations', [])
|
setattr(opt, 'with_variations', [])
|
||||||
|
setattr(opt, 'embiggen', None)
|
||||||
|
setattr(opt, 'embiggen_tiles', None)
|
||||||
|
|
||||||
broken = False
|
broken = False
|
||||||
if int(post_data['seed']) != -1 and post_data['with_variations'] != '':
|
if int(post_data['seed']) != -1 and post_data['with_variations'] != '':
|
||||||
@ -80,12 +82,11 @@ class DreamServer(BaseHTTPRequestHandler):
|
|||||||
self.wfile.write(content.read())
|
self.wfile.write(content.read())
|
||||||
elif self.path == "/config.js":
|
elif self.path == "/config.js":
|
||||||
# unfortunately this import can't be at the top level, since that would cause a circular import
|
# unfortunately this import can't be at the top level, since that would cause a circular import
|
||||||
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
|
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_header("Content-type", "application/javascript")
|
self.send_header("Content-type", "application/javascript")
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
config = {
|
config = {
|
||||||
'gfpgan_model_exists': gfpgan_model_exists
|
'gfpgan_model_exists': self.gfpgan_model_exists
|
||||||
}
|
}
|
||||||
self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8"))
|
self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8"))
|
||||||
elif self.path == "/run_log.json":
|
elif self.path == "/run_log.json":
|
||||||
@ -138,11 +139,10 @@ class DreamServer(BaseHTTPRequestHandler):
|
|||||||
self.end_headers()
|
self.end_headers()
|
||||||
|
|
||||||
# unfortunately this import can't be at the top level, since that would cause a circular import
|
# unfortunately this import can't be at the top level, since that would cause a circular import
|
||||||
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
|
|
||||||
|
|
||||||
content_length = int(self.headers['Content-Length'])
|
content_length = int(self.headers['Content-Length'])
|
||||||
post_data = json.loads(self.rfile.read(content_length))
|
post_data = json.loads(self.rfile.read(content_length))
|
||||||
opt = build_opt(post_data, self.model.seed, gfpgan_model_exists)
|
opt = build_opt(post_data, self.model.seed, self.gfpgan_model_exists)
|
||||||
|
|
||||||
self.canceled.clear()
|
self.canceled.clear()
|
||||||
# In order to handle upscaled images, the PngWriter needs to maintain state
|
# In order to handle upscaled images, the PngWriter needs to maintain state
|
||||||
|
224
ldm/generate.py
224
ldm/generate.py
@ -23,9 +23,9 @@ from PIL import Image, ImageOps
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from pytorch_lightning import seed_everything, logging
|
from pytorch_lightning import seed_everything, logging
|
||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from ldm.models.diffusion.plms import PLMSSampler
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
from ldm.models.diffusion.ksampler import KSampler
|
from ldm.models.diffusion.ksampler import KSampler
|
||||||
from ldm.dream.pngwriter import PngWriter, retrieve_metadata
|
from ldm.dream.pngwriter import PngWriter, retrieve_metadata
|
||||||
from ldm.dream.args import metadata_loads
|
from ldm.dream.args import metadata_loads
|
||||||
@ -51,6 +51,24 @@ torch.randint_like = fix_func(torch.randint_like)
|
|||||||
torch.bernoulli = fix_func(torch.bernoulli)
|
torch.bernoulli = fix_func(torch.bernoulli)
|
||||||
torch.multinomial = fix_func(torch.multinomial)
|
torch.multinomial = fix_func(torch.multinomial)
|
||||||
|
|
||||||
|
def fix_func(orig):
|
||||||
|
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||||
|
def new_func(*args, **kw):
|
||||||
|
device = kw.get("device", "mps")
|
||||||
|
kw["device"]="cpu"
|
||||||
|
return orig(*args, **kw).to(device)
|
||||||
|
return new_func
|
||||||
|
return orig
|
||||||
|
|
||||||
|
torch.rand = fix_func(torch.rand)
|
||||||
|
torch.rand_like = fix_func(torch.rand_like)
|
||||||
|
torch.randn = fix_func(torch.randn)
|
||||||
|
torch.randn_like = fix_func(torch.randn_like)
|
||||||
|
torch.randint = fix_func(torch.randint)
|
||||||
|
torch.randint_like = fix_func(torch.randint_like)
|
||||||
|
torch.bernoulli = fix_func(torch.bernoulli)
|
||||||
|
torch.multinomial = fix_func(torch.multinomial)
|
||||||
|
|
||||||
"""Simplified text to image API for stable diffusion/latent diffusion
|
"""Simplified text to image API for stable diffusion/latent diffusion
|
||||||
|
|
||||||
Example Usage:
|
Example Usage:
|
||||||
@ -135,6 +153,9 @@ class Generate:
|
|||||||
# these are deprecated; if present they override values in the conf file
|
# these are deprecated; if present they override values in the conf file
|
||||||
weights = None,
|
weights = None,
|
||||||
config = None,
|
config = None,
|
||||||
|
gfpgan=None,
|
||||||
|
codeformer=None,
|
||||||
|
esrgan=None
|
||||||
):
|
):
|
||||||
models = OmegaConf.load(conf)
|
models = OmegaConf.load(conf)
|
||||||
mconfig = models[model]
|
mconfig = models[model]
|
||||||
@ -158,6 +179,9 @@ class Generate:
|
|||||||
self.generators = {}
|
self.generators = {}
|
||||||
self.base_generator = None
|
self.base_generator = None
|
||||||
self.seed = None
|
self.seed = None
|
||||||
|
self.gfpgan = gfpgan
|
||||||
|
self.codeformer = codeformer
|
||||||
|
self.esrgan = esrgan
|
||||||
|
|
||||||
# Note that in previous versions, there was an option to pass the
|
# Note that in previous versions, there was an option to pass the
|
||||||
# device to Generate(). However the device was then ignored, so
|
# device to Generate(). However the device was then ignored, so
|
||||||
@ -234,8 +258,8 @@ class Generate:
|
|||||||
strength = None,
|
strength = None,
|
||||||
init_color = None,
|
init_color = None,
|
||||||
# these are specific to embiggen (which also relies on img2img args)
|
# these are specific to embiggen (which also relies on img2img args)
|
||||||
embiggen = None,
|
embiggen=None,
|
||||||
embiggen_tiles = None,
|
embiggen_tiles=None,
|
||||||
# these are specific to GFPGAN/ESRGAN
|
# these are specific to GFPGAN/ESRGAN
|
||||||
facetool = None,
|
facetool = None,
|
||||||
gfpgan_strength = 0,
|
gfpgan_strength = 0,
|
||||||
@ -284,15 +308,15 @@ class Generate:
|
|||||||
write the prompt into the PNG metadata.
|
write the prompt into the PNG metadata.
|
||||||
"""
|
"""
|
||||||
# TODO: convert this into a getattr() loop
|
# TODO: convert this into a getattr() loop
|
||||||
steps = steps or self.steps
|
steps = steps or self.steps
|
||||||
width = width or self.width
|
width = width or self.width
|
||||||
height = height or self.height
|
height = height or self.height
|
||||||
seamless = seamless or self.seamless
|
seamless = seamless or self.seamless
|
||||||
cfg_scale = cfg_scale or self.cfg_scale
|
cfg_scale = cfg_scale or self.cfg_scale
|
||||||
ddim_eta = ddim_eta or self.ddim_eta
|
ddim_eta = ddim_eta or self.ddim_eta
|
||||||
iterations = iterations or self.iterations
|
iterations = iterations or self.iterations
|
||||||
strength = strength or self.strength
|
strength = strength or self.strength
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.log_tokenization = log_tokenization
|
self.log_tokenization = log_tokenization
|
||||||
self.step_callback = step_callback
|
self.step_callback = step_callback
|
||||||
with_variations = [] if with_variations is None else with_variations
|
with_variations = [] if with_variations is None else with_variations
|
||||||
@ -303,16 +327,17 @@ class Generate:
|
|||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||||
m.padding_mode = 'circular' if seamless else m._orig_padding_mode
|
m.padding_mode = 'circular' if seamless else m._orig_padding_mode
|
||||||
|
|
||||||
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
|
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
|
||||||
assert (
|
assert (
|
||||||
0.0 < strength < 1.0
|
0.0 < strength < 1.0
|
||||||
), 'img2img and inpaint strength can only work with 0.0 < strength < 1.0'
|
), 'img2img and inpaint strength can only work with 0.0 < strength < 1.0'
|
||||||
assert (
|
assert (
|
||||||
0.0 <= variation_amount <= 1.0
|
0.0 <= variation_amount <= 1.0
|
||||||
), '-v --variation_amount must be in [0.0, 1.0]'
|
), '-v --variation_amount must be in [0.0, 1.0]'
|
||||||
assert (
|
assert (
|
||||||
(embiggen == None and embiggen_tiles == None) or ((embiggen != None or embiggen_tiles != None) and init_img != None)
|
(embiggen == None and embiggen_tiles == None) or (
|
||||||
|
(embiggen != None or embiggen_tiles != None) and init_img != None)
|
||||||
), 'Embiggen requires an init/input image to be specified'
|
), 'Embiggen requires an init/input image to be specified'
|
||||||
|
|
||||||
if len(with_variations) > 0 or variation_amount > 1.0:
|
if len(with_variations) > 0 or variation_amount > 1.0:
|
||||||
@ -334,9 +359,9 @@ class Generate:
|
|||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
torch.cuda.reset_peak_memory_stats()
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
results = list()
|
results = list()
|
||||||
init_image = None
|
init_image = None
|
||||||
mask_image = None
|
mask_image = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
uc, c = get_uc_and_c(
|
uc, c = get_uc_and_c(
|
||||||
@ -345,8 +370,9 @@ class Generate:
|
|||||||
log_tokens =self.log_tokenization
|
log_tokens =self.log_tokenization
|
||||||
)
|
)
|
||||||
|
|
||||||
(init_image,mask_image) = self._make_images(init_img,init_mask, width, height, fit)
|
(init_image, mask_image) = self._make_images(
|
||||||
|
init_img, init_mask, width, height, fit)
|
||||||
|
|
||||||
if (init_image is not None) and (mask_image is not None):
|
if (init_image is not None) and (mask_image is not None):
|
||||||
generator = self._make_inpaint()
|
generator = self._make_inpaint()
|
||||||
elif (embiggen != None or embiggen_tiles != None):
|
elif (embiggen != None or embiggen_tiles != None):
|
||||||
@ -356,26 +382,27 @@ class Generate:
|
|||||||
else:
|
else:
|
||||||
generator = self._make_txt2img()
|
generator = self._make_txt2img()
|
||||||
|
|
||||||
generator.set_variation(self.seed, variation_amount, with_variations)
|
generator.set_variation(
|
||||||
|
self.seed, variation_amount, with_variations)
|
||||||
results = generator.generate(
|
results = generator.generate(
|
||||||
prompt,
|
prompt,
|
||||||
iterations = iterations,
|
iterations=iterations,
|
||||||
seed = self.seed,
|
seed=self.seed,
|
||||||
sampler = self.sampler,
|
sampler=self.sampler,
|
||||||
steps = steps,
|
steps=steps,
|
||||||
cfg_scale = cfg_scale,
|
cfg_scale=cfg_scale,
|
||||||
conditioning = (uc,c),
|
conditioning=(uc, c),
|
||||||
ddim_eta = ddim_eta,
|
ddim_eta=ddim_eta,
|
||||||
image_callback = image_callback, # called after the final image is generated
|
image_callback=image_callback, # called after the final image is generated
|
||||||
step_callback = step_callback, # called after each intermediate image is generated
|
step_callback=step_callback, # called after each intermediate image is generated
|
||||||
width = width,
|
width=width,
|
||||||
height = height,
|
height=height,
|
||||||
init_img = init_img, # embiggen needs to manipulate from the unmodified init_img
|
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
|
||||||
init_image = init_image, # notice that init_image is different from init_img
|
init_image=init_image, # notice that init_image is different from init_img
|
||||||
mask_image = mask_image,
|
mask_image=mask_image,
|
||||||
strength = strength,
|
strength=strength,
|
||||||
embiggen = embiggen,
|
embiggen=embiggen,
|
||||||
embiggen_tiles = embiggen_tiles,
|
embiggen_tiles=embiggen_tiles,
|
||||||
)
|
)
|
||||||
|
|
||||||
if init_color:
|
if init_color:
|
||||||
@ -404,7 +431,8 @@ class Generate:
|
|||||||
toc = time.time()
|
toc = time.time()
|
||||||
print('>> Usage stats:')
|
print('>> Usage stats:')
|
||||||
print(
|
print(
|
||||||
f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
|
f'>> {len(results)} image(s) generated in', '%4.2fs' % (
|
||||||
|
toc - tic)
|
||||||
)
|
)
|
||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
print(
|
print(
|
||||||
@ -515,29 +543,35 @@ class Generate:
|
|||||||
|
|
||||||
|
|
||||||
def _make_images(self, img_path, mask_path, width, height, fit=False):
|
def _make_images(self, img_path, mask_path, width, height, fit=False):
|
||||||
init_image = None
|
init_image = None
|
||||||
init_mask = None
|
init_mask = None
|
||||||
if not img_path:
|
if not img_path:
|
||||||
return None,None
|
return None, None
|
||||||
|
|
||||||
image = self._load_img(img_path, width, height, fit=fit) # this returns an Image
|
image = self._load_img(img_path, width, height,
|
||||||
init_image = self._create_init_image(image) # this returns a torch tensor
|
fit=fit) # this returns an Image
|
||||||
|
# this returns a torch tensor
|
||||||
|
init_image = self._create_init_image(image)
|
||||||
|
|
||||||
if self._has_transparency(image) and not mask_path: # if image has a transparent area and no mask was provided, then try to generate mask
|
# if image has a transparent area and no mask was provided, then try to generate mask
|
||||||
print('>> Initial image has transparent areas. Will inpaint in these regions.')
|
if self._has_transparency(image) and not mask_path:
|
||||||
|
print(
|
||||||
|
'>> Initial image has transparent areas. Will inpaint in these regions.')
|
||||||
if self._check_for_erasure(image):
|
if self._check_for_erasure(image):
|
||||||
print(
|
print(
|
||||||
'>> WARNING: Colors underneath the transparent region seem to have been erased.\n',
|
'>> WARNING: Colors underneath the transparent region seem to have been erased.\n',
|
||||||
'>> Inpainting will be suboptimal. Please preserve the colors when making\n',
|
'>> Inpainting will be suboptimal. Please preserve the colors when making\n',
|
||||||
'>> a transparency mask, or provide mask explicitly using --init_mask (-M).'
|
'>> a transparency mask, or provide mask explicitly using --init_mask (-M).'
|
||||||
)
|
)
|
||||||
init_mask = self._create_init_mask(image) # this returns a torch tensor
|
# this returns a torch tensor
|
||||||
|
init_mask = self._create_init_mask(image)
|
||||||
|
|
||||||
if mask_path:
|
if mask_path:
|
||||||
mask_image = self._load_img(mask_path, width, height, fit=fit) # this returns an Image
|
mask_image = self._load_img(
|
||||||
init_mask = self._create_init_mask(mask_image)
|
mask_path, width, height, fit=fit) # this returns an Image
|
||||||
|
init_mask = self._create_init_mask(mask_image)
|
||||||
|
|
||||||
return init_image,init_mask
|
return init_image, init_mask
|
||||||
|
|
||||||
def _make_img2img(self):
|
def _make_img2img(self):
|
||||||
if not self.generators.get('img2img'):
|
if not self.generators.get('img2img'):
|
||||||
@ -619,38 +653,26 @@ class Generate:
|
|||||||
codeformer_fidelity = 0.75,
|
codeformer_fidelity = 0.75,
|
||||||
save_original = False,
|
save_original = False,
|
||||||
image_callback = None):
|
image_callback = None):
|
||||||
try:
|
|
||||||
if upscale is not None:
|
|
||||||
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
|
|
||||||
if strength > 0:
|
|
||||||
if facetool == 'codeformer':
|
|
||||||
from ldm.restoration.codeformer.codeformer import CodeFormerRestoration
|
|
||||||
else:
|
|
||||||
from ldm.gfpgan.gfpgan_tools import run_gfpgan
|
|
||||||
except (ModuleNotFoundError, ImportError):
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
|
|
||||||
return
|
|
||||||
|
|
||||||
for r in image_list:
|
for r in image_list:
|
||||||
image, seed = r
|
image, seed = r
|
||||||
try:
|
try:
|
||||||
if upscale is not None:
|
if upscale is not None:
|
||||||
if len(upscale) < 2:
|
if self.esrgan is not None:
|
||||||
upscale.append(0.75)
|
if len(upscale) < 2:
|
||||||
image = real_esrgan_upscale(
|
upscale.append(0.75)
|
||||||
image,
|
image = self.esrgan.process(
|
||||||
upscale[1],
|
image, upscale[1], seed, int(upscale[0]))
|
||||||
int(upscale[0]),
|
|
||||||
seed,
|
|
||||||
)
|
|
||||||
if strength > 0:
|
|
||||||
if facetool == 'codeformer':
|
|
||||||
image = CodeFormerRestoration().process(image=image, strength=strength, device=self.device, seed=seed, fidelity=codeformer_fidelity)
|
|
||||||
else:
|
else:
|
||||||
image = run_gfpgan(
|
print(">> ESRGAN is disabled. Image not upscaled.")
|
||||||
image, strength, seed, 1
|
if strength > 0:
|
||||||
)
|
if self.gfpgan is not None and self.codeformer is not None:
|
||||||
|
if facetool == 'codeformer':
|
||||||
|
image = self.codeformer.process(image=image, strength=strength, device=self.device, seed=seed, fidelity=codeformer_fidelity)
|
||||||
|
else:
|
||||||
|
image = self.gfpgan.process(image, strength, seed)
|
||||||
|
else:
|
||||||
|
print(">> Face Restoration is disabled.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(
|
print(
|
||||||
f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}'
|
f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}'
|
||||||
@ -662,10 +684,10 @@ class Generate:
|
|||||||
r[0] = image
|
r[0] = image
|
||||||
|
|
||||||
# to help WebGUI - front end to generator util function
|
# to help WebGUI - front end to generator util function
|
||||||
def sample_to_image(self,samples):
|
def sample_to_image(self, samples):
|
||||||
return self._sample_to_image(samples)
|
return self._sample_to_image(samples)
|
||||||
|
|
||||||
def _sample_to_image(self,samples):
|
def _sample_to_image(self, samples):
|
||||||
if not self.base_generator:
|
if not self.base_generator:
|
||||||
from ldm.dream.generator import Generator
|
from ldm.dream.generator import Generator
|
||||||
self.base_generator = Generator(self.model)
|
self.base_generator = Generator(self.model)
|
||||||
@ -708,7 +730,7 @@ class Generate:
|
|||||||
# for usage statistics
|
# for usage statistics
|
||||||
device_type = choose_torch_device()
|
device_type = choose_torch_device()
|
||||||
if device_type == 'cuda':
|
if device_type == 'cuda':
|
||||||
torch.cuda.reset_peak_memory_stats()
|
torch.cuda.reset_peak_memory_stats()
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
|
|
||||||
# this does the work
|
# this does the work
|
||||||
@ -756,12 +778,12 @@ class Generate:
|
|||||||
f'>> loaded input image of size {image.width}x{image.height} from {path}'
|
f'>> loaded input image of size {image.width}x{image.height} from {path}'
|
||||||
)
|
)
|
||||||
if fit:
|
if fit:
|
||||||
image = self._fit_image(image,(width,height))
|
image = self._fit_image(image, (width, height))
|
||||||
else:
|
else:
|
||||||
image = self._squeeze_image(image)
|
image = self._squeeze_image(image)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def _create_init_image(self,image):
|
def _create_init_image(self, image):
|
||||||
image = image.convert('RGB')
|
image = image.convert('RGB')
|
||||||
# print(
|
# print(
|
||||||
# f'>> DEBUG: writing the image to img.png'
|
# f'>> DEBUG: writing the image to img.png'
|
||||||
@ -770,7 +792,7 @@ class Generate:
|
|||||||
image = np.array(image).astype(np.float32) / 255.0
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
image = image[None].transpose(0, 3, 1, 2)
|
image = image[None].transpose(0, 3, 1, 2)
|
||||||
image = torch.from_numpy(image)
|
image = torch.from_numpy(image)
|
||||||
image = 2.0 * image - 1.0
|
image = 2.0 * image - 1.0
|
||||||
return image.to(self.device)
|
return image.to(self.device)
|
||||||
|
|
||||||
def _create_init_mask(self, image):
|
def _create_init_mask(self, image):
|
||||||
@ -779,7 +801,8 @@ class Generate:
|
|||||||
image = image.convert('RGB')
|
image = image.convert('RGB')
|
||||||
# BUG: We need to use the model's downsample factor rather than hardcoding "8"
|
# BUG: We need to use the model's downsample factor rather than hardcoding "8"
|
||||||
from ldm.dream.generator.base import downsampling
|
from ldm.dream.generator.base import downsampling
|
||||||
image = image.resize((image.width//downsampling, image.height//downsampling), resample=Image.Resampling.LANCZOS)
|
image = image.resize((image.width//downsampling, image.height //
|
||||||
|
downsampling), resample=Image.Resampling.LANCZOS)
|
||||||
# print(
|
# print(
|
||||||
# f'>> DEBUG: writing the mask to mask.png'
|
# f'>> DEBUG: writing the mask to mask.png'
|
||||||
# )
|
# )
|
||||||
@ -801,7 +824,7 @@ class Generate:
|
|||||||
mask = ImageOps.invert(mask)
|
mask = ImageOps.invert(mask)
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
def _has_transparency(self,image):
|
def _has_transparency(self, image):
|
||||||
if image.info.get("transparency", None) is not None:
|
if image.info.get("transparency", None) is not None:
|
||||||
return True
|
return True
|
||||||
if image.mode == "P":
|
if image.mode == "P":
|
||||||
@ -815,11 +838,10 @@ class Generate:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _check_for_erasure(self, image):
|
||||||
def _check_for_erasure(self,image):
|
|
||||||
width, height = image.size
|
width, height = image.size
|
||||||
pixdata = image.load()
|
pixdata = image.load()
|
||||||
colored = 0
|
colored = 0
|
||||||
for y in range(height):
|
for y in range(height):
|
||||||
for x in range(width):
|
for x in range(width):
|
||||||
if pixdata[x, y][3] == 0:
|
if pixdata[x, y][3] == 0:
|
||||||
@ -829,28 +851,28 @@ class Generate:
|
|||||||
colored += 1
|
colored += 1
|
||||||
return colored == 0
|
return colored == 0
|
||||||
|
|
||||||
def _squeeze_image(self,image):
|
def _squeeze_image(self, image):
|
||||||
x,y,resize_needed = self._resolution_check(image.width,image.height)
|
x, y, resize_needed = self._resolution_check(image.width, image.height)
|
||||||
if resize_needed:
|
if resize_needed:
|
||||||
return InitImageResizer(image).resize(x,y)
|
return InitImageResizer(image).resize(x, y)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
def _fit_image(self, image, max_dimensions):
|
||||||
def _fit_image(self,image,max_dimensions):
|
w, h = max_dimensions
|
||||||
w,h = max_dimensions
|
|
||||||
print(
|
print(
|
||||||
f'>> image will be resized to fit inside a box {w}x{h} in size.'
|
f'>> image will be resized to fit inside a box {w}x{h} in size.'
|
||||||
)
|
)
|
||||||
if image.width > image.height:
|
if image.width > image.height:
|
||||||
h = None # by setting h to none, we tell InitImageResizer to fit into the width and calculate height
|
h = None # by setting h to none, we tell InitImageResizer to fit into the width and calculate height
|
||||||
elif image.height > image.width:
|
elif image.height > image.width:
|
||||||
w = None # ditto for w
|
w = None # ditto for w
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
image = InitImageResizer(image).resize(w,h) # note that InitImageResizer does the multiple of 64 truncation internally
|
# note that InitImageResizer does the multiple of 64 truncation internally
|
||||||
|
image = InitImageResizer(image).resize(w, h)
|
||||||
print(
|
print(
|
||||||
f'>> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}'
|
f'>> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}'
|
||||||
)
|
)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def _resolution_check(self, width, height, log=False):
|
def _resolution_check(self, width, height, log=False):
|
||||||
@ -864,7 +886,7 @@ class Generate:
|
|||||||
f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}'
|
f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}'
|
||||||
)
|
)
|
||||||
height = h
|
height = h
|
||||||
width = w
|
width = w
|
||||||
resize_needed = True
|
resize_needed = True
|
||||||
|
|
||||||
if (width * height) > (self.width * self.height):
|
if (width * height) > (self.width * self.height):
|
||||||
|
@ -1,168 +0,0 @@
|
|||||||
import torch
|
|
||||||
import warnings
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
#from scripts.dream import create_argv_parser
|
|
||||||
from ldm.dream.args import Args
|
|
||||||
|
|
||||||
opt = Args()
|
|
||||||
opt.parse_args()
|
|
||||||
model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path)
|
|
||||||
gfpgan_model_exists = os.path.isfile(model_path)
|
|
||||||
|
|
||||||
def run_gfpgan(image, strength, seed, upsampler_scale=4):
|
|
||||||
print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')
|
|
||||||
gfpgan = None
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
|
||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if not gfpgan_model_exists:
|
|
||||||
raise Exception('GFPGAN model not found at path ' + model_path)
|
|
||||||
|
|
||||||
sys.path.append(os.path.abspath(opt.gfpgan_dir))
|
|
||||||
from gfpgan import GFPGANer
|
|
||||||
|
|
||||||
bg_upsampler = _load_gfpgan_bg_upsampler(
|
|
||||||
opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile
|
|
||||||
)
|
|
||||||
|
|
||||||
gfpgan = GFPGANer(
|
|
||||||
model_path=model_path,
|
|
||||||
upscale=upsampler_scale,
|
|
||||||
arch='clean',
|
|
||||||
channel_multiplier=2,
|
|
||||||
bg_upsampler=bg_upsampler,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
print('>> Error loading GFPGAN:', file=sys.stderr)
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
if gfpgan is None:
|
|
||||||
print(
|
|
||||||
f'>> WARNING: GFPGAN not initialized.'
|
|
||||||
)
|
|
||||||
print(
|
|
||||||
f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth to {model_path}, \nor change GFPGAN directory with --gfpgan_dir.'
|
|
||||||
)
|
|
||||||
return image
|
|
||||||
|
|
||||||
image = image.convert('RGB')
|
|
||||||
|
|
||||||
cropped_faces, restored_faces, restored_img = gfpgan.enhance(
|
|
||||||
np.array(image, dtype=np.uint8),
|
|
||||||
has_aligned=False,
|
|
||||||
only_center_face=False,
|
|
||||||
paste_back=True,
|
|
||||||
)
|
|
||||||
res = Image.fromarray(restored_img)
|
|
||||||
|
|
||||||
if strength < 1.0:
|
|
||||||
# Resize the image to the new image if the sizes have changed
|
|
||||||
if restored_img.size != image.size:
|
|
||||||
image = image.resize(res.size)
|
|
||||||
res = Image.blend(image, res, strength)
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gfpgan = None
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400):
|
|
||||||
if bg_upsampler == 'realesrgan':
|
|
||||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
|
||||||
use_half_precision = False
|
|
||||||
else:
|
|
||||||
use_half_precision = True
|
|
||||||
|
|
||||||
model_path = {
|
|
||||||
2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
|
|
||||||
4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
|
|
||||||
}
|
|
||||||
|
|
||||||
if upsampler_scale not in model_path:
|
|
||||||
return None
|
|
||||||
|
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
||||||
from realesrgan import RealESRGANer
|
|
||||||
|
|
||||||
if upsampler_scale == 4:
|
|
||||||
model = RRDBNet(
|
|
||||||
num_in_ch=3,
|
|
||||||
num_out_ch=3,
|
|
||||||
num_feat=64,
|
|
||||||
num_block=23,
|
|
||||||
num_grow_ch=32,
|
|
||||||
scale=4,
|
|
||||||
)
|
|
||||||
if upsampler_scale == 2:
|
|
||||||
model = RRDBNet(
|
|
||||||
num_in_ch=3,
|
|
||||||
num_out_ch=3,
|
|
||||||
num_feat=64,
|
|
||||||
num_block=23,
|
|
||||||
num_grow_ch=32,
|
|
||||||
scale=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
bg_upsampler = RealESRGANer(
|
|
||||||
scale=upsampler_scale,
|
|
||||||
model_path=model_path[upsampler_scale],
|
|
||||||
model=model,
|
|
||||||
tile=bg_tile,
|
|
||||||
tile_pad=10,
|
|
||||||
pre_pad=0,
|
|
||||||
half=use_half_precision,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
bg_upsampler = None
|
|
||||||
|
|
||||||
return bg_upsampler
|
|
||||||
|
|
||||||
|
|
||||||
def real_esrgan_upscale(image, strength, upsampler_scale, seed):
|
|
||||||
print(
|
|
||||||
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
|
|
||||||
)
|
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
|
||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
|
||||||
|
|
||||||
try:
|
|
||||||
upsampler = _load_gfpgan_bg_upsampler(
|
|
||||||
opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
print('>> Error loading Real-ESRGAN:', file=sys.stderr)
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
output, img_mode = upsampler.enhance(
|
|
||||||
np.array(image, dtype=np.uint8),
|
|
||||||
outscale=upsampler_scale,
|
|
||||||
alpha_upsampler=opt.gfpgan_bg_upsampler,
|
|
||||||
)
|
|
||||||
|
|
||||||
res = Image.fromarray(output)
|
|
||||||
|
|
||||||
if strength < 1.0:
|
|
||||||
# Resize the image to the new image if the sizes have changed
|
|
||||||
if output.size != image.size:
|
|
||||||
image = image.resize(res.size)
|
|
||||||
res = Image.blend(image, res, strength)
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
upsampler = None
|
|
||||||
|
|
||||||
return res
|
|
@ -2,12 +2,20 @@ import os
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import warnings
|
import warnings
|
||||||
|
import sys
|
||||||
|
|
||||||
pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||||
|
|
||||||
class CodeFormerRestoration():
|
class CodeFormerRestoration():
|
||||||
def __init__(self) -> None:
|
def __init__(self,
|
||||||
pass
|
codeformer_dir='ldm/restoration/codeformer',
|
||||||
|
codeformer_model_path='weights/codeformer.pth') -> None:
|
||||||
|
self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
|
||||||
|
self.codeformer_model_exists = os.path.isfile(self.model_path)
|
||||||
|
|
||||||
|
if not self.codeformer_model_exists:
|
||||||
|
print('## NOT FOUND: CodeFormer model not found at ' + self.model_path)
|
||||||
|
sys.path.append(os.path.abspath(codeformer_dir))
|
||||||
|
|
||||||
def process(self, image, strength, device, seed=None, fidelity=0.75):
|
def process(self, image, strength, device, seed=None, fidelity=0.75):
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
|
76
ldm/restoration/gfpgan/gfpgan.py
Normal file
76
ldm/restoration/gfpgan/gfpgan.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
import torch
|
||||||
|
import warnings
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class GFPGAN():
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
gfpgan_dir='src/gfpgan',
|
||||||
|
gfpgan_model_path='experiments/pretrained_models/GFPGANv1.3.pth') -> None:
|
||||||
|
|
||||||
|
self.model_path = os.path.join(gfpgan_dir, gfpgan_model_path)
|
||||||
|
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
||||||
|
|
||||||
|
if not self.gfpgan_model_exists:
|
||||||
|
raise Exception(
|
||||||
|
'GFPGAN model not found at path ' + self.model_path)
|
||||||
|
sys.path.append(os.path.abspath(gfpgan_dir))
|
||||||
|
|
||||||
|
def model_exists(self):
|
||||||
|
return os.path.isfile(self.model_path)
|
||||||
|
|
||||||
|
def process(self, image, strength: float, seed: str = None):
|
||||||
|
if seed is not None:
|
||||||
|
print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||||
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
|
try:
|
||||||
|
from gfpgan import GFPGANer
|
||||||
|
self.gfpgan = GFPGANer(
|
||||||
|
model_path=self.model_path,
|
||||||
|
upscale=1,
|
||||||
|
arch='clean',
|
||||||
|
channel_multiplier=2,
|
||||||
|
bg_upsampler=None,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
import traceback
|
||||||
|
print('>> Error loading GFPGAN:', file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
if self.gfpgan is None:
|
||||||
|
print(
|
||||||
|
f'>> WARNING: GFPGAN not initialized.'
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth to {self.model_path}, \nor change GFPGAN directory with --gfpgan_dir.'
|
||||||
|
)
|
||||||
|
|
||||||
|
image = image.convert('RGB')
|
||||||
|
|
||||||
|
_, _, restored_img = self.gfpgan.enhance(
|
||||||
|
np.array(image, dtype=np.uint8),
|
||||||
|
has_aligned=False,
|
||||||
|
only_center_face=False,
|
||||||
|
paste_back=True,
|
||||||
|
)
|
||||||
|
res = Image.fromarray(restored_img)
|
||||||
|
|
||||||
|
if strength < 1.0:
|
||||||
|
# Resize the image to the new image if the sizes have changed
|
||||||
|
if restored_img.size != image.size:
|
||||||
|
image = image.resize(res.size)
|
||||||
|
res = Image.blend(image, res, strength)
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
self.gfpgan = None
|
||||||
|
|
||||||
|
return res
|
102
ldm/restoration/realesrgan/realesrgan.py
Normal file
102
ldm/restoration/realesrgan/realesrgan.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
import torch
|
||||||
|
import warnings
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class ESRGAN():
    """Background upscaler built on Real-ESRGAN.

    Args:
        bg_tile_size: Tile size handed to RealESRGANer to bound GPU memory
            usage while upscaling large images.
    """

    def __init__(self, bg_tile_size=400) -> None:
        self.bg_tile_size = bg_tile_size
        # Half precision is only safe with CUDA; CPU or MPS on M1 must run
        # in full precision. (Bug fix: this flag was previously computed
        # and then discarded.)
        self.use_half_precision = torch.cuda.is_available()

    def load_esrgan_bg_upsampler(self, upsampler_scale):
        """Build a RealESRGANer for the requested scale.

        Args:
            upsampler_scale: Upscaling factor; only 2 and 4 are supported.

        Returns:
            A configured RealESRGANer, or None when *upsampler_scale* is
            not a supported scale.
        """
        model_path = {
            2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
            4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
        }

        if upsampler_scale not in model_path:
            return None

        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        # The x2 and x4 networks differ only in their scale parameter, so
        # one parameterized construction replaces the duplicated branches.
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=upsampler_scale,
        )

        return RealESRGANer(
            scale=upsampler_scale,
            model_path=model_path[upsampler_scale],
            model=model,
            tile=self.bg_tile_size,
            tile_pad=10,
            pre_pad=0,
            half=self.use_half_precision,
        )

    def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2):
        """Upscale *image* with Real-ESRGAN and blend with the original.

        Args:
            image: PIL image to upscale.
            strength: Blend factor in [0, 1]; 1.0 keeps the fully upscaled
                result, lower values blend it with the (resized) original.
            seed: Seed of the generated image, used only for logging.
            upsampler_scale: Upscaling factor (2 or 4).

        Returns:
            A PIL image; the input image unchanged if the upsampler could
            not be loaded.
        """
        if seed is not None:
            print(
                f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
            )

        upsampler = None
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', category=DeprecationWarning)
            warnings.filterwarnings('ignore', category=UserWarning)

            try:
                upsampler = self.load_esrgan_bg_upsampler(upsampler_scale)
            except Exception:
                import traceback
                import sys

                print('>> Error loading Real-ESRGAN:', file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

        if upsampler is None:
            # Bug fix: a failed load (or unsupported scale) previously fell
            # through to upsampler.enhance() and raised; return the input
            # unchanged instead.
            print('>> WARNING: Real-ESRGAN not initialized.')
            return image

        output, _ = upsampler.enhance(
            np.array(image, dtype=np.uint8),
            outscale=upsampler_scale,
            alpha_upsampler='realesrgan',
        )

        res = Image.fromarray(output)

        if strength < 1.0:
            # Resize the image to the new image if the sizes have changed.
            # Bug fix: compare the PIL result's size, not the numpy array's
            # .size (an element count), which never equals a (w, h) tuple.
            if res.size != image.size:
                image = image.resize(res.size)
            res = Image.blend(image, res, strength)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        upsampler = None

        return res
|
34
ldm/restoration/restoration.py
Normal file
34
ldm/restoration/restoration.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
class Restoration():
    """Factory for face-restoration (GFPGAN, CodeFormer) and upscaling
    (Real-ESRGAN) modules.

    Args:
        gfpgan_dir: Directory containing the GFPGAN sources/checkpoints.
        gfpgan_model_path: GFPGAN checkpoint path relative to *gfpgan_dir*.
        esrgan_bg_tile: Tile size handed to the ESRGAN upsampler.
    """

    def __init__(self, gfpgan_dir='./src/gfpgan', gfpgan_model_path='experiments/pretrained_models/GFPGANv1.3.pth', esrgan_bg_tile=400) -> None:
        self.gfpgan_dir = gfpgan_dir
        self.gfpgan_model_path = gfpgan_model_path
        self.esrgan_bg_tile = esrgan_bg_tile

    def load_face_restore_models(self):
        """Load and return the (gfpgan, codeformer) restoration modules."""
        # Load GFPGAN
        gfpgan = self.load_gfpgan()
        if gfpgan.gfpgan_model_exists:
            print('>> GFPGAN Initialized')

        # Load CodeFormer
        codeformer = self.load_codeformer()
        if codeformer.codeformer_model_exists:
            print('>> CodeFormer Initialized')

        return gfpgan, codeformer

    # Face Restore Models
    def load_gfpgan(self):
        """Instantiate the GFPGAN face-restoration module."""
        from ldm.restoration.gfpgan.gfpgan import GFPGAN
        return GFPGAN(self.gfpgan_dir, self.gfpgan_model_path)

    def load_codeformer(self):
        """Instantiate the CodeFormer face-restoration module."""
        from ldm.restoration.codeformer.codeformer import CodeFormerRestoration
        return CodeFormerRestoration()

    # Upscale Models
    def load_esrgan(self):
        """Instantiate the ESRGAN upscaling module."""
        from ldm.restoration.realesrgan.realesrgan import ESRGAN
        esrgan = ESRGAN(self.esrgan_bg_tile)
        print('>> ESRGAN Initialized')
        return esrgan  # fixed: stray trailing semicolon removed

    # Backward-compatible alias: existing callers (e.g. dream.py) invoke
    # the original misspelled name.
    load_ersgan = load_esrgan
|
@ -43,7 +43,25 @@ def main():
|
|||||||
import transformers
|
import transformers
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
# creating a simple Generate object with a handful of
|
# Loading Face Restoration and ESRGAN Modules
|
||||||
|
try:
|
||||||
|
gfpgan, codeformer, esrgan = None, None, None
|
||||||
|
from ldm.restoration.restoration import Restoration
|
||||||
|
restoration = Restoration(opt.gfpgan_dir, opt.gfpgan_model_path, opt.esrgan_bg_tile)
|
||||||
|
if opt.restore:
|
||||||
|
gfpgan, codeformer = restoration.load_face_restore_models()
|
||||||
|
else:
|
||||||
|
print('>> Face Restoration Disabled')
|
||||||
|
if opt.esrgan:
|
||||||
|
esrgan = restoration.load_ersgan()
|
||||||
|
else:
|
||||||
|
print('>> ESRGAN Disabled')
|
||||||
|
except (ModuleNotFoundError, ImportError):
|
||||||
|
import traceback
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
|
||||||
|
|
||||||
|
# creating a simple text2image object with a handful of
|
||||||
# defaults passed on the command line.
|
# defaults passed on the command line.
|
||||||
# additional parameters will be added (or overriden) during
|
# additional parameters will be added (or overriden) during
|
||||||
# the user input loop
|
# the user input loop
|
||||||
@ -55,6 +73,9 @@ def main():
|
|||||||
embedding_path = opt.embedding_path,
|
embedding_path = opt.embedding_path,
|
||||||
full_precision = opt.full_precision,
|
full_precision = opt.full_precision,
|
||||||
precision = opt.precision,
|
precision = opt.precision,
|
||||||
|
gfpgan=gfpgan,
|
||||||
|
codeformer=codeformer,
|
||||||
|
esrgan=esrgan
|
||||||
)
|
)
|
||||||
except (FileNotFoundError, IOError, KeyError) as e:
|
except (FileNotFoundError, IOError, KeyError) as e:
|
||||||
print(f'{e}. Aborting.')
|
print(f'{e}. Aborting.')
|
||||||
@ -91,7 +112,7 @@ def main():
|
|||||||
|
|
||||||
# web server loops forever
|
# web server loops forever
|
||||||
if opt.web:
|
if opt.web:
|
||||||
dream_server_loop(gen, opt.host, opt.port, opt.outdir)
|
dream_server_loop(gen, opt.host, opt.port, opt.outdir, gfpgan)
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
main_loop(gen, opt, infile)
|
main_loop(gen, opt, infile)
|
||||||
@ -347,7 +368,7 @@ def get_next_command(infile=None) -> str: # command string
|
|||||||
print(f'#{command}')
|
print(f'#{command}')
|
||||||
return command
|
return command
|
||||||
|
|
||||||
def dream_server_loop(gen, host, port, outdir):
|
def dream_server_loop(gen, host, port, outdir, gfpgan):
|
||||||
print('\n* --web was specified, starting web server...')
|
print('\n* --web was specified, starting web server...')
|
||||||
# Change working directory to the stable-diffusion directory
|
# Change working directory to the stable-diffusion directory
|
||||||
os.chdir(
|
os.chdir(
|
||||||
@ -357,6 +378,10 @@ def dream_server_loop(gen, host, port, outdir):
|
|||||||
# Start server
|
# Start server
|
||||||
DreamServer.model = gen # misnomer in DreamServer - this is not the model you are looking for
|
DreamServer.model = gen # misnomer in DreamServer - this is not the model you are looking for
|
||||||
DreamServer.outdir = outdir
|
DreamServer.outdir = outdir
|
||||||
|
DreamServer.gfpgan_model_exists = False
|
||||||
|
if gfpgan is not None:
|
||||||
|
DreamServer.gfpgan_model_exists = gfpgan.gfpgan_model_exists
|
||||||
|
|
||||||
dream_server = ThreadingDreamServer((host, port))
|
dream_server = ThreadingDreamServer((host, port))
|
||||||
print(">> Started Stable Diffusion dream server!")
|
print(">> Started Stable Diffusion dream server!")
|
||||||
if host == '0.0.0.0':
|
if host == '0.0.0.0':
|
||||||
@ -374,5 +399,19 @@ def dream_server_loop(gen, host, port, outdir):
|
|||||||
dream_server.server_close()
|
dream_server.server_close()
|
||||||
|
|
||||||
|
|
||||||
|
<<<<<<< HEAD
|
||||||
|
=======
|
||||||
|
def write_log_message(results, log_path):
    """Log each output image path and its prompt to the terminal and to *log_path*."""
    global output_cntr
    log_lines = [f'{path}: {prompt}\n' for path, prompt in results]
    for entry in log_lines:
        output_cntr = output_cntr + 1
        print(f'[{output_cntr}] {entry}', end='')

    with open(log_path, 'a', encoding='utf-8') as log_file:
        log_file.writelines(log_lines)
|
||||||
|
|
||||||
|
>>>>>>> GFPGAN and Real ESRGAN Implementation Refactor
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user