Merge branch 'development' into development

This commit is contained in:
Peter Baylies 2022-09-12 10:37:49 -04:00 committed by GitHub
commit ea60d036d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 7 deletions

View File

@ -32,16 +32,16 @@ dependencies:
- omegaconf==2.1.1
- onnx==1.12.0
- onnxruntime==1.12.1
- opencv-python==4.5.5.64
- pudb==2022.1
- pytorch-lightning==1.6.5
- realesrgan==0.2.5.0
- scipy==1.9.1
- streamlit==1.12.2
- sympy==1.10.1
- tensorboard==2.9.0
- torchmetrics==0.9.3
- pip:
- opencv-python==4.6.0
- realesrgan==0.2.5.0
- test-tube==0.7.5
- transformers==4.21.2
- torch-fidelity==0.3.0

View File

@ -297,9 +297,9 @@ class BasicTransformerBlock(nn.Module):
def _forward(self, x, context=None):
x = x.contiguous() if x.device.type == 'mps' else x
x += self.attn1(self.norm1(x))
x += self.attn2(self.norm2(x), context=context)
x += self.ff(self.norm3(x))
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x

View File

@ -261,6 +261,7 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
grid_images = dict() # seed -> Image, only used if `do_grid`
def image_writer(image, seed, upscaled=False):
path = None
if do_grid:
grid_images[seed] = image
else:
@ -298,13 +299,14 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
t2i.prompt2image(image_callback=image_writer, **vars(opt))
if do_grid and len(grid_images) > 0:
grid_img = make_grid(list(grid_images.values()))
grid_img = make_grid(list(grid_images.values()))
grid_seeds = list(grid_images.keys())
first_seed = last_results[0][1]
filename = f'{prefix}.{first_seed}.png'
# TODO better metadata for grid images
normalized_prompt = PromptFormatter(
t2i, opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{first_seed} --grid -N{len(grid_images)}'
metadata_prompt = f'{normalized_prompt} -S{first_seed} --grid -n{len(grid_images)} # {grid_seeds}'
path = file_writer.save_image_and_prompt_to_png(
grid_img, metadata_prompt, filename
)