mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Code is reorganized and mostly functional. The grid feature still needs to be brought back online, as does the naming of img2img variants (currently the variants are written to disk but not logged).
This commit is contained in:
parent
b12955c963
commit
b978536385
@ -1,6 +1,7 @@
|
||||
'''Utilities for dealing with PNG images and their path names'''
|
||||
import os
|
||||
import atexit
|
||||
import re
|
||||
from PIL import Image,PngImagePlugin
|
||||
|
||||
# ---------------readline utilities---------------------
|
||||
@ -94,40 +95,43 @@ if readline_available:
|
||||
# -------------------image generation utils-----
|
||||
class PngWriter:
|
||||
|
||||
def __init__(self,opt):
|
||||
self.opt = opt
|
||||
self.filepath = None
|
||||
self.files_written = []
|
||||
def __init__(self,outdir,opt,prompt):
|
||||
self.outdir = outdir
|
||||
self.opt = opt
|
||||
self.prompt = prompt
|
||||
self.filepath = None
|
||||
self.files_written = []
|
||||
|
||||
def write_image(self,image,seed):
|
||||
self.filepath = self.unique_filename(self,opt,seed,self.filepath) # will increment name in some sensible way
|
||||
self.filepath = self.unique_filename(seed,self.filepath) # will increment name in some sensible way
|
||||
try:
|
||||
image.save(self.filename)
|
||||
prompt = f'{self.prompt} -S{seed}'
|
||||
self.save_image_and_prompt_to_png(image,prompt,self.filepath)
|
||||
except IOError as e:
|
||||
print(e)
|
||||
self.files_written.append([self.filepath,seed])
|
||||
|
||||
def unique_filename(self,opt,seed,previouspath):
|
||||
def unique_filename(self,seed,previouspath):
|
||||
revision = 1
|
||||
|
||||
if previouspath is None:
|
||||
# sort reverse alphabetically until we find max+1
|
||||
dirlist = sorted(os.listdir(outdir),reverse=True)
|
||||
dirlist = sorted(os.listdir(self.outdir),reverse=True)
|
||||
# find the first filename that matches our pattern or return 000000.0.png
|
||||
filename = next((f for f in dirlist if re.match('^(\d+)\..*\.png',f)),'0000000.0.png')
|
||||
basecount = int(filename.split('.',1)[0])
|
||||
basecount += 1
|
||||
if opt.batch_size > 1:
|
||||
if self.opt.batch_size > 1:
|
||||
filename = f'{basecount:06}.{seed}.01.png'
|
||||
else:
|
||||
filename = f'{basecount:06}.{seed}.png'
|
||||
return os.path.join(outdir,filename)
|
||||
return os.path.join(self.outdir,filename)
|
||||
|
||||
else:
|
||||
basename = os.path.basename(previouspath)
|
||||
x = re.match('^(\d+)\..*\.png',basename)
|
||||
if not x:
|
||||
return self.unique_filename(opt,seed,previouspath)
|
||||
return self.unique_filename(seed,previouspath)
|
||||
|
||||
basecount = int(x.groups()[0])
|
||||
series = 0
|
||||
@ -135,9 +139,41 @@ class PngWriter:
|
||||
while not finished:
|
||||
series += 1
|
||||
filename = f'{basecount:06}.{seed}.png'
|
||||
if isbatch or os.path.exists(os.path.join(outdir,filename)):
|
||||
if self.opt.batch_size>1 or os.path.exists(os.path.join(self.outdir,filename)):
|
||||
filename = f'{basecount:06}.{seed}.{series:02}.png'
|
||||
finished = not os.path.exists(os.path.join(outdir,filename))
|
||||
return os.path.join(outdir,filename)
|
||||
finished = not os.path.exists(os.path.join(self.outdir,filename))
|
||||
return os.path.join(self.outdir,filename)
|
||||
|
||||
def save_image_and_prompt_to_png(self,image,prompt,path):
    '''Save `image` to `path` as a PNG with `prompt` stored in a "Dream" text chunk.

    `image` must provide a PIL-style save(path, format, pnginfo=...) method.
    '''
    # The prompt travels with the file as PNG text metadata under the key "Dream".
    png_metadata = PngImagePlugin.PngInfo()
    png_metadata.add_text("Dream",prompt)
    image.save(path,"PNG",pnginfo=png_metadata)
