diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 1769b716e7..39c5514148 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -148,6 +148,8 @@ class CompelInvocation(BaseInvocation): cross_attention_control_args=options.get( "cross_attention_control", None),) + c = c.detach().to("cpu") + conditioning_data = ConditioningFieldData( conditionings=[ BasicConditioningInfo( @@ -230,6 +232,10 @@ class SDXLPromptInvocationBase: del tokenizer_info del text_encoder_info + c = c.detach().to("cpu") + if c_pooled is not None: + c_pooled = c_pooled.detach().to("cpu") + return c, c_pooled, None def run_clip_compel(self, context, clip_field, prompt, get_pooled): @@ -306,6 +312,10 @@ class SDXLPromptInvocationBase: del tokenizer_info del text_encoder_info + c = c.detach().to("cpu") + if c_pooled is not None: + c_pooled = c_pooled.detach().to("cpu") + return c, c_pooled, ec class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index 35ce668a1b..7d72d61ea1 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -146,13 +146,13 @@ class InpaintInvocation(BaseInvocation): source_node_id=source_node_id, ) - def get_conditioning(self, context): + def get_conditioning(self, context, unet): positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) - c = positive_cond_data.conditionings[0].embeds + c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype) extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name) - uc = negative_cond_data.conditionings[0].embeds + uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype) return (uc, c, extra_conditioning_info) @@ -213,7 +213,6 @@ class InpaintInvocation(BaseInvocation): ) source_node_id = graph_execution_state.prepared_source_mapping[self.id] - conditioning = self.get_conditioning(context) scheduler = get_scheduler( context=context, scheduler_info=self.unet.scheduler, @@ -221,6 +220,8 @@ class InpaintInvocation(BaseInvocation): ) with self.load_model_old_way(context, scheduler) as model: + conditioning = self.get_conditioning(context, model.context.model.unet) + outputs = Inpaint(model).generate( conditioning=conditioning, scheduler=scheduler, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index c6df25052c..575ba28c49 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -167,13 +167,14 @@ class TextToLatentsInvocation(BaseInvocation): self, context: InvocationContext, scheduler, + unet, ) -> ConditioningData: positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) - c = positive_cond_data.conditionings[0].embeds + c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype) extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name) - uc = negative_cond_data.conditionings[0].embeds + uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype) conditioning_data = ConditioningData( unconditioned_embeddings=uc, @@ -195,7 +196,7 @@ class TextToLatentsInvocation(BaseInvocation): eta=0.0, # ddim_eta # for ancestral and sde schedulers - generator=torch.Generator(device=uc.device).manual_seed(0), + generator=torch.Generator(device=unet.device).manual_seed(0), ) return conditioning_data @@ -334,6 +335,8 @@ class TextToLatentsInvocation(BaseInvocation): ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ unet_info as unet: + noise = noise.to(device=unet.device, dtype=unet.dtype) + scheduler = get_scheduler( context=context, scheduler_info=self.unet.scheduler, @@ -341,7 +344,7 @@ class TextToLatentsInvocation(BaseInvocation): ) pipeline = self.create_pipeline(unet, scheduler) - conditioning_data = self.get_conditioning_data(context, scheduler) + conditioning_data = self.get_conditioning_data(context, scheduler, unet) control_data = self.prep_control_data( model=pipeline, context=context, control_input=self.control, @@ -362,6 +365,7 @@ class TextToLatentsInvocation(BaseInvocation): ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + result_latents = result_latents.to("cpu") torch.cuda.empty_cache() name = f'{context.graph_execution_state_id}__{self.id}' @@ -424,6 +428,9 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ unet_info as unet: + noise = noise.to(device=unet.device, dtype=unet.dtype) + latent = latent.to(device=unet.device, dtype=unet.dtype) + scheduler = get_scheduler( context=context, scheduler_info=self.unet.scheduler, @@ -431,7 +438,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): ) pipeline = self.create_pipeline(unet, scheduler) - conditioning_data = self.get_conditioning_data(context, scheduler) + conditioning_data = self.get_conditioning_data(context, scheduler, unet) control_data = self.prep_control_data( model=pipeline, context=context, control_input=self.control, @@ -463,6 +470,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + result_latents = result_latents.to("cpu") torch.cuda.empty_cache() name = f'{context.graph_execution_state_id}__{self.id}' @@ -503,6 +511,7 @@ class LatentsToImageInvocation(BaseInvocation): ) with vae_info as vae: + latents = latents.to(vae.device) if self.fp32: vae.to(dtype=torch.float32) @@ -590,13 +599,17 @@ class ResizeLatentsInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) + # TODO: + device=choose_torch_device() + resized_latents = torch.nn.functional.interpolate( - latents, size=(self.height // 8, self.width // 8), + latents.to(device), size=(self.height // 8, self.width // 8), mode=self.mode, antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + resized_latents = resized_latents.to("cpu") torch.cuda.empty_cache() name = f"{context.graph_execution_state_id}__{self.id}" @@ -624,14 +637,18 @@ class ScaleLatentsInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) + # TODO: + device=choose_torch_device() + # resizing resized_latents = torch.nn.functional.interpolate( - latents, scale_factor=self.scale_factor, mode=self.mode, + latents.to(device), scale_factor=self.scale_factor, mode=self.mode, antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + resized_latents = resized_latents.to("cpu") torch.cuda.empty_cache() name = f"{context.graph_execution_state_id}__{self.id}" @@ -722,6 +739,6 @@ class ImageToLatentsInvocation(BaseInvocation): latents = latents.to(dtype=orig_dtype) name = f"{context.graph_execution_state_id}__{self.id}" - # context.services.latents.set(name, latents) + latents = latents.to("cpu") context.services.latents.save(name, latents) return build_latents_output(latents_name=name, latents=latents) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 0d62ada34e..abe67131ff 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -48,7 +48,7 @@ def get_noise( dtype=torch_dtype(device), device=noise_device_type, generator=generator, - ).to(device) + ).to("cpu") return noise_tensor diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index d7a4b398b3..2cc12f779e 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -305,7 +305,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation): add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype) latents = latents.to(device=unet.device, dtype=unet.dtype) - with tqdm(total=self.steps) as progress_bar: + with tqdm(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents @@ -351,7 +351,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation): add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype) latents = latents.to(device=unet.device, dtype=unet.dtype) - with tqdm(total=self.steps) as progress_bar: + with tqdm(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance #latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents @@ -415,6 +415,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation): ################# + latents = latents.to("cpu") torch.cuda.empty_cache() name = f'{context.graph_execution_state_id}__{self.id}' @@ -651,6 +652,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): ################# + latents = latents.to("cpu") torch.cuda.empty_cache() name = f'{context.graph_execution_state_id}__{self.id}' diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py index e4cba3517e..7c3d43e3e2 100644 --- a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_management/model_cache.py @@ -104,7 +104,8 @@ class ModelCache(object): :param sha_chunksize: Chunksize to use when calculating sha256 model hash ''' self.model_infos: Dict[str, ModelBase] = dict() - self.lazy_offloading = lazy_offloading + # allow lazy offloading only when vram cache enabled + self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0 self.precision: torch.dtype=precision self.max_cache_size: float=max_cache_size self.max_vram_cache_size: float=max_vram_cache_size