diff --git a/ldm/generate.py b/ldm/generate.py index 63eaf79b50..83cdc6b852 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -146,7 +146,7 @@ class Generate: gfpgan=None, codeformer=None, esrgan=None, - free_gpu_mem=False, + free_gpu_mem: bool=False, safety_checker:bool=False, max_loaded_models:int=2, # these are deprecated; if present they override values in the conf file @@ -460,10 +460,13 @@ class Generate: init_image = None mask_image = None - - if self.free_gpu_mem and self.model.cond_stage_model.device != self.model.device: - self.model.cond_stage_model.device = self.model.device - self.model.cond_stage_model.to(self.model.device) + try: + if self.free_gpu_mem and self.model.cond_stage_model.device != self.model.device: + self.model.cond_stage_model.device = self.model.device + self.model.cond_stage_model.to(self.model.device) + except AttributeError: + print(">> Warning: '--free_gpu_mem' is not yet supported when generating image using model based on HuggingFace Diffuser.") + pass try: uc, c, extra_conditioning_info = get_uc_and_c_and_ec( @@ -531,6 +534,7 @@ class Generate: inpaint_height = inpaint_height, inpaint_width = inpaint_width, enable_image_debugging = enable_image_debugging, + free_gpu_mem=self.free_gpu_mem, ) if init_color: diff --git a/ldm/invoke/ckpt_generator/base.py b/ldm/invoke/ckpt_generator/base.py index c84550a6e3..9d137b74d6 100644 --- a/ldm/invoke/ckpt_generator/base.py +++ b/ldm/invoke/ckpt_generator/base.py @@ -56,9 +56,11 @@ class CkptGenerator(): image_callback=None, step_callback=None, threshold=0.0, perlin=0.0, safety_checker:dict=None, attention_maps_callback = None, + free_gpu_mem: bool=False, **kwargs): scope = choose_autocast(self.precision) self.safety_checker = safety_checker + self.free_gpu_mem = free_gpu_mem attention_maps_images = [] attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image()) make_image = self.get_make_image( diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py index 3fd34765c6..a17badd022 100644 --- a/ldm/invoke/generator/base.py +++ b/ldm/invoke/generator/base.py @@ -62,9 +62,11 @@ class Generator: def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None, image_callback=None, step_callback=None, threshold=0.0, perlin=0.0, safety_checker:dict=None, + free_gpu_mem: bool=False, **kwargs): scope = nullcontext self.safety_checker = safety_checker + self.free_gpu_mem = free_gpu_mem attention_maps_images = [] attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image()) make_image = self.get_make_image(