|
||||
|
||||
class PromptFormatter():
    '''Rebuild a normalized command string from per-prompt options and generator defaults.

    t2i -- generator object supplying defaults (steps, batch_size, width,
           height, cfg_scale, sampler_name, full_precision and, optionally,
           strength)
    opt -- parsed per-prompt options (prompt, steps, batch_size, width,
           height, cfg_scale, variants, init_img, strength)
    '''

    def __init__(self,t2i,opt):
        self.t2i = t2i
        self.opt = opt

    def normalize_prompt(self):
        '''Normalize the prompt and switches.

        Returns a single space-joined string: the quoted prompt followed by
        one switch per setting. A falsy per-prompt option falls back to the
        corresponding t2i attribute.
        '''
        t2i = self.t2i
        opt = self.opt

        switches = [
            f'"{opt.prompt}"',
            f'-s{opt.steps or t2i.steps}',
            f'-b{opt.batch_size or t2i.batch_size}',
            f'-W{opt.width or t2i.width}',
            f'-H{opt.height or t2i.height}',
            f'-C{opt.cfg_scale or t2i.cfg_scale}',
            f'-m{t2i.sampler_name}',
        ]
        if opt.variants:
            switches.append(f'-v{opt.variants}')
        if opt.init_img:
            switches.append(f'-I{opt.init_img}')
            # -f only makes sense next to -I, so it lives under the same guard.
            # The fallback to the generator's strength used to be unreachable
            # because the old guard already required opt.strength to be truthy.
            # getattr is used because t2i.strength is not guaranteed to exist.
            strength = opt.strength or getattr(t2i,'strength',None)
            if strength:
                switches.append(f'-f{strength}')
        if t2i.full_precision:
            switches.append('-F')
        return ' '.join(switches)
