From 1480ef84dcafc7a2119879f1e2bef8d14fd810e8 Mon Sep 17 00:00:00 2001
From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com>
Date: Wed, 31 Aug 2022 14:49:00 +1200
Subject: [PATCH] Add Resolution Checker

---
 ldm/dream/image_util.py |  2 +
 ldm/simplet2i.py        | 84 +++++++++++++++++++++++++++--------------
 2 files changed, 57 insertions(+), 29 deletions(-)

diff --git a/ldm/dream/image_util.py b/ldm/dream/image_util.py
index fa14ec897b..55610a9bab 100644
--- a/ldm/dream/image_util.py
+++ b/ldm/dream/image_util.py
@@ -49,6 +49,8 @@ class InitImageResizer():
         new_image = Image.new('RGB',(width,height))
         new_image.paste(resized_image,((width-rw)//2,(height-rh)//2))
 
+        print(f'>> Resized image size to {width}x{height}')
+
         return new_image
 
diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py
index 6901d45774..9ec10fe5a9 100644
--- a/ldm/simplet2i.py
+++ b/ldm/simplet2i.py
@@ -27,7 +27,6 @@ from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
 from ldm.models.diffusion.ksampler import KSampler
 from ldm.dream.pngwriter import PngWriter
-from ldm.dream.image_util import InitImageResizer
 
 """Simplified text to image API for stable diffusion/latent diffusion
 
@@ -261,16 +260,9 @@ class T2I:
         assert (
             0.0 <= strength <= 1.0
         ), 'can only work with strength in [0.0, 1.0]'
-        w, h = map(
-            lambda x: x - x % 64, (width, height)
-        )  # resize to integer multiple of 64
-        if h != height or w != width:
-            print(
-                f'Height and width must be multiples of 64. Resizing to {h}x{w}.'
-            )
-            height = h
-            width = w
+
+        if not(width == self.width and height == self.height):
+            width, height, _ = self._resolution_check(width, height, log=True)
 
         scope = autocast if self.precision == 'autocast' else nullcontext
 
@@ -352,7 +344,7 @@ class T2I:
                             image_callback(image, seed)
                         else:
                             image_callback(image, seed, upscaled=True)
-                    else: # no callback passed, so we simply replace old image with rescaled one
+                    else:  # no callback passed, so we simply replace old image with rescaled one
                         result[0] = image
 
         except KeyboardInterrupt:
@@ -434,7 +426,7 @@ class T2I:
         width,
         height,
         strength,
-        callback, # Currently not implemented for img2img
+        callback,  # Currently not implemented for img2img
     ):
         """
         An infinite iterator of images from the prompt and the initial image
@@ -443,13 +435,13 @@ class T2I:
         """
 
         # PLMS sampler not supported yet, so ignore previous sampler
         if self.sampler_name != 'ddim':
             print(
-                f"sampler '{self.sampler_name}' is not yet supported. Using DDIM sampler"
+                f"sampler '{self.sampler_name}' is not yet supported. Using DDIM sampler"
             )
             sampler = DDIMSampler(self.model, device=self.device)
         else:
             sampler = self.sampler
 
-        init_image = self._load_img(init_img,width,height).to(self.device)
+        init_image = self._load_img(init_img, width, height).to(self.device)
         with precision_scope(self.device.type):
             init_latent = self.model.get_first_stage_encoding(
                 self.model.encode_first_stage(init_image)
             )
@@ -512,7 +504,8 @@ class T2I:
         x_samples = self.model.decode_first_stage(samples)
         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
         if len(x_samples) != 1:
-            raise Exception(f'expected to get a single image, but got {len(x_samples)}')
+            raise Exception(
+                f'expected to get a single image, but got {len(x_samples)}')
         x_sample = 255.0 * rearrange(
             x_samples[0].cpu().numpy(), 'c h w -> h w c'
         )
@@ -547,8 +540,9 @@ class T2I:
                 self.model.cond_stage_model.device = self.device
             except AttributeError:
                 import traceback
-                print('Error loading model. Only the CUDA backend is supported',file=sys.stderr)
-                print(traceback.format_exc(),file=sys.stderr)
+                print(
+                    'Error loading model. Only the CUDA backend is supported', file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
                 raise SystemExit
             self._set_sampler()
 
@@ -608,10 +602,26 @@ class T2I:
         print(f'image path = {path}, cwd = {os.getcwd()}')
         with Image.open(path) as img:
             image = img.convert('RGB')
-        print(f'loaded input image of size {image.width}x{image.height} from {path}')
+        print(
+            f'loaded input image of size {image.width}x{image.height} from {path}')
 
-        image = InitImageResizer(image).resize(width,height)
-        print(f'resized input image to size {image.width}x{image.height}')
+        from ldm.dream.image_util import InitImageResizer
+        if width == self.width and height == self.height:
+            new_image_width, new_image_height, resize_needed = self._resolution_check(
+                image.width, image.height)
+        else:
+            if height == self.height:
+                new_image_width, new_image_height, resize_needed = self._resolution_check(
+                    width, image.height)
+            if width == self.width:
+                new_image_width, new_image_height, resize_needed = self._resolution_check(
+                    image.width, height)
+            else:
+                image = InitImageResizer(image).resize(width, height)
+                resize_needed = False
+        if resize_needed:
+            image = InitImageResizer(image).resize(
+                new_image_width, new_image_height)
 
         image = np.array(image).astype(np.float32) / 255.0
         image = image[None].transpose(0, 3, 1, 2)
@@ -635,7 +645,7 @@ class T2I:
                 prompt = text[:idx]
                 remaining -= idx
                 # remove from main text
-                text = text[idx + 1 :]
+                text = text[idx + 1:]
                 # find value for weight
                 if ' ' in text:
                     idx = text.index(' ')  # first occurence
@@ -653,7 +663,7 @@ class T2I:
                     weight = 1.0
                 # remove from main text
                 remaining -= idx
-                text = text[idx + 1 :]
+                text = text[idx + 1:]
                 # append the sub-prompt and its weight
                 prompts.append(prompt)
                 weights.append(weight)
@@ -664,9 +674,9 @@ class T2I:
             weights.append(1.0)
             remaining = 0
         return prompts, weights
-
-    # shows how the prompt is tokenized
-    # usually tokens have '</w>' to indicate end-of-word,
+
+    # shows how the prompt is tokenized
+    # usually tokens have '</w>' to indicate end-of-word,
     # but for readability it has been replaced with ' '
     def _log_tokenization(self, text):
         if not self.log_tokenization:
@@ -676,15 +686,31 @@ class T2I:
             return
         tokens = self.model.cond_stage_model.tokenizer._tokenize(text)
         tokenized = ""
         discarded = ""
         usedTokens = 0
         totalTokens = len(tokens)
-        for i in range(0,totalTokens):
-            token = tokens[i].replace('</w>',' ')
+        for i in range(0, totalTokens):
+            token = tokens[i].replace('</w>', ' ')
             # alternate color
             s = (usedTokens % 6) + 1
             if i < self.model.cond_stage_model.max_length:
                 tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
                 usedTokens += 1
-            else: # over max token length
+            else:  # over max token length
                 discarded = discarded + f"\x1b[0;3{s};40m{token}"
         print(f"\nTokens ({usedTokens}):\n{tokenized}\x1b[0m")
         if discarded != "":
-            print(f"Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m")
+            print(
+                f"Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m")
+
+    def _resolution_check(self, width, height, log=False):
+        resize_needed = False
+        w, h = map(
+            lambda x: x - x % 64, (width, height)
+        )  # resize to integer multiple of 64
+        if h != height or w != width:
+            if log:
+                print(
+                    f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}'
+                )
+            height = h
+            width = w
+            resize_needed = True
+        return width, height, resize_needed
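
Note (not part of the patch): the rounding that _resolution_check performs can be sanity-checked in isolation. The sketch below mirrors the floor-to-multiple-of-64 logic from the hunk above as a standalone function; the name resolution_check and the example sizes are illustrative only, not identifiers introduced by this patch.

    # Standalone sketch of the multiple-of-64 rounding used by _resolution_check.
    # Mirrors the patched logic; not part of the patch itself.
    def resolution_check(width, height, log=False):
        resize_needed = False
        w, h = map(lambda x: x - x % 64, (width, height))  # floor to multiple of 64
        if h != height or w != width:
            if log:
                print(f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}')
            width, height, resize_needed = w, h, True
        return width, height, resize_needed

    # 1000x704 floors to 960x704 and flags a resize; 512x512 passes through untouched.
    print(resolution_check(1000, 704, log=True))   # -> (960, 704, True)
    print(resolution_check(512, 512))              # -> (512, 512, False)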