img2img is now working; small refactoring of grid code in simplet2i.py

This commit is contained in:
Lincoln Stein 2022-08-18 10:47:53 -04:00
parent c477525036
commit bf76c4f283
2 changed files with 170 additions and 29 deletions

View File

@ -26,23 +26,22 @@ t2i.load_model()
# override the default values assigned during class initialization # override the default values assigned during class initialization
# Will call load_model() if the model was not previously loaded. # Will call load_model() if the model was not previously loaded.
# The method returns a list of images. Each row of the list is a sub-list of [filename,seed] # The method returns a list of images. Each row of the list is a sub-list of [filename,seed]
results = t2i.txt2img(prompt = <string> // required results = t2i.txt2img(prompt = "an astronaut riding a horse"
outdir = <path> // the remaining option arguments override constructur value when present outdir = "./outputs/txt2img-samples)
iterations = <integer> )
batch = <integer>
steps = <integer>
seed = <integer>
sampler = ['ddim','plms']
grid = <boolean>
width = <integer>
height = <integer>
cfg_scale = <float>
) -> boolean
for row in results: for row in results:
print(f'filename={row[0]}') print(f'filename={row[0]}')
print(f'seed ={row[1]}') print(f'seed ={row[1]}')
# Same thing, but using an initial image.
results = t2i.img2img(prompt = "an astronaut riding a horse"
outdir = "./outputs/img2img-samples"
init_img = "./sketches/horse+rider.png")
for row in results:
print(f'filename={row[0]}')
print(f'seed ={row[1]}')
""" """
import torch import torch
@ -54,7 +53,7 @@ from omegaconf import OmegaConf
from PIL import Image from PIL import Image
from tqdm import tqdm, trange from tqdm import tqdm, trange
from itertools import islice from itertools import islice
from einops import rearrange from einops import rearrange, repeat
from torchvision.utils import make_grid from torchvision.utils import make_grid
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from torch import autocast from torch import autocast
@ -87,6 +86,7 @@ class T2I:
latent_channels latent_channels
downsampling_factor downsampling_factor
precision precision
strength
""" """
def __init__(self, def __init__(self,
outdir="outputs/txt2img-samples", outdir="outputs/txt2img-samples",
@ -106,7 +106,8 @@ class T2I:
downsampling_factor=8, downsampling_factor=8,
ddim_eta=0.0, # deterministic ddim_eta=0.0, # deterministic
fixed_code=False, fixed_code=False,
precision='autocast' precision='autocast',
strength=0.75 # default in scripts/img2img.py
): ):
self.outdir = outdir self.outdir = outdir
self.batch = batch self.batch = batch
@ -124,15 +125,17 @@ class T2I:
self.downsampling_factor = downsampling_factor self.downsampling_factor = downsampling_factor
self.ddim_eta = ddim_eta self.ddim_eta = ddim_eta
self.precision = precision self.precision = precision
self.strength = strength
self.model = None # empty for now self.model = None # empty for now
self.sampler = None self.sampler = None
if seed is None: if seed is None:
self.seed = self._new_seed() self.seed = self._new_seed()
else: else:
self.seed = seed self.seed = seed
def txt2img(self,prompt,outdir=None,batch=None,iterations=None, def txt2img(self,prompt,outdir=None,batch=None,iterations=None,
steps=None,seed=None,grid=None,individual=None,width=None,height=None, steps=None,seed=None,grid=None,individual=None,width=None,height=None,
cfg_scale=None,ddim_eta=None): cfg_scale=None,ddim_eta=None,strength=None,init_img=None):
""" """
Generate an image from the prompt, writing iteration images into the outdir Generate an image from the prompt, writing iteration images into the outdir
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...] The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
@ -146,6 +149,7 @@ class T2I:
ddim_eta = ddim_eta or self.ddim_eta ddim_eta = ddim_eta or self.ddim_eta
batch = batch or self.batch batch = batch or self.batch
iterations = iterations or self.iterations iterations = iterations or self.iterations
strength = strength or self.strength # not actually used here, but preserved for code refactoring
model = self.load_model() # will instantiate the model or return it from cache model = self.load_model() # will instantiate the model or return it from cache
@ -218,24 +222,146 @@ class T2I:
seed = self._new_seed() seed = self._new_seed()
if grid: if grid:
n_rows = batch if batch>1 else int(math.sqrt(batch * iterations)) images = self._make_grid(samples=all_samples,
# save as grid seeds=seeds,
grid = torch.stack(all_samples, 0) batch_size=batch,
grid = rearrange(grid, 'n b c h w -> (n b) c h w') iterations=iterations,
grid = make_grid(grid, nrow=n_rows) outdir=outdir)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
filename = os.path.join(outdir, f"{base_count:05}.png")
Image.fromarray(grid.astype(np.uint8)).save(filename)
for s in seeds:
images.append([filename,s])
toc = time.time() toc = time.time()
print(f'{batch * iterations} images generated in',"%4.2fs"% (toc-tic)) print(f'{batch * iterations} images generated in',"%4.2fs"% (toc-tic))
return images return images
# There is lots of shared code between this and txt2img and should be refactored.
def img2img(self,prompt,outdir=None,init_img=None,batch=None,iterations=None,
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
cfg_scale=None,ddim_eta=None,strength=None):
"""
Generate an image from the prompt and the initial image, writing iteration images into the outdir
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
"""
outdir = outdir or self.outdir
steps = steps or self.steps
seed = seed or self.seed
cfg_scale = cfg_scale or self.cfg_scale
ddim_eta = ddim_eta or self.ddim_eta
batch = batch or self.batch
iterations = iterations or self.iterations
strength = strength or self.strength
if init_img is None:
print("no init_img provided!")
return []
model = self.load_model() # will instantiate the model or return it from cache
# grid and individual are mutually exclusive, with individual taking priority.
# not necessary, but needed for compatability with dream bot
if (grid is None):
grid = self.grid
if individual:
grid = False
data = [batch * [prompt]]
# PLMS sampler not supported yet, so ignore previous sampler
if self.sampler_name!='ddim':
print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler")
sampler = DDIMSampler(model)
else:
sampler = self.sampler
# make directories and establish names for the output files
os.makedirs(outdir, exist_ok=True)
base_count = len(os.listdir(outdir))-1
assert os.path.isfile(init_img)
init_image = self._load_img(init_img).to(self.device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch)
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
try:
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
except AssertionError:
print(f"strength must be between 0.0 and 1.0, but received value {strength}")
return []
t_enc = int(strength * steps)
print(f"target t_enc is {t_enc} steps")
precision_scope = autocast if self.precision=="autocast" else nullcontext
images = list()
seeds = list()
tic = time.time()
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
all_samples = list()
for n in trange(iterations, desc="Sampling"):
seed_everything(seed)
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
uc = None
if cfg_scale != 1.0:
uc = model.get_learned_conditioning(batch * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
# encode (scaled latent)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch).to(self.device))
# decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,)
x_samples = model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
if not grid:
for x_sample in x_samples:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
filename = os.path.join(outdir, f"{base_count:05}.png")
Image.fromarray(x_sample.astype(np.uint8)).save(filename)
images.append([filename,seed])
base_count += 1
else:
all_samples.append(x_samples)
seeds.append(seed)
seed = self._new_seed()
if grid:
images = self._make_grid(samples=all_samples,
seeds=seeds,
batch_size=batch,
iterations=iterations,
outdir=outdir)
toc = time.time()
print(f'{batch * iterations} images generated in',"%4.2fs"% (toc-tic))
return images
def _make_grid(self,samples,seeds,batch_size,iterations,outdir):
images = list()
base_count = len(os.listdir(outdir))-1
n_rows = batch_size if batch_size>1 else int(math.sqrt(batch_size * iterations))
# save as grid
grid = torch.stack(samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
filename = os.path.join(outdir, f"{base_count:05}.png")
Image.fromarray(grid.astype(np.uint8)).save(filename)
for s in seeds:
images.append([filename,s])
return images
def _new_seed(self): def _new_seed(self):
self.seed = random.randrange(0,np.iinfo(np.uint32).max) self.seed = random.randrange(0,np.iinfo(np.uint32).max)
@ -277,3 +403,13 @@ class T2I:
model.eval() model.eval()
return model return model
def _load_img(self,path):
image = Image.open(path).convert("RGB")
w, h = image.size
print(f"loaded input image of size ({w}, {h}) from {path}")
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.*image - 1.

View File

@ -92,7 +92,10 @@ def main_loop(t2i,parser,log):
print("Try again with a prompt!") print("Try again with a prompt!")
continue continue
results = t2i.txt2img(**vars(opt)) if opt.init_img is None:
results = t2i.txt2img(**vars(opt))
else:
results = t2i.img2img(**vars(opt))
print("Outputs:") print("Outputs:")
write_log_message(opt,switches,results,log) write_log_message(opt,switches,results,log)
@ -161,9 +164,11 @@ def create_cmd_parser():
parser.add_argument('-b','--batch',type=int,default=1,help="number of images to produce per sampling (currently broken)") parser.add_argument('-b','--batch',type=int,default=1,help="number of images to produce per sampling (currently broken)")
parser.add_argument('-W','--width',type=int,help="image width, multiple of 64") parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
parser.add_argument('-H','--height',type=int,help="image height, multiple of 64") parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
parser.add_argument('-C','--cfg_scale',type=float,help="prompt configuration scale (7.5)") parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale")
parser.add_argument('-g','--grid',action='store_true',help="generate a grid") parser.add_argument('-g','--grid',action='store_true',help="generate a grid")
parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)") parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)")
parser.add_argument('-I','--init_img',type=str,help="path to input image (supersedes width and height)")
parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely")
return parser return parser
def load_history(): def load_history():