add argument --outdir

This commit is contained in:
Hideyuki Katsushiro 2022-10-05 10:08:53 +09:00
parent 9318719b9e
commit e7fb9f342c

View File

@ -118,6 +118,7 @@ class Generate:
embedding_path = None, embedding_path = None,
device_type = 'cuda', device_type = 'cuda',
ignore_ctrl_c = False, ignore_ctrl_c = False,
outdir = 'outputs/img-samples',
): ):
self.iterations = iterations self.iterations = iterations
self.width = width self.width = width
@ -142,6 +143,7 @@ class Generate:
self.generators = {} self.generators = {}
self.base_generator = None self.base_generator = None
self.seed = None self.seed = None
self.outdir = outdir
if device_type == 'cuda' and not torch.cuda.is_available(): if device_type == 'cuda' and not torch.cuda.is_available():
device_type = choose_torch_device() device_type = choose_torch_device()
@ -152,7 +154,6 @@ class Generate:
device_type = choose_torch_device() device_type = choose_torch_device()
self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
def prompt2png(self, prompt, outdir, **kwargs): def prompt2png(self, prompt, outdir, **kwargs):
""" """
Takes a prompt and an output directory, writes out the requested number Takes a prompt and an output directory, writes out the requested number
@ -171,11 +172,11 @@ class Generate:
return outputs return outputs
def txt2img(self, prompt, **kwargs): def txt2img(self, prompt, **kwargs):
outdir = kwargs.pop('outdir', 'outputs/img-samples') outdir = kwargs.pop('outdir', self.outdir)
return self.prompt2png(prompt, outdir, **kwargs) return self.prompt2png(prompt, outdir, **kwargs)
def img2img(self, prompt, **kwargs): def img2img(self, prompt, **kwargs):
outdir = kwargs.pop('outdir', 'outputs/img-samples') outdir = kwargs.pop('outdir', self.outdir)
assert ( assert (
'init_img' in kwargs 'init_img' in kwargs
), 'call to img2img() must include the init_img argument' ), 'call to img2img() must include the init_img argument'
@ -209,6 +210,7 @@ class Generate:
gfpgan_strength= 0, gfpgan_strength= 0,
save_original = False, save_original = False,
upscale = None, upscale = None,
outdir = None,
**args, **args,
): # eat up additional cruft ): # eat up additional cruft
""" """
@ -255,7 +257,7 @@ class Generate:
ddim_eta = ddim_eta or self.ddim_eta ddim_eta = ddim_eta or self.ddim_eta
iterations = iterations or self.iterations iterations = iterations or self.iterations
strength = strength or self.strength strength = strength or self.strength
self.seed = seed outdir = outdir or self.outdir
self.log_tokenization = log_tokenization self.log_tokenization = log_tokenization
with_variations = [] if with_variations is None else with_variations with_variations = [] if with_variations is None else with_variations