optionally scale initial image to fit box defined by width x height

* This functionality is triggered by the --fit option in the CLI (default
false), and by the "fit" checkbox in the WebGUI (default True)

* In addition, this commit contains a number of whitespace changes to
make the code more readable, as well as an attempt to unify the visual
appearance of info and warning messages.
This commit is contained in:
Lincoln Stein 2022-09-01 00:50:28 -04:00
parent 4b560b50c2
commit 28fe84177e
8 changed files with 232 additions and 182 deletions

View File

@ -8,11 +8,10 @@ class InitImageResizer():
def resize(self,width=None,height=None) -> Image:
"""
Return a copy of the image resized to width x height.
The aspect ratio is maintained, with any excess space
filled using black borders (i.e. letterboxed). If
neither width nor height are provided, then returns
a copy of the original image. If one or the other is
Return a copy of the image resized to fit within
a box width x height. The aspect ratio is
maintained. If neither width nor height are provided,
then returns a copy of the original image. If one or the other is
provided, then the other will be calculated from the
aspect ratio.
@ -21,38 +20,34 @@ class InitImageResizer():
"""
im = self.image
if not(width or height):
return im.copy()
ar = im.width/im.height
ar = im.width/float(im.height)
# Infer missing values from aspect ratio
if not height: # height missing
if not(width or height): # both missing
width = im.width
height = im.height
elif not height: # height missing
height = int(width/ar)
if not width: # width missing
elif not width: # width missing
width = int(height*ar)
# rw and rh are the resizing width and height for the image
# they maintain the aspect ratio, but may not completelyl fill up
# the requested destination size
(rw,rh) = (width,int(width/ar)) if im.width>=im.height else (int(height*ar),width)
(rw,rh) = (width,int(width/ar)) if im.width>=im.height else (int(height*ar),height)
#round everything to multiples of 64
width,height,rw,rh = map(
lambda x: x-x%64, (width,height,rw,rh)
)
# resize the original image so that it fits inside the dest
# no resize necessary, but return a copy
if im.width == width and im.height == height:
return im.copy()
# otherwise resize the original image so that it fits inside the bounding box
resized_image = self.image.resize((rw,rh),resample=Image.Resampling.LANCZOS)
# create new destination image of specified dimensions
# and paste the resized image into it centered appropriately
new_image = Image.new('RGB',(width,height))
new_image.paste(resized_image,((width-rw)//2,(height-rh)//2))
print(f'>> Resized image size to {width}x{height}')
return new_image
return resized_image
def make_grid(image_list, rows=None, cols=None):
image_cnt = len(image_list)

View File

@ -61,6 +61,8 @@ class PromptFormatter:
switches.append(f'-A{opt.sampler_name or t2i.sampler_name}')
if opt.init_img:
switches.append(f'-I{opt.init_img}')
if opt.fit:
switches.append(f'--fit')
if opt.strength and opt.init_img is not None:
switches.append(f'-f{opt.strength or t2i.strength}')
if opt.gfpgan_strength:

View File

@ -70,6 +70,7 @@ class DreamServer(BaseHTTPRequestHandler):
steps = int(post_data['steps'])
width = int(post_data['width'])
height = int(post_data['height'])
fit = 'fit' in post_data
cfgscale = float(post_data['cfgscale'])
sampler_name = post_data['sampler']
gfpgan_strength = float(post_data['gfpgan_strength']) if gfpgan_model_exists else 0
@ -80,7 +81,7 @@ class DreamServer(BaseHTTPRequestHandler):
seed = self.model.seed if int(post_data['seed']) == -1 else int(post_data['seed'])
self.canceled.clear()
print(f"Request to generate with prompt: {prompt}")
print(f">> Request to generate with prompt: {prompt}")
# In order to handle upscaled images, the PngWriter needs to maintain state
# across images generated by each call to prompt2img(), so we define it in
# the outer scope of image_done()
@ -181,6 +182,9 @@ class DreamServer(BaseHTTPRequestHandler):
seed = seed,
steps = steps,
sampler_name = sampler_name,
width = width,
height = height,
fit = fit,
gfpgan_strength=gfpgan_strength,
upscale = upscale,
step_callback=image_progress,
@ -192,8 +196,6 @@ class DreamServer(BaseHTTPRequestHandler):
print(f"Canceled.")
return
print(f"Prompt generated!")
class ThreadingDreamServer(ThreadingHTTPServer):
def __init__(self, server_address):

View File

@ -14,7 +14,7 @@ model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path)
gfpgan_model_exists = os.path.isfile(model_path)
def _run_gfpgan(image, strength, prompt, seed, upsampler_scale=4):
print(f'\n* GFPGAN - Restoring Faces: {prompt} : seed:{seed}')
print(f'>> GFPGAN - Restoring Faces: {prompt} : seed:{seed}')
gfpgan = None
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
@ -41,12 +41,12 @@ def _run_gfpgan(image, strength, prompt, seed, upsampler_scale=4):
except Exception:
import traceback
print('Error loading GFPGAN:', file=sys.stderr)
print('>> Error loading GFPGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if gfpgan is None:
print(
f'GFPGAN not initialized, it must be loaded via the --gfpgan argument'
f'>> GFPGAN not initialized, it must be loaded via the --gfpgan argument'
)
return image
@ -129,7 +129,7 @@ def _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400):
def real_esrgan_upscale(image, strength, upsampler_scale, prompt, seed):
print(
f'\n* Real-ESRGAN Upscaling: {prompt} : seed:{seed} : scale:{upsampler_scale}x'
f'>> Real-ESRGAN Upscaling: {prompt} : seed:{seed} : scale:{upsampler_scale}x'
)
with warnings.catch_warnings():
@ -143,7 +143,7 @@ def real_esrgan_upscale(image, strength, upsampler_scale, prompt, seed):
except Exception:
import traceback
print('Error loading Real-ESRGAN:', file=sys.stderr)
print('>> Error loading Real-ESRGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
output, img_mode = upsampler.enhance(

View File

@ -209,11 +209,11 @@ class T2I:
height = None,
# these are specific to img2img
init_img = None,
fit = False,
strength = None,
gfpgan_strength= 0,
save_original = False,
upscale = None,
variants=None,
sampler_name = None,
log_tokenization= False,
**args,
@ -232,7 +232,6 @@ class T2I:
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
step_callback // a function or method that will be called each step
image_callback // a function or method that will be called each time an image is generated
@ -269,9 +268,7 @@ class T2I:
0.0 <= strength <= 1.0
), 'can only work with strength in [0.0, 1.0]'
if not(width == self.width and height == self.height):
width, height, _ = self._resolution_check(width, height, log=True)
scope = autocast if self.precision == 'autocast' else nullcontext
if sampler_name and (sampler_name != self.sampler_name):
@ -295,6 +292,7 @@ class T2I:
init_img=init_img,
width=width,
height=height,
fit=fit,
strength=strength,
callback=step_callback,
)
@ -312,7 +310,7 @@ class T2I:
)
with scope(self.device.type), self.model.ema_scope():
for n in trange(iterations, desc='Generating'):
for n in trange(iterations, desc='>> Generating'):
seed_everything(seed)
image = next(images_iterator)
results.append([image, seed])
@ -365,12 +363,12 @@ class T2I:
print('Are you sure your system has an adequate NVIDIA GPU?')
toc = time.time()
print('Usage stats:')
print('>> Usage stats:')
print(
f' {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
)
print(
f' Max VRAM used for this generation:',
f'>> Max VRAM used for this generation:',
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
)
@ -379,7 +377,7 @@ class T2I:
self.session_peakmem, torch.cuda.max_memory_allocated()
)
print(
f' Max VRAM used since script start: ',
f'>> Max VRAM used since script start: ',
'%4.2fG' % (self.session_peakmem / 1e9),
)
return results
@ -435,6 +433,7 @@ class T2I:
init_img,
width,
height,
fit,
strength,
callback, # Currently not implemented for img2img
):
@ -445,13 +444,13 @@ class T2I:
# PLMS sampler not supported yet, so ignore previous sampler
if self.sampler_name != 'ddim':
print(
f"sampler '{self.sampler_name}' is not yet supported. Using DDIM sampler"
f">> sampler '{self.sampler_name}' is not yet supported. Using DDIM sampler"
)
sampler = DDIMSampler(self.model, device=self.device)
else:
sampler = self.sampler
init_image = self._load_img(init_img, width, height).to(self.device)
init_image = self._load_img(init_img, width, height,fit).to(self.device)
with precision_scope(self.device.type):
init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
@ -581,7 +580,7 @@ class T2I:
print(msg)
def _load_model_from_config(self, config, ckpt):
print(f'Loading model from {ckpt}')
print(f'>> Loading model from {ckpt}')
pl_sd = torch.load(ckpt, map_location='cpu')
# if "global_step" in pl_sd:
# print(f"Global Step: {pl_sd['global_step']}")
@ -596,41 +595,63 @@ class T2I:
)
else:
print(
'Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.'
'>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.'
)
model.half()
return model
def _load_img(self, path, width, height):
print(f'image path = {path}, cwd = {os.getcwd()}')
def _load_img(self, path, width, height, fit=False):
with Image.open(path) as img:
image = img.convert('RGB')
print(
f'loaded input image of size {image.width}x{image.height} from {path}')
f'>> loaded input image of size {image.width}x{image.height} from {path}'
)
from ldm.dream.image_util import InitImageResizer
if width == self.width and height == self.height:
new_image_width, new_image_height, resize_needed = self._resolution_check(
image.width, image.height)
# The logic here is:
# 1. If "fit" is true, then the image will be fit into the bounding box defined
# by width and height. It will do this in a way that preserves the init image's
# aspect ratio while preventing letterboxing. This means that if there is
# leftover horizontal space after rescaling the image to fit in the bounding box,
# the generated image's width will be reduced to the rescaled init image's width.
# Similarly for the vertical space.
# 2. Otherwise, if "fit" is false, then the image will be scaled, preserving its
# aspect ratio, to the nearest multiple of 64. Large images may generate an
# unexpected OOM error.
if fit:
image = self._fit_image(image,(width,height))
else:
if height == self.height:
new_image_width, new_image_height, resize_needed = self._resolution_check(
width, image.height)
if width == self.width:
new_image_width, new_image_height, resize_needed = self._resolution_check(
image.width, height)
else:
image = InitImageResizer(image).resize(width, height)
resize_needed=False
if resize_needed:
image = InitImageResizer(image).resize(
new_image_width, new_image_height)
image = self._squeeze_image(image)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
def _squeeze_image(self,image):
x,y,resize_needed = self._resolution_check(image.width,image.height)
if resize_needed:
return InitImageResizer(image).resize(x,y)
return image
def _fit_image(self,image,max_dimensions):
w,h = max_dimensions
print(
f'>> image will be resized to fit inside a box {w}x{h} in size.'
)
if image.width > image.height:
h = None # by setting h to none, we tell InitImageResizer to fit into the width and calculate height
elif image.height > image.width:
w = None # ditto for w
else:
pass
image = InitImageResizer(image).resize(w,h) # note that InitImageResizer does the multiple of 64 truncation internally
print(
f'>> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}'
)
return image
# TO DO: Move this and related weighted subprompt code into its own module.
def _split_weighted_subprompts(text, skip_normalize=False):
"""
grabs all text up to the first occurrence of ':'

View File

@ -88,7 +88,7 @@ def main():
tic = time.time()
t2i.load_model()
print(
f'model loaded in', '%4.2fs' % (time.time() - tic)
f'>> model loaded in', '%4.2fs' % (time.time() - tic)
)
if not infile:
@ -483,6 +483,13 @@ def create_cmd_parser():
type=str,
help='Path to input image for img2img mode (supersedes width and height)',
)
parser.add_argument(
'-T',
'-fit',
'--fit',
action='store_true',
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
)
parser.add_argument(
'-f',
'--strength',

View File

@ -8,13 +8,15 @@
margin-top: 20vh;
margin-left: auto;
margin-right: auto;
max-width: 800px;
max-width: 1024px;
text-align: center;
}
fieldset {
border: none;
}
div {
padding: 10px 10px 10px 10px;
}
#fieldset-search {
display: flex;
}
@ -78,3 +80,18 @@ label {
cursor: pointer;
color: red;
}
#txt2img {
background-color: #DCDCDC;
}
#img2img {
background-color: #F5F5F5;
}
#gfpgan {
background-color: #DCDCDC;
}
#progress-section {
background-color: #F5F5F5;
}
#about {
background-color: #DCDCDC;
}

View File

@ -14,6 +14,7 @@
<h2 id="header">Stable Diffusion Dream Server</h2>
<form id="generate-form" method="post" action="#">
<div id="txt2img">
<fieldset id="fieldset-search">
<input type="text" id="prompt" name="prompt">
<input type="submit" id="submit" value="Generate">
@ -62,15 +63,19 @@
<label title="Set to -1 for random seed" for="seed">Seed:</label>
<input value="-1" type="number" id="seed" name="seed">
<button type="button" id="reset-seed">&olarr;</button>
<input type="checkbox" name="progress_images" id="progress_images">
<label for="progress_images">Display in-progress images (slows down generation):</label>
<button type="button" id="reset-all">Reset to Defaults</button>
</div>
<div id="img2img">
<label title="Upload an image to use img2img" for="initimg">Initial image:</label>
<input type="file" id="initimg" name="initimg" accept=".jpg, .jpeg, .png">
<br>
<label for="strength">Img2Img Strength:</label>
<input value="0.75" type="number" id="strength" name="strength" step="0.01" min="0" max="1">
<label title="Upload an image to use img2img" for="initimg">Init:</label>
<input type="file" id="initimg" name="initimg" accept=".jpg, .jpeg, .png">
<button type="button" id="reset-all">Reset to Defaults</button>
<br>
<label for="progress_images">Display in-progress images (slows down generation):</label>
<input type="checkbox" name="progress_images" id="progress_images">
<input type="checkbox" id="fit" name="fit" checked>
<label title="Rescale image to fit within requested width and height" for="fit">Fit to width/height:</label>
</div>
<div id="gfpgan">
<label title="Strength of the gfpgan (face fixing) algorithm." for="gfpgan_strength">GPFGAN Strength (0 to disable):</label>
<input value="0.8" min="0" max="1" type="number" id="gfpgan_strength" name="gfpgan_strength" step="0.05">
@ -86,6 +91,7 @@
</fieldset>
</form>
<div id="about">For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a></div>
<br>
<div id="progress-section">
<progress id="progress-bar" value="0" max="1"></progress>
<span id="cancel-button" title="Cancel">&#10006;</span>