folded in changes from img2img-dev

Lincoln Stein 2022-08-18 12:45:02 -04:00
commit 87fb4186d4
2 changed files with 256 additions and 45 deletions

View File

@@ -8,7 +8,7 @@ t2i = T2I(outdir = <path> // outputs/txt2img-samples
          model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
          config = <path> // default="configs/stable-diffusion/v1-inference.yaml
          iterations = <integer> // how many times to run the sampling (1)
-         batch = <integer> // how many images to generate per sampling (1)
+         batch_size = <integer> // how many images to generate per sampling (1)
          steps = <integer> // 50
          seed = <integer> // current system time
          sampler = ['ddim','plms'] // ddim
@@ -26,23 +26,22 @@ t2i.load_model()
# override the default values assigned during class initialization
# Will call load_model() if the model was not previously loaded.
# The method returns a list of images. Each row of the list is a sub-list of [filename,seed]
-results = t2i.txt2img(prompt = <string> // required
-                      outdir = <path> // the remaining option arguments override constructor values when present
-                      iterations = <integer>
-                      batch = <integer>
-                      steps = <integer>
-                      seed = <integer>
-                      sampler = ['ddim','plms']
-                      grid = <boolean>
-                      width = <integer>
-                      height = <integer>
-                      cfg_scale = <float>
-                      ) -> boolean
+results = t2i.txt2img(prompt = "an astronaut riding a horse",
+                      outdir = "./outputs/txt2img-samples")
for row in results:
    print(f'filename={row[0]}')
    print(f'seed ={row[1]}')
+# Same thing, but using an initial image.
+results = t2i.img2img(prompt = "an astronaut riding a horse",
+                      outdir = "./outputs/img2img-samples",
+                      init_img = "./sketches/horse+rider.png")
+for row in results:
+    print(f'filename={row[0]}')
+    print(f'seed ={row[1]}')
"""

import torch
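
Taken together, the updated docstring describes the calling convention below. A minimal sketch of using it end to end (assuming the class is importable; the diff view does not show the module's file name, so the simplet2i import path is a guess):

from simplet2i import T2I  # hypothetical import path; not shown in this view

t2i = T2I(batch_size=1, iterations=2)  # batch_size replaces the old 'batch' argument
t2i.load_model()

# text-to-image: each result row is [filename, seed]
for filename, seed in t2i.txt2img(prompt="an astronaut riding a horse"):
    print(f"{filename} (seed {seed})")

# image-to-image: start from an existing image instead of pure noise
for filename, seed in t2i.img2img(prompt="an astronaut riding a horse",
                                  init_img="./sketches/horse+rider.png",
                                  strength=0.75):
    print(f"{filename} (seed {seed})")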
@@ -54,7 +53,7 @@ from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
-from einops import rearrange
+from einops import rearrange, repeat
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from torch import autocast
@@ -74,7 +73,7 @@ class T2I:
        model
        config
        iterations
-       batch
+       batch_size
        steps
        seed
        sampler
@@ -87,10 +86,11 @@ class T2I:
        latent_channels
        downsampling_factor
        precision
+       strength
    """
    def __init__(self,
                 outdir="outputs/txt2img-samples",
-                 batch=1,
+                 batch_size=1,
                 iterations = 1,
                 width=512,
                 height=512,
@@ -106,10 +106,11 @@ class T2I:
                 downsampling_factor=8,
                 ddim_eta=0.0,  # deterministic
                 fixed_code=False,
-                 precision='autocast'
+                 precision='autocast',
+                 strength=0.75  # default in scripts/img2img.py
                 ):
        self.outdir = outdir
-        self.batch = batch
+        self.batch_size = batch_size
        self.iterations = iterations
        self.width = width
        self.height = height
@@ -124,15 +125,17 @@ class T2I:
        self.downsampling_factor = downsampling_factor
        self.ddim_eta = ddim_eta
        self.precision = precision
+        self.strength = strength
        self.model = None  # empty for now
        self.sampler = None
        if seed is None:
            self.seed = self._new_seed()
        else:
            self.seed = seed

-    def txt2img(self,prompt,outdir=None,batch=None,iterations=None,
+    def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None,
                steps=None,seed=None,grid=None,individual=None,width=None,height=None,
-                cfg_scale=None,ddim_eta=None):
+                cfg_scale=None,ddim_eta=None,strength=None,init_img=None):
        """
        Generate an image from the prompt, writing iteration images into the outdir
        The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
@@ -144,8 +147,9 @@ class T2I:
        height = height or self.height
        cfg_scale = cfg_scale or self.cfg_scale
        ddim_eta = ddim_eta or self.ddim_eta
-        batch = batch or self.batch
+        batch_size = batch_size or self.batch_size
        iterations = iterations or self.iterations
+        strength = strength or self.strength  # not actually used here, but preserved for code refactoring

        model = self.load_model()  # will instantiate the model or return it from cache
@@ -156,7 +160,7 @@ class T2I:
        if individual:
            grid = False

-        data = [batch * [prompt]]
+        data = [batch_size * [prompt]]

        # make directories and establish names for the output files
        os.makedirs(outdir, exist_ok=True)
@@ -164,7 +168,7 @@ class T2I:
        start_code = None
        if self.fixed_code:
-            start_code = torch.randn([batch,
+            start_code = torch.randn([batch_size,
                                      self.latent_channels,
                                      height // self.downsampling_factor,
                                      width // self.downsampling_factor],
@@ -186,14 +190,14 @@ class T2I:
                    for prompts in tqdm(data, desc="data", dynamic_ncols=True):
                        uc = None
                        if cfg_scale != 1.0:
-                            uc = model.get_learned_conditioning(batch * [""])
+                            uc = model.get_learned_conditioning(batch_size * [""])
                        if isinstance(prompts, tuple):
                            prompts = list(prompts)
                        c = model.get_learned_conditioning(prompts)
                        shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
                        samples_ddim, _ = sampler.sample(S=steps,
                                                         conditioning=c,
-                                                         batch_size=batch,
+                                                         batch_size=batch_size,
                                                         shape=shape,
                                                         verbose=False,
                                                         unconditional_guidance_scale=cfg_scale,
@@ -218,24 +222,146 @@ class T2I:
                    seed = self._new_seed()

                if grid:
-                    n_rows = batch if batch>1 else int(math.sqrt(batch * iterations))
-                    # save as grid
-                    grid = torch.stack(all_samples, 0)
-                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
-                    grid = make_grid(grid, nrow=n_rows)
-                    # to image
-                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
-                    filename = os.path.join(outdir, f"{base_count:05}.png")
-                    Image.fromarray(grid.astype(np.uint8)).save(filename)
-                    for s in seeds:
-                        images.append([filename,s])
+                    images = self._make_grid(samples=all_samples,
+                                             seeds=seeds,
+                                             batch_size=batch_size,
+                                             iterations=iterations,
+                                             outdir=outdir)

        toc = time.time()
-        print(f'{batch * iterations} images generated in',"%4.2fs"% (toc-tic))
+        print(f'{batch_size * iterations} images generated in',"%4.2fs"% (toc-tic))

        return images
+    # There is a lot of shared code between this and txt2img; it should be refactored.
+    def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
+                steps=None,seed=None,grid=None,individual=None,width=None,height=None,
+                cfg_scale=None,ddim_eta=None,strength=None):
+        """
+        Generate an image from the prompt and the initial image, writing iteration images into the outdir
+        The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
+        """
+        outdir = outdir or self.outdir
+        steps = steps or self.steps
+        seed = seed or self.seed
+        cfg_scale = cfg_scale or self.cfg_scale
+        ddim_eta = ddim_eta or self.ddim_eta
+        batch_size = batch_size or self.batch_size
+        iterations = iterations or self.iterations
+        strength = strength or self.strength
+
+        if init_img is None:
+            print("no init_img provided!")
+            return []
+
+        model = self.load_model()  # will instantiate the model or return it from cache
+
+        # grid and individual are mutually exclusive, with individual taking priority.
+        # not necessary, but needed for compatibility with dream bot
+        if (grid is None):
+            grid = self.grid
+        if individual:
+            grid = False
+
+        data = [batch_size * [prompt]]
+
+        # PLMS sampler not supported yet, so ignore previous sampler
+        if self.sampler_name != 'ddim':
+            print(f"sampler '{self.sampler_name}' is not yet supported. Using DDIM sampler")
+            sampler = DDIMSampler(model)
+        else:
+            sampler = self.sampler
+
+        # make directories and establish names for the output files
+        os.makedirs(outdir, exist_ok=True)
+        base_count = len(os.listdir(outdir))-1
+
+        assert os.path.isfile(init_img)
+        init_image = self._load_img(init_img).to(self.device)
+        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
+        init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space
+
+        sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
+
+        try:
+            assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
+        except AssertionError:
+            print(f"strength must be between 0.0 and 1.0, but received value {strength}")
+            return []
+
+        t_enc = int(strength * steps)
+        print(f"target t_enc is {t_enc} steps")
+
+        precision_scope = autocast if self.precision=="autocast" else nullcontext
+        images = list()
+        seeds = list()
+        tic = time.time()
+
+        with torch.no_grad():
+            with precision_scope("cuda"):
+                with model.ema_scope():
+                    all_samples = list()
+                    for n in trange(iterations, desc="Sampling"):
+                        seed_everything(seed)
+                        for prompts in tqdm(data, desc="data", dynamic_ncols=True):
+                            uc = None
+                            if cfg_scale != 1.0:
+                                uc = model.get_learned_conditioning(batch_size * [""])
+                            if isinstance(prompts, tuple):
+                                prompts = list(prompts)
+                            c = model.get_learned_conditioning(prompts)
+
+                            # encode (scaled latent)
+                            z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
+                            # decode it
+                            samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
+                                                     unconditional_conditioning=uc,)
+
+                            x_samples = model.decode_first_stage(samples)
+                            x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+
+                            if not grid:
+                                for x_sample in x_samples:
+                                    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                                    filename = os.path.join(outdir, f"{base_count:05}.png")
+                                    Image.fromarray(x_sample.astype(np.uint8)).save(filename)
+                                    images.append([filename,seed])
+                                    base_count += 1
+                            else:
+                                all_samples.append(x_samples)
+                                seeds.append(seed)
+
+                        seed = self._new_seed()

+                if grid:
+                    images = self._make_grid(samples=all_samples,
+                                             seeds=seeds,
+                                             batch_size=batch_size,
+                                             iterations=iterations,
+                                             outdir=outdir)
+
+        toc = time.time()
+        print(f'{batch_size * iterations} images generated in',"%4.2fs"% (toc-tic))
+
+        return images
+
+    def _make_grid(self,samples,seeds,batch_size,iterations,outdir):
+        images = list()
+        base_count = len(os.listdir(outdir))-1
+        n_rows = batch_size if batch_size>1 else int(math.sqrt(batch_size * iterations))
+        # save as grid
+        grid = torch.stack(samples, 0)
+        grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+        grid = make_grid(grid, nrow=n_rows)
+
+        # to image
+        grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+        filename = os.path.join(outdir, f"{base_count:05}.png")
+        Image.fromarray(grid.astype(np.uint8)).save(filename)
+        for s in seeds:
+            images.append([filename,s])
+        return images
    def _new_seed(self):
        self.seed = random.randrange(0,np.iinfo(np.uint32).max)
@@ -277,3 +403,13 @@ class T2I:
        model.eval()
        return model

+    def _load_img(self,path):
+        image = Image.open(path).convert("RGB")
+        w, h = image.size
+        print(f"loaded input image of size ({w}, {h}) from {path}")
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+        image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image[None].transpose(0, 3, 1, 2)
+        image = torch.from_numpy(image)
+        return 2.*image - 1.
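
The key new knob in the img2img path is strength: the init image is noised to timestep t_enc = int(strength * steps) and then denoised from there, so small values stay close to the input while 1.0 discards it entirely. A worked illustration of that arithmetic (t_enc_for is a hypothetical helper, not part of the commit):

def t_enc_for(strength: float, steps: int) -> int:
    # mirrors the check and computation in img2img() above
    assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
    return int(strength * steps)

print(t_enc_for(0.75, 50))  # 37 -- the default: heavily re-imagined, composition kept
print(t_enc_for(0.30, 50))  # 15 -- a gentle touch-up of the init image
print(t_enc_for(1.00, 50))  # 50 -- all steps renoised; init image effectively ignored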

View File

@@ -6,6 +6,8 @@ import shlex
import atexit
import os

+debugging = False
def main():
    ''' Initialize command-line parsers and the diffusion model '''
    arg_parser = create_argv_parser()
@@ -24,7 +26,7 @@ def main():
    weights = "models/ldm/stable-diffusion-v1/model.ckpt"

    # command line history will be stored in a file called "~/.dream_history"
-    load_history()
+    setup_readline()

    print("* Initializing, be patient...\n")
    from pytorch_lightning import logging
@@ -36,7 +38,7 @@ def main():
    # the user input loop
    t2i = T2I(width=width,
              height=height,
-              batch=opt.batch,
+              batch_size=opt.batch_size,
              outdir=opt.outdir,
              sampler=opt.sampler,
              weights=weights,
@@ -50,7 +52,8 @@ def main():
    logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)

    # preload the model
-    t2i.load_model()
+    if not debugging:
+        t2i.load_model()
    print("\n* Initialization done! Awaiting your command (-h for help)...")

    log_path = os.path.join(opt.outdir,"dream_log.txt")
@@ -92,7 +95,10 @@ def main_loop(t2i,parser,log):
            print("Try again with a prompt!")
            continue

-        results = t2i.txt2img(**vars(opt))
+        if opt.init_img is None:
+            results = t2i.txt2img(**vars(opt))
+        else:
+            results = t2i.img2img(**vars(opt))
        print("Outputs:")
        write_log_message(opt,switches,results,log)
@@ -136,7 +142,7 @@ def create_argv_parser():
                        type=int,
                        default=1,
                        help="number of images to generate")
-    parser.add_argument('-b','--batch',
+    parser.add_argument('-b','--batch_size',
                        type=int,
                        default=1,
                        help="number of images to produce per iteration (currently not working properly - producing too many images)")
@@ -158,14 +164,24 @@ def create_cmd_parser():
    parser.add_argument('-s','--steps',type=int,help="number of steps")
    parser.add_argument('-S','--seed',type=int,help="image seed")
    parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform")
-    parser.add_argument('-b','--batch',type=int,default=1,help="number of images to produce per sampling (currently broken)")
+    parser.add_argument('-b','--batch_size',type=int,default=1,help="number of images to produce per sampling (currently broken)")
    parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
    parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
-    parser.add_argument('-C','--cfg_scale',type=float,help="prompt configuration scale (7.5)")
+    parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale")
    parser.add_argument('-g','--grid',action='store_true',help="generate a grid")
    parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)")
+    parser.add_argument('-I','--init_img',type=str,help="path to input image (supersedes width and height)")
+    parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely")
    return parser

+def setup_readline():
+    readline.set_completer(Completer(['--steps','-s','--seed','-S','--iterations','-n','--batch_size','-b',
+                                      '--width','-W','--height','-H','--cfg_scale','-C','--grid','-g',
+                                      '--individual','-i','--init_img','-I','--strength','-f']).complete)
+    readline.set_completer_delims(" ")
+    readline.parse_and_bind('tab: complete')
+    load_history()
+
def load_history():
    histfile = os.path.join(os.path.expanduser('~'),".dream_history")
    try:
@@ -175,5 +191,64 @@ def load_history():
        pass
    atexit.register(readline.write_history_file,histfile)

+class Completer():
+    def __init__(self,options):
+        self.options = sorted(options)
+        return
+
+    def complete(self,text,state):
+        if text.startswith('-I') or text.startswith('--init_img'):
+            return self._image_completions(text,state)
+
+        response = None
+        if state == 0:
+            # This is the first time for this text, so build a match list.
+            if text:
+                self.matches = [s
+                                for s in self.options
+                                if s and s.startswith(text)]
+            else:
+                self.matches = self.options[:]
+
+        # Return the state'th item from the match list,
+        # if we have that many.
+        try:
+            response = self.matches[state]
+        except IndexError:
+            response = None
+        return response
+
+    def _image_completions(self,text,state):
+        # get the path so far
+        if text.startswith('-I'):
+            path = text.replace('-I','',1).lstrip()
+        elif text.startswith('--init_img='):
+            path = text.replace('--init_img=','',1).lstrip()
+        else:
+            path = ''  # bare '--init_img' with no path typed yet
+
+        matches = list()
+        path = os.path.expanduser(path)
+        if len(path)==0:
+            matches.append(text+'./')
+        else:
+            dir = os.path.dirname(path)
+            dir_list = os.listdir(dir)
+            for n in dir_list:
+                if n.startswith('.') and len(n)>1:
+                    continue
+                full_path = os.path.join(dir,n)
+                if full_path.startswith(path):
+                    if os.path.isdir(full_path):
+                        matches.append(os.path.join(os.path.dirname(text),n)+'/')
+                    elif n.endswith('.png'):
+                        matches.append(os.path.join(os.path.dirname(text),n))
+
+        try:
+            response = matches[state]
+        except IndexError:
+            response = None
+        return response
if __name__ == "__main__":
    main()
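
To see how the new switches reach main_loop's txt2img/img2img dispatch, here is a small sketch that drives the pieces added above. It assumes create_cmd_parser() and Completer are in scope (the script's file name is not shown in this view) and that the command parser accepts the prompt as a positional argument, as the dream-bot command format implies:

import shlex

# parse a dream-bot style command that uses the new -I/-f switches
parser = create_cmd_parser()
opt = parser.parse_args(shlex.split('"an astronaut riding a horse" -I ./sketches/horse+rider.png -f 0.6'))

# main_loop() routes on init_img: img2img when present, txt2img otherwise
assert opt.strength == 0.6
print('img2img' if opt.init_img is not None else 'txt2img')  # img2img

# the readline completer cycles through switch matches by state
c = Completer(['--steps','-s','--strength','-f'])
print(c.complete('--st', 0))  # --steps
print(c.complete('--st', 1))  # --strength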