Mirror of https://github.com/invoke-ai/InvokeAI
Code is reorganized and mostly functional. The grid feature needs to be brought back online, as does the naming of img2img variants (currently the variant images are written but not logged).
@@ -1,6 +1,7 @@
 '''Utilities for dealing with PNG images and their path names'''
 import os
 import atexit
+import re
 from PIL import Image,PngImagePlugin
 
 # ---------------readline utilities---------------------
@@ -94,40 +95,43 @@ if readline_available:
 # -------------------image generation utils-----
 class PngWriter:
 
-    def __init__(self,opt):
+    def __init__(self,outdir,opt,prompt):
+        self.outdir = outdir
         self.opt = opt
+        self.prompt = prompt
         self.filepath = None
         self.files_written = []
 
     def write_image(self,image,seed):
-        self.filepath = self.unique_filename(self,opt,seed,self.filepath) # will increment name in some sensible way
+        self.filepath = self.unique_filename(seed,self.filepath) # will increment name in some sensible way
         try:
-            image.save(self.filename)
+            prompt = f'{self.prompt} -S{seed}'
+            self.save_image_and_prompt_to_png(image,prompt,self.filepath)
         except IOError as e:
             print(e)
         self.files_written.append([self.filepath,seed])
 
-    def unique_filename(self,opt,seed,previouspath):
+    def unique_filename(self,seed,previouspath):
         revision = 1
 
         if previouspath is None:
             # sort reverse alphabetically until we find max+1
-            dirlist = sorted(os.listdir(outdir),reverse=True)
+            dirlist = sorted(os.listdir(self.outdir),reverse=True)
             # find the first filename that matches our pattern or return 000000.0.png
             filename = next((f for f in dirlist if re.match('^(\d+)\..*\.png',f)),'0000000.0.png')
             basecount = int(filename.split('.',1)[0])
             basecount += 1
-            if opt.batch_size > 1:
+            if self.opt.batch_size > 1:
                 filename = f'{basecount:06}.{seed}.01.png'
             else:
                 filename = f'{basecount:06}.{seed}.png'
-            return os.path.join(outdir,filename)
+            return os.path.join(self.outdir,filename)
 
         else:
             basename = os.path.basename(previouspath)
             x = re.match('^(\d+)\..*\.png',basename)
             if not x:
-                return self.unique_filename(opt,seed,previouspath)
+                return self.unique_filename(seed,previouspath)
 
             basecount = int(x.groups()[0])
             series = 0
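For orientation between the two PngWriter hunks: a minimal sketch of how the reworked class is driven. The options object and output directory are stand-ins (any object with a batch_size attribute will do, and the directory must already exist); only calls visible in the hunks are used.

    # Sketch only: exercises the new PngWriter constructor and write_image().
    from argparse import Namespace
    from PIL import Image
    from ldm.dream_util import PngWriter

    opt = Namespace(batch_size=1)                      # stand-in for the parsed CLI options
    writer = PngWriter('outputs/img-samples', opt,     # illustrative outdir; must already exist
                       '"a fantasy landscape" -s50')   # the normalized prompt string

    image = Image.new('RGB', (512, 512))               # placeholder image
    writer.write_image(image, 42)                      # saved under a name like 000001.42.png
    print(writer.files_written)                        # [[path, seed], ...]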
@@ -135,9 +139,41 @@ class PngWriter:
             while not finished:
                 series += 1
                 filename = f'{basecount:06}.{seed}.png'
-                if isbatch or os.path.exists(os.path.join(outdir,filename)):
+                if self.opt.batch_size>1 or os.path.exists(os.path.join(self.outdir,filename)):
                     filename = f'{basecount:06}.{seed}.{series:02}.png'
-                finished = not os.path.exists(os.path.join(outdir,filename))
-            return os.path.join(outdir,filename)
+                finished = not os.path.exists(os.path.join(self.outdir,filename))
+            return os.path.join(self.outdir,filename)
 
+    def save_image_and_prompt_to_png(self,image,prompt,path):
+        info = PngImagePlugin.PngInfo()
+        info.add_text("Dream",prompt)
+        image.save(path,"PNG",pnginfo=info)
+
+class PromptFormatter():
+    def __init__(self,t2i,opt):
+        self.t2i = t2i
+        self.opt = opt
+
+    def normalize_prompt(self):
+        '''Normalize the prompt and switches'''
+        t2i = self.t2i
+        opt = self.opt
+
+        switches = list()
+        switches.append(f'"{opt.prompt}"')
+        switches.append(f'-s{opt.steps or t2i.steps}')
+        switches.append(f'-b{opt.batch_size or t2i.batch_size}')
+        switches.append(f'-W{opt.width or t2i.width}')
+        switches.append(f'-H{opt.height or t2i.height}')
+        switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
+        switches.append(f'-m{t2i.sampler_name}')
+        if opt.variants:
+            switches.append(f'-v{opt.variants}')
+        if opt.init_img:
+            switches.append(f'-I{opt.init_img}')
+        if opt.strength and opt.init_img is not None:
+            switches.append(f'-f{opt.strength or t2i.strength}')
+        if t2i.full_precision:
+            switches.append('-F')
+        return ' '.join(switches)
 
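The new save_image_and_prompt_to_png() embeds the reproduction command in a "Dream" text chunk, and PromptFormatter.normalize_prompt() is what builds that command. A small self-contained sketch, using stand-in objects that carry only the attributes referenced above:

    # Sketch: build the normalized prompt, then read a "Dream" chunk back from a PNG.
    from types import SimpleNamespace
    from PIL import Image
    from ldm.dream_util import PromptFormatter

    opt = SimpleNamespace(prompt='a fantasy landscape', steps=50, batch_size=1,
                          width=512, height=512, cfg_scale=7.5,
                          variants=None, init_img=None, strength=None)
    t2i = SimpleNamespace(steps=50, batch_size=1, width=512, height=512,
                          cfg_scale=7.5, sampler_name='klms', full_precision=False)

    print(PromptFormatter(t2i, opt).normalize_prompt())
    # -> "a fantasy landscape" -s50 -b1 -W512 -H512 -C7.5 -mklms

    # Recovering the command later from a written file (illustrative path):
    with Image.open('outputs/img-samples/000001.42.png') as im:
        print(im.text.get('Dream'))    # the prompt string plus the -S<seed> suffix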
@@ -99,13 +99,13 @@ The vast majority of these arguments default to reasonable values.
     def __init__(self,
                  batch_size=1,
                  iterations = 1,
-                 grid=False,
-                 individual=None, # redundant
                  steps=50,
                  seed=None,
                  cfg_scale=7.5,
                  weights="models/ldm/stable-diffusion-v1/model.ckpt",
                  config = "configs/stable-diffusion/v1-inference.yaml",
+                 width=512,
+                 height=512,
                  sampler_name="klms",
                  latent_channels=4,
                  downsampling_factor=8,
@@ -121,7 +121,6 @@ The vast majority of these arguments default to reasonable values.
         self.iterations = iterations
         self.width = width
         self.height = height
-        self.grid = grid
         self.steps = steps
         self.cfg_scale = cfg_scale
         self.weights = weights
@@ -143,7 +142,7 @@ The vast majority of these arguments default to reasonable values.
         else:
             self.seed = seed
 
-    def generate(self,
+    def prompt2image(self,
                  # these are common
                  prompt,
                  batch_size=None,
@@ -160,8 +159,9 @@ The vast majority of these arguments default to reasonable values.
                  # these are specific to img2img
                  init_img=None,
                  strength=None,
-                 variants=None):
-        '''ldm.generate() is the common entry point for txt2img() and img2img()'''
+                 variants=None,
+                 **args):    # eat up additional cruft
+        '''ldm.prompt2image() is the common entry point for txt2img() and img2img()'''
         steps = steps or self.steps
         seed = seed or self.seed
         width = width or self.width
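With generate() renamed to prompt2image() and **args swallowing unknown keywords, callers can pass their whole option namespace straight through. A hedged sketch of a direct call; the ldm.simplet2i module path is an assumption, as is the image_callback keyword, which is what dream.py passes further below:

    # Sketch: driving the renamed entry point directly.
    from ldm.simplet2i import T2I              # assumed module path

    t2i = T2I(width=512, height=512, sampler_name='klms')

    def on_image(image, seed):
        '''Same (image, seed) shape as PngWriter.write_image.'''
        image.save(f'sample.{seed}.png')

    t2i.prompt2image('a fantasy landscape',
                     steps=50,
                     cfg_scale=7.5,
                     image_callback=on_image,  # called once per finished image
                     outdir='outputs')         # keywords prompt2image does not know are eaten by **args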
@@ -178,10 +178,6 @@ The vast majority of these arguments default to reasonable values.
 
         data = [batch_size * [prompt]]
         scope = autocast if self.precision=="autocast" else nullcontext
-        if grid:
-            callback = self.image2png
-        else:
-            callback = None
 
         tic = time.time()
         if init_img:
@@ -212,7 +208,7 @@ The vast majority of these arguments default to reasonable values.
                  steps,seed,cfg_scale,ddim_eta,
                  skip_normalize,
                  width,height,
-                 callback=callback): # the callback is called each time a new Image is generated
+                 callback): # the callback is called each time a new Image is generated
         """
         Generate an image from the prompt, writing iteration images into the outdir
         The output is a list of lists in the format: [[image1,seed1], [image2,seed2],...]
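The internal generators now receive the callback directly instead of deriving it from the removed grid setting, and they invoke it once per finished image. Any callable with the (image, seed) shape of PngWriter.write_image can be plugged in, for example:

    # Sketch: a drop-in callback collecting results in the [[image, seed], ...]
    # format described by the docstring above.
    finished = []

    def collect(image, seed):
        finished.append([image, seed])
        print(f'got an image for seed {seed}')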
@@ -224,14 +220,14 @@ The vast majority of these arguments default to reasonable values.
 
         # Gawd. Too many levels of indent here. Need to refactor into smaller routines!
         try:
-            with precision_scope(self.device.type), model.ema_scope():
+            with precision_scope(self.device.type), self.model.ema_scope():
                 all_samples = list()
                 for n in trange(iterations, desc="Sampling"):
                     seed_everything(seed)
                     for prompts in tqdm(data, desc="data", dynamic_ncols=True):
                         uc = None
                         if cfg_scale != 1.0:
-                            uc = model.get_learned_conditioning(batch_size * [""])
+                            uc = self.model.get_learned_conditioning(batch_size * [""])
                         if isinstance(prompts, tuple):
                             prompts = list(prompts)
 
@@ -247,9 +243,9 @@ The vast majority of these arguments default to reasonable values.
                                 weight = weights[i]
                                 if not skip_normalize:
                                     weight = weight / totalWeight
-                                c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
+                                c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
                         else: # just standard 1 prompt
-                            c = model.get_learned_conditioning(prompts)
+                            c = self.model.get_learned_conditioning(prompts)
 
                         shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
                         samples_ddim, _ = sampler.sample(S=steps,
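The weighted-subprompt branch builds the conditioning as a weighted sum of per-subprompt embeddings through torch.add(..., alpha=weight), dividing by the total weight unless skip_normalize is set. A standalone numeric sketch of that accumulation, with small random tensors standing in for get_learned_conditioning() outputs:

    # Sketch: the weighted-subprompt accumulation pattern from the hunk above,
    # using toy tensors in place of learned conditioning embeddings.
    import torch

    embeddings = [torch.randn(1, 77, 768) for _ in range(2)]   # toy stand-ins
    weights = [1.0, 3.0]
    total = sum(weights)

    c = torch.zeros_like(embeddings[0])
    for emb, w in zip(embeddings, weights):
        c = torch.add(c, emb, alpha=w / total)   # same torch.add(..., alpha=weight) call as above

    # c is now (1.0*e0 + 3.0*e1) / 4.0
    assert torch.allclose(c, (1.0 * embeddings[0] + 3.0 * embeddings[1]) / total)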
@@ -261,7 +257,7 @@ The vast majority of these arguments default to reasonable values.
                                                  unconditional_conditioning=uc,
                                                  eta=ddim_eta)
 
-                        x_samples_ddim = model.decode_first_stage(samples_ddim)
+                        x_samples_ddim = self.model.decode_first_stage(samples_ddim)
                         x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                         for x_sample in x_samples_ddim:
                             x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
@@ -277,8 +273,6 @@ The vast majority of these arguments default to reasonable values.
         except RuntimeError as e:
             print(str(e))
 
-        toc = time.time()
-        print(f'{image_count} images generated in',"%4.2fs"% (toc-tic))
         return images
 
     @torch.no_grad()
@@ -297,14 +291,14 @@ The vast majority of these arguments default to reasonable values.
         # PLMS sampler not supported yet, so ignore previous sampler
         if self.sampler_name!='ddim':
             print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler")
-            sampler = DDIMSampler(model, device=self.device)
+            sampler = DDIMSampler(self.model, device=self.device)
         else:
             sampler = self.sampler
 
         init_image = self._load_img(init_img).to(self.device)
         init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
         with precision_scope(self.device.type):
-            init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
+            init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image)) # move to latent space
 
         sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
 
@@ -314,14 +308,14 @@ The vast majority of these arguments default to reasonable values.
         images = list()
 
         try:
-            with precision_scope(self.device.type), model.ema_scope():
+            with precision_scope(self.device.type), self.model.ema_scope():
                 all_samples = list()
                 for n in trange(iterations, desc="Sampling"):
                     seed_everything(seed)
                     for prompts in tqdm(data, desc="data", dynamic_ncols=True):
                         uc = None
                         if cfg_scale != 1.0:
-                            uc = model.get_learned_conditioning(batch_size * [""])
+                            uc = self.model.get_learned_conditioning(batch_size * [""])
                         if isinstance(prompts, tuple):
                             prompts = list(prompts)
 
@@ -337,9 +331,9 @@ The vast majority of these arguments default to reasonable values.
                                 weight = weights[i]
                                 if not skip_normalize:
                                     weight = weight / totalWeight
-                                c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
+                                c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
                         else: # just standard 1 prompt
-                            c = model.get_learned_conditioning(prompts)
+                            c = self.model.get_learned_conditioning(prompts)
 
                         # encode (scaled latent)
                         z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
@@ -347,7 +341,7 @@ The vast majority of these arguments default to reasonable values.
                         samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
                                                  unconditional_conditioning=uc,)
 
-                        x_samples = model.decode_first_stage(samples)
+                        x_samples = self.model.decode_first_stage(samples)
                         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
 
                         for x_sample in x_samples:
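Taken together, the img2img hunks encode the init image to a latent, re-noise it partway with stochastic_encode(), then denoise it back under the prompt's conditioning with sampler.decode(). A condensed, hypothetical helper restating that sequence; the strength-to-t_enc mapping is an assumption, since it is not visible in these hunks:

    # Hypothetical helper restating the img2img path above; every argument is
    # assumed to be prepared exactly as in the surrounding methods.
    import torch

    def img2img_core(model, sampler, init_image, c, uc,
                     steps, strength, cfg_scale, ddim_eta, batch_size, device):
        init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))
        sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
        t_enc = int(strength * steps)    # assumed mapping from strength to re-noising depth
        z_enc = sampler.stochastic_encode(init_latent,
                                          torch.tensor([t_enc] * batch_size).to(device))
        samples = sampler.decode(z_enc, c, t_enc,
                                 unconditional_guidance_scale=cfg_scale,
                                 unconditional_conditioning=uc)
        return model.decode_first_stage(samples)    # latents back to image space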
@@ -6,7 +6,7 @@ import shlex
 import os
 import sys
 import copy
-from ldm.dream_util import Completer,PngWriter
+from ldm.dream_util import Completer,PngWriter,PromptFormatter
 
 debugging = False
 
@@ -27,10 +27,6 @@ def main():
     config = "configs/stable-diffusion/v1-inference.yaml"
     weights = "models/ldm/stable-diffusion-v1/model.ckpt"
 
-    # command line history will be stored in a file called "~/.dream_history"
-    if readline_available:
-        setup_readline()
-
     print("* Initializing, be patient...\n")
     sys.path.append('.')
     from pytorch_lightning import logging
@@ -46,8 +42,6 @@ def main():
     # the user input loop
     t2i = T2I(width=width,
               height=height,
-              batch_size=opt.batch_size,
-              outdir=opt.outdir,
               sampler_name=opt.sampler_name,
               weights=weights,
               full_precision=opt.full_precision,
@@ -79,13 +73,13 @@ def main():
     log_path = os.path.join(opt.outdir,'dream_log.txt')
     with open(log_path,'a') as log:
         cmd_parser = create_cmd_parser()
-        main_loop(t2i,cmd_parser,log,infile)
+        main_loop(t2i,opt.outdir,cmd_parser,log,infile)
         log.close()
     if infile:
         infile.close()
 
 
-def main_loop(t2i,parser,log,infile):
+def main_loop(t2i,outdir,parser,log,infile):
     ''' prompt/read/execute loop '''
     done = False
 
@@ -123,13 +117,13 @@ def main_loop(t2i,parser,log,infile):
             if elements[0]=='cd' and len(elements)>1:
                 if os.path.exists(elements[1]):
                     print(f"setting image output directory to {elements[1]}")
-                    opt.outdir=elements[1]
+                    outdir=elements[1]
                 else:
                     print(f"directory {elements[1]} does not exist")
                 continue
 
             if elements[0]=='pwd':
-                print(f"current output directory is {opt.outdir}")
+                print(f"current output directory is {outdir}")
                 continue
 
             if elements[0].startswith('!dream'): # in case a stored prompt still contains the !dream command
@@ -158,87 +152,40 @@ def main_loop(t2i,parser,log,infile):
             print("Try again with a prompt!")
             continue
 
+        normalized_prompt = PromptFormatter(t2i,opt).normalize_prompt()
         try:
-            file_writer = PngWriter(opt)
-            opt.callback = file_writer(write_image)
-            run_generator(**vars(opt))
+            file_writer = PngWriter(outdir,opt,normalized_prompt)
+            callback = file_writer.write_image
+            t2i.prompt2image(image_callback=callback,
+                             **vars(opt))
             results = file_writer.files_written
 
         except AssertionError as e:
             print(e)
             continue
 
         print("Outputs:")
-        write_log_message(t2i,opt,results,log)
+        write_log_message(t2i,normalized_prompt,results,log)
 
     print("goodbye!")
 
-def write_log_message(t2i,opt,results,logfile):
+def write_log_message(t2i,prompt,results,logfile):
     ''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata '''
-    switches = _reconstruct_switches(t2i,opt)
-    prompt_str = ' '.join(switches)
-
-    # when multiple images are produced in batch, then we keep track of where each starts
     last_seed = None
     img_num = 1
-    batch_size = opt.batch_size or t2i.batch_size
     seenit = {}
 
     seeds = [a[1] for a in results]
-    if batch_size > 1:
-        seeds = f"(seeds for each batch row: {seeds})"
-    else:
-        seeds = f"(seeds for individual images: {seeds})"
+    seeds = f"(seeds for individual images: {seeds})"
 
     for r in results:
         seed = r[1]
-        log_message = (f'{r[0]}: {prompt_str} -S{seed}')
+        log_message = (f'{r[0]}: {prompt} -S{seed}')
 
-        if batch_size > 1:
-            if seed != last_seed:
-                img_num = 1
-                log_message += f' # (batch image {img_num} of {batch_size})'
-            else:
-                img_num += 1
-                log_message += f' # (batch image {img_num} of {batch_size})'
-            last_seed = seed
         print(log_message)
         logfile.write(log_message+"\n")
         logfile.flush()
-        if r[0] not in seenit:
-            seenit[r[0]] = True
-            try:
-                if opt.grid:
-                    _write_prompt_to_png(r[0],f'{prompt_str} -g -S{seed} {seeds}')
-                else:
-                    _write_prompt_to_png(r[0],f'{prompt_str} -S{seed}')
-            except FileNotFoundError:
-                print(f"Could not open file '{r[0]}' for reading")
 
-def _reconstruct_switches(t2i,opt):
-    '''Normalize the prompt and switches'''
-    switches = list()
-    switches.append(f'"{opt.prompt}"')
-    switches.append(f'-s{opt.steps or t2i.steps}')
-    switches.append(f'-b{opt.batch_size or t2i.batch_size}')
-    switches.append(f'-W{opt.width or t2i.width}')
-    switches.append(f'-H{opt.height or t2i.height}')
-    switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
-    switches.append(f'-m{t2i.sampler_name}')
-    if opt.variants:
-        switches.append(f'-v{opt.variants}')
-    if opt.init_img:
-        switches.append(f'-I{opt.init_img}')
-    if opt.strength and opt.init_img is not None:
-        switches.append(f'-f{opt.strength or t2i.strength}')
-    if t2i.full_precision:
-        switches.append('-F')
-    return switches
-
-def _write_prompt_to_png(path,prompt):
-    info = PngImagePlugin.PngInfo()
-    info.add_text("Dream",prompt)
-    im = Image.open(path)
-    im.save(path,"PNG",pnginfo=info)
 
 def create_argv_parser():
     parser = argparse.ArgumentParser(description="Parse script's command line args")
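The reworked main_loop threads one normalized prompt string through the whole pipeline: PromptFormatter builds it, PngWriter stores it in each PNG, and write_log_message reuses it verbatim. A condensed, hypothetical helper restating that per-prompt sequence, using only calls visible in the hunk above:

    # Hypothetical helper condensing the per-prompt flow from main_loop above.
    from ldm.dream_util import PngWriter, PromptFormatter

    def generate_and_log(t2i, outdir, opt, log):
        normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
        file_writer = PngWriter(outdir, opt, normalized_prompt)
        # write_image names each file, saves it with a "Dream" text chunk,
        # and records [path, seed] in files_written.
        t2i.prompt2image(image_callback=file_writer.write_image, **vars(opt))
        for path, seed in file_writer.files_written:
            message = f'{path}: {normalized_prompt} -S{seed}'
            print(message)
            log.write(message + "\n")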
@@ -260,10 +207,6 @@ def create_argv_parser():
                         dest='full_precision',
                         action='store_true',
                         help="use slower full precision math for calculations")
-    parser.add_argument('-b','--batch_size',
-                        type=int,
-                        default=1,
-                        help="number of images to produce per iteration (faster, but doesn't generate individual seeds")
     parser.add_argument('--sampler','-m',
                         dest="sampler_name",
                         choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'],