diff --git a/ldm/generate.py b/ldm/generate.py
index 8f67403633..943fe69101 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -118,6 +118,7 @@ class Generate:
             embedding_path = None,
             device_type = 'cuda',
             ignore_ctrl_c = False,
+            outdir = 'outputs/img-samples',
     ):
         self.iterations = iterations
         self.width = width
@@ -142,6 +143,7 @@ class Generate:
         self.generators = {}
         self.base_generator = None
         self.seed = None
+        self.outdir = outdir
 
         if device_type == 'cuda' and not torch.cuda.is_available():
             device_type = choose_torch_device()
@@ -152,7 +154,6 @@ class Generate:
         device_type = choose_torch_device()
         self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None
         transformers.logging.set_verbosity_error()
-
     def prompt2png(self, prompt, outdir, **kwargs):
         """
         Takes a prompt and an output directory, writes out the requested number
@@ -171,11 +172,11 @@ class Generate:
         return outputs
 
     def txt2img(self, prompt, **kwargs):
-        outdir = kwargs.pop('outdir', 'outputs/img-samples')
+        outdir = kwargs.pop('outdir', self.outdir)
         return self.prompt2png(prompt, outdir, **kwargs)
 
     def img2img(self, prompt, **kwargs):
-        outdir = kwargs.pop('outdir', 'outputs/img-samples')
+        outdir = kwargs.pop('outdir', self.outdir)
         assert (
             'init_img' in kwargs
         ), 'call to img2img() must include the init_img argument'
@@ -209,6 +210,7 @@ class Generate:
             gfpgan_strength= 0,
             save_original = False,
             upscale = None,
+            outdir = None,
             **args,
     ): # eat up additional cruft
         """
@@ -255,7 +257,7 @@ class Generate:
         ddim_eta = ddim_eta or self.ddim_eta
         iterations = iterations or self.iterations
         strength = strength or self.strength
-        self.seed = seed
+        outdir = outdir or self.outdir
         self.log_tokenization = log_tokenization
         with_variations = [] if with_variations is None else with_variations
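
A minimal usage sketch (not part of the patch) of the behavior this diff enables: outdir can now be set once on the Generate instance, and txt2img() / img2img() fall back to self.outdir instead of the hard-coded 'outputs/img-samples'. The prompt and directory names below are illustrative placeholders.

    from ldm.generate import Generate

    # outdir is the new constructor argument introduced by this diff;
    # every other argument keeps its existing default.
    gr = Generate(outdir='outputs/my-project')

    # No explicit outdir: txt2img() now defaults to gr.outdir
    # ('outputs/my-project') rather than 'outputs/img-samples'.
    gr.txt2img('a watercolor lighthouse at dawn')

    # Passing outdir explicitly still overrides the instance default.
    gr.txt2img('a watercolor lighthouse at dawn', outdir='outputs/one-off')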