diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 3b1edf24b9..a916e59507 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -147,6 +147,8 @@ class CompelInvocation(BaseInvocation):
                 cross_attention_control_args=options.get(
                     "cross_attention_control", None),)
 
+        c = c.detach().to("cpu")
+
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 BasicConditioningInfo(
@@ -229,6 +231,10 @@ class SDXLPromptInvocationBase:
             del tokenizer_info
             del text_encoder_info
 
+        c = c.detach().to("cpu")
+        if c_pooled is not None:
+            c_pooled = c_pooled.detach().to("cpu")
+
         return c, c_pooled, None
 
     def run_clip_compel(self, context, clip_field, prompt, get_pooled):
@@ -305,6 +311,10 @@ class SDXLPromptInvocationBase:
             del tokenizer_info
             del text_encoder_info
 
+        c = c.detach().to("cpu")
+        if c_pooled is not None:
+            c_pooled = c_pooled.detach().to("cpu")
+
         return c, c_pooled, ec
 
 class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 19207d78d7..1a34e58d28 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -167,13 +167,14 @@ class TextToLatentsInvocation(BaseInvocation):
         self,
         context: InvocationContext,
         scheduler,
+        unet,
     ) -> ConditioningData:
         positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
-        c = positive_cond_data.conditionings[0].embeds
+        c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
         extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
 
         negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
-        uc = negative_cond_data.conditionings[0].embeds
+        uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
 
         conditioning_data = ConditioningData(
             unconditioned_embeddings=uc,
@@ -195,7 +196,7 @@ class TextToLatentsInvocation(BaseInvocation):
             eta=0.0,  # ddim_eta
 
             # for ancestral and sde schedulers
-            generator=torch.Generator(device=uc.device).manual_seed(0),
+            generator=torch.Generator(device=unet.device).manual_seed(0),
         )
         return conditioning_data
 
@@ -340,7 +341,7 @@ class TextToLatentsInvocation(BaseInvocation):
             )
 
             pipeline = self.create_pipeline(unet, scheduler)
-            conditioning_data = self.get_conditioning_data(context, scheduler)
+            conditioning_data = self.get_conditioning_data(context, scheduler, unet)
 
             control_data = self.prep_control_data(
                 model=pipeline, context=context, control_input=self.control,
@@ -361,6 +362,7 @@ class TextToLatentsInvocation(BaseInvocation):
             )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        result_latents = result_latents.to("cpu")
         torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'
@@ -430,7 +432,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             )
 
             pipeline = self.create_pipeline(unet, scheduler)
-            conditioning_data = self.get_conditioning_data(context, scheduler)
+            conditioning_data = self.get_conditioning_data(context, scheduler, unet)
 
             control_data = self.prep_control_data(
                 model=pipeline, context=context, control_input=self.control,
@@ -462,6 +464,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        result_latents = result_latents.to("cpu")
         torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'
@@ -502,6 +505,7 @@ class LatentsToImageInvocation(BaseInvocation):
             )
 
         with vae_info as vae:
+            latents = latents.to(vae.device)
             if self.fp32:
                 vae.to(dtype=torch.float32)
 
@@ -589,13 +593,17 @@ class ResizeLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
 
+        # TODO:
+        device=choose_torch_device()
+
         resized_latents = torch.nn.functional.interpolate(
-            latents, size=(self.height // 8, self.width // 8),
+            latents.to(device), size=(self.height // 8, self.width // 8),
             mode=self.mode, antialias=self.antialias
             if self.mode in ["bilinear", "bicubic"] else False,
         )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        resized_latents = resized_latents.to("cpu")
         torch.cuda.empty_cache()
 
         name = f"{context.graph_execution_state_id}__{self.id}"
@@ -623,14 +631,18 @@ class ScaleLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
 
+        # TODO:
+        device=choose_torch_device()
+
         # resizing
         resized_latents = torch.nn.functional.interpolate(
-            latents, scale_factor=self.scale_factor, mode=self.mode,
+            latents.to(device), scale_factor=self.scale_factor, mode=self.mode,
             antialias=self.antialias
             if self.mode in ["bilinear", "bicubic"] else False,
         )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        resized_latents = resized_latents.to("cpu")
         torch.cuda.empty_cache()
 
         name = f"{context.graph_execution_state_id}__{self.id}"
@@ -721,6 +733,6 @@ class ImageToLatentsInvocation(BaseInvocation):
             latents = latents.to(dtype=orig_dtype)
 
         name = f"{context.graph_execution_state_id}__{self.id}"
-        # context.services.latents.set(name, latents)
+        latents = latents.to("cpu")
         context.services.latents.save(name, latents)
         return build_latents_output(latents_name=name, latents=latents)
diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py
index 0d62ada34e..abe67131ff 100644
--- a/invokeai/app/invocations/noise.py
+++ b/invokeai/app/invocations/noise.py
@@ -48,7 +48,7 @@ def get_noise(
         dtype=torch_dtype(device),
         device=noise_device_type,
         generator=generator,
-    ).to(device)
+    ).to("cpu")
 
     return noise_tensor
 
diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py
index d7a4b398b3..f73418e18e 100644
--- a/invokeai/app/invocations/sdxl.py
+++ b/invokeai/app/invocations/sdxl.py
@@ -415,6 +415,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
 
             #################
 
+        latents = latents.to("cpu")
         torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'
@@ -651,6 +652,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
 
             #################
 
+        latents = latents.to("cpu")
         torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'
diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py
index e4cba3517e..7c3d43e3e2 100644
--- a/invokeai/backend/model_management/model_cache.py
+++ b/invokeai/backend/model_management/model_cache.py
@@ -104,7 +104,8 @@ class ModelCache(object):
         :param sha_chunksize: Chunksize to use when calculating sha256 model hash
         '''
         self.model_infos: Dict[str, ModelBase] = dict()
-        self.lazy_offloading = lazy_offloading
+        # allow lazy offloading only when vram cache enabled
+        self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
         self.precision: torch.dtype=precision
         self.max_cache_size: float=max_cache_size
         self.max_vram_cache_size: float=max_vram_cache_size
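
The hunks above all follow one pattern: detach intermediate tensors (conditioning, noise, latents) and park them on the CPU before they are cached, then move them back to the consuming model's device and dtype only at the point of use. Below is a minimal standalone sketch of that pattern, assuming hypothetical names (save_result, load_for_model, _latents_cache); it is not InvokeAI's actual latents-service API, just an illustration of the idea.

import torch

# Hypothetical in-memory cache standing in for context.services.latents.
_latents_cache: dict = {}

def save_result(name: str, tensor: torch.Tensor) -> None:
    # Detach and offload to CPU so cached intermediates do not pin VRAM,
    # then release any freed CUDA blocks.
    _latents_cache[name] = tensor.detach().to("cpu")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def load_for_model(name: str, model: torch.nn.Module) -> torch.Tensor:
    # Move back to the model's device/dtype right before consumption,
    # mirroring `.to(device=unet.device, dtype=unet.dtype)` in the diff.
    param = next(model.parameters())
    return _latents_cache[name].to(device=param.device, dtype=param.dtype)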