grid is broken, needs the grid-fix PR#166 to fix

This commit is contained in:
Lincoln Stein 2022-08-29 13:39:20 -04:00
parent 90cbc6362c
commit c7db038c96
3 changed files with 19 additions and 3 deletions

View File

@ -86,7 +86,7 @@ class PngWriter:
if None in (rows, cols): if None in (rows, cols):
rows = floor(sqrt(image_cnt)) # try to make it square rows = floor(sqrt(image_cnt)) # try to make it square
cols = ceil(image_cnt / rows) cols = ceil(image_cnt / rows)
width = image_list[0].width width = image_list[0].width
height = image_list[0].height height = image_list[0].height
grid_img = Image.new('RGB', (width * cols, height * rows)) grid_img = Image.new('RGB', (width * cols, height * rows))

View File

@ -122,6 +122,7 @@ class T2I:
cfg_scale=7.5, cfg_scale=7.5,
weights='models/ldm/stable-diffusion-v1/model.ckpt', weights='models/ldm/stable-diffusion-v1/model.ckpt',
config='configs/stable-diffusion/v1-inference.yaml', config='configs/stable-diffusion/v1-inference.yaml',
grid=False,
width=512, width=512,
height=512, height=512,
sampler_name='klms', sampler_name='klms',
@ -147,6 +148,7 @@ class T2I:
self.sampler_name = sampler_name self.sampler_name = sampler_name
self.latent_channels = latent_channels self.latent_channels = latent_channels
self.downsampling_factor = downsampling_factor self.downsampling_factor = downsampling_factor
self.grid = grid
self.ddim_eta = ddim_eta self.ddim_eta = ddim_eta
self.precision = precision self.precision = precision
self.full_precision = full_precision self.full_precision = full_precision

View File

@ -51,6 +51,7 @@ def main():
weights=weights, weights=weights,
full_precision=opt.full_precision, full_precision=opt.full_precision,
config=config, config=config,
grid = opt.grid,
# this is solely for recreating the prompt # this is solely for recreating the prompt
latent_diffusion_weights=opt.laion400m, latent_diffusion_weights=opt.laion400m,
embedding_path=opt.embedding_path, embedding_path=opt.embedding_path,
@ -179,7 +180,8 @@ def main_loop(t2i, outdir, parser, infile):
file_writer.files_written if individual_images else image_list file_writer.files_written if individual_images else image_list
) )
if opt.grid and len(results) > 0: grid = opt.grid or t2i.grid
if grid and len(results) > 0:
grid_img = file_writer.make_grid([r[0] for r in results]) grid_img = file_writer.make_grid([r[0] for r in results])
filename = file_writer.unique_filename(results[0][1]) filename = file_writer.unique_filename(results[0][1])
seeds = [a[1] for a in results] seeds = [a[1] for a in results]
@ -261,7 +263,13 @@ SAMPLER_CHOICES=[
def create_argv_parser(): def create_argv_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Parse script's command line args" description="""Generate images using Stable Diffusion.
Use --web to launch the web interface.
Use --from_file to load prompts from a file path or standard input ("-").
Otherwise you will be dropped into an interactive command prompt (type -h for help.)
Other command-line arguments are defaults that can usually be overridden
prompt the command prompt.
"""
) )
parser.add_argument( parser.add_argument(
'--laion400m', '--laion400m',
@ -291,6 +299,12 @@ def create_argv_parser():
action='store_true', action='store_true',
help='Use slower full precision math for calculations', help='Use slower full precision math for calculations',
) )
parser.add_argument(
'-g',
'--grid',
action='store_true',
help='Generate a grid instead of individual images',
)
parser.add_argument( parser.add_argument(
'-A', '-A',
'-m', '-m',