mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
img2img is now working; small refactoring of grid code in simplet2i.py
This commit is contained in:
parent
c477525036
commit
bf76c4f283
190
ldm/simplet2i.py
190
ldm/simplet2i.py
@ -26,23 +26,22 @@ t2i.load_model()
|
|||||||
# override the default values assigned during class initialization
|
# override the default values assigned during class initialization
|
||||||
# Will call load_model() if the model was not previously loaded.
|
# Will call load_model() if the model was not previously loaded.
|
||||||
# The method returns a list of images. Each row of the list is a sub-list of [filename,seed]
|
# The method returns a list of images. Each row of the list is a sub-list of [filename,seed]
|
||||||
results = t2i.txt2img(prompt = <string> // required
|
results = t2i.txt2img(prompt = "an astronaut riding a horse"
|
||||||
outdir = <path> // the remaining option arguments override constructur value when present
|
outdir = "./outputs/txt2img-samples)
|
||||||
iterations = <integer>
|
)
|
||||||
batch = <integer>
|
|
||||||
steps = <integer>
|
|
||||||
seed = <integer>
|
|
||||||
sampler = ['ddim','plms']
|
|
||||||
grid = <boolean>
|
|
||||||
width = <integer>
|
|
||||||
height = <integer>
|
|
||||||
cfg_scale = <float>
|
|
||||||
) -> boolean
|
|
||||||
|
|
||||||
for row in results:
|
for row in results:
|
||||||
print(f'filename={row[0]}')
|
print(f'filename={row[0]}')
|
||||||
print(f'seed ={row[1]}')
|
print(f'seed ={row[1]}')
|
||||||
|
|
||||||
|
# Same thing, but using an initial image.
|
||||||
|
results = t2i.img2img(prompt = "an astronaut riding a horse"
|
||||||
|
outdir = "./outputs/img2img-samples"
|
||||||
|
init_img = "./sketches/horse+rider.png")
|
||||||
|
|
||||||
|
for row in results:
|
||||||
|
print(f'filename={row[0]}')
|
||||||
|
print(f'seed ={row[1]}')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -54,7 +53,7 @@ from omegaconf import OmegaConf
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from einops import rearrange
|
from einops import rearrange, repeat
|
||||||
from torchvision.utils import make_grid
|
from torchvision.utils import make_grid
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
from torch import autocast
|
from torch import autocast
|
||||||
@ -87,6 +86,7 @@ class T2I:
|
|||||||
latent_channels
|
latent_channels
|
||||||
downsampling_factor
|
downsampling_factor
|
||||||
precision
|
precision
|
||||||
|
strength
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
outdir="outputs/txt2img-samples",
|
outdir="outputs/txt2img-samples",
|
||||||
@ -106,7 +106,8 @@ class T2I:
|
|||||||
downsampling_factor=8,
|
downsampling_factor=8,
|
||||||
ddim_eta=0.0, # deterministic
|
ddim_eta=0.0, # deterministic
|
||||||
fixed_code=False,
|
fixed_code=False,
|
||||||
precision='autocast'
|
precision='autocast',
|
||||||
|
strength=0.75 # default in scripts/img2img.py
|
||||||
):
|
):
|
||||||
self.outdir = outdir
|
self.outdir = outdir
|
||||||
self.batch = batch
|
self.batch = batch
|
||||||
@ -124,15 +125,17 @@ class T2I:
|
|||||||
self.downsampling_factor = downsampling_factor
|
self.downsampling_factor = downsampling_factor
|
||||||
self.ddim_eta = ddim_eta
|
self.ddim_eta = ddim_eta
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
|
self.strength = strength
|
||||||
self.model = None # empty for now
|
self.model = None # empty for now
|
||||||
self.sampler = None
|
self.sampler = None
|
||||||
if seed is None:
|
if seed is None:
|
||||||
self.seed = self._new_seed()
|
self.seed = self._new_seed()
|
||||||
else:
|
else:
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
|
|
||||||
def txt2img(self,prompt,outdir=None,batch=None,iterations=None,
|
def txt2img(self,prompt,outdir=None,batch=None,iterations=None,
|
||||||
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
|
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
|
||||||
cfg_scale=None,ddim_eta=None):
|
cfg_scale=None,ddim_eta=None,strength=None,init_img=None):
|
||||||
"""
|
"""
|
||||||
Generate an image from the prompt, writing iteration images into the outdir
|
Generate an image from the prompt, writing iteration images into the outdir
|
||||||
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
|
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
|
||||||
@ -146,6 +149,7 @@ class T2I:
|
|||||||
ddim_eta = ddim_eta or self.ddim_eta
|
ddim_eta = ddim_eta or self.ddim_eta
|
||||||
batch = batch or self.batch
|
batch = batch or self.batch
|
||||||
iterations = iterations or self.iterations
|
iterations = iterations or self.iterations
|
||||||
|
strength = strength or self.strength # not actually used here, but preserved for code refactoring
|
||||||
|
|
||||||
model = self.load_model() # will instantiate the model or return it from cache
|
model = self.load_model() # will instantiate the model or return it from cache
|
||||||
|
|
||||||
@ -218,24 +222,146 @@ class T2I:
|
|||||||
seed = self._new_seed()
|
seed = self._new_seed()
|
||||||
|
|
||||||
if grid:
|
if grid:
|
||||||
n_rows = batch if batch>1 else int(math.sqrt(batch * iterations))
|
images = self._make_grid(samples=all_samples,
|
||||||
# save as grid
|
seeds=seeds,
|
||||||
grid = torch.stack(all_samples, 0)
|
batch_size=batch,
|
||||||
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
iterations=iterations,
|
||||||
grid = make_grid(grid, nrow=n_rows)
|
outdir=outdir)
|
||||||
|
|
||||||
# to image
|
|
||||||
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
|
||||||
filename = os.path.join(outdir, f"{base_count:05}.png")
|
|
||||||
Image.fromarray(grid.astype(np.uint8)).save(filename)
|
|
||||||
for s in seeds:
|
|
||||||
images.append([filename,s])
|
|
||||||
|
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
print(f'{batch * iterations} images generated in',"%4.2fs"% (toc-tic))
|
print(f'{batch * iterations} images generated in',"%4.2fs"% (toc-tic))
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
# There is lots of shared code between this and txt2img and should be refactored.
|
||||||
|
def img2img(self,prompt,outdir=None,init_img=None,batch=None,iterations=None,
|
||||||
|
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
|
||||||
|
cfg_scale=None,ddim_eta=None,strength=None):
|
||||||
|
"""
|
||||||
|
Generate an image from the prompt and the initial image, writing iteration images into the outdir
|
||||||
|
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
|
||||||
|
"""
|
||||||
|
outdir = outdir or self.outdir
|
||||||
|
steps = steps or self.steps
|
||||||
|
seed = seed or self.seed
|
||||||
|
cfg_scale = cfg_scale or self.cfg_scale
|
||||||
|
ddim_eta = ddim_eta or self.ddim_eta
|
||||||
|
batch = batch or self.batch
|
||||||
|
iterations = iterations or self.iterations
|
||||||
|
strength = strength or self.strength
|
||||||
|
|
||||||
|
if init_img is None:
|
||||||
|
print("no init_img provided!")
|
||||||
|
return []
|
||||||
|
|
||||||
|
model = self.load_model() # will instantiate the model or return it from cache
|
||||||
|
|
||||||
|
# grid and individual are mutually exclusive, with individual taking priority.
|
||||||
|
# not necessary, but needed for compatability with dream bot
|
||||||
|
if (grid is None):
|
||||||
|
grid = self.grid
|
||||||
|
if individual:
|
||||||
|
grid = False
|
||||||
|
|
||||||
|
data = [batch * [prompt]]
|
||||||
|
|
||||||
|
# PLMS sampler not supported yet, so ignore previous sampler
|
||||||
|
if self.sampler_name!='ddim':
|
||||||
|
print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler")
|
||||||
|
sampler = DDIMSampler(model)
|
||||||
|
else:
|
||||||
|
sampler = self.sampler
|
||||||
|
|
||||||
|
# make directories and establish names for the output files
|
||||||
|
os.makedirs(outdir, exist_ok=True)
|
||||||
|
base_count = len(os.listdir(outdir))-1
|
||||||
|
|
||||||
|
assert os.path.isfile(init_img)
|
||||||
|
init_image = self._load_img(init_img).to(self.device)
|
||||||
|
init_image = repeat(init_image, '1 ... -> b ...', b=batch)
|
||||||
|
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
|
||||||
|
|
||||||
|
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||||
|
except AssertionError:
|
||||||
|
print(f"strength must be between 0.0 and 1.0, but received value {strength}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
t_enc = int(strength * steps)
|
||||||
|
print(f"target t_enc is {t_enc} steps")
|
||||||
|
|
||||||
|
precision_scope = autocast if self.precision=="autocast" else nullcontext
|
||||||
|
images = list()
|
||||||
|
seeds = list()
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
with precision_scope("cuda"):
|
||||||
|
with model.ema_scope():
|
||||||
|
all_samples = list()
|
||||||
|
for n in trange(iterations, desc="Sampling"):
|
||||||
|
seed_everything(seed)
|
||||||
|
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
|
||||||
|
uc = None
|
||||||
|
if cfg_scale != 1.0:
|
||||||
|
uc = model.get_learned_conditioning(batch * [""])
|
||||||
|
if isinstance(prompts, tuple):
|
||||||
|
prompts = list(prompts)
|
||||||
|
c = model.get_learned_conditioning(prompts)
|
||||||
|
|
||||||
|
# encode (scaled latent)
|
||||||
|
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch).to(self.device))
|
||||||
|
# decode it
|
||||||
|
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
|
||||||
|
unconditional_conditioning=uc,)
|
||||||
|
|
||||||
|
x_samples = model.decode_first_stage(samples)
|
||||||
|
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
if not grid:
|
||||||
|
for x_sample in x_samples:
|
||||||
|
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||||
|
filename = os.path.join(outdir, f"{base_count:05}.png")
|
||||||
|
Image.fromarray(x_sample.astype(np.uint8)).save(filename)
|
||||||
|
images.append([filename,seed])
|
||||||
|
base_count += 1
|
||||||
|
else:
|
||||||
|
all_samples.append(x_samples)
|
||||||
|
seeds.append(seed)
|
||||||
|
|
||||||
|
seed = self._new_seed()
|
||||||
|
|
||||||
|
if grid:
|
||||||
|
images = self._make_grid(samples=all_samples,
|
||||||
|
seeds=seeds,
|
||||||
|
batch_size=batch,
|
||||||
|
iterations=iterations,
|
||||||
|
outdir=outdir)
|
||||||
|
|
||||||
|
toc = time.time()
|
||||||
|
print(f'{batch * iterations} images generated in',"%4.2fs"% (toc-tic))
|
||||||
|
|
||||||
|
return images
|
||||||
|
|
||||||
|
def _make_grid(self,samples,seeds,batch_size,iterations,outdir):
|
||||||
|
images = list()
|
||||||
|
base_count = len(os.listdir(outdir))-1
|
||||||
|
n_rows = batch_size if batch_size>1 else int(math.sqrt(batch_size * iterations))
|
||||||
|
# save as grid
|
||||||
|
grid = torch.stack(samples, 0)
|
||||||
|
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||||
|
grid = make_grid(grid, nrow=n_rows)
|
||||||
|
|
||||||
|
# to image
|
||||||
|
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||||
|
filename = os.path.join(outdir, f"{base_count:05}.png")
|
||||||
|
Image.fromarray(grid.astype(np.uint8)).save(filename)
|
||||||
|
for s in seeds:
|
||||||
|
images.append([filename,s])
|
||||||
|
return images
|
||||||
|
|
||||||
def _new_seed(self):
|
def _new_seed(self):
|
||||||
self.seed = random.randrange(0,np.iinfo(np.uint32).max)
|
self.seed = random.randrange(0,np.iinfo(np.uint32).max)
|
||||||
@ -277,3 +403,13 @@ class T2I:
|
|||||||
model.eval()
|
model.eval()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def _load_img(self,path):
|
||||||
|
image = Image.open(path).convert("RGB")
|
||||||
|
w, h = image.size
|
||||||
|
print(f"loaded input image of size ({w}, {h}) from {path}")
|
||||||
|
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||||
|
image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
|
||||||
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
|
image = image[None].transpose(0, 3, 1, 2)
|
||||||
|
image = torch.from_numpy(image)
|
||||||
|
return 2.*image - 1.
|
||||||
|
@ -92,7 +92,10 @@ def main_loop(t2i,parser,log):
|
|||||||
print("Try again with a prompt!")
|
print("Try again with a prompt!")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
results = t2i.txt2img(**vars(opt))
|
if opt.init_img is None:
|
||||||
|
results = t2i.txt2img(**vars(opt))
|
||||||
|
else:
|
||||||
|
results = t2i.img2img(**vars(opt))
|
||||||
print("Outputs:")
|
print("Outputs:")
|
||||||
write_log_message(opt,switches,results,log)
|
write_log_message(opt,switches,results,log)
|
||||||
|
|
||||||
@ -161,9 +164,11 @@ def create_cmd_parser():
|
|||||||
parser.add_argument('-b','--batch',type=int,default=1,help="number of images to produce per sampling (currently broken)")
|
parser.add_argument('-b','--batch',type=int,default=1,help="number of images to produce per sampling (currently broken)")
|
||||||
parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
|
parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
|
||||||
parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
|
parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
|
||||||
parser.add_argument('-C','--cfg_scale',type=float,help="prompt configuration scale (7.5)")
|
parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale")
|
||||||
parser.add_argument('-g','--grid',action='store_true',help="generate a grid")
|
parser.add_argument('-g','--grid',action='store_true',help="generate a grid")
|
||||||
parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)")
|
parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)")
|
||||||
|
parser.add_argument('-I','--init_img',type=str,help="path to input image (supersedes width and height)")
|
||||||
|
parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
def load_history():
|
def load_history():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user