Merge branch 'resolution-checker' of https://github.com/blessedcoolant/stable-diffusion into main

Lincoln Stein 2022-08-31 14:43:17 -04:00
commit 0be2351c97
2 changed files with 61 additions and 29 deletions


@@ -50,6 +50,8 @@ class InitImageResizer():
         new_image = Image.new('RGB',(width,height))
         new_image.paste(resized_image,((width-rw)//2,(height-rh)//2))
+        print(f'>> Resized image size to {width}x{height}')
+
         return new_image


 def make_grid(image_list, rows=None, cols=None):
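
The context above is InitImageResizer's fit-and-paste step. As a standalone illustration (not part of the diff), the same idea can be sketched with Pillow alone; the function name below is an assumption, not the repository's API:

from PIL import Image

def fit_and_center(img, width, height):
    # Scale the source to fit inside width x height while keeping its aspect ratio,
    # then paste it centered on a black RGB canvas, as the lines above do.
    scale = min(width / img.width, height / img.height)
    rw, rh = int(img.width * scale), int(img.height * scale)
    resized_image = img.resize((rw, rh), Image.LANCZOS)
    new_image = Image.new('RGB', (width, height))
    new_image.paste(resized_image, ((width - rw) // 2, (height - rh) // 2))
    print(f'>> Resized image size to {width}x{height}')
    return new_image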


@@ -266,16 +266,9 @@ class T2I:
         assert (
             0.0 <= strength <= 1.0
         ), 'can only work with strength in [0.0, 1.0]'
-        w, h = map(
-            lambda x: x - x % 64, (width, height)
-        ) # resize to integer multiple of 64
-        if h != height or w != width:
-            print(
-                f'Height and width must be multiples of 64. Resizing to {h}x{w}.'
-            )
-            height = h
-            width = w
+        if not(width == self.width and height == self.height):
+            width, height, _ = self._resolution_check(width, height, log=True)

         scope = autocast if self.precision == 'autocast' else nullcontext
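
The replacement defers bounds handling to _resolution_check (its definition is added at the bottom of this diff) and only runs it when the caller overrode the session defaults. A standalone sketch of that guard, with illustrative names and a stand-in checker:

def maybe_check_resolution(width, height, default_w, default_h, check):
    # Only validate when the caller asked for something other than the defaults;
    # `check` stands in for _resolution_check and returns (width, height, resize_needed).
    if not (width == default_w and height == default_h):
        width, height, _ = check(width, height)
    return width, height

check = lambda w, h: (w - w % 64, h - h % 64, (w % 64 or h % 64) != 0)
print(maybe_check_resolution(1000, 700, 512, 512, check))   # (960, 640)
print(maybe_check_resolution(512, 512, 512, 512, check))    # (512, 512) -- defaults, check skipped
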
@@ -353,8 +346,11 @@ class T2I:
                         f'Error running RealESRGAN - Your image was not upscaled.\n{e}'
                     )
                 if image_callback is not None:
-                    image_callback(image, seed, upscaled=True)
+                    if save_original:
+                        image_callback(image, seed)
+                    else:
+                        image_callback(image, seed, upscaled=True)
                 else: # no callback passed, so we simply replace old image with rescaled one
                     result[0] = image

         except KeyboardInterrupt:
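
The new save_original branch only changes which arguments reach the callback. A hypothetical image_callback compatible with both call sites above might look like the following; the file-naming scheme is an assumption, not the repository's code:

def image_callback(image, seed, upscaled=False):
    # Called as image_callback(image, seed) for the original image and as
    # image_callback(image, seed, upscaled=True) for the ESRGAN result.
    suffix = '.upscaled' if upscaled else ''
    filename = f'{seed}{suffix}.png'
    image.save(filename)
    print(f'>> saved {filename}')
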
@@ -436,7 +432,7 @@ class T2I:
         width,
         height,
         strength,
-        callback, # Currently not implemented for img2img
+        callback,  # Currently not implemented for img2img
     ):
         """
         An infinite iterator of images from the prompt and the initial image
@@ -445,13 +441,13 @@ class T2I:
         # PLMS sampler not supported yet, so ignore previous sampler
         if self.sampler_name != 'ddim':
             print(
-                f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler"
+                f"sampler '{self.sampler_name}' is not yet supported. Using DDIM sampler"
             )
             sampler = DDIMSampler(self.model, device=self.device)
         else:
             sampler = self.sampler

-        init_image = self._load_img(init_img,width,height).to(self.device)
+        init_image = self._load_img(init_img, width, height).to(self.device)
         with precision_scope(self.device.type):
             init_latent = self.model.get_first_stage_encoding(
                 self.model.encode_first_stage(init_image)
@@ -514,7 +510,8 @@ class T2I:
             x_samples = self.model.decode_first_stage(samples)
             x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
             if len(x_samples) != 1:
-                raise Exception(f'expected to get a single image, but got {len(x_samples)}')
+                raise Exception(
+                    f'expected to get a single image, but got {len(x_samples)}')
             x_sample = 255.0 * rearrange(
                 x_samples[0].cpu().numpy(), 'c h w -> h w c'
             )
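
For orientation, the clamp/rearrange lines around this hunk turn a decoded CHW tensor in [-1, 1] into an HWC uint8 image. A minimal self-contained sketch with torch, einops and Pillow; the random tensor is only a stand-in for a decoded sample:

import numpy as np
import torch
from einops import rearrange
from PIL import Image

x_sample = torch.rand(3, 512, 512) * 2 - 1                        # stand-in for a decoded sample
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)  # map [-1, 1] to [0, 1]
arr = 255.0 * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
img = Image.fromarray(arr.astype(np.uint8))                       # 512x512 RGB image
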
@@ -545,8 +542,9 @@ class T2I:
             self.model.cond_stage_model.device = self.device
         except AttributeError:
             import traceback
-            print('Error loading model. Only the CUDA backend is supported',file=sys.stderr)
-            print(traceback.format_exc(),file=sys.stderr)
+            print(
+                'Error loading model. Only the CUDA backend is supported', file=sys.stderr)
+            print(traceback.format_exc(), file=sys.stderr)
             raise SystemExit

         self._set_sampler()
@@ -606,10 +604,26 @@ class T2I:
         print(f'image path = {path}, cwd = {os.getcwd()}')
         with Image.open(path) as img:
             image = img.convert('RGB')
-            print(f'loaded input image of size {image.width}x{image.height} from {path}')
+            print(
+                f'loaded input image of size {image.width}x{image.height} from {path}')

-        image = InitImageResizer(image).resize(width,height)
-        print(f'resized input image to size {image.width}x{image.height}')
+        from ldm.dream.image_util import InitImageResizer
+        if width == self.width and height == self.height:
+            new_image_width, new_image_height, resize_needed = self._resolution_check(
+                image.width, image.height)
+        else:
+            if height == self.height:
+                new_image_width, new_image_height, resize_needed = self._resolution_check(
+                    width, image.height)
+            if width == self.width:
+                new_image_width, new_image_height, resize_needed = self._resolution_check(
+                    image.width, height)
+            else:
+                image = InitImageResizer(image).resize(width, height)
+                resize_needed=False
+        if resize_needed:
+            image = InitImageResizer(image).resize(
+                new_image_width, new_image_height)

         image = np.array(image).astype(np.float32) / 255.0
         image = image[None].transpose(0, 3, 1, 2)
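
The new _load_img logic only auto-resizes dimensions the caller did not explicitly override. A simplified standalone sketch of that decision; helper and variable names are illustrative, and the repository's method differs in detail:

def plan_init_image_size(req_w, req_h, img_w, img_h, default_w, default_h):
    # Dimensions left at the session defaults follow the init image itself and are
    # clamped down to multiples of 64; an explicitly requested width and height
    # are used as given.
    def clamp64(w, h):
        return w - w % 64, h - h % 64

    if req_w == default_w and req_h == default_h:
        return clamp64(img_w, img_h)      # nothing overridden: follow the image
    if req_h == default_h:
        return clamp64(req_w, img_h)      # only width overridden
    if req_w == default_w:
        return clamp64(img_w, req_h)      # only height overridden
    return req_w, req_h                   # both overridden: use as requested

print(plan_init_image_size(512, 512, 1000, 700, 512, 512))   # (960, 640)
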
@@ -633,7 +647,7 @@ class T2I:
                 prompt = text[:idx]
                 remaining -= idx
                 # remove from main text
-                text = text[idx + 1 :]
+                text = text[idx + 1:]
                 # find value for weight
                 if ' ' in text:
                     idx = text.index(' ') # first occurence
@@ -651,7 +665,7 @@ class T2I:
                     weight = 1.0
                 # remove from main text
                 remaining -= idx
-                text = text[idx + 1 :]
+                text = text[idx + 1:]
                 # append the sub-prompt and its weight
                 prompts.append(prompt)
                 weights.append(weight)
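
The two one-line edits above sit inside the weighted sub-prompt parser, which walks the prompt splitting on ':' into (fragment, weight) pairs. A simplified standalone sketch of that idea, not the repository's exact parser:

def split_weighted_subprompts(text):
    # Everything before each ':' is a fragment; the token after it (up to the next
    # space) is its weight. Trailing text without a ':' gets the default weight 1.0.
    prompts, weights = [], []
    while ':' in text:
        fragment, _, rest = text.partition(':')
        value, _, text = rest.partition(' ')
        try:
            weight = float(value)
        except ValueError:
            weight = 1.0
        prompts.append(fragment.strip())
        weights.append(weight)
    if text.strip():
        prompts.append(text.strip())
        weights.append(1.0)
    return prompts, weights

print(split_weighted_subprompts('a big tree:1.5 blue sky:0.3'))
# (['a big tree', 'blue sky'], [1.5, 0.3])
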
@@ -662,9 +676,9 @@ class T2I:
             weights.append(1.0)
             remaining = 0
         return prompts, weights

     # shows how the prompt is tokenized
     # usually tokens have '</w>' to indicate end-of-word,
     # but for readability it has been replaced with ' '
     def _log_tokenization(self, text):
         if not self.log_tokenization:
@@ -674,15 +688,31 @@ class T2I:
         discarded = ""
         usedTokens = 0
         totalTokens = len(tokens)
-        for i in range(0,totalTokens):
-            token = tokens[i].replace('</w>',' ')
+        for i in range(0, totalTokens):
+            token = tokens[i].replace('</w>', ' ')
             # alternate color
             s = (usedTokens % 6) + 1
             if i < self.model.cond_stage_model.max_length:
                 tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
                 usedTokens += 1
             else: # over max token length
                 discarded = discarded + f"\x1b[0;3{s};40m{token}"
         print(f"\nTokens ({usedTokens}):\n{tokenized}\x1b[0m")
         if discarded != "":
-            print(f"Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m")
+            print(
+                f"Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m")
+
+    def _resolution_check(self, width, height, log=False):
+        resize_needed = False
+        w, h = map(
+            lambda x: x - x % 64, (width, height)
+        ) # resize to integer multiple of 64
+        if h != height or w != width:
+            if log:
+                print(
+                    f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}'
+                )
+            height = h
+            width = w
+            resize_needed = True
+        return width, height, resize_needed
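
As a usage note, the new helper rounds each dimension down to the nearest multiple of 64 and reports whether anything changed, so callers know whether an actual image resize is still needed. A standalone replica of the logic above, with illustrative values:

def resolution_check(width, height, log=False):
    # Free-function replica of T2I._resolution_check, for illustration only.
    resize_needed = False
    w, h = map(lambda x: x - x % 64, (width, height))
    if h != height or w != width:
        if log:
            print(f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}')
        width, height, resize_needed = w, h, True
    return width, height, resize_needed

print(resolution_check(1000, 700, log=True))   # (960, 640, True), after the auto-resize notice
print(resolution_check(512, 512))              # (512, 512, False)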