Merge pull request #177 from lstein/bugfixes

Bugfixes to image generation logic
This commit is contained in:
Lincoln Stein 2022-08-29 06:58:42 -04:00 committed by GitHub
commit 30dbf0e589
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -20,6 +20,7 @@ from contextlib import contextmanager, nullcontext
import transformers
import time
import re
import sys
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
@ -174,17 +175,15 @@ class T2I:
outdir, prompt, kwargs.get('batch_size', self.batch_size)
)
for r in results:
# gets written into the PNG
metadata_str = f'prompt2png("{prompt}" {kwargs} seed={r[1]}'
pngwriter.write_image(r[0], r[1])
return pngwriter.files_written
def txt2img(self, prompt, **kwargs):
outdir = kwargs.get('outdir', 'outputs/img-samples')
outdir = kwargs.pop('outdir', 'outputs/img-samples')
return self.prompt2png(prompt, outdir, **kwargs)
def img2img(self, prompt, **kwargs):
outdir = kwargs.get('outdir', 'outputs/img-samples')
outdir = kwargs.pop('outdir', 'outputs/img-samples')
assert (
'init_img' in kwargs
), 'call to img2img() must include the init_img argument'
@ -349,6 +348,8 @@ class T2I:
image_callback(image, seed)
else:
image_callback(image, seed, upscaled=True)
else: # no callback passed, so we simply replace old image with rescaled one
result[0] = image
except KeyboardInterrupt:
print('*interrupted*')
@ -540,6 +541,9 @@ class T2I:
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
self.model.cond_stage_model.device = self.device
except AttributeError:
import traceback
print('Error loading model. Only the CUDA backend is supported',file=sys.stderr)
print(traceback.format_exc(),file=sys.stderr)
raise SystemExit
self._set_sampler()