mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
added command-line completion
This commit is contained in:
parent
bf76c4f283
commit
750408f793
@ -8,7 +8,7 @@ t2i = T2I(outdir = <path> // outputs/txt2img-samples
|
|||||||
model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
|
model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
|
||||||
config = <path> // default="configs/stable-diffusion/v1-inference.yaml
|
config = <path> // default="configs/stable-diffusion/v1-inference.yaml
|
||||||
iterations = <integer> // how many times to run the sampling (1)
|
iterations = <integer> // how many times to run the sampling (1)
|
||||||
batch = <integer> // how many images to generate per sampling (1)
|
batch_size = <integer> // how many images to generate per sampling (1)
|
||||||
steps = <integer> // 50
|
steps = <integer> // 50
|
||||||
seed = <integer> // current system time
|
seed = <integer> // current system time
|
||||||
sampler = ['ddim','plms'] // ddim
|
sampler = ['ddim','plms'] // ddim
|
||||||
@ -73,7 +73,7 @@ class T2I:
|
|||||||
model
|
model
|
||||||
config
|
config
|
||||||
iterations
|
iterations
|
||||||
batch
|
batch_size
|
||||||
steps
|
steps
|
||||||
seed
|
seed
|
||||||
sampler
|
sampler
|
||||||
@ -90,7 +90,7 @@ class T2I:
|
|||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
outdir="outputs/txt2img-samples",
|
outdir="outputs/txt2img-samples",
|
||||||
batch=1,
|
batch_size=1,
|
||||||
iterations = 1,
|
iterations = 1,
|
||||||
width=512,
|
width=512,
|
||||||
height=512,
|
height=512,
|
||||||
@ -110,7 +110,7 @@ class T2I:
|
|||||||
strength=0.75 # default in scripts/img2img.py
|
strength=0.75 # default in scripts/img2img.py
|
||||||
):
|
):
|
||||||
self.outdir = outdir
|
self.outdir = outdir
|
||||||
self.batch = batch
|
self.batch_size = batch_size
|
||||||
self.iterations = iterations
|
self.iterations = iterations
|
||||||
self.width = width
|
self.width = width
|
||||||
self.height = height
|
self.height = height
|
||||||
@ -133,7 +133,7 @@ class T2I:
|
|||||||
else:
|
else:
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
|
|
||||||
def txt2img(self,prompt,outdir=None,batch=None,iterations=None,
|
def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None,
|
||||||
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
|
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
|
||||||
cfg_scale=None,ddim_eta=None,strength=None,init_img=None):
|
cfg_scale=None,ddim_eta=None,strength=None,init_img=None):
|
||||||
"""
|
"""
|
||||||
@ -147,7 +147,7 @@ class T2I:
|
|||||||
height = height or self.height
|
height = height or self.height
|
||||||
cfg_scale = cfg_scale or self.cfg_scale
|
cfg_scale = cfg_scale or self.cfg_scale
|
||||||
ddim_eta = ddim_eta or self.ddim_eta
|
ddim_eta = ddim_eta or self.ddim_eta
|
||||||
batch = batch or self.batch
|
batch_size = batch_size or self.batch_size
|
||||||
iterations = iterations or self.iterations
|
iterations = iterations or self.iterations
|
||||||
strength = strength or self.strength # not actually used here, but preserved for code refactoring
|
strength = strength or self.strength # not actually used here, but preserved for code refactoring
|
||||||
|
|
||||||
@ -160,7 +160,7 @@ class T2I:
|
|||||||
if individual:
|
if individual:
|
||||||
grid = False
|
grid = False
|
||||||
|
|
||||||
data = [batch * [prompt]]
|
data = [batch_size * [prompt]]
|
||||||
|
|
||||||
# make directories and establish names for the output files
|
# make directories and establish names for the output files
|
||||||
os.makedirs(outdir, exist_ok=True)
|
os.makedirs(outdir, exist_ok=True)
|
||||||
@ -168,7 +168,7 @@ class T2I:
|
|||||||
|
|
||||||
start_code = None
|
start_code = None
|
||||||
if self.fixed_code:
|
if self.fixed_code:
|
||||||
start_code = torch.randn([batch,
|
start_code = torch.randn([batch_size,
|
||||||
self.latent_channels,
|
self.latent_channels,
|
||||||
height // self.downsampling_factor,
|
height // self.downsampling_factor,
|
||||||
width // self.downsampling_factor],
|
width // self.downsampling_factor],
|
||||||
@ -190,14 +190,14 @@ class T2I:
|
|||||||
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
|
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
|
||||||
uc = None
|
uc = None
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
uc = model.get_learned_conditioning(batch * [""])
|
uc = model.get_learned_conditioning(batch_size * [""])
|
||||||
if isinstance(prompts, tuple):
|
if isinstance(prompts, tuple):
|
||||||
prompts = list(prompts)
|
prompts = list(prompts)
|
||||||
c = model.get_learned_conditioning(prompts)
|
c = model.get_learned_conditioning(prompts)
|
||||||
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
|
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
|
||||||
samples_ddim, _ = sampler.sample(S=steps,
|
samples_ddim, _ = sampler.sample(S=steps,
|
||||||
conditioning=c,
|
conditioning=c,
|
||||||
batch_size=batch,
|
batch_size_size=batch_size,
|
||||||
shape=shape,
|
shape=shape,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
unconditional_guidance_scale=cfg_scale,
|
unconditional_guidance_scale=cfg_scale,
|
||||||
@ -224,17 +224,17 @@ class T2I:
|
|||||||
if grid:
|
if grid:
|
||||||
images = self._make_grid(samples=all_samples,
|
images = self._make_grid(samples=all_samples,
|
||||||
seeds=seeds,
|
seeds=seeds,
|
||||||
batch_size=batch,
|
batch_size=batch_size,
|
||||||
iterations=iterations,
|
iterations=iterations,
|
||||||
outdir=outdir)
|
outdir=outdir)
|
||||||
|
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
print(f'{batch * iterations} images generated in',"%4.2fs"% (toc-tic))
|
print(f'{batch_size * iterations} images generated in',"%4.2fs"% (toc-tic))
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
# There is lots of shared code between this and txt2img and should be refactored.
|
# There is lots of shared code between this and txt2img and should be refactored.
|
||||||
def img2img(self,prompt,outdir=None,init_img=None,batch=None,iterations=None,
|
def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
|
||||||
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
|
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
|
||||||
cfg_scale=None,ddim_eta=None,strength=None):
|
cfg_scale=None,ddim_eta=None,strength=None):
|
||||||
"""
|
"""
|
||||||
@ -246,7 +246,7 @@ class T2I:
|
|||||||
seed = seed or self.seed
|
seed = seed or self.seed
|
||||||
cfg_scale = cfg_scale or self.cfg_scale
|
cfg_scale = cfg_scale or self.cfg_scale
|
||||||
ddim_eta = ddim_eta or self.ddim_eta
|
ddim_eta = ddim_eta or self.ddim_eta
|
||||||
batch = batch or self.batch
|
batch_size = batch_size or self.batch_size
|
||||||
iterations = iterations or self.iterations
|
iterations = iterations or self.iterations
|
||||||
strength = strength or self.strength
|
strength = strength or self.strength
|
||||||
|
|
||||||
@ -263,7 +263,7 @@ class T2I:
|
|||||||
if individual:
|
if individual:
|
||||||
grid = False
|
grid = False
|
||||||
|
|
||||||
data = [batch * [prompt]]
|
data = [batch_size * [prompt]]
|
||||||
|
|
||||||
# PLMS sampler not supported yet, so ignore previous sampler
|
# PLMS sampler not supported yet, so ignore previous sampler
|
||||||
if self.sampler_name!='ddim':
|
if self.sampler_name!='ddim':
|
||||||
@ -278,7 +278,7 @@ class T2I:
|
|||||||
|
|
||||||
assert os.path.isfile(init_img)
|
assert os.path.isfile(init_img)
|
||||||
init_image = self._load_img(init_img).to(self.device)
|
init_image = self._load_img(init_img).to(self.device)
|
||||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch)
|
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
|
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
|
||||||
|
|
||||||
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
|
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
|
||||||
@ -307,13 +307,13 @@ class T2I:
|
|||||||
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
|
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
|
||||||
uc = None
|
uc = None
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
uc = model.get_learned_conditioning(batch * [""])
|
uc = model.get_learned_conditioning(batch_size * [""])
|
||||||
if isinstance(prompts, tuple):
|
if isinstance(prompts, tuple):
|
||||||
prompts = list(prompts)
|
prompts = list(prompts)
|
||||||
c = model.get_learned_conditioning(prompts)
|
c = model.get_learned_conditioning(prompts)
|
||||||
|
|
||||||
# encode (scaled latent)
|
# encode (scaled latent)
|
||||||
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch).to(self.device))
|
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
|
||||||
# decode it
|
# decode it
|
||||||
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
|
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
|
||||||
unconditional_conditioning=uc,)
|
unconditional_conditioning=uc,)
|
||||||
@ -337,12 +337,12 @@ class T2I:
|
|||||||
if grid:
|
if grid:
|
||||||
images = self._make_grid(samples=all_samples,
|
images = self._make_grid(samples=all_samples,
|
||||||
seeds=seeds,
|
seeds=seeds,
|
||||||
batch_size=batch,
|
batch_size=batch_size,
|
||||||
iterations=iterations,
|
iterations=iterations,
|
||||||
outdir=outdir)
|
outdir=outdir)
|
||||||
|
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
print(f'{batch * iterations} images generated in',"%4.2fs"% (toc-tic))
|
print(f'{batch_size * iterations} images generated in',"%4.2fs"% (toc-tic))
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
@ -6,6 +6,8 @@ import shlex
|
|||||||
import atexit
|
import atexit
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
debugging = False
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
''' Initialize command-line parsers and the diffusion model '''
|
''' Initialize command-line parsers and the diffusion model '''
|
||||||
arg_parser = create_argv_parser()
|
arg_parser = create_argv_parser()
|
||||||
@ -24,7 +26,7 @@ def main():
|
|||||||
weights = "models/ldm/stable-diffusion-v1/model.ckpt"
|
weights = "models/ldm/stable-diffusion-v1/model.ckpt"
|
||||||
|
|
||||||
# command line history will be stored in a file called "~/.dream_history"
|
# command line history will be stored in a file called "~/.dream_history"
|
||||||
load_history()
|
setup_readline()
|
||||||
|
|
||||||
print("* Initializing, be patient...\n")
|
print("* Initializing, be patient...\n")
|
||||||
from pytorch_lightning import logging
|
from pytorch_lightning import logging
|
||||||
@ -36,7 +38,7 @@ def main():
|
|||||||
# the user input loop
|
# the user input loop
|
||||||
t2i = T2I(width=width,
|
t2i = T2I(width=width,
|
||||||
height=height,
|
height=height,
|
||||||
batch=opt.batch,
|
batch_size=opt.batch_size,
|
||||||
outdir=opt.outdir,
|
outdir=opt.outdir,
|
||||||
sampler=opt.sampler,
|
sampler=opt.sampler,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
@ -50,6 +52,7 @@ def main():
|
|||||||
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
|
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
|
||||||
|
|
||||||
# preload the model
|
# preload the model
|
||||||
|
if not debugging:
|
||||||
t2i.load_model()
|
t2i.load_model()
|
||||||
print("\n* Initialization done! Awaiting your command (-h for help)...")
|
print("\n* Initialization done! Awaiting your command (-h for help)...")
|
||||||
|
|
||||||
@ -139,7 +142,7 @@ def create_argv_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=1,
|
||||||
help="number of images to generate")
|
help="number of images to generate")
|
||||||
parser.add_argument('-b','--batch',
|
parser.add_argument('-b','--batch_size',
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=1,
|
||||||
help="number of images to produce per iteration (currently not working properly - producing too many images)")
|
help="number of images to produce per iteration (currently not working properly - producing too many images)")
|
||||||
@ -161,7 +164,7 @@ def create_cmd_parser():
|
|||||||
parser.add_argument('-s','--steps',type=int,help="number of steps")
|
parser.add_argument('-s','--steps',type=int,help="number of steps")
|
||||||
parser.add_argument('-S','--seed',type=int,help="image seed")
|
parser.add_argument('-S','--seed',type=int,help="image seed")
|
||||||
parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform")
|
parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform")
|
||||||
parser.add_argument('-b','--batch',type=int,default=1,help="number of images to produce per sampling (currently broken)")
|
parser.add_argument('-b','--batch_size',type=int,default=1,help="number of images to produce per sampling (currently broken)")
|
||||||
parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
|
parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
|
||||||
parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
|
parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
|
||||||
parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale")
|
parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale")
|
||||||
@ -171,6 +174,14 @@ def create_cmd_parser():
|
|||||||
parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely")
|
parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
def setup_readline():
    """Configure readline for the interactive prompt.

    Installs a tab-completer that knows the per-prompt command switches,
    restricts word splitting to spaces, binds the Tab key to completion,
    and finally loads the saved command history via ``load_history()``.
    """
    # Every long/short switch that can be completed at the prompt.
    switches = [
        '--steps', '-s', '--seed', '-S', '--iterations', '-n', '--batch_size', '-b',
        '--width', '-W', '--height', '-H', '--cfg_scale', '-C', '--grid', '-g',
        '--individual', '-i', '--init_img', '-I', '--strength', '-f',
    ]
    completer = Completer(switches)
    readline.set_completer(completer.complete)

    # Only a space separates words, so '-', '=', '/' and '~' stay inside the
    # text handed to the completer.
    readline.set_completer_delims(" ")
    readline.parse_and_bind('tab: complete')

    load_history()
