#!/usr/bin/env python3
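"""Interactive "dream>" front end for the stable-diffusion / latent-diffusion
text-to-image pipeline.

The script loads a T2I model (ldm.simplet2i), then runs a read/parse/generate
loop: each command is split into a prompt plus switches, generated images are
written to the output directory, and every prompt and seed is recorded both in
a dream_log.txt file and in a "Dream" text chunk of the PNG metadata."""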

import argparse
import shlex
import atexit
import os
import sys

from PIL import Image, PngImagePlugin

# readline is unavailable on Windows systems
try:
    import readline
    readline_available = True
except ImportError:
    readline_available = False

debugging = False


def main():
    '''Initialize command-line parsers and the diffusion model'''
    arg_parser = create_argv_parser()
    opt = arg_parser.parse_args()
    if opt.laion400m:
        # defaults suitable to the older latent diffusion weights
        width = 256
        height = 256
        config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
        weights = "models/ldm/text2img-large/model.ckpt"
    else:
        # some defaults suitable for stable diffusion weights
        width = 512
        height = 512
        config = "configs/stable-diffusion/v1-inference.yaml"
        weights = "models/ldm/stable-diffusion-v1/model.ckpt"

    # command line history will be stored in a file called "~/.dream_history"
    if readline_available:
        setup_readline()

    print("* Initializing, be patient...\n")
    sys.path.append('.')
    from pytorch_lightning import logging
    from ldm.simplet2i import T2I
    # these two lines prevent a horrible warning message from appearing
    # when the frozen CLIP tokenizer is imported
    import transformers
    transformers.logging.set_verbosity_error()

    # create a simple text2image object with a handful of defaults passed on
    # the command line; additional parameters will be added (or overridden)
    # during the user input loop
    t2i = T2I(width=width,
              height=height,
              batch_size=opt.batch_size,
              outdir=opt.outdir,
              sampler_name=opt.sampler_name,
              weights=weights,
              full_precision=opt.full_precision,
              config=config,
              latent_diffusion_weights=opt.laion400m  # this is solely for recreating the prompt
              )

    # make sure the output directory exists
    if not os.path.exists(opt.outdir):
        os.makedirs(opt.outdir)

    # gets rid of annoying messages about random seed
    logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)

    # preload the model
    if not debugging:
        t2i.load_model()
    print("\n* Initialization done! Awaiting your command (-h for help, q to quit)...")

    log_path = os.path.join(opt.outdir, "dream_log.txt")
    with open(log_path, 'a') as log:
        cmd_parser = create_cmd_parser()
        main_loop(t2i, cmd_parser, log)

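
# Typical launch, from the repository root (the script path may differ in your checkout):
#
#   python3 scripts/dream.py --sampler klms --outdir outputs/img-samples
#
# and a sample command at the interactive prompt (taken from create_cmd_parser below):
#
#   dream> a fantastic alien landscape -W1024 -H960 -s100 -n12
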

def main_loop(t2i, parser, log):
    '''prompt/read/execute loop'''
    done = False

    while not done:
        try:
            command = input("dream> ")
        except EOFError:
            done = True
            break

        elements = shlex.split(command)
        if len(elements) == 0:
            continue

        if elements[0] == 'q':
            done = True
            break
        if elements[0].startswith('!dream'):  # in case a stored prompt still contains the !dream command
            elements.pop(0)

        # rearrange the arguments to mimic how it works in the Dream bot.
        switches = ['']
        switches_started = False

        for el in elements:
            if el[0] == '-' and not switches_started:
                switches_started = True
            if switches_started:
                switches.append(el)
            else:
                switches[0] += el
                switches[0] += ' '
        switches[0] = switches[0][:len(switches[0]) - 1]  # drop the trailing space
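        # e.g. "a fantastic alien landscape -W1024 -H960 -s100" becomes
        # ['a fantastic alien landscape', '-W1024', '-H960', '-s100'], so the
        # cmd parser sees a single positional prompt followed by its switches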
        try:
            opt = parser.parse_args(switches)
        except SystemExit:
            parser.print_help()
            continue
        if len(opt.prompt) == 0:
            print("Try again with a prompt!")
            continue

        try:
            if opt.init_img is None:
                results = t2i.txt2img(**vars(opt))
            else:
                results = t2i.img2img(**vars(opt))
            print("Outputs:")
            write_log_message(t2i, opt, results, log)
        except KeyboardInterrupt:
            print('*interrupted*')
            continue

    print("goodbye!")


def write_log_message(t2i, opt, results, logfile):
    '''logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata'''
    switches = _reconstruct_switches(t2i, opt)
    prompt_str = ' '.join(switches)

    # when multiple images are produced in a batch, keep track of where each one starts
    last_seed = None
    img_num = 1
    batch_size = opt.batch_size or t2i.batch_size
    seenit = {}

    for r in results:
        seed = r[1]
        log_message = (f'{r[0]}: {prompt_str} -S{seed}')

        if batch_size > 1:
            if seed != last_seed:
                img_num = 1
                log_message += f' # (batch image {img_num} of {batch_size})'
            else:
                img_num += 1
                log_message += f' # (batch image {img_num} of {batch_size})'
            last_seed = seed
        print(log_message)
        logfile.write(log_message + "\n")
        logfile.flush()
        if r[0] not in seenit:
            seenit[r[0]] = True
            try:
                _write_prompt_to_png(r[0], f'{prompt_str} -S{seed}')
            except FileNotFoundError:
                print(f"Could not open file '{r[0]}' for reading")


def _reconstruct_switches(t2i, opt):
    '''Normalize the prompt and switches'''
    switches = list()
    switches.append(f'"{opt.prompt}"')
    switches.append(f'-s{opt.steps or t2i.steps}')
    switches.append(f'-b{opt.batch_size or t2i.batch_size}')
    switches.append(f'-W{opt.width or t2i.width}')
    switches.append(f'-H{opt.height or t2i.height}')
    switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
    if opt.init_img:
        switches.append(f'-I{opt.init_img}')
    if opt.strength and opt.init_img is not None:
        switches.append(f'-f{opt.strength or t2i.strength}')
    if t2i.full_precision:
        switches.append('-F')
    return switches
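
# A reconstructed prompt string therefore looks like (values are illustrative defaults):
#   "a fantastic alien landscape" -s50 -b1 -W512 -H512 -C7.5
# write_log_message() appends -S<seed>, so each log line and PNG "Dream" chunk carries
# everything needed to reproduce the image.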

def _write_prompt_to_png(path, prompt):
    '''write the prompt into a "Dream" text chunk of the PNG file at path'''
    info = PngImagePlugin.PngInfo()
    info.add_text("Dream", prompt)
    im = Image.open(path)
    im.save(path, "PNG", pnginfo=info)
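
# The stored prompt can be read back later with Pillow; a minimal sketch (path is illustrative):
#   im = Image.open("outputs/img-samples/000001.png")
#   print(im.info.get("Dream"))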


def create_argv_parser():
    parser = argparse.ArgumentParser(description="Parse script's command line args")
    parser.add_argument("--laion400m",
                        "--latent_diffusion",
                        "-l",
                        dest='laion400m',
                        action='store_true',
                        help="fall back to the latent diffusion (laion400m) weights and config")
    parser.add_argument('-n', '--iterations',
                        type=int,
                        default=1,
                        help="number of images to generate")
    parser.add_argument('-F', '--full_precision',
                        dest='full_precision',
                        action='store_true',
                        help="use slower full precision math for calculations")
    parser.add_argument('-b', '--batch_size',
                        type=int,
                        default=1,
                        help="number of images to produce per iteration (faster, but doesn't generate individual seeds)")
    parser.add_argument('--sampler', '-m',
                        dest="sampler_name",
                        choices=['plms', 'ddim', 'klms'],
                        default='klms',
                        help="which sampler to use (default: klms); can only be set on the command line")
    parser.add_argument('-o',
                        '--outdir',
                        type=str,
                        default="outputs/img-samples",
                        help="directory in which to place generated images and a log of prompts and seeds")
    return parser


def create_cmd_parser():
    parser = argparse.ArgumentParser(description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12')
    parser.add_argument('prompt')
    parser.add_argument('-s', '--steps', type=int, help="number of steps")
    parser.add_argument('-S', '--seed', type=int, help="image seed")
    parser.add_argument('-n', '--iterations', type=int, default=1, help="number of samplings to perform")
    parser.add_argument('-b', '--batch_size', type=int, default=1, help="number of images to produce per sampling")
    parser.add_argument('-W', '--width', type=int, help="image width, multiple of 64")
    parser.add_argument('-H', '--height', type=int, help="image height, multiple of 64")
    parser.add_argument('-C', '--cfg_scale', default=7.5, type=float, help="prompt configuration (classifier-free guidance) scale")
    parser.add_argument('-g', '--grid', action='store_true', help="generate a grid")
    parser.add_argument('-i', '--individual', action='store_true', help="generate individual files (default)")
    parser.add_argument('-I', '--init_img', type=str, help="path to input image (supersedes width and height)")
    parser.add_argument('-f', '--strength', default=0.75, type=float, help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely")
    return parser
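
# Example dream> commands accepted by the parser above (file names and values are illustrative):
#   dream> "a sunlit forest clearing" -s50 -C7.5 -S42
#   dream> "the same scene in winter" -I outputs/img-samples/000001.png -f0.6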


if readline_available:
    def setup_readline():
        readline.set_completer(Completer(['--steps', '-s', '--seed', '-S', '--iterations', '-n', '--batch_size', '-b',
                                          '--width', '-W', '--height', '-H', '--cfg_scale', '-C', '--grid', '-g',
                                          '--individual', '-i', '--init_img', '-I', '--strength', '-f']).complete)
        readline.set_completer_delims(" ")
        readline.parse_and_bind('tab: complete')
        load_history()

    def load_history():
        histfile = os.path.join(os.path.expanduser('~'), ".dream_history")
        try:
            readline.read_history_file(histfile)
            readline.set_history_length(1000)
        except FileNotFoundError:
            pass
        atexit.register(readline.write_history_file, histfile)

    class Completer():
        def __init__(self, options):
            self.options = sorted(options)
            return

        def complete(self, text, state):
            if text.startswith('-I') or text.startswith('--init_img'):
                return self._image_completions(text, state)

            response = None
            if state == 0:
                # This is the first time for this text, so build a match list.
                if text:
                    self.matches = [s
                                    for s in self.options
                                    if s and s.startswith(text)]
                else:
                    self.matches = self.options[:]

            # Return the state'th item from the match list,
            # if we have that many.
            try:
                response = self.matches[state]
            except IndexError:
                response = None
            return response

        def _image_completions(self, text, state):
            # get the path so far
            if text.startswith('-I'):
                path = text.replace('-I', '', 1).lstrip()
            elif text.startswith('--init_img='):
                path = text.replace('--init_img=', '', 1).lstrip()
            else:
                # bare '--init_img' with no value typed yet
                path = ''

            matches = list()

            path = os.path.expanduser(path)
            if len(path) == 0:
                matches.append(text + './')
            else:
                dir = os.path.dirname(path)
                dir_list = os.listdir(dir)
                for n in dir_list:
                    if n.startswith('.') and len(n) > 1:
                        continue
                    full_path = os.path.join(dir, n)
                    if full_path.startswith(path):
                        if os.path.isdir(full_path):
                            matches.append(os.path.join(os.path.dirname(text), n) + '/')
                        elif n.endswith('.png'):
                            matches.append(os.path.join(os.path.dirname(text), n))

            try:
                response = matches[state]
            except IndexError:
                response = None
            return response


if __name__ == "__main__":
    main()