From 5a40aadbee2a0ffb7b14bf9f05a7e1b66d72d62d Mon Sep 17 00:00:00 2001 From: Daya Adianto Date: Wed, 18 Jan 2023 23:23:18 +0700 Subject: [PATCH] Ensure free_gpu_mem option is passed into the generator (#2326) --- ldm/generate.py | 3 ++- ldm/invoke/ckpt_generator/base.py | 2 ++ ldm/invoke/generator/base.py | 2 ++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/ldm/generate.py b/ldm/generate.py index 9f19f1aefe..83cdc6b852 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -146,7 +146,7 @@ class Generate: gfpgan=None, codeformer=None, esrgan=None, - free_gpu_mem=False, + free_gpu_mem: bool=False, safety_checker:bool=False, max_loaded_models:int=2, # these are deprecated; if present they override values in the conf file @@ -534,6 +534,7 @@ class Generate: inpaint_height = inpaint_height, inpaint_width = inpaint_width, enable_image_debugging = enable_image_debugging, + free_gpu_mem=self.free_gpu_mem, ) if init_color: diff --git a/ldm/invoke/ckpt_generator/base.py b/ldm/invoke/ckpt_generator/base.py index c84550a6e3..9d137b74d6 100644 --- a/ldm/invoke/ckpt_generator/base.py +++ b/ldm/invoke/ckpt_generator/base.py @@ -56,9 +56,11 @@ class CkptGenerator(): image_callback=None, step_callback=None, threshold=0.0, perlin=0.0, safety_checker:dict=None, attention_maps_callback = None, + free_gpu_mem: bool=False, **kwargs): scope = choose_autocast(self.precision) self.safety_checker = safety_checker + self.free_gpu_mem = free_gpu_mem attention_maps_images = [] attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image()) make_image = self.get_make_image( diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py index 3fd34765c6..a17badd022 100644 --- a/ldm/invoke/generator/base.py +++ b/ldm/invoke/generator/base.py @@ -62,9 +62,11 @@ class Generator: def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None, image_callback=None, step_callback=None, threshold=0.0, perlin=0.0, safety_checker:dict=None, + free_gpu_mem: bool=False, **kwargs): scope = nullcontext self.safety_checker = safety_checker + self.free_gpu_mem = free_gpu_mem attention_maps_images = [] attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image()) make_image = self.get_make_image(