From bc11296a5ead60064f3b226e2b8bf3b51b32fab9 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 18 Jul 2023 16:20:25 +0300 Subject: [PATCH 1/2] Disable lazy offloading on disabled vram cache, move resulted tensors to cpu(to not stack vram tensors in cache), fix - text encoder not freed(detach) --- invokeai/app/invocations/compel.py | 10 +++++++ invokeai/app/invocations/latent.py | 28 +++++++++++++------ invokeai/app/invocations/noise.py | 2 +- invokeai/app/invocations/sdxl.py | 2 ++ .../backend/model_management/model_cache.py | 3 +- 5 files changed, 35 insertions(+), 10 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 3b1edf24b9..a916e59507 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -147,6 +147,8 @@ class CompelInvocation(BaseInvocation): cross_attention_control_args=options.get( "cross_attention_control", None),) + c = c.detach().to("cpu") + conditioning_data = ConditioningFieldData( conditionings=[ BasicConditioningInfo( @@ -229,6 +231,10 @@ class SDXLPromptInvocationBase: del tokenizer_info del text_encoder_info + c = c.detach().to("cpu") + if c_pooled is not None: + c_pooled = c_pooled.detach().to("cpu") + return c, c_pooled, None def run_clip_compel(self, context, clip_field, prompt, get_pooled): @@ -305,6 +311,10 @@ class SDXLPromptInvocationBase: del tokenizer_info del text_encoder_info + c = c.detach().to("cpu") + if c_pooled is not None: + c_pooled = c_pooled.detach().to("cpu") + return c, c_pooled, ec class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 19207d78d7..1a34e58d28 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -167,13 +167,14 @@ class TextToLatentsInvocation(BaseInvocation): self, context: InvocationContext, scheduler, + unet, ) -> ConditioningData: positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) - c = positive_cond_data.conditionings[0].embeds + c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype) extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name) - uc = negative_cond_data.conditionings[0].embeds + uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype) conditioning_data = ConditioningData( unconditioned_embeddings=uc, @@ -195,7 +196,7 @@ class TextToLatentsInvocation(BaseInvocation): eta=0.0, # ddim_eta # for ancestral and sde schedulers - generator=torch.Generator(device=uc.device).manual_seed(0), + generator=torch.Generator(device=unet.device).manual_seed(0), ) return conditioning_data @@ -340,7 +341,7 @@ class TextToLatentsInvocation(BaseInvocation): ) pipeline = self.create_pipeline(unet, scheduler) - conditioning_data = self.get_conditioning_data(context, scheduler) + conditioning_data = self.get_conditioning_data(context, scheduler, unet) control_data = self.prep_control_data( model=pipeline, context=context, control_input=self.control, @@ -361,6 +362,7 @@ class TextToLatentsInvocation(BaseInvocation): ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + result_latents = result_latents.to("cpu") torch.cuda.empty_cache() name = f'{context.graph_execution_state_id}__{self.id}' @@ -430,7 +432,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): ) pipeline = self.create_pipeline(unet, scheduler) - conditioning_data = self.get_conditioning_data(context, scheduler) + conditioning_data = self.get_conditioning_data(context, scheduler, unet) control_data = self.prep_control_data( model=pipeline, context=context, control_input=self.control, @@ -462,6 +464,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + result_latents = result_latents.to("cpu") torch.cuda.empty_cache() name = f'{context.graph_execution_state_id}__{self.id}' @@ -502,6 +505,7 @@ class LatentsToImageInvocation(BaseInvocation): ) with vae_info as vae: + latents = latents.to(vae.device) if self.fp32: vae.to(dtype=torch.float32) @@ -589,13 +593,17 @@ class ResizeLatentsInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) + # TODO: + device=choose_torch_device() + resized_latents = torch.nn.functional.interpolate( - latents, size=(self.height // 8, self.width // 8), + latents.to(device), size=(self.height // 8, self.width // 8), mode=self.mode, antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + resized_latents = resized_latents.to("cpu") torch.cuda.empty_cache() name = f"{context.graph_execution_state_id}__{self.id}" @@ -623,14 +631,18 @@ class ScaleLatentsInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) + # TODO: + device=choose_torch_device() + # resizing resized_latents = torch.nn.functional.interpolate( - latents, scale_factor=self.scale_factor, mode=self.mode, + latents.to(device), scale_factor=self.scale_factor, mode=self.mode, antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + resized_latents = resized_latents.to("cpu") torch.cuda.empty_cache() name = f"{context.graph_execution_state_id}__{self.id}" @@ -721,6 +733,6 @@ class ImageToLatentsInvocation(BaseInvocation): latents = latents.to(dtype=orig_dtype) name = f"{context.graph_execution_state_id}__{self.id}" - # context.services.latents.set(name, latents) + latents = latents.to("cpu") context.services.latents.save(name, latents) return build_latents_output(latents_name=name, latents=latents) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 0d62ada34e..abe67131ff 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -48,7 +48,7 @@ def get_noise( dtype=torch_dtype(device), device=noise_device_type, generator=generator, - ).to(device) + ).to("cpu") return noise_tensor diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index d7a4b398b3..f73418e18e 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -415,6 +415,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation): ################# + latents = latents.to("cpu") torch.cuda.empty_cache() name = f'{context.graph_execution_state_id}__{self.id}' @@ -651,6 +652,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): ################# + latents = latents.to("cpu") torch.cuda.empty_cache() name = f'{context.graph_execution_state_id}__{self.id}' diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py index e4cba3517e..7c3d43e3e2 100644 --- a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_management/model_cache.py @@ -104,7 +104,8 @@ class ModelCache(object): :param sha_chunksize: Chunksize to use when calculating sha256 model hash ''' self.model_infos: Dict[str, ModelBase] = dict() - self.lazy_offloading = lazy_offloading + # allow lazy offloading only when vram cache enabled + self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0 self.precision: torch.dtype=precision self.max_cache_size: float=max_cache_size self.max_vram_cache_size: float=max_vram_cache_size From fbbc4b3f69141f09e1617cc94821736cb6759054 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 18 Jul 2023 16:51:16 +0300 Subject: [PATCH 2/2] Fixes --- invokeai/app/invocations/generate.py | 13 +++++++++---- invokeai/app/invocations/latent.py | 5 +++++ invokeai/app/invocations/sdxl.py | 4 ++-- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index 6cdb83effc..c5967e4074 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -146,9 +146,13 @@ class InpaintInvocation(BaseInvocation): source_node_id=source_node_id, ) - def get_conditioning(self, context): - c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name) - uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) + def get_conditioning(self, context, unet): + positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) + c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype) + extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning + + negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name) + uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype) return (uc, c, extra_conditioning_info) @@ -209,7 +213,6 @@ class InpaintInvocation(BaseInvocation): ) source_node_id = graph_execution_state.prepared_source_mapping[self.id] - conditioning = self.get_conditioning(context) scheduler = get_scheduler( context=context, scheduler_info=self.unet.scheduler, @@ -217,6 +220,8 @@ class InpaintInvocation(BaseInvocation): ) with self.load_model_old_way(context, scheduler) as model: + conditioning = self.get_conditioning(context, model.context.model.unet) + outputs = Inpaint(model).generate( conditioning=conditioning, scheduler=scheduler, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 1a34e58d28..8af5a943a8 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -334,6 +334,8 @@ class TextToLatentsInvocation(BaseInvocation): ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ unet_info as unet: + noise = noise.to(device=unet.device, dtype=unet.dtype) + scheduler = get_scheduler( context=context, scheduler_info=self.unet.scheduler, @@ -425,6 +427,9 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ unet_info as unet: + noise = noise.to(device=unet.device, dtype=unet.dtype) + latent = latent.to(device=unet.device, dtype=unet.dtype) + scheduler = get_scheduler( context=context, scheduler_info=self.unet.scheduler, diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index f73418e18e..2cc12f779e 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -305,7 +305,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation): add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype) latents = latents.to(device=unet.device, dtype=unet.dtype) - with tqdm(total=self.steps) as progress_bar: + with tqdm(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents @@ -351,7 +351,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation): add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype) latents = latents.to(device=unet.device, dtype=unet.dtype) - with tqdm(total=self.steps) as progress_bar: + with tqdm(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance #latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents