Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit 23fb4a72bb: Merge branch 'bakkot-more-refactor' into main

Changed file: ldm/simplet2i.py (268 lines changed)
@@ -52,7 +52,7 @@ t2i = T2I(model = <path> // models/ldm/stable-diffusion-v1/model.ck
 # do the slow model initialization
 t2i.load_model()

 # Do the fast inference & image generation. Any options passed here
 # override the default values assigned during class initialization
 # Will call load_model() if the model was not previously loaded and so
 # may be slow at first.
@@ -70,7 +70,7 @@ results = t2i.prompt2png(prompt = "an astronaut riding a horse",
 outdir = "./outputs/",
 iterations = 3,
 init_img = "./sketches/horse+rider.png")

 for row in results:
 print(f'filename={row[0]}')
 print(f'seed ={row[1]}')
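Each row returned by prompt2png pairs the written filename with the seed that produced it, so a caller can reproduce a specific result later. A minimal sketch, assuming the same t2i instance as in the docstring above and that prompt2png forwards a seed keyword to the generation call (an assumption, not shown in this hunk):

results = t2i.prompt2png(prompt="an astronaut riding a horse",
                         outdir="./outputs/")
best_seed = results[0][1]            # row format is [filename, seed]

# Re-render the first image deterministically by fixing its seed.
repeat = t2i.prompt2png(prompt="an astronaut riding a horse",
                        outdir="./outputs/",
                        seed=best_seed)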
@@ -181,7 +181,7 @@ The vast majority of these arguments default to reasonable values.
 outdir = kwargs.get('outdir','outputs/img-samples')
 assert 'init_img' in kwargs,'call to img2img() must include the init_img argument'
 return self.prompt2png(prompt,outdir,**kwargs)

 def prompt2image(self,
 # these are common
 prompt,
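As the assertion above shows, img2img is a thin wrapper over prompt2png that insists on an init_img and otherwise forwards its keyword arguments. A small sketch of a call, assuming the strength keyword documented in the next hunk and an existing sketch file (both taken from the docstring examples, not from this hunk):

results = t2i.img2img(prompt="watercolor painting of a horse and rider",
                      init_img="./sketches/horse+rider.png",  # required by the assert
                      strength=0.75,        # 0.0 keeps the init image, 1.0 ignores it
                      outdir="outputs/img-samples")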
@@ -216,10 +216,10 @@ The vast majority of these arguments default to reasonable values.
 strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
 ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
 variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
-callback // a function or method that will be called each time an image is generated
+image_callback // a function or method that will be called each time an image is generated

 To use the callback, define a function or method that receives two arguments, an Image object
 and the seed. You can then do whatever you like with the image, including converting it to
 different formats and manipulating it. For example:

 def process_image(image,seed):
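The hunk above renames the callback parameter to image_callback and documents that it receives the finished PIL Image and the seed that produced it. A minimal usage sketch, assuming the t2i instance from the earlier example; the output directory and filename scheme are illustrative only:

from pathlib import Path

def process_image(image, seed):
    # Called once per generated image with the PIL Image and its seed.
    out = Path("./outputs")                     # hypothetical output directory
    out.mkdir(parents=True, exist_ok=True)
    image.save(out / f"callback_{seed}.png")    # convert or post-process as needed
    print(f"saved image for seed {seed}")

# Pass the function under the new keyword name (older code used `callback`).
results = t2i.prompt2image(prompt="an astronaut riding a horse",
                           image_callback=process_image)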
@@ -249,116 +249,86 @@ The vast majority of these arguments default to reasonable values.
 height = h
 width = w

-data = [batch_size * [prompt]]
 scope = autocast if self.precision=="autocast" else nullcontext

 tic = time.time()
-if init_img:
-assert os.path.exists(init_img),f'{init_img}: File not found'
-results = self._img2img(prompt,
-data=data,precision_scope=scope,
-batch_size=batch_size,iterations=iterations,
-steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
-skip_normalize=skip_normalize,
-init_img=init_img,strength=strength,variants=variants,
-callback=image_callback)
-else:
-results = self._txt2img(prompt,
-data=data,precision_scope=scope,
-batch_size=batch_size,iterations=iterations,
-steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
-skip_normalize=skip_normalize,
-width=width,height=height,
-callback=image_callback)
-toc = time.time()
-print(f'{len(results)} images generated in',"%4.2fs"% (toc-tic))
-return results
-
-@torch.no_grad()
-def _txt2img(self,prompt,
-data,precision_scope,
-batch_size,iterations,
-steps,seed,cfg_scale,ddim_eta,
-skip_normalize,
-width,height,
-callback): # the callback is called each time a new Image is generated
-"""
-Generate an image from the prompt, writing iteration images into the outdir
-The output is a list of lists in the format: [[image1,seed1], [image2,seed2],...]
-"""
-
-sampler = self.sampler
-images = list()
-image_count = 0
-
-# Gawd. Too many levels of indent here. Need to refactor into smaller routines!
+results = list()
 try:
-with precision_scope(self.device.type), self.model.ema_scope():
-all_samples = list()
+if init_img:
+assert os.path.exists(init_img),f'{init_img}: File not found'
+images_iterator = self._img2img(prompt,
+precision_scope=scope,
+batch_size=batch_size,
+steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
+skip_normalize=skip_normalize,
+init_img=init_img,strength=strength)
+else:
+images_iterator = self._txt2img(prompt,
+precision_scope=scope,
+batch_size=batch_size,
+steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
+skip_normalize=skip_normalize,
+width=width,height=height)
+
+with scope(self.device.type), self.model.ema_scope():
 for n in trange(iterations, desc="Sampling"):
 seed_everything(seed)
-for prompts in tqdm(data, desc="data", dynamic_ncols=True):
-uc = None
-if cfg_scale != 1.0:
-uc = self.model.get_learned_conditioning(batch_size * [""])
-if isinstance(prompts, tuple):
-prompts = list(prompts)
-
-# weighted sub-prompts
-subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
-if len(subprompts) > 1:
-# i dont know if this is correct.. but it works
-c = torch.zeros_like(uc)
-# get total weight for normalizing
-totalWeight = sum(weights)
-# normalize each "sub prompt" and add it
-for i in range(0,len(subprompts)):
-weight = weights[i]
-if not skip_normalize:
-weight = weight / totalWeight
-c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
-else: # just standard 1 prompt
-c = self.model.get_learned_conditioning(prompts)
-
-shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
-samples_ddim, _ = sampler.sample(S=steps,
-conditioning=c,
-batch_size=batch_size,
-shape=shape,
-verbose=False,
-unconditional_guidance_scale=cfg_scale,
-unconditional_conditioning=uc,
-eta=ddim_eta)
-
-x_samples_ddim = self.model.decode_first_stage(samples_ddim)
-x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-for x_sample in x_samples_ddim:
-x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-image = Image.fromarray(x_sample.astype(np.uint8))
-images.append([image,seed])
-if callback is not None:
-callback(image,seed)
+iter_images = next(images_iterator)
+for image in iter_images:
+results.append([image, seed])
+if image_callback is not None:
+image_callback(image,seed)

 seed = self._new_seed()

 except KeyboardInterrupt:
 print('*interrupted*')
 print('Partial results will be returned; if --grid was requested, nothing will be returned.')
 except RuntimeError as e:
 print(str(e))
+print('Are you sure your system has an adequate NVIDIA GPU?')

+toc = time.time()
+print(f'{len(results)} images generated in',"%4.2fs"% (toc-tic))
+return results

-return images

 @torch.no_grad()
-def _img2img(self,prompt,
-data,precision_scope,
-batch_size,iterations,
-steps,seed,cfg_scale,ddim_eta,
+def _txt2img(self,
+prompt,
+precision_scope,
+batch_size,
+steps,cfg_scale,ddim_eta,
 skip_normalize,
-init_img,strength,variants,
-callback):
+width,height):
 """
-Generate an image from the prompt and the initial image, writing iteration images into the outdir
-The output is a list of lists in the format: [[image,seed1], [image,seed2],...]
+An infinite iterator of images from the prompt.
+"""
+
+sampler = self.sampler
+
+while True:
+uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
+shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
+samples, _ = sampler.sample(S=steps,
+conditioning=c,
+batch_size=batch_size,
+shape=shape,
+verbose=False,
+unconditional_guidance_scale=cfg_scale,
+unconditional_conditioning=uc,
+eta=ddim_eta)
+yield self._samples_to_images(samples)
+
+@torch.no_grad()
+def _img2img(self,
+prompt,
+precision_scope,
+batch_size,
+steps,cfg_scale,ddim_eta,
+skip_normalize,
+init_img,strength):
+"""
+An infinite iterator of images from the prompt and the initial image
 """

 # PLMS sampler not supported yet, so ignore previous sampler
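With this change prompt2image no longer owns the sampling loop: _txt2img and _img2img become infinite generators that yield one decoded batch per next() call, and prompt2image simply draws from whichever generator matches the request. The shape of that pattern is shown below with stand-in names; image_batches and its string "images" are invented for illustration, while the real generators run the sampler and yield PIL images:

def image_batches(prompt, batch_size):
    # Stand-in for _txt2img/_img2img: yields a fresh batch forever;
    # the caller decides how many batches to draw.
    n = 0
    while True:
        yield [f"{prompt}-{n}-{i}" for i in range(batch_size)]
        n += 1

results = []
batches = image_batches("an astronaut riding a horse", batch_size=2)
for _ in range(3):                 # plays the role of 'iterations' in prompt2image
    for image in next(batches):    # one batch per iteration
        results.append(image)
print(results)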
@@ -374,62 +344,50 @@ The vast majority of these arguments default to reasonable values.
 init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image)) # move to latent space

 sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)

 t_enc = int(strength * steps)
 # print(f"target t_enc is {t_enc} steps")

+while True:
+uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
+
+# encode (scaled latent)
+z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
+# decode it
+samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
+unconditional_conditioning=uc,)
+yield self._samples_to_images(samples)
+
+# TODO: does this actually need to run every loop? does anything in it vary by random seed?
+def _get_uc_and_c(self, prompt, batch_size, skip_normalize):
+uc = self.model.get_learned_conditioning(batch_size * [""])
+
+# weighted sub-prompts
+subprompts,weights = T2I._split_weighted_subprompts(prompt)
+if len(subprompts) > 1:
+# i dont know if this is correct.. but it works
+c = torch.zeros_like(uc)
+# get total weight for normalizing
+totalWeight = sum(weights)
+# normalize each "sub prompt" and add it
+for i in range(0,len(subprompts)):
+weight = weights[i]
+if not skip_normalize:
+weight = weight / totalWeight
+c = torch.add(c, self.model.get_learned_conditioning(batch_size * [subprompts[i]]), alpha=weight)
+else: # just standard 1 prompt
+c = self.model.get_learned_conditioning(batch_size * [prompt])
+return (uc, c)
+
+def _samples_to_images(self, samples):
+x_samples = self.model.decode_first_stage(samples)
+x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
 images = list()
-try:
-with precision_scope(self.device.type), self.model.ema_scope():
-all_samples = list()
-for n in trange(iterations, desc="Sampling"):
-seed_everything(seed)
-for prompts in tqdm(data, desc="data", dynamic_ncols=True):
-uc = None
-if cfg_scale != 1.0:
-uc = self.model.get_learned_conditioning(batch_size * [""])
-if isinstance(prompts, tuple):
-prompts = list(prompts)
-
-# weighted sub-prompts
-subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
-if len(subprompts) > 1:
-# i dont know if this is correct.. but it works
-c = torch.zeros_like(uc)
-# get total weight for normalizing
-totalWeight = sum(weights)
-# normalize each "sub prompt" and add it
-for i in range(0,len(subprompts)):
-weight = weights[i]
-if not skip_normalize:
-weight = weight / totalWeight
-c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
-else: # just standard 1 prompt
-c = self.model.get_learned_conditioning(prompts)
-
-# encode (scaled latent)
-z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
-# decode it
-samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
-unconditional_conditioning=uc,)
-
-x_samples = self.model.decode_first_stage(samples)
-x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
-
-for x_sample in x_samples:
-x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-image = Image.fromarray(x_sample.astype(np.uint8))
-images.append([image,seed])
-if callback is not None:
-callback(image,seed)
-seed = self._new_seed()
-
-except KeyboardInterrupt:
-print('*interrupted*')
-print('Partial results will be returned; if --grid was requested, nothing will be returned.')
-except RuntimeError as e:
-print("Oops! A runtime error has occurred. If this is unexpected, please copy-and-paste this stack trace and post it as an Issue to http://github.com/lstein/stable-diffusion")
-traceback.print_exc()
+for x_sample in x_samples:
+x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+image = Image.fromarray(x_sample.astype(np.uint8))
+images.append(image)
 return images

 def _new_seed(self):
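The added _get_uc_and_c helper builds the unconditional embedding plus a single conditioning tensor that is a weighted sum of the sub-prompt embeddings, dividing each weight by the total unless skip_normalize is set. A small sketch of just that normalize-and-accumulate arithmetic, using made-up tensors in place of get_learned_conditioning outputs:

import torch

# Stand-ins for get_learned_conditioning(...) results of two sub-prompts.
embeddings = [torch.ones(4) * 1.0, torch.ones(4) * 3.0]
weights = [1.0, 0.5]   # e.g. from a prompt like "a cat:1 in a hat:0.5"

c = torch.zeros_like(embeddings[0])
total_weight = sum(weights)
for emb, weight in zip(embeddings, weights):
    # skip_normalize=False: each weight becomes a fraction of the total
    c = torch.add(c, emb, alpha=weight / total_weight)

# (1*1.0 + 3*0.5) / 1.5 = 1.6667 in every position
print(c)

With skip_normalize the raw weights are used directly, so the magnitude of c grows with the sum of the weights.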
@@ -476,7 +434,7 @@ The vast majority of these arguments default to reasonable values.
 print(msg)

 return self.model

 def _load_model_from_config(self, config, ckpt):
 print(f"Loading model from {ckpt}")
 pl_sd = torch.load(ckpt, map_location="cpu")
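For context on _load_model_from_config: torch.load with map_location="cpu" deserializes the Lightning checkpoint into a plain dict, and the weights live under its "state_dict" key. The snippet below is a generic sketch of that loading pattern, not a copy of the method; the paths are placeholders and the model construction step is elided:

import torch

ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"    # placeholder path
pl_sd = torch.load(ckpt, map_location="cpu")          # load tensors onto CPU first
sd = pl_sd["state_dict"]                              # weights of the LightningModule
# model = instantiate_from_config(config.model)       # model construction elided here
# missing, unexpected = model.load_state_dict(sd, strict=False)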
@@ -507,7 +465,7 @@ The vast majority of these arguments default to reasonable values.

 def _split_weighted_subprompts(text):
 """
 grabs all text up to the first occurrence of ':'
 uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
 if ':' has no value defined, defaults to 1.0
 repeats until no text remaining
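The docstring above describes the colon-weight syntax that _split_weighted_subprompts parses. A worked example makes the rules concrete; the helper below is a simplified re-implementation for illustration only, assuming exactly the behaviour the docstring states (text up to ':' is the sub-prompt, the number after ':' is its weight, and a missing weight defaults to 1.0), not the exact code of the method:

def split_weighted_subprompts(text):
    # Simplified illustration of the parsing rules described in the docstring.
    subprompts, weights = [], []
    while text:
        if ":" not in text:
            subprompts.append(text.strip()); weights.append(1.0)
            break
        prompt, rest = text.split(":", 1)
        # the weight runs up to the next space (or to the end of the string)
        weight_str, _, text = rest.partition(" ")
        try:
            weight = float(weight_str)
        except ValueError:
            weight = 1.0
        subprompts.append(prompt.strip()); weights.append(weight)
    return subprompts, weights

print(split_weighted_subprompts("a cat:1.5 wearing a hat:0.5"))
# (['a cat', 'wearing a hat'], [1.5, 0.5])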
@@ -523,7 +481,7 @@ The vast majority of these arguments default to reasonable values.
 remaining -= idx
 # remove from main text
 text = text[idx+1:]
 # find value for weight
 if " " in text:
 idx = text.index(" ") # first occurrence
 else: # no space, read to end
@@ -252,7 +252,7 @@ def create_argv_parser():
 '-o',
 type=str,
 default="outputs/img-samples",
-help="directory in which to place generated images and a log of prompts and seeds")
+help="directory in which to place generated images and a log of prompts and seeds (outputs/img-samples)")
 parser.add_argument('--embedding_path',
 type=str,
 help="Path to a pre-trained embedding manager checkpoint - can only be set on command line")
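The -o option above is standard argparse: when the flag is omitted, the default outputs/img-samples is used and the new help string now mentions that default. A self-contained sketch of the same option; the dest name and the parser wiring around it are assumptions for illustration, not taken from the file:

import argparse

def create_argv_parser():
    # Minimal sketch of the options shown in the hunk; all other arguments omitted.
    parser = argparse.ArgumentParser()
    parser.add_argument('-o',
                        type=str,
                        default="outputs/img-samples",
                        dest="outdir",   # assumed destination name for illustration
                        help="directory in which to place generated images and a log of prompts and seeds (outputs/img-samples)")
    parser.add_argument('--embedding_path',
                        type=str,
                        help="Path to a pre-trained embedding manager checkpoint - can only be set on command line")
    return parser

args = create_argv_parser().parse_args([])   # no flags passed: defaults apply
print(args.outdir)                           # outputs/img-samples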