2022-08-21 23:57:48 +00:00
#!/usr/bin/env python3
2022-08-24 13:22:04 +00:00
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
2022-08-17 01:34:37 +00:00
import argparse
import shlex
import atexit
2022-08-17 02:49:47 +00:00
import os
2022-08-21 15:03:22 +00:00
import sys
2022-08-22 04:12:16 +00:00
from PIL import Image , PngImagePlugin
2022-08-17 01:34:37 +00:00
2022-08-19 03:03:22 +00:00
# readline (line editing, history, tab completion) is unavailable on Windows
# systems; the rest of the script checks readline_available before using it.
try:
    import readline
    readline_available = True
except ImportError:
    # only a missing module means "no readline"; other errors should surface
    readline_available = False

# when True, main() skips loading the model weights (useful for exercising the CLI)
debugging = False
2022-08-18 16:43:59 +00:00
2022-08-17 01:34:37 +00:00
def main():
    '''Initialize command-line parsers and the diffusion model, then run the
    prompt loop (interactively, or over the lines of --from_file).

    Exits with status -1 if the --from_file prompt file cannot be opened.
    '''
    arg_parser = create_argv_parser()
    opt = arg_parser.parse_args()
    if opt.laion400m:
        # defaults suitable to the older latent diffusion weights
        width = 256
        height = 256
        config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
        weights = "models/ldm/text2img-large/model.ckpt"
    else:
        # some defaults suitable for stable diffusion weights
        width = 512
        height = 512
        config = "configs/stable-diffusion/v1-inference.yaml"
        weights = "models/ldm/stable-diffusion-v1/model.ckpt"

    # command line history will be stored in a file called "~/.dream_history"
    if readline_available:
        setup_readline()

    print("* Initializing, be patient...\n")
    sys.path.append('.')
    # heavy imports are deferred until after argument parsing so that
    # --help and argument errors respond immediately
    from pytorch_lightning import logging
    from ldm.simplet2i import T2I

    # these two lines prevent a horrible warning message from appearing
    # when the frozen CLIP tokenizer is imported
    import transformers
    transformers.logging.set_verbosity_error()

    # creating a simple text2image object with a handful of
    # defaults passed on the command line.
    # additional parameters will be added (or overridden) during
    # the user input loop
    t2i = T2I(width=width,
              height=height,
              batch_size=opt.batch_size,
              outdir=opt.outdir,
              sampler_name=opt.sampler_name,
              weights=weights,
              full_precision=opt.full_precision,
              config=config,
              latent_diffusion_weights=opt.laion400m,  # this is solely for recreating the prompt
              embedding_path=opt.embedding_path)

    # make sure the output directory exists (no exists()/makedirs() race)
    os.makedirs(opt.outdir, exist_ok=True)

    # gets rid of annoying messages about random seed
    logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)

    infile = None
    if opt.infile is not None:
        try:
            infile = open(opt.infile, 'r', encoding='utf-8')
        except OSError as e:
            # missing, unreadable, or a directory: report and stop
            print(e)
            sys.exit(-1)

    try:
        # preload the model
        if not debugging:
            t2i.load_model()
        print("\n* Initialization done! Awaiting your command (-h for help, 'q' to quit, 'cd' to change output dir, 'pwd' to print output dir)...")

        log_path = os.path.join(opt.outdir, 'dream_log.txt')
        # utf-8 so non-ASCII prompts can always be logged, whatever the locale
        with open(log_path, 'a', encoding='utf-8') as log:
            cmd_parser = create_cmd_parser()
            main_loop(t2i, cmd_parser, log, infile)
    finally:
        # release the prompt file even if model loading or the loop raises
        if infile:
            infile.close()
