Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

commit 750408f793
parent bf76c4f283

    added command-line completion
@@ -8,7 +8,7 @@ t2i = T2I(outdir = <path> // outputs/txt2img-samples
           model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
           config = <path> // default="configs/stable-diffusion/v1-inference.yaml
           iterations = <integer> // how many times to run the sampling (1)
-          batch = <integer> // how many images to generate per sampling (1)
+          batch_size = <integer> // how many images to generate per sampling (1)
           steps = <integer> // 50
           seed = <integer> // current system time
           sampler = ['ddim','plms'] // ddim
@@ -73,7 +73,7 @@ class T2I:
         model
         config
         iterations
-        batch
+        batch_size
         steps
         seed
         sampler
@@ -90,7 +90,7 @@ class T2I:
     """
     def __init__(self,
                  outdir="outputs/txt2img-samples",
-                 batch=1,
+                 batch_size=1,
                  iterations = 1,
                  width=512,
                  height=512,
@@ -110,7 +110,7 @@ class T2I:
                  strength=0.75 # default in scripts/img2img.py
     ):
         self.outdir = outdir
-        self.batch = batch
+        self.batch_size = batch_size
         self.iterations = iterations
         self.width = width
         self.height = height
@@ -133,7 +133,7 @@ class T2I:
         else:
             self.seed = seed
 
-    def txt2img(self,prompt,outdir=None,batch=None,iterations=None,
+    def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None,
                 steps=None,seed=None,grid=None,individual=None,width=None,height=None,
                 cfg_scale=None,ddim_eta=None,strength=None,init_img=None):
         """
@@ -147,7 +147,7 @@ class T2I:
         height = height or self.height
         cfg_scale = cfg_scale or self.cfg_scale
         ddim_eta = ddim_eta or self.ddim_eta
-        batch = batch or self.batch
+        batch_size = batch_size or self.batch_size
         iterations = iterations or self.iterations
         strength = strength or self.strength # not actually used here, but preserved for code refactoring
 
@@ -160,7 +160,7 @@ class T2I:
         if individual:
             grid = False
 
-        data = [batch * [prompt]]
+        data = [batch_size * [prompt]]
 
         # make directories and establish names for the output files
         os.makedirs(outdir, exist_ok=True)
@@ -168,7 +168,7 @@ class T2I:
 
         start_code = None
         if self.fixed_code:
-            start_code = torch.randn([batch,
+            start_code = torch.randn([batch_size,
                                       self.latent_channels,
                                       height // self.downsampling_factor,
                                       width // self.downsampling_factor],
@@ -190,14 +190,14 @@ class T2I:
                     for prompts in tqdm(data, desc="data", dynamic_ncols=True):
                         uc = None
                         if cfg_scale != 1.0:
-                            uc = model.get_learned_conditioning(batch * [""])
+                            uc = model.get_learned_conditioning(batch_size * [""])
                         if isinstance(prompts, tuple):
                             prompts = list(prompts)
                         c = model.get_learned_conditioning(prompts)
                         shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
                         samples_ddim, _ = sampler.sample(S=steps,
                                                          conditioning=c,
-                                                         batch_size=batch,
+                                                         batch_size=batch_size,
                                                          shape=shape,
                                                          verbose=False,
                                                          unconditional_guidance_scale=cfg_scale,
@@ -224,17 +224,17 @@ class T2I:
             if grid:
                 images = self._make_grid(samples=all_samples,
                                          seeds=seeds,
-                                         batch_size=batch,
+                                         batch_size=batch_size,
                                          iterations=iterations,
                                          outdir=outdir)
 
         toc = time.time()
-        print(f'{batch * iterations} images generated in',"%4.2fs"% (toc-tic))
+        print(f'{batch_size * iterations} images generated in',"%4.2fs"% (toc-tic))
 
         return images
 
     # There is lots of shared code between this and txt2img and should be refactored.
-    def img2img(self,prompt,outdir=None,init_img=None,batch=None,iterations=None,
+    def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
                 steps=None,seed=None,grid=None,individual=None,width=None,height=None,
                 cfg_scale=None,ddim_eta=None,strength=None):
         """
@@ -246,7 +246,7 @@ class T2I:
         seed = seed or self.seed
         cfg_scale = cfg_scale or self.cfg_scale
         ddim_eta = ddim_eta or self.ddim_eta
-        batch = batch or self.batch
+        batch_size = batch_size or self.batch_size
         iterations = iterations or self.iterations
         strength = strength or self.strength
 
@@ -263,7 +263,7 @@ class T2I:
         if individual:
             grid = False
 
-        data = [batch * [prompt]]
+        data = [batch_size * [prompt]]
 
         # PLMS sampler not supported yet, so ignore previous sampler
         if self.sampler_name!='ddim':
@@ -278,7 +278,7 @@ class T2I:
 
         assert os.path.isfile(init_img)
         init_image = self._load_img(init_img).to(self.device)
-        init_image = repeat(init_image, '1 ... -> b ...', b=batch)
+        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
         init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
 
         sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
@@ -307,13 +307,13 @@ class T2I:
                     for prompts in tqdm(data, desc="data", dynamic_ncols=True):
                         uc = None
                         if cfg_scale != 1.0:
-                            uc = model.get_learned_conditioning(batch * [""])
+                            uc = model.get_learned_conditioning(batch_size * [""])
                         if isinstance(prompts, tuple):
                             prompts = list(prompts)
                         c = model.get_learned_conditioning(prompts)
 
                         # encode (scaled latent)
-                        z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch).to(self.device))
+                        z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
                         # decode it
                         samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
                                                  unconditional_conditioning=uc,)
@@ -337,12 +337,12 @@ class T2I:
             if grid:
                 images = self._make_grid(samples=all_samples,
                                          seeds=seeds,
-                                         batch_size=batch,
+                                         batch_size=batch_size,
                                          iterations=iterations,
                                          outdir=outdir)
 
         toc = time.time()
-        print(f'{batch * iterations} images generated in',"%4.2fs"% (toc-tic))
+        print(f'{batch_size * iterations} images generated in',"%4.2fs"% (toc-tic))
 
         return images
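Taken together, the hunks above rename `batch` to `batch_size` across the T2I constructor, txt2img(), and img2img(). A minimal usage sketch of the renamed API, not part of the commit; the import path is hypothetical (the diff does not show the file name), and it assumes the default weights/config paths from the docstring exist locally:

    # Sketch only. The module name `simplet2i` is an assumption; adjust the
    # import to wherever the T2I class shown above actually lives.
    from simplet2i import T2I

    t2i = T2I(batch_size=2,   # renamed from `batch` in this commit
              iterations=3,   # run the sampler three times
              sampler='ddim')

    # txt2img() also takes batch_size now; per the print statement in the
    # diff, this run should report batch_size * iterations = 6 images.
    results = t2i.txt2img('a watercolor lighthouse at dusk',
                          steps=50,
                          cfg_scale=7.5)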
@@ -6,6 +6,8 @@ import shlex
 import atexit
 import os
 
+debugging = False
+
 def main():
     ''' Initialize command-line parsers and the diffusion model '''
     arg_parser = create_argv_parser()
@@ -24,7 +26,7 @@ def main():
     weights = "models/ldm/stable-diffusion-v1/model.ckpt"
 
     # command line history will be stored in a file called "~/.dream_history"
-    load_history()
+    setup_readline()
 
     print("* Initializing, be patient...\n")
     from pytorch_lightning import logging
@@ -36,7 +38,7 @@ def main():
     # the user input loop
     t2i = T2I(width=width,
               height=height,
-              batch=opt.batch,
+              batch_size=opt.batch_size,
               outdir=opt.outdir,
               sampler=opt.sampler,
               weights=weights,
@@ -50,7 +52,8 @@ def main():
     logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
 
     # preload the model
-    t2i.load_model()
+    if not debugging:
+        t2i.load_model()
     print("\n* Initialization done! Awaiting your command (-h for help)...")
 
     log_path = os.path.join(opt.outdir,"dream_log.txt")
@@ -139,7 +142,7 @@ def create_argv_parser():
                         type=int,
                         default=1,
                         help="number of images to generate")
-    parser.add_argument('-b','--batch',
+    parser.add_argument('-b','--batch_size',
                         type=int,
                         default=1,
                         help="number of images to produce per iteration (currently not working properly - producing too many images)")
@@ -161,7 +164,7 @@ def create_cmd_parser():
     parser.add_argument('-s','--steps',type=int,help="number of steps")
     parser.add_argument('-S','--seed',type=int,help="image seed")
     parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform")
-    parser.add_argument('-b','--batch',type=int,default=1,help="number of images to produce per sampling (currently broken)")
+    parser.add_argument('-b','--batch_size',type=int,default=1,help="number of images to produce per sampling (currently broken)")
     parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
     parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
     parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale")
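In both parsers the flag is renamed from -b/--batch to -b/--batch_size, so downstream code reads opt.batch_size instead of opt.batch. A self-contained illustration of how argparse exposes the renamed flag (not code from this commit; the argument values are made up):

    # Standalone sketch mirroring the add_argument calls above.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-b','--batch_size',type=int,default=1,
                        help="number of images to produce per sampling")
    parser.add_argument('-n','--iterations',type=int,default=1,
                        help="number of samplings to perform")

    opt = parser.parse_args(['-b','2','-n','3'])
    print(opt.batch_size, opt.iterations)  # argparse stores --batch_size as opt.batch_size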
@@ -171,6 +174,14 @@ def create_cmd_parser():
     parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely")
     return parser
 
+def setup_readline():
+    readline.set_completer(Completer(['--steps','-s','--seed','-S','--iterations','-n','--batch_size','-b',
+                                      '--width','-W','--height','-H','--cfg_scale','-C','--grid','-g',
+                                      '--individual','-i','--init_img','-I','--strength','-f']).complete)
+    readline.set_completer_delims(" ")
+    readline.parse_and_bind('tab: complete')
+    load_history()
+
 def load_history():
     histfile = os.path.join(os.path.expanduser('~'),".dream_history")
     try:
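setup_readline() is the new entry point: it installs the Completer, makes the space character the only word delimiter so whole flags like "--batch_size" are completed in one piece, binds the Tab key, and then loads the history file as before. A stripped-down, runnable version of the same wiring, with a trivial inline completer standing in for the Completer class added in the next hunk:

    # Minimal sketch of the same readline wiring; the option list is abridged.
    import readline

    OPTIONS = sorted(['--steps','-s','--seed','-S','--batch_size','-b'])

    def complete(text, state):
        # readline calls this with state = 0, 1, 2, ... until it returns None
        matches = [o for o in OPTIONS if o.startswith(text)] if text else OPTIONS[:]
        return matches[state] if state < len(matches) else None

    readline.set_completer(complete)
    readline.set_completer_delims(" ")        # only spaces end a word, so "--ba<tab>" works
    readline.parse_and_bind('tab: complete')  # bind the Tab key to completion

    line = input('dream> ')                   # try typing "--s" and pressing Tab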
@@ -180,5 +191,64 @@ def load_history():
         pass
     atexit.register(readline.write_history_file,histfile)
 
+class Completer():
+    def __init__(self,options):
+        self.options = sorted(options)
+        return
+
+    def complete(self,text,state):
+        if text.startswith('-I') or text.startswith('--init_img'):
+            return self._image_completions(text,state)
+
+        response = None
+        if state == 0:
+            # This is the first time for this text, so build a match list.
+            if text:
+                self.matches = [s
+                                for s in self.options
+                                if s and s.startswith(text)]
+            else:
+                self.matches = self.options[:]
+
+        # Return the state'th item from the match list,
+        # if we have that many.
+        try:
+            response = self.matches[state]
+        except IndexError:
+            response = None
+        return response
+
+    def _image_completions(self,text,state):
+        # get the path so far
+        if text.startswith('-I'):
+            path = text.replace('-I','',1).lstrip()
+        elif text.startswith('--init_img='):
+            path = text.replace('--init_img=','',1).lstrip()
+
+        matches = list()
+
+        path = os.path.expanduser(path)
+        if len(path)==0:
+            matches.append(text+'./')
+        else:
+            dir = os.path.dirname(path)
+            dir_list = os.listdir(dir)
+            for n in dir_list:
+                if n.startswith('.') and len(n)>1:
+                    continue
+                full_path = os.path.join(dir,n)
+                if full_path.startswith(path):
+                    if os.path.isdir(full_path):
+                        matches.append(os.path.join(os.path.dirname(text),n)+'/')
+                    elif n.endswith('.png'):
+                        matches.append(os.path.join(os.path.dirname(text),n))
+
+        try:
+            response = matches[state]
+        except IndexError:
+            response = None
+        return response
+
+
 if __name__ == "__main__":
     main()
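The Completer implements readline's stateful protocol: on state 0 it rebuilds self.matches for the current prefix, then hands back one match per call until an IndexError yields None, which tells readline to stop. Arguments beginning with -I/--init_img are instead completed against the filesystem (directories plus .png files). The protocol can be exercised by hand, assuming the Completer class above is in scope:

    # Driving Completer.complete() the way readline would (no terminal needed).
    c = Completer(['--steps','--seed','--strength','--width'])

    state = 0
    while True:
        match = c.complete('--s', state)   # prefix typed by the user
        if match is None:                  # IndexError path: no more matches
            break
        print(match)                       # --seed, --steps, --strength
        state += 1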