|
||||
|
||||
|
@ -99,13 +99,13 @@ The vast majority of these arguments default to reasonable values.
|
||||
def __init__(self,
|
||||
batch_size=1,
|
||||
iterations = 1,
|
||||
grid=False,
|
||||
individual=None, # redundant
|
||||
steps=50,
|
||||
seed=None,
|
||||
cfg_scale=7.5,
|
||||
weights="models/ldm/stable-diffusion-v1/model.ckpt",
|
||||
config = "configs/stable-diffusion/v1-inference.yaml",
|
||||
width=512,
|
||||
height=512,
|
||||
sampler_name="klms",
|
||||
latent_channels=4,
|
||||
downsampling_factor=8,
|
||||
@ -121,7 +121,6 @@ The vast majority of these arguments default to reasonable values.
|
||||
self.iterations = iterations
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.grid = grid
|
||||
self.steps = steps
|
||||
self.cfg_scale = cfg_scale
|
||||
self.weights = weights
|
||||
@ -143,25 +142,26 @@ The vast majority of these arguments default to reasonable values.
|
||||
else:
|
||||
self.seed = seed
|
||||
|
||||
def generate(self,
|
||||
# these are common
|
||||
prompt,
|
||||
batch_size=None,
|
||||
iterations=None,
|
||||
steps=None,
|
||||
seed=None,
|
||||
cfg_scale=None,
|
||||
ddim_eta=None,
|
||||
skip_normalize=False,
|
||||
image_callback=None,
|
||||
# these are specific to txt2img
|
||||
width=None,
|
||||
height=None,
|
||||
# these are specific to img2img
|
||||
init_img=None,
|
||||
strength=None,
|
||||
variants=None):
|
||||
'''ldm.generate() is the common entry point for txt2img() and img2img()'''
|
||||
def prompt2image(self,
|
||||
# these are common
|
||||
prompt,
|
||||
batch_size=None,
|
||||
iterations=None,
|
||||
steps=None,
|
||||
seed=None,
|
||||
cfg_scale=None,
|
||||
ddim_eta=None,
|
||||
skip_normalize=False,
|
||||
image_callback=None,
|
||||
# these are specific to txt2img
|
||||
width=None,
|
||||
height=None,
|
||||
# these are specific to img2img
|
||||
init_img=None,
|
||||
strength=None,
|
||||
variants=None,
|
||||
**args): # eat up additional cruft
|
||||
'''ldm.prompt2image() is the common entry point for txt2img() and img2img()'''
|
||||
steps = steps or self.steps
|
||||
seed = seed or self.seed
|
||||
width = width or self.width
|
||||
@ -178,10 +178,6 @@ The vast majority of these arguments default to reasonable values.
|
||||
|
||||
data = [batch_size * [prompt]]
|
||||
scope = autocast if self.precision=="autocast" else nullcontext
|
||||
if grid:
|
||||
callback = self.image2png
|
||||
else:
|
||||
callback = None
|
||||
|
||||
tic = time.time()
|
||||
if init_img:
|
||||
@ -212,7 +208,7 @@ The vast majority of these arguments default to reasonable values.
|
||||
steps,seed,cfg_scale,ddim_eta,
|
||||
skip_normalize,
|
||||
width,height,
|
||||
callback=callback): # the callback is called each time a new Image is generated
|
||||
callback): # the callback is called each time a new Image is generated
|
||||
"""
|
||||
Generate an image from the prompt, writing iteration images into the outdir
|
||||
The output is a list of lists in the format: [[image1,seed1], [image2,seed2],...]
|
||||
@ -224,14 +220,14 @@ The vast majority of these arguments default to reasonable values.
|
||||
|
||||
# Gawd. Too many levels of indent here. Need to refactor into smaller routines!
|
||||
try:
|
||||
with precision_scope(self.device.type), model.ema_scope():
|
||||
with precision_scope(self.device.type), self.model.ema_scope():
|
||||
all_samples = list()
|
||||
for n in trange(iterations, desc="Sampling"):
|
||||
seed_everything(seed)
|
||||
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
|
||||
uc = None
|
||||
if cfg_scale != 1.0:
|
||||
uc = model.get_learned_conditioning(batch_size * [""])
|
||||
uc = self.model.get_learned_conditioning(batch_size * [""])
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
|
||||
@ -247,9 +243,9 @@ The vast majority of these arguments default to reasonable values.
|
||||
weight = weights[i]
|
||||
if not skip_normalize:
|
||||
weight = weight / totalWeight
|
||||
c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||
c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||
else: # just standard 1 prompt
|
||||
c = model.get_learned_conditioning(prompts)
|
||||
c = self.model.get_learned_conditioning(prompts)
|
||||
|
||||
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
|
||||
samples_ddim, _ = sampler.sample(S=steps,
|
||||
@ -261,7 +257,7 @@ The vast majority of these arguments default to reasonable values.
|
||||
unconditional_conditioning=uc,
|
||||
eta=ddim_eta)
|
||||
|
||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||
x_samples_ddim = self.model.decode_first_stage(samples_ddim)
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
for x_sample in x_samples_ddim:
|
||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||
@ -277,8 +273,6 @@ The vast majority of these arguments default to reasonable values.
|
||||
except RuntimeError as e:
|
||||
print(str(e))
|
||||
|
||||
toc = time.time()
|
||||
print(f'{image_count} images generated in',"%4.2fs"% (toc-tic))
|
||||
return images
|
||||
|
||||
@torch.no_grad()
|
||||
@ -297,14 +291,14 @@ The vast majority of these arguments default to reasonable values.
|
||||
# PLMS sampler not supported yet, so ignore previous sampler
|
||||
if self.sampler_name!='ddim':
|
||||
print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler")
|
||||
sampler = DDIMSampler(model, device=self.device)
|
||||
sampler = DDIMSampler(self.model, device=self.device)
|
||||
else:
|
||||
sampler = self.sampler
|
||||
|
||||
init_image = self._load_img(init_img).to(self.device)
|
||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||
with precision_scope(self.device.type):
|
||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
|
||||
init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image)) # move to latent space
|
||||
|
||||
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
|
||||
|
||||
@ -314,14 +308,14 @@ The vast majority of these arguments default to reasonable values.
|
||||
images = list()
|
||||
|
||||
try:
|
||||
with precision_scope(self.device.type), model.ema_scope():
|
||||
with precision_scope(self.device.type), self.model.ema_scope():
|
||||
all_samples = list()
|
||||
for n in trange(iterations, desc="Sampling"):
|
||||
seed_everything(seed)
|
||||
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
|
||||
uc = None
|
||||
if cfg_scale != 1.0:
|
||||
uc = model.get_learned_conditioning(batch_size * [""])
|
||||
uc = self.model.get_learned_conditioning(batch_size * [""])
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
|
||||
@ -337,9 +331,9 @@ The vast majority of these arguments default to reasonable values.
|
||||
weight = weights[i]
|
||||
if not skip_normalize:
|
||||
weight = weight / totalWeight
|
||||
c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||
c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||
else: # just standard 1 prompt
|
||||
c = model.get_learned_conditioning(prompts)
|
||||
c = self.model.get_learned_conditioning(prompts)
|
||||
|
||||
# encode (scaled latent)
|
||||
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
|
||||
@ -347,7 +341,7 @@ The vast majority of these arguments default to reasonable values.
|
||||
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,)
|
||||
|
||||
x_samples = model.decode_first_stage(samples)
|
||||
x_samples = self.model.decode_first_stage(samples)
|
||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
for x_sample in x_samples:
|
||||
|
@ -6,7 +6,7 @@ import shlex
|
||||
import os
|
||||
import sys
|
||||
import copy
|
||||
from ldm.dream_util import Completer,PngWriter
|
||||
from ldm.dream_util import Completer,PngWriter,PromptFormatter
|
||||
|
||||
debugging = False
|
||||
|
||||
@ -27,10 +27,6 @@ def main():
|
||||
config = "configs/stable-diffusion/v1-inference.yaml"
|
||||
weights = "models/ldm/stable-diffusion-v1/model.ckpt"
|
||||
|
||||
# command line history will be stored in a file called "~/.dream_history"
|
||||
if readline_available:
|
||||
setup_readline()
|
||||
|
||||
print("* Initializing, be patient...\n")
|
||||
sys.path.append('.')
|
||||
from pytorch_lightning import logging
|
||||
@ -46,8 +42,6 @@ def main():
|
||||
# the user input loop
|
||||
t2i = T2I(width=width,
|
||||
height=height,
|
||||
batch_size=opt.batch_size,
|
||||
outdir=opt.outdir,
|
||||
sampler_name=opt.sampler_name,
|
||||
weights=weights,
|
||||
full_precision=opt.full_precision,
|
||||
@ -79,13 +73,13 @@ def main():
|
||||
log_path = os.path.join(opt.outdir,'dream_log.txt')
|
||||
with open(log_path,'a') as log:
|
||||
cmd_parser = create_cmd_parser()
|
||||
main_loop(t2i,cmd_parser,log,infile)
|
||||
main_loop(t2i,opt.outdir,cmd_parser,log,infile)
|
||||
log.close()
|
||||
if infile:
|
||||
infile.close()
|
||||
|
||||
|
||||
def main_loop(t2i,parser,log,infile):
|
||||
def main_loop(t2i,outdir,parser,log,infile):
|
||||
''' prompt/read/execute loop '''
|
||||
done = False
|
||||
|
||||
@ -123,13 +117,13 @@ def main_loop(t2i,parser,log,infile):
|
||||
if elements[0]=='cd' and len(elements)>1:
|
||||
if os.path.exists(elements[1]):
|
||||
print(f"setting image output directory to {elements[1]}")
|
||||
opt.outdir=elements[1]
|
||||
outdir=elements[1]
|
||||
else:
|
||||
print(f"directory {elements[1]} does not exist")
|
||||
continue
|
||||
|
||||
if elements[0]=='pwd':
|
||||
print(f"current output directory is {opt.outdir}")
|
||||
print(f"current output directory is {outdir}")
|
||||
continue
|
||||
|
||||
if elements[0].startswith('!dream'): # in case a stored prompt still contains the !dream command
|
||||
@ -158,88 +152,41 @@ def main_loop(t2i,parser,log,infile):
|
||||
print("Try again with a prompt!")
|
||||
continue
|
||||
|
||||
normalized_prompt = PromptFormatter(t2i,opt).normalize_prompt()
|
||||
try:
|
||||
file_writer = PngWriter(opt)
|
||||
opt.callback = file_writer(write_image)
|
||||
run_generator(**vars(opt))
|
||||
file_writer = PngWriter(outdir,opt,normalized_prompt)
|
||||
callback = file_writer.write_image
|
||||
|
||||
t2i.prompt2image(image_callback=callback,
|
||||
**vars(opt))
|
||||
results = file_writer.files_written
|
||||
|
||||
except AssertionError as e:
|
||||
print(e)
|
||||
continue
|
||||
|
||||
print("Outputs:")
|
||||
write_log_message(t2i,opt,results,log)
|
||||
write_log_message(t2i,normalized_prompt,results,log)
|
||||
|
||||
print("goodbye!")
|
||||
|
||||
def write_log_message(t2i,opt,results,logfile):
|
||||
def write_log_message(t2i,prompt,results,logfile):
|
||||
''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata '''
|
||||
switches = _reconstruct_switches(t2i,opt)
|
||||
prompt_str = ' '.join(switches)
|
||||
|
||||
# when multiple images are produced in batch, then we keep track of where each starts
|
||||
last_seed = None
|
||||
img_num = 1
|
||||
batch_size = opt.batch_size or t2i.batch_size
|
||||
seenit = {}
|
||||
|
||||
seeds = [a[1] for a in results]
|
||||
if batch_size > 1:
|
||||
seeds = f"(seeds for each batch row: {seeds})"
|
||||
else:
|
||||
seeds = f"(seeds for individual images: {seeds})"
|
||||
seeds = f"(seeds for individual images: {seeds})"
|
||||
|
||||
for r in results:
|
||||
seed = r[1]
|
||||
log_message = (f'{r[0]}: {prompt_str} -S{seed}')
|
||||
log_message = (f'{r[0]}: {prompt} -S{seed}')
|
||||
|
||||
if batch_size > 1:
|
||||
if seed != last_seed:
|
||||
img_num = 1
|
||||
log_message += f' # (batch image {img_num} of {batch_size})'
|
||||
else:
|
||||
img_num += 1
|
||||
log_message += f' # (batch image {img_num} of {batch_size})'
|
||||
last_seed = seed
|
||||
print(log_message)
|
||||
logfile.write(log_message+"\n")
|
||||
logfile.flush()
|
||||
if r[0] not in seenit:
|
||||
seenit[r[0]] = True
|
||||
try:
|
||||
if opt.grid:
|
||||
_write_prompt_to_png(r[0],f'{prompt_str} -g -S{seed} {seeds}')
|
||||
else:
|
||||
_write_prompt_to_png(r[0],f'{prompt_str} -S{seed}')
|
||||
except FileNotFoundError:
|
||||
print(f"Could not open file '{r[0]}' for reading")
|
||||
|
||||
def _reconstruct_switches(t2i,opt):
|
||||
'''Normalize the prompt and switches'''
|
||||
switches = list()
|
||||
switches.append(f'"{opt.prompt}"')
|
||||
switches.append(f'-s{opt.steps or t2i.steps}')
|
||||
switches.append(f'-b{opt.batch_size or t2i.batch_size}')
|
||||
switches.append(f'-W{opt.width or t2i.width}')
|
||||
switches.append(f'-H{opt.height or t2i.height}')
|
||||
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
|
||||
switches.append(f'-m{t2i.sampler_name}')
|
||||
if opt.variants:
|
||||
switches.append(f'-v{opt.variants}')
|
||||
if opt.init_img:
|
||||
switches.append(f'-I{opt.init_img}')
|
||||
if opt.strength and opt.init_img is not None:
|
||||
switches.append(f'-f{opt.strength or t2i.strength}')
|
||||
if t2i.full_precision:
|
||||
switches.append('-F')
|
||||
return switches
|
||||
|
||||
def _write_prompt_to_png(path,prompt):
    '''Rewrite the PNG at `path` in place with `prompt` stored in a "Dream" text chunk.

    Errors from opening `path` (e.g. FileNotFoundError) propagate to the caller.
    '''
    info = PngImagePlugin.PngInfo()
    info.add_text("Dream",prompt)
    # Read the pixel data and release the file handle before the same path is
    # reopened for writing; the original left the handle from Image.open()
    # to be closed implicitly.
    with Image.open(path) as im:
        im.load()
    im.save(path,"PNG",pnginfo=info)
|
||||
|
||||
def create_argv_parser():
|
||||
parser = argparse.ArgumentParser(description="Parse script's command line args")
|
||||
parser.add_argument("--laion400m",
|
||||
@ -260,10 +207,6 @@ def create_argv_parser():
|
||||
dest='full_precision',
|
||||
action='store_true',
|
||||
help="use slower full precision math for calculations")
|
||||
parser.add_argument('-b','--batch_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of images to produce per iteration (faster, but doesn't generate individual seeds")
|
||||
parser.add_argument('--sampler','-m',
|
||||
dest="sampler_name",
|
||||
choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'],
|
||||
|
Loading…
Reference in New Issue
Block a user