2022-08-17 01:34:37 +00:00
2022-08-19 03:23:44 +00:00
2022-08-23 04:30:06 +00:00
def main_loop(t2i, parser, log, infile):
    '''Prompt/read/execute loop.

    Commands are read from ``infile`` one line at a time when it is given,
    otherwise interactively from stdin.

    * ``q`` quits.
    * ``cd DIR`` changes ``t2i.outdir``.
    * ``pwd`` prints ``t2i.outdir``.
    * Lines starting with ``#`` or ``//`` are ignored.
    * Any other line is parsed with ``parser`` as a prompt followed by
      switches. It is passed to ``t2i.txt2img``, or to ``t2i.img2img`` when
      ``-I`` is given, and the results are logged to ``log``.

    :param t2i: the image generator
    :param parser: argparse parser for a single command (see create_cmd_parser)
    :param log: open, writable text file for the prompt log
    :param infile: optional open text file of commands; None for interactive use
    '''
    while True:
        try:
            command = infile.readline() if infile else input("dream> ")
        except EOFError:
            break
        # readline() returns '' only at end of file
        if infile and len(command) == 0:
            break
        if command.startswith(('#', '//')):
            continue

        # before splitting, escape single quotes so as not to mess
        # up the parser (apostrophes inside prompts)
        command = command.replace("'", "\\'")
        try:
            elements = shlex.split(command)
        except ValueError as e:
            print(str(e))
            continue
        if len(elements) == 0:
            continue

        if elements[0] == 'q':
            break

        if elements[0] == 'cd' and len(elements) > 1:
            # only accept real directories; an existing plain file would
            # break every subsequent image save
            if os.path.isdir(elements[1]):
                print(f"setting image output directory to {elements[1]}")
                t2i.outdir = elements[1]
            else:
                print(f"directory {elements[1]} does not exist")
            continue
        if elements[0] == 'pwd':
            print(f"current output directory is {t2i.outdir}")
            continue

        if elements[0].startswith('!dream'):  # in case a stored prompt still contains the !dream command
            elements.pop(0)

        # rearrange the arguments to mimic how it works in the Dream bot:
        # every word before the first switch is joined into the prompt.
        # startswith() (not el[0]) so that empty tokens such as "" don't crash.
        prompt_words = []
        switches = []
        for el in elements:
            if switches or el.startswith('-'):
                switches.append(el)
            else:
                prompt_words.append(el)
        switches.insert(0, ' '.join(prompt_words))

        try:
            opt = parser.parse_args(switches)
        except SystemExit:
            parser.print_help()
            continue
        if not opt.prompt:
            print("Try again with a prompt!")
            continue

        # explicit check rather than assert, which is stripped under python -O
        if opt.init_img is not None and not os.path.exists(opt.init_img):
            print(f"No file found at {opt.init_img}. On Linux systems, pressing <tab> after -I will autocomplete a list of possible image files.")
            continue

        try:
            if opt.init_img is None:
                results = t2i.txt2img(**vars(opt))
            else:
                results = t2i.img2img(**vars(opt))
        except AssertionError as e:
            # presumably the generator reports invalid parameters via assert;
            # report them and keep the session alive
            print(e)
            continue

        print("Outputs:")
        write_log_message(t2i, opt, results, log)

    print("goodbye!")
2022-08-17 16:00:00 +00:00
2022-08-22 02:48:40 +00:00
def write_log_message(t2i, opt, results, logfile):
    '''Log the name of each output image, its prompt and seed.

    One line per result goes to the terminal and to ``logfile``. The
    reconstructed command is also stored in each PNG as a "Dream" text chunk.

    :param t2i: the generator; supplies defaults for switches not set in ``opt``
    :param opt: parsed command namespace for the run that produced ``results``
    :param results: items whose [0] is the image path and [1] its seed
    :param logfile: open, writable text file
    '''
    prompt_str = ' '.join(_reconstruct_switches(t2i, opt))
    batch_size = opt.batch_size or t2i.batch_size

    # summary of every seed in this run; embedded in grid images
    all_seeds = [r[1] for r in results]
    if batch_size > 1:
        seed_summary = f"(seeds for each batch row: {all_seeds})"
    else:
        seed_summary = f"(seeds for individual images: {all_seeds})"

    # when multiple images are produced in batch, consecutive results share a
    # seed; number each image within its batch
    last_seed = None
    img_num = 1
    tagged = set()  # paths whose PNG metadata has already been written

    for r in results:
        path, seed = r[0], r[1]
        log_message = f'{path}: {prompt_str} -S{seed}'
        if batch_size > 1:
            img_num = 1 if seed != last_seed else img_num + 1
            log_message += f' # (batch image {img_num} of {batch_size})'
        last_seed = seed

        print(log_message)
        logfile.write(log_message + "\n")
        logfile.flush()

        # the same path may appear several times in results (presumably a grid
        # file shared by several seeds); annotate each file only once
        if path not in tagged:
            tagged.add(path)
            try:
                if opt.grid:
                    _write_prompt_to_png(path, f'{prompt_str} -g -S{seed} {seed_summary}')
                else:
                    _write_prompt_to_png(path, f'{prompt_str} -S{seed}')
            except FileNotFoundError:
                print(f"Could not open file '{path}' for reading")