|
||||||
|
|
||||||
def load_history():
|
def load_history():
|
||||||
histfile = os.path.join(os.path.expanduser('~'),".dream_history")
|
histfile = os.path.join(os.path.expanduser('~'),".dream_history")
|
||||||
try:
|
try:
|
||||||
@ -180,5 +191,64 @@ def load_history():
|
|||||||
pass
|
pass
|
||||||
atexit.register(readline.write_history_file,histfile)
|
atexit.register(readline.write_history_file,histfile)
|
||||||
|
|
||||||
|
class Completer():
    """Readline tab-completer for the interactive prompt.

    Two kinds of text are completed:

    * option switches (e.g. ``--st`` -> ``--steps``), matched against the
      list of options given to the constructor;
    * image paths attached to ``-I`` or ``--init_img=``, matched against
      sub-directories and ``.png`` files on disk.

    An instance's ``complete`` method is meant to be passed to
    ``readline.set_completer``.
    """

    # Prefixes that switch completion from option names to image paths.
    # '--init_img' without the '=' is deliberately absent so that the switch
    # itself can still be completed like any other option.
    _IMAGE_FLAGS = ('-I', '--init_img=')

    def __init__(self, options):
        """Remember the completable option strings, in sorted order."""
        self.options = sorted(options)
        # Candidates for the text currently being completed. Rebuilt each
        # time readline starts a new completion (state == 0); initialised
        # here so a call with state > 0 can never hit a missing attribute.
        self.matches = []

    def complete(self, text, state):
        """Return the ``state``-th completion of ``text``, or None.

        readline calls this repeatedly with state = 0, 1, 2, ... until it
        gets None. The candidate list is built once, on state 0, and then
        indexed on the following calls.
        """
        if state == 0:
            if text.startswith(self._IMAGE_FLAGS):
                self.matches = self._image_matches(text)
            elif text:
                self.matches = [s for s in self.options
                                if s and s.startswith(text)]
            else:
                self.matches = self.options[:]

        try:
            return self.matches[state]
        except IndexError:
            return None

    def _image_completions(self, text, state):
        """Return the ``state``-th image-path completion of ``text``, or None."""
        matches = self._image_matches(text)
        try:
            return matches[state]
        except IndexError:
            return None

    def _image_matches(self, text):
        """Build the list of path completions for ``-I<path>`` / ``--init_img=<path>``.

        Each returned string is the full replacement for ``text`` (switch
        included). Directories get a trailing '/', plain files are offered
        only when they end in '.png', and hidden entries are skipped. An
        unreadable or missing directory yields no completions.
        """
        for flag in self._IMAGE_FLAGS:
            if text.startswith(flag):
                raw_path = text[len(flag):].lstrip()
                break
        else:
            # Not an image switch at all: nothing to offer.
            return []

        if not raw_path:
            # Nothing typed after the switch yet: suggest the current directory.
            return [text + './']

        # Split what the user typed into the directory to list and the
        # partial file name to match. '~' is expanded only for the lookup so
        # the completion keeps the text exactly as the user wrote it.
        raw_dir, prefix = os.path.split(raw_path)
        directory = os.path.expanduser(raw_dir) or '.'

        try:
            names = sorted(os.listdir(directory))
        except OSError:
            return []

        # Everything the user typed except the partial file name.
        head = text[:len(text) - len(prefix)]

        matches = []
        for name in names:
            if name.startswith('.') or not name.startswith(prefix):
                continue
            if os.path.isdir(os.path.join(directory, name)):
                matches.append(head + name + '/')
            elif name.endswith('.png'):
                matches.append(head + name)
        return matches
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user