diff --git a/ldm/dream/args.py b/ldm/dream/args.py index 79b1d49f5d..8650fbdfbf 100644 --- a/ldm/dream/args.py +++ b/ldm/dream/args.py @@ -339,6 +339,12 @@ class Args(object): action='store_true', help='Deprecated way to set --precision=float32', ) + model_group.add_argument( + '--free_gpu_mem', + dest='free_gpu_mem', + action='store_true', + help='Force free gpu memory before final decoding', + ) model_group.add_argument( '--precision', dest='precision', diff --git a/ldm/dream/generator/txt2img.py b/ldm/dream/generator/txt2img.py index 0c77705a1c..1ab15ba7cd 100644 --- a/ldm/dream/generator/txt2img.py +++ b/ldm/dream/generator/txt2img.py @@ -27,6 +27,10 @@ class Txt2Img(Generator): height // self.downsampling_factor, width // self.downsampling_factor, ] + + if self.free_gpu_mem and self.model.model.device != self.model.device: + self.model.model.to(self.model.device) + samples, _ = sampler.sample( batch_size = 1, S = steps, @@ -39,6 +43,10 @@ class Txt2Img(Generator): eta = ddim_eta, img_callback = step_callback ) + + if self.free_gpu_mem: + self.model.model.to("cpu") + return self.sample_to_image(samples) return make_image diff --git a/ldm/generate.py b/ldm/generate.py index 6dfbfdc018..854a519520 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -655,6 +655,7 @@ class Generate: if not self.generators.get('txt2img'): from ldm.dream.generator.txt2img import Txt2Img self.generators['txt2img'] = Txt2Img(self.model, self.precision) + self.generators['txt2img'].free_gpu_mem = self.free_gpu_mem return self.generators['txt2img'] def _make_inpaint(self): diff --git a/scripts/dream.py b/scripts/dream.py index f102902263..cac8c2aee4 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -108,6 +108,8 @@ def main(): # preload the model gen.load_model() + #set additional option + gen.free_gpu_mem = opt.free_gpu_mem if not infile: print(