2022-08-17 16:00:00 +00:00
2022-08-22 02:48:40 +00:00
def _reconstruct_switches(t2i, opt):
    '''Normalize the prompt and switches into a list of command tokens.

    Values missing from ``opt`` fall back to the generator's defaults, so
    that the joined list records how an image was made.

    :param t2i: generator exposing steps, batch_size, width, height,
                cfg_scale, sampler_name and full_precision
    :param opt: parsed command namespace (prompt and per-image switches)
    :returns: list of strings, the quoted prompt first
    '''
    switches = [f'"{opt.prompt}"']
    switches.append(f'-s{opt.steps or t2i.steps}')
    switches.append(f'-b{opt.batch_size or t2i.batch_size}')
    switches.append(f'-W{opt.width or t2i.width}')
    switches.append(f'-H{opt.height or t2i.height}')
    switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
    switches.append(f'-m{t2i.sampler_name}')
    if opt.init_img:
        switches.append(f'-I{opt.init_img}')
    if opt.strength and opt.init_img is not None:
        # opt.strength is truthy here, so no fallback to t2i is needed
        switches.append(f'-f{opt.strength}')
    # -x changes how subprompt weights are applied, so it must be recorded
    # for the logged command to reproduce the image
    if getattr(opt, 'skip_normalize', False):
        switches.append('-x')
    if t2i.full_precision:
        switches.append('-F')
    return switches
2022-08-22 04:12:16 +00:00
def _write_prompt_to_png(path, prompt):
    '''Rewrite the PNG at ``path`` in place with ``prompt`` stored in a
    "Dream" text chunk.

    Raises FileNotFoundError if ``path`` does not exist.
    '''
    info = PngImagePlugin.PngInfo()
    info.add_text("Dream", prompt)
    # the context manager releases the source file handle even if save() fails
    with Image.open(path) as im:
        im.save(path, "PNG", pnginfo=info)
2022-08-17 01:34:37 +00:00
def create_argv_parser():
    '''Build the parser for the script's own command-line arguments.

    These options are read once at startup (model weights, prompt file,
    precision, sampler, output directory, embeddings). Per-image options
    typed at the dream> prompt are parsed by create_cmd_parser instead.
    '''
    parser = argparse.ArgumentParser(description="Parse script's command line args")
    parser.add_argument("--laion400m",
                        "--latent_diffusion",
                        "-l",
                        dest='laion400m',
                        action='store_true',
                        help="fallback to the latent diffusion (laion400m) weights and config")
    parser.add_argument("--from_file",
                        dest='infile',
                        type=str,
                        help="if specified, load prompts from this file")
    # NOTE(review): main() does not appear to read opt.iterations; confirm
    # whether -n should be forwarded to the generator
    parser.add_argument('-n', '--iterations',
                        type=int,
                        default=1,
                        help="number of images to generate")
    parser.add_argument('-F', '--full_precision',
                        dest='full_precision',
                        action='store_true',
                        help="use slower full precision math for calculations")
    parser.add_argument('-b', '--batch_size',
                        type=int,
                        default=1,
                        help="number of images to produce per iteration (faster, but doesn't generate individual seeds)")
    parser.add_argument('--sampler', '-m',
                        dest="sampler_name",
                        choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'],
                        default='k_lms',
                        help="which sampler to use (k_lms) - can only be set on command line")
    parser.add_argument('--outdir',
                        '-o',
                        type=str,
                        default="outputs/img-samples",
                        help="directory in which to place generated images and a log of prompts and seeds")
    parser.add_argument('--embedding_path',
                        type=str,
                        help="Path to a pre-trained embedding manager checkpoint - can only be set on command line")
    return parser
