Refactor generate.py and dream.py (#534)

* revert inadvertent change of conda env name (#528)

* Refactor generate.py and dream.py

* config file path (models.yaml) is parsed inside Generate() to simplify the API (see the usage sketch after this list)

* Better handling of keyboard interrupts in file-loading mode vs. interactive mode

* Removed oodles of unused variables.

* move nonfunctional inpainting out of the scripts directory

* fix ugly ddim tqdm formatting
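
The practical effect of the refactor is that callers no longer pass weights, config, width, or height to the constructor; they name a model defined in models.yaml instead. A minimal usage sketch assembled from the docstring changes below (the prompt text and output directory are illustrative, not part of this commit):

    from ldm.generate import Generate

    # the conf file supplies weights, config, width and height for the named model
    gr = Generate(model='stable-diffusion-1.4', conf='configs/models.yaml')

    # the slow model initialization happens here, not in __init__()
    gr.load_model()

    # writes PNGs to outdir and returns a list of [filename, seed] pairs
    results = gr.prompt2png('a sunlit meadow', outdir='outputs/')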
Lincoln Stein, 2022-09-14 07:02:31 -04:00, committed by GitHub
parent d15c75ecae
commit e6179af46a
4 changed files with 141 additions and 162 deletions

ldm/generate.py

@@ -17,7 +17,7 @@ import transformers
 from omegaconf import OmegaConf
 from PIL import Image, ImageOps
 from torch import nn
-from pytorch_lightning import seed_everything
+from pytorch_lightning import seed_everything, logging
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
@@ -35,7 +35,7 @@ Example Usage:
 from ldm.generate import Generate
 # Create an object with default values
-gr = Generate()
+gr = Generate('stable-diffusion-1.4')
 # do the slow model initialization
 gr.load_model()
@@ -79,16 +79,17 @@ still work.
 The full list of arguments to Generate() are:
     gr = Generate(
+          # these values are set once and shouldn't be changed
+          conf = path to configuration file ('configs/models.yaml')
+          model = symbolic name of the model in the configuration file
+          full_precision = False
+          # this value is sticky and maintained between generation calls
+          sampler_name = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
+          # these are deprecated - use conf and model instead
           weights = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
           config = path to model configuraiton ('configs/stable-diffusion/v1-inference.yaml')
-          iterations = <integer> // how many times to run the sampling (1)
-          steps = <integer> // 50
-          seed = <integer> // current system time
-          sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
-          grid = <boolean> // false
-          width = <integer> // image width, multiple of 64 (512)
-          height = <integer> // image height, multiple of 64 (512)
-          cfg_scale = <float> // condition-free guidance scale (7.5)
     )
 """
@@ -101,66 +102,62 @@ class Generate:
     def __init__(
             self,
-            iterations = 1,
-            steps = 50,
-            cfg_scale = 7.5,
-            weights = 'models/ldm/stable-diffusion-v1/model.ckpt',
-            config = 'configs/stable-diffusion/v1-inference.yaml',
-            grid = False,
-            width = 512,
-            height = 512,
+            model = 'stable-diffusion-1.4',
+            conf = 'configs/models.yaml',
+            embedding_path = None,
             sampler_name = 'k_lms',
             ddim_eta = 0.0,  # deterministic
             full_precision = False,
-            strength = 0.75,  # default in scripts/img2img.py
-            seamless = False,
-            embedding_path = None,
-            device_type = 'cuda',
-            ignore_ctrl_c = False,
+            # these are deprecated; if present they override values in the conf file
+            weights = None,
+            config = None,
     ):
-        self.iterations = iterations
-        self.width = width
-        self.height = height
-        self.steps = steps
-        self.cfg_scale = cfg_scale
-        self.weights = weights
-        self.config = config
-        self.sampler_name = sampler_name
-        self.grid = grid
-        self.ddim_eta = ddim_eta
-        self.full_precision = True if choose_torch_device() == 'mps' else full_precision
-        self.strength = strength
-        self.seamless = seamless
-        self.embedding_path = embedding_path
-        self.device_type = device_type
-        self.ignore_ctrl_c = ignore_ctrl_c  # note, this logic probably doesn't belong here...
-        self.model = None  # empty for now
-        self.sampler = None
-        self.device = None
+        models = OmegaConf.load(conf)
+        mconfig = models[model]
+        self.weights = mconfig.weights if weights is None else weights
+        self.config = mconfig.config if config is None else config
+        self.height = mconfig.height
+        self.width = mconfig.width
+        self.iterations = 1
+        self.steps = 50
+        self.cfg_scale = 7.5
+        self.sampler_name = sampler_name
+        self.ddim_eta = 0.0  # same seed always produces same image
+        self.full_precision = True if choose_torch_device() == 'mps' else full_precision
+        self.strength = 0.75
+        self.seamless = False
+        self.embedding_path = embedding_path
+        self.model = None  # empty for now
+        self.sampler = None
+        self.device = None
+        self.session_peakmem = None
         self.generators = {}
         self.base_generator = None
         self.seed = None
-        if device_type == 'cuda' and not torch.cuda.is_available():
-            device_type = choose_torch_device()
-            print(">> cuda not available, using device", device_type)
+        # Note that in previous versions, there was an option to pass the
+        # device to Generate(). However the device was then ignored, so
+        # it wasn't actually doing anything. This logic could be reinstated.
+        device_type = choose_torch_device()
         self.device = torch.device(device_type)
         # for VRAM usage statistics
-        device_type = choose_torch_device()
-        self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None
+        self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
         transformers.logging.set_verbosity_error()
+        # gets rid of annoying messages about random seed
+        logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
     def prompt2png(self, prompt, outdir, **kwargs):
         """
         Takes a prompt and an output directory, writes out the requested number
         of PNG files, and returns an array of [[filename,seed],[filename,seed]...]
         Optional named arguments are the same as those passed to Generate and prompt2image()
         """
         results = self.prompt2image(prompt, **kwargs)
         pngwriter = PngWriter(outdir)
         prefix = pngwriter.unique_prefix()
         outputs = []
         for image, seed in results:
             name = f'{prefix}.{seed}.png'
             path = pngwriter.save_image_and_prompt_to_png(
@@ -183,33 +180,35 @@ class Generate:
             self,
             # these are common
             prompt,
             iterations = None,
             steps = None,
             seed = None,
             cfg_scale = None,
             ddim_eta = None,
             skip_normalize = False,
             image_callback = None,
             step_callback = None,
             width = None,
             height = None,
             sampler_name = None,
             seamless = False,
-            log_tokenization= False,
+            log_tokenization = False,
             with_variations = None,
             variation_amount = 0.0,
             # these are specific to img2img and inpaint
             init_img = None,
             init_mask = None,
             fit = False,
             strength = None,
             # these are specific to embiggen (which also relies on img2img args)
             embiggen = None,
             embiggen_tiles = None,
             # these are specific to GFPGAN/ESRGAN
-            gfpgan_strength= 0,
+            gfpgan_strength = 0,
             save_original = False,
             upscale = None,
+            # Set this True to handle KeyboardInterrupt internally
+            catch_interrupts = False,
             **args,
     ):  # eat up additional cruft
         """
@@ -262,9 +261,8 @@ class Generate:
         self.log_tokenization = log_tokenization
         with_variations = [] if with_variations is None else with_variations
-        model = (
-            self.load_model()
-        )  # will instantiate the model or return it from cache
+        # will instantiate the model or return it from cache
+        model = self.load_model()
         for m in model.modules():
             if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
@@ -281,7 +279,6 @@ class Generate:
             (embiggen == None and embiggen_tiles == None) or ((embiggen != None or embiggen_tiles != None) and init_img != None)
         ), 'Embiggen requires an init/input image to be specified'
-        # check this logic - doesn't look right
         if len(with_variations) > 0 or variation_amount > 1.0:
             assert seed is not None,\
                 'seed must be specified when using with_variations'
@@ -298,7 +295,7 @@ class Generate:
         self._set_sampler()
         tic = time.time()
-        if torch.cuda.is_available():
+        if self._has_cuda():
             torch.cuda.reset_peak_memory_stats()
         results = list()
@@ -307,9 +304,9 @@ class Generate:
         try:
             uc, c = get_uc_and_c(
-                prompt, model=self.model,
+                prompt, model =self.model,
                 skip_normalize=skip_normalize,
-                log_tokens=self.log_tokenization
+                log_tokens =self.log_tokenization
             )
             (init_image,mask_image) = self._make_images(init_img,init_mask, width, height, fit)
@@ -352,27 +349,25 @@ class Generate:
                     save_original = save_original,
                     image_callback = image_callback)
-        except KeyboardInterrupt:
-            print('*interrupted*')
-            if not self.ignore_ctrl_c:
-                raise KeyboardInterrupt
-            print(
-                '>> Partial results will be returned; if --grid was requested, nothing will be returned.'
-            )
         except RuntimeError as e:
             print(traceback.format_exc(), file=sys.stderr)
             print('>> Could not generate image.')
+        except KeyboardInterrupt:
+            if catch_interrupts:
+                print('**Interrupted** Partial results will be returned.')
+            else:
+                raise KeyboardInterrupt
         toc = time.time()
         print('>> Usage stats:')
         print(
             f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
         )
-        if torch.cuda.is_available() and self.device.type == 'cuda':
+        if self._has_cuda():
             print(
                 f'>> Max VRAM used for this generation:',
                 '%4.2fG.' % (torch.cuda.max_memory_allocated() / 1e9),
-                'Current VRAM utilization:'
+                'Current VRAM utilization:',
                 '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
             )
@@ -439,8 +434,7 @@ class Generate:
         if self.model is None:
             seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
             try:
-                config = OmegaConf.load(self.config)
-                model = self._load_model_from_config(config, self.weights)
+                model = self._load_model_from_config(self.config, self.weights)
                 if self.embedding_path is not None:
                     model.embedding_manager.load(
                         self.embedding_path, self.full_precision
@@ -541,8 +535,11 @@ class Generate:
             print(msg)
-    def _load_model_from_config(self, config, ckpt):
-        print(f'>> Loading model from {ckpt}')
+    # Be warned: config is the path to the model config file, not the dream conf file!
+    # Also note that we can get config and weights from self, so why do we need to
+    # pass them as args?
+    def _load_model_from_config(self, config, weights):
+        print(f'>> Loading model from {weights}')
         # for usage statistics
         device_type = choose_torch_device()
@@ -551,10 +548,11 @@ class Generate:
         tic = time.time()
         # this does the work
-        pl_sd = torch.load(ckpt, map_location='cpu')
-        sd = pl_sd['state_dict']
-        model = instantiate_from_config(config.model)
-        m, u = model.load_state_dict(sd, strict=False)
+        c = OmegaConf.load(config)
+        pl_sd = torch.load(weights, map_location='cpu')
+        sd = pl_sd['state_dict']
+        model = instantiate_from_config(c.model)
+        m, u = model.load_state_dict(sd, strict=False)
         if self.full_precision:
             print(
@@ -573,7 +571,7 @@ class Generate:
         print(
             f'>> Model loaded in', '%4.2fs' % (toc - tic)
         )
-        if device_type == 'cuda':
+        if self._has_cuda():
             print(
                 '>> Max VRAM used to load the model:',
                 '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
@@ -710,3 +708,5 @@ class Generate:
         return width, height, resize_needed
+    def _has_cuda(self):
+        return self.device.type == 'cuda'
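
The new constructor dereferences four keys on the selected models.yaml entry: weights, config, width, and height. A quick sanity check of an entry, sketched with OmegaConf under the assumption that the entry name and paths match the defaults quoted in the docstring above (the models.yaml file itself is not part of this diff):

    from omegaconf import OmegaConf

    models = OmegaConf.load('configs/models.yaml')
    mconfig = models['stable-diffusion-1.4']

    # the four fields Generate.__init__() reads
    print(mconfig.weights)                # e.g. models/ldm/stable-diffusion-v1/model.ckpt
    print(mconfig.config)                 # e.g. configs/stable-diffusion/v1-inference.yaml
    print(mconfig.width, mconfig.height)  # e.g. 512 512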

ldm/models/diffusion/ddim.py

@@ -225,7 +225,7 @@ class DDIMSampler(object):
         total_steps = (
             timesteps if ddim_use_original_steps else timesteps.shape[0]
         )
-        print(f'Running DDIM Sampling with {total_steps} timesteps')
+        print(f'\nRunning DDIM Sampling with {total_steps} timesteps')
         iterator = tqdm(
             time_range,

scripts/dream.py

@@ -33,53 +33,35 @@ def main():
         print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.')
         sys.exit(-1)
-    try:
-        models = OmegaConf.load(opt.config)
-        width = models[opt.model].width
-        height = models[opt.model].height
-        config = models[opt.model].config
-        weights = models[opt.model].weights
-    except (FileNotFoundError, IOError, KeyError) as e:
-        print(f'{e}. Aborting.')
-        sys.exit(-1)
     print('* Initializing, be patient...\n')
     sys.path.append('.')
-    from pytorch_lightning import logging
     from ldm.generate import Generate
     # these two lines prevent a horrible warning message from appearing
     # when the frozen CLIP tokenizer is imported
     import transformers
     transformers.logging.set_verbosity_error()
-    # creating a simple text2image object with a handful of
+    # creating a simple Generate object with a handful of
     # defaults passed on the command line.
     # additional parameters will be added (or overriden) during
     # the user input loop
-    t2i = Generate(
-        width=width,
-        height=height,
-        sampler_name=opt.sampler_name,
-        weights=weights,
-        full_precision=opt.full_precision,
-        config=config,
-        grid=opt.grid,
-        # this is solely for recreating the prompt
-        seamless=opt.seamless,
-        embedding_path=opt.embedding_path,
-        device_type=opt.device,
-        ignore_ctrl_c=opt.infile is None,
-    )
+    try:
+        gen = Generate(
+            conf = opt.config,
+            model = opt.model,
+            sampler_name = opt.sampler_name,
+            embedding_path = opt.embedding_path,
+            full_precision = opt.full_precision,
+        )
+    except (FileNotFoundError, IOError, KeyError) as e:
+        print(f'{e}. Aborting.')
+        sys.exit(-1)
     # make sure the output directory exists
     if not os.path.exists(opt.outdir):
         os.makedirs(opt.outdir)
-    # gets rid of annoying messages about random seed
-    logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
     # load the infile as a list of lines
     infile = None
     if opt.infile:
@@ -98,21 +80,23 @@ def main():
             print(">> changed to seamless tiling mode")
     # preload the model
-    t2i.load_model()
+    gen.load_model()
     if not infile:
         print(
             "\n* Initialization done! Awaiting your command (-h for help, 'q' to quit)"
         )
-    cmd_parser = create_cmd_parser()
+    # web server loops forever
     if opt.web:
-        dream_server_loop(t2i, opt.host, opt.port, opt.outdir)
-    else:
-        main_loop(t2i, opt.outdir, opt.prompt_as_dir, cmd_parser, infile)
+        dream_server_loop(gen, opt.host, opt.port, opt.outdir)
+        sys.exit(0)
+    cmd_parser = create_cmd_parser()
+    main_loop(gen, opt.outdir, opt.prompt_as_dir, cmd_parser, infile)
-def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
+# TODO: main_loop() has gotten busy. Needs to be refactored.
+def main_loop(gen, outdir, prompt_as_dir, parser, infile):
     """prompt/read/execute loop"""
     done = False
     path_filter = re.compile(r'[<>:"/\\|?*]')
@@ -132,9 +116,6 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
         except EOFError:
             done = True
             continue
-        except KeyboardInterrupt:
-            done = True
-            continue
         # skip empty lines
         if not command.strip():
@@ -184,6 +165,7 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
         if len(opt.prompt) == 0:
             print('Try again with a prompt!')
             continue
         # retrieve previous value!
         if opt.init_img is not None and re.match('^-\\d+$', opt.init_img):
             try:
@@ -204,8 +186,6 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
                 opt.seed = None
                 continue
-        do_grid = opt.grid or t2i.grid
         if opt.with_variations is not None:
             # shotgun parsing, woo
             parts = []
@@ -258,11 +238,11 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
         file_writer = PngWriter(current_outdir)
         prefix = file_writer.unique_prefix()
         results = []  # list of filename, prompt pairs
-        grid_images = dict()  # seed -> Image, only used if `do_grid`
+        grid_images = dict()  # seed -> Image, only used if `opt.grid`
         def image_writer(image, seed, upscaled=False):
             path = None
-            if do_grid:
+            if opt.grid:
                 grid_images[seed] = image
             else:
                 if upscaled and opt.save_original:
@@ -278,16 +258,16 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
                     iter_opt.with_variations = opt.with_variations + this_variation
                     iter_opt.variation_amount = 0
                     normalized_prompt = PromptFormatter(
-                        t2i, iter_opt).normalize_prompt()
+                        gen, iter_opt).normalize_prompt()
                     metadata_prompt = f'{normalized_prompt} -S{iter_opt.seed}'
                 elif opt.with_variations is not None:
                     normalized_prompt = PromptFormatter(
-                        t2i, opt).normalize_prompt()
+                        gen, opt).normalize_prompt()
                     # use the original seed - the per-iteration value is the last variation-seed
                     metadata_prompt = f'{normalized_prompt} -S{opt.seed}'
                 else:
                     normalized_prompt = PromptFormatter(
-                        t2i, opt).normalize_prompt()
+                        gen, opt).normalize_prompt()
                     metadata_prompt = f'{normalized_prompt} -S{seed}'
                 path = file_writer.save_image_and_prompt_to_png(
                     image, metadata_prompt, filename)
@@ -296,16 +276,21 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
             results.append([path, metadata_prompt])
             last_results.append([path, seed])
-        t2i.prompt2image(image_callback=image_writer, **vars(opt))
+        catch_ctrl_c = infile is None  # if running interactively, we catch keyboard interrupts
+        gen.prompt2image(
+            image_callback=image_writer,
+            catch_interrupts=catch_ctrl_c,
+            **vars(opt)
+        )
-        if do_grid and len(grid_images) > 0:
+        if opt.grid and len(grid_images) > 0:
             grid_img = make_grid(list(grid_images.values()))
             grid_seeds = list(grid_images.keys())
             first_seed = last_results[0][1]
             filename = f'{prefix}.{first_seed}.png'
             # TODO better metadata for grid images
             normalized_prompt = PromptFormatter(
-                t2i, opt).normalize_prompt()
+                gen, opt).normalize_prompt()
             metadata_prompt = f'{normalized_prompt} -S{first_seed} --grid -n{len(grid_images)} # {grid_seeds}'
             path = file_writer.save_image_and_prompt_to_png(
                 grid_img, metadata_prompt, filename
@@ -337,11 +322,12 @@ def get_next_command(infile=None) -> str:  # command string
             raise EOFError
     else:
         command = command.strip()
-        print(f'#{command}')
+        if len(command)>0:
+            print(f'#{command}')
     return command
-def dream_server_loop(t2i, host, port, outdir):
+def dream_server_loop(gen, host, port, outdir):
     print('\n* --web was specified, starting web server...')
     # Change working directory to the stable-diffusion directory
     os.chdir(
@@ -349,7 +335,7 @@ def dream_server_loop(t2i, host, port, outdir):
     )
     # Start server
-    DreamServer.model = t2i
+    DreamServer.model = gen  # misnomer in DreamServer - this is not the model you are looking for
     DreamServer.outdir = outdir
     dream_server = ThreadingDreamServer((host, port))
     print(">> Started Stable Diffusion dream server!")
@@ -519,13 +505,6 @@ def create_argv_parser():
         default='model',
         help='Indicates the Stable Diffusion model to use.',
     )
-    parser.add_argument(
-        '--device',
-        '-d',
-        type=str,
-        default='cuda',
-        help="device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if available"
-    )
     parser.add_argument(
         '--model',
         default='stable-diffusion-1.4',
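
For comparison with the old ignore_ctrl_c constructor flag, interrupt handling now travels with each generation call: dream.py passes catch_interrupts so that Ctrl-C during sampling returns partial results when running interactively. A condensed sketch of the new call pattern (the prompt and callback body are illustrative placeholders):

    from ldm.generate import Generate

    gen = Generate('stable-diffusion-1.4')
    gen.load_model()

    def image_writer(image, seed, upscaled=False):
        # stand-in for the real callback in main_loop(), which writes PNGs
        # and collects grid images
        print(f'received image for seed {seed}')

    gen.prompt2image(
        prompt='a sunlit meadow',     # illustrative
        image_callback=image_writer,
        catch_interrupts=True,        # dream.py passes (infile is None) here
    )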