Make sure --free_gpu_mem still works when using CKPT-based diffuser model (#2367)

This PR attempts to fix the `--free_gpu_mem` option, which stopped working for the CKPT-based diffuser model after #1583.

I noticed that after #1583, memory usage no longer decreased after generating an image when the `--free_gpu_mem` option was enabled.
It turns out the option was never propagated to the `Generator` instance, so generation always ran without the memory-saving procedure.
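
For context, the memory-saving path only runs when the generator's own flag is set, which is why the missing propagation matters. Below is a minimal sketch of the idea; the class body, `make_image` callable, and the exact offloading calls are illustrative, not the actual InvokeAI code:

```python
import gc
import torch

class Generator:
    """Minimal sketch of the idea; the real InvokeAI Generator does far more."""

    def __init__(self, model, free_gpu_mem: bool = False):
        self.model = model
        # If Generate never forwards --free_gpu_mem, this stays False and
        # the offloading branch in generate() is effectively dead code.
        self.free_gpu_mem = free_gpu_mem

    def generate(self, make_image, seed: int):
        image = make_image(seed)
        if self.free_gpu_mem:
            # Illustrative clean-up: move weights back to CPU and release
            # cached VRAM once the image has been produced.
            self.model.to('cpu')
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        return image
```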

This PR is also related to #2326. Initially, I was trying to make `--free_gpu_mem` work with the 🤗 diffusers model as well.
In the process, I noticed that InvokeAI raises an exception when `--free_gpu_mem` is enabled with that model.
As a quick fix, the exception is now caught and ignored, and a warning message is printed to the user's console.
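
The guard added in `Generate` boils down to the pattern below. This is a simplified sketch: the helper name `_sync_cond_stage_device` is hypothetical, and the assumption is that diffusers-backed model objects do not expose `cond_stage_model`, which is what raises the `AttributeError`:

```python
def _sync_cond_stage_device(self) -> None:
    """Move the CKPT text encoder back onto the main device if it was offloaded."""
    try:
        if self.free_gpu_mem and self.model.cond_stage_model.device != self.model.device:
            self.model.cond_stage_model.device = self.model.device
            self.model.cond_stage_model.to(self.model.device)
    except AttributeError:
        # Diffusers pipelines have no cond_stage_model attribute, so skip
        # the step and warn instead of crashing the generation.
        print(">> Warning: '--free_gpu_mem' is not yet supported when "
              "generating image using model based on HuggingFace Diffuser.")
```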
Commit 884768c39d by Lincoln Stein, committed 2023-01-23 21:48:23 -05:00 via GitHub
3 changed files with 13 additions and 5 deletions


@@ -146,7 +146,7 @@ class Generate:
             gfpgan=None,
             codeformer=None,
             esrgan=None,
-            free_gpu_mem=False,
+            free_gpu_mem: bool=False,
             safety_checker:bool=False,
             max_loaded_models:int=2,
             # these are deprecated; if present they override values in the conf file
@@ -460,10 +460,13 @@ class Generate:
         init_image = None
         mask_image = None
-        if self.free_gpu_mem and self.model.cond_stage_model.device != self.model.device:
-            self.model.cond_stage_model.device = self.model.device
-            self.model.cond_stage_model.to(self.model.device)
+        try:
+            if self.free_gpu_mem and self.model.cond_stage_model.device != self.model.device:
+                self.model.cond_stage_model.device = self.model.device
+                self.model.cond_stage_model.to(self.model.device)
+        except AttributeError:
+            print(">> Warning: '--free_gpu_mem' is not yet supported when generating image using model based on HuggingFace Diffuser.")
+            pass

         try:
             uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
@@ -531,6 +534,7 @@ class Generate:
                 inpaint_height = inpaint_height,
                 inpaint_width = inpaint_width,
                 enable_image_debugging = enable_image_debugging,
+                free_gpu_mem=self.free_gpu_mem,
             )
             if init_color:


@@ -56,9 +56,11 @@ class CkptGenerator():
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  safety_checker:dict=None,
                  attention_maps_callback = None,
+                 free_gpu_mem: bool=False,
                  **kwargs):
         scope = choose_autocast(self.precision)
         self.safety_checker = safety_checker
+        self.free_gpu_mem = free_gpu_mem
         attention_maps_images = []
         attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
         make_image = self.get_make_image(


@@ -62,9 +62,11 @@ class Generator:
     def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  safety_checker:dict=None,
+                 free_gpu_mem: bool=False,
                  **kwargs):
         scope = nullcontext
         self.safety_checker = safety_checker
+        self.free_gpu_mem = free_gpu_mem
         attention_maps_images = []
         attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
         make_image = self.get_make_image(
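
A hypothetical usage sketch showing the flag flowing end to end after this change; the `prompt2image` call, its arguments, and the import path are assumptions based on the InvokeAI 2.x layout, not part of this diff:

```python
from ldm.generate import Generate  # module path assumed from the InvokeAI 2.x layout

# With this PR, the flag set on Generate is forwarded into Generator.generate(),
# so the CKPT memory-saving branch actually runs after each image.
gen = Generate(free_gpu_mem=True)   # e.g. driven by the --free_gpu_mem CLI flag
results = gen.prompt2image('a lighthouse at dusk', steps=30)
```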