def create_cmd_parser():
    '''Return the argparse parser for one line typed at the dream> prompt.

    The prompt text is the single positional argument and the switches
    follow it. Switches without a default parse to None when omitted, and
    the -g, -i and -x flags parse to False.
    '''
    parser = argparse.ArgumentParser(
        description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12')
    parser.add_argument('prompt')

    # (short flag, long flag, keyword arguments for add_argument)
    # NOTE(review): -b, -C and -f carry concrete defaults here, so a caller
    # that falls back with `opt.x or t2i.x` never reaches the generator's own
    # value for them; confirm this is intended.
    switch_table = [
        ('-s', '--steps', {'type': int, 'help': "number of steps"}),
        ('-S', '--seed', {'type': int, 'help': "image seed"}),
        ('-n', '--iterations', {'type': int, 'default': 1,
                                'help': "number of samplings to perform (slower, but will provide seeds for individual images)"}),
        ('-b', '--batch_size', {'type': int, 'default': 1,
                                'help': "number of images to produce per sampling (will not provide seeds for individual images!)"}),
        ('-W', '--width', {'type': int, 'help': "image width, multiple of 64"}),
        ('-H', '--height', {'type': int, 'help': "image height, multiple of 64"}),
        ('-C', '--cfg_scale', {'default': 7.5, 'type': float, 'help': "prompt configuration scale"}),
        ('-g', '--grid', {'action': 'store_true', 'help': "generate a grid"}),
        ('-i', '--individual', {'action': 'store_true', 'help': "generate individual files (default)"}),
        ('-I', '--init_img', {'type': str, 'help': "path to input image (supersedes width and height)"}),
        ('-f', '--strength', {'default': 0.75, 'type': float,
                              'help': "strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely"}),
        ('-x', '--skip_normalize', {'action': 'store_true', 'help': "skip subprompt weight normalization"}),
    ]
    for short_flag, long_flag, options in switch_table:
        parser.add_argument(short_flag, long_flag, **options)
    return parser
2022-08-18 20:00:44 +00:00
if readline_available:
    def setup_readline():
        '''Install tab completion for dream> commands and load saved history.'''
        readline.set_completer(Completer(['cd', 'pwd',
                                          '--steps', '-s', '--seed', '-S', '--iterations', '-n', '--batch_size', '-b',
                                          '--width', '-W', '--height', '-H', '--cfg_scale', '-C', '--grid', '-g',
                                          '--individual', '-i', '--init_img', '-I', '--strength', '-f',
                                          '--skip_normalize', '-x']).complete)
        # split words on spaces only, so "-I./some/path" is completed as one token
        readline.set_completer_delims(" ")
        readline.parse_and_bind('tab: complete')
        load_history()

    def load_history():
        '''Load ~/.dream_history if possible and register saving it at exit.'''
        histfile = os.path.join(os.path.expanduser('~'), ".dream_history")
        try:
            readline.read_history_file(histfile)
        except OSError:
            # absent on first run; an unreadable history file should not stop startup
            pass
        # cap the saved history even when no history file existed yet, so the
        # file written at exit stays bounded
        readline.set_history_length(1000)
        atexit.register(readline.write_history_file, histfile)
2022-08-18 16:43:59 +00:00
2022-08-18 20:00:44 +00:00
class Completer:
    '''readline tab completion for the dream> prompt.

    Completes file-system paths in three cases:

    * after ``-I`` / ``--init_img=``: directories and .png files;
    * after ``cd``: directories only;
    * for any word starting with ``.`` or ``/``: directories only.

    Everything else is completed from a fixed list of options.
    '''

    def __init__(self, options):
        # options: candidate words offered for ordinary (non-path) completion
        self.options = sorted(options)
        self.matches = []

    def complete(self, text, state):
        '''readline completer callback: return the ``state``-th completion
        of ``text``, or None when there are no more.'''
        buffer = readline.get_line_buffer()
        if text.startswith(('-I', '--init_img')):
            return self._path_completions(text, state, ('.png',))
        if buffer.strip().endswith('cd') or text.startswith(('.', '/')):
            # an empty extension tuple matches no file, so only directories are offered
            return self._path_completions(text, state, ())

        if state == 0:
            # This is the first time for this text, so build a match list.
            if text:
                self.matches = [s for s in self.options if s and s.startswith(text)]
            else:
                self.matches = self.options[:]
        # Return the state'th item from the match list, if we have that many.
        try:
            return self.matches[state]
        except IndexError:
            return None

    def _path_completions(self, text, state, extensions):
        '''Return the ``state``-th path completion of ``text``, or None.

        A leading ``-I`` or ``--init_img=`` switch in ``text`` is kept on
        every completion. Directories are offered with a trailing ``/``.
        Files are offered only when their lower-cased name ends with one of
        ``extensions``. Hidden entries are skipped. A bare name with no
        directory part is completed against the current directory.
        '''
        # separate the switch (if any) from the path typed so far
        if text.startswith('-I'):
            switch = '-I'
        elif text.startswith('--init_img='):
            switch = '--init_img='
        else:
            switch = ''
        typed = text[len(switch):].lstrip()

        matches = []
        path = os.path.expanduser(typed)
        if len(path) == 0:
            matches.append(text + './')
        else:
            listing_dir = os.path.dirname(path) or '.'
            prefix = os.path.basename(path)
            try:
                names = os.listdir(listing_dir)
            except OSError:
                names = []  # nonexistent or unreadable directory: nothing to offer
            # keep exactly what the user typed before the last path component
            typed_dir = os.path.dirname(typed)
            for n in names:
                if n.startswith('.') and len(n) > 1:
                    continue
                if not n.startswith(prefix):
                    continue
                candidate = switch + os.path.join(typed_dir, n)
                if os.path.isdir(os.path.join(listing_dir, n)):
                    matches.append(candidate + '/')
                elif n.lower().endswith(extensions):
                    matches.append(candidate)
        try:
            return matches[state]
        except IndexError:
            return None
2022-08-18 16:43:59 +00:00
2022-08-17 01:34:37 +00:00
if __name__ == "__main__":
    # start the CLI only when executed as a script, never on import
    main()
2022-08-24 13:22:04 +00:00