mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'resolution-checker' of https://github.com/blessedcoolant/stable-diffusion into main
This commit is contained in:
commit
0be2351c97
@ -50,6 +50,8 @@ class InitImageResizer():
|
|||||||
new_image = Image.new('RGB',(width,height))
|
new_image = Image.new('RGB',(width,height))
|
||||||
new_image.paste(resized_image,((width-rw)//2,(height-rh)//2))
|
new_image.paste(resized_image,((width-rw)//2,(height-rh)//2))
|
||||||
|
|
||||||
|
print(f'>> Resized image size to {width}x{height}')
|
||||||
|
|
||||||
return new_image
|
return new_image
|
||||||
|
|
||||||
def make_grid(image_list, rows=None, cols=None):
|
def make_grid(image_list, rows=None, cols=None):
|
||||||
|
@ -266,16 +266,9 @@ class T2I:
|
|||||||
assert (
|
assert (
|
||||||
0.0 <= strength <= 1.0
|
0.0 <= strength <= 1.0
|
||||||
), 'can only work with strength in [0.0, 1.0]'
|
), 'can only work with strength in [0.0, 1.0]'
|
||||||
w, h = map(
|
|
||||||
lambda x: x - x % 64, (width, height)
|
|
||||||
) # resize to integer multiple of 64
|
|
||||||
|
|
||||||
if h != height or w != width:
|
if not(width == self.width and height == self.height):
|
||||||
print(
|
width, height, _ = self._resolution_check(width, height, log=True)
|
||||||
f'Height and width must be multiples of 64. Resizing to {h}x{w}.'
|
|
||||||
)
|
|
||||||
height = h
|
|
||||||
width = w
|
|
||||||
|
|
||||||
scope = autocast if self.precision == 'autocast' else nullcontext
|
scope = autocast if self.precision == 'autocast' else nullcontext
|
||||||
|
|
||||||
@ -353,8 +346,11 @@ class T2I:
|
|||||||
f'Error running RealESRGAN - Your image was not upscaled.\n{e}'
|
f'Error running RealESRGAN - Your image was not upscaled.\n{e}'
|
||||||
)
|
)
|
||||||
if image_callback is not None:
|
if image_callback is not None:
|
||||||
image_callback(image, seed, upscaled=True)
|
if save_original:
|
||||||
else: # no callback passed, so we simply replace old image with rescaled one
|
image_callback(image, seed)
|
||||||
|
else:
|
||||||
|
image_callback(image, seed, upscaled=True)
|
||||||
|
else: # no callback passed, so we simply replace old image with rescaled one
|
||||||
result[0] = image
|
result[0] = image
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@ -436,7 +432,7 @@ class T2I:
|
|||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
strength,
|
strength,
|
||||||
callback, # Currently not implemented for img2img
|
callback, # Currently not implemented for img2img
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
An infinite iterator of images from the prompt and the initial image
|
An infinite iterator of images from the prompt and the initial image
|
||||||
@ -445,13 +441,13 @@ class T2I:
|
|||||||
# PLMS sampler not supported yet, so ignore previous sampler
|
# PLMS sampler not supported yet, so ignore previous sampler
|
||||||
if self.sampler_name != 'ddim':
|
if self.sampler_name != 'ddim':
|
||||||
print(
|
print(
|
||||||
f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler"
|
f"sampler '{self.sampler_name}' is not yet supported. Using DDIM sampler"
|
||||||
)
|
)
|
||||||
sampler = DDIMSampler(self.model, device=self.device)
|
sampler = DDIMSampler(self.model, device=self.device)
|
||||||
else:
|
else:
|
||||||
sampler = self.sampler
|
sampler = self.sampler
|
||||||
|
|
||||||
init_image = self._load_img(init_img,width,height).to(self.device)
|
init_image = self._load_img(init_img, width, height).to(self.device)
|
||||||
with precision_scope(self.device.type):
|
with precision_scope(self.device.type):
|
||||||
init_latent = self.model.get_first_stage_encoding(
|
init_latent = self.model.get_first_stage_encoding(
|
||||||
self.model.encode_first_stage(init_image)
|
self.model.encode_first_stage(init_image)
|
||||||
@ -514,7 +510,8 @@ class T2I:
|
|||||||
x_samples = self.model.decode_first_stage(samples)
|
x_samples = self.model.decode_first_stage(samples)
|
||||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
if len(x_samples) != 1:
|
if len(x_samples) != 1:
|
||||||
raise Exception(f'expected to get a single image, but got {len(x_samples)}')
|
raise Exception(
|
||||||
|
f'expected to get a single image, but got {len(x_samples)}')
|
||||||
x_sample = 255.0 * rearrange(
|
x_sample = 255.0 * rearrange(
|
||||||
x_samples[0].cpu().numpy(), 'c h w -> h w c'
|
x_samples[0].cpu().numpy(), 'c h w -> h w c'
|
||||||
)
|
)
|
||||||
@ -545,8 +542,9 @@ class T2I:
|
|||||||
self.model.cond_stage_model.device = self.device
|
self.model.cond_stage_model.device = self.device
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
import traceback
|
import traceback
|
||||||
print('Error loading model. Only the CUDA backend is supported',file=sys.stderr)
|
print(
|
||||||
print(traceback.format_exc(),file=sys.stderr)
|
'Error loading model. Only the CUDA backend is supported', file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
raise SystemExit
|
raise SystemExit
|
||||||
|
|
||||||
self._set_sampler()
|
self._set_sampler()
|
||||||
@ -606,10 +604,26 @@ class T2I:
|
|||||||
print(f'image path = {path}, cwd = {os.getcwd()}')
|
print(f'image path = {path}, cwd = {os.getcwd()}')
|
||||||
with Image.open(path) as img:
|
with Image.open(path) as img:
|
||||||
image = img.convert('RGB')
|
image = img.convert('RGB')
|
||||||
print(f'loaded input image of size {image.width}x{image.height} from {path}')
|
print(
|
||||||
|
f'loaded input image of size {image.width}x{image.height} from {path}')
|
||||||
|
|
||||||
image = InitImageResizer(image).resize(width,height)
|
from ldm.dream.image_util import InitImageResizer
|
||||||
print(f'resized input image to size {image.width}x{image.height}')
|
if width == self.width and height == self.height:
|
||||||
|
new_image_width, new_image_height, resize_needed = self._resolution_check(
|
||||||
|
image.width, image.height)
|
||||||
|
else:
|
||||||
|
if height == self.height:
|
||||||
|
new_image_width, new_image_height, resize_needed = self._resolution_check(
|
||||||
|
width, image.height)
|
||||||
|
if width == self.width:
|
||||||
|
new_image_width, new_image_height, resize_needed = self._resolution_check(
|
||||||
|
image.width, height)
|
||||||
|
else:
|
||||||
|
image = InitImageResizer(image).resize(width, height)
|
||||||
|
resize_needed=False
|
||||||
|
if resize_needed:
|
||||||
|
image = InitImageResizer(image).resize(
|
||||||
|
new_image_width, new_image_height)
|
||||||
|
|
||||||
image = np.array(image).astype(np.float32) / 255.0
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
image = image[None].transpose(0, 3, 1, 2)
|
image = image[None].transpose(0, 3, 1, 2)
|
||||||
@ -633,7 +647,7 @@ class T2I:
|
|||||||
prompt = text[:idx]
|
prompt = text[:idx]
|
||||||
remaining -= idx
|
remaining -= idx
|
||||||
# remove from main text
|
# remove from main text
|
||||||
text = text[idx + 1 :]
|
text = text[idx + 1:]
|
||||||
# find value for weight
|
# find value for weight
|
||||||
if ' ' in text:
|
if ' ' in text:
|
||||||
idx = text.index(' ') # first occurrence
|
idx = text.index(' ') # first occurrence
|
||||||
@ -651,7 +665,7 @@ class T2I:
|
|||||||
weight = 1.0
|
weight = 1.0
|
||||||
# remove from main text
|
# remove from main text
|
||||||
remaining -= idx
|
remaining -= idx
|
||||||
text = text[idx + 1 :]
|
text = text[idx + 1:]
|
||||||
# append the sub-prompt and its weight
|
# append the sub-prompt and its weight
|
||||||
prompts.append(prompt)
|
prompts.append(prompt)
|
||||||
weights.append(weight)
|
weights.append(weight)
|
||||||
@ -662,9 +676,9 @@ class T2I:
|
|||||||
weights.append(1.0)
|
weights.append(1.0)
|
||||||
remaining = 0
|
remaining = 0
|
||||||
return prompts, weights
|
return prompts, weights
|
||||||
|
|
||||||
# shows how the prompt is tokenized
|
# shows how the prompt is tokenized
|
||||||
# usually tokens have '</w>' to indicate end-of-word,
|
# usually tokens have '</w>' to indicate end-of-word,
|
||||||
# but for readability it has been replaced with ' '
|
# but for readability it has been replaced with ' '
|
||||||
def _log_tokenization(self, text):
|
def _log_tokenization(self, text):
|
||||||
if not self.log_tokenization:
|
if not self.log_tokenization:
|
||||||
@ -674,15 +688,31 @@ class T2I:
|
|||||||
discarded = ""
|
discarded = ""
|
||||||
usedTokens = 0
|
usedTokens = 0
|
||||||
totalTokens = len(tokens)
|
totalTokens = len(tokens)
|
||||||
for i in range(0,totalTokens):
|
for i in range(0, totalTokens):
|
||||||
token = tokens[i].replace('</w>',' ')
|
token = tokens[i].replace('</w>', ' ')
|
||||||
# alternate color
|
# alternate color
|
||||||
s = (usedTokens % 6) + 1
|
s = (usedTokens % 6) + 1
|
||||||
if i < self.model.cond_stage_model.max_length:
|
if i < self.model.cond_stage_model.max_length:
|
||||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
||||||
usedTokens += 1
|
usedTokens += 1
|
||||||
else: # over max token length
|
else: # over max token length
|
||||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||||
print(f"\nTokens ({usedTokens}):\n{tokenized}\x1b[0m")
|
print(f"\nTokens ({usedTokens}):\n{tokenized}\x1b[0m")
|
||||||
if discarded != "":
|
if discarded != "":
|
||||||
print(f"Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m")
|
print(
|
||||||
|
f"Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m")
|
||||||
|
|
||||||
|
def _resolution_check(self, width, height, log=False):
|
||||||
|
resize_needed = False
|
||||||
|
w, h = map(
|
||||||
|
lambda x: x - x % 64, (width, height)
|
||||||
|
) # resize to integer multiple of 64
|
||||||
|
if h != height or w != width:
|
||||||
|
if log:
|
||||||
|
print(
|
||||||
|
f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}'
|
||||||
|
)
|
||||||
|
height = h
|
||||||
|
width = w
|
||||||
|
resize_needed = True
|
||||||
|
return width, height, resize_needed
|
||||||
|
Loading…
x
Reference in New Issue
Block a user