diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 158f11a58e..766b44fdc8 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -65,11 +65,7 @@ class CompelInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: tokenizer_info = context.models.load(self.clip.tokenizer) - tokenizer_model = tokenizer_info.model - assert isinstance(tokenizer_model, CLIPTokenizer) text_encoder_info = context.models.load(self.clip.text_encoder) - text_encoder_model = text_encoder_info.model - assert isinstance(text_encoder_model, CLIPTextModel) def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.clip.loras: @@ -84,19 +80,21 @@ class CompelInvocation(BaseInvocation): ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context) with ( - ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as ( - tokenizer, - ti_manager, - ), + # apply all patches while the model is on the target device text_encoder_info as text_encoder, - # Apply the LoRA after text_encoder has been moved to its target device for faster patching. + tokenizer_info as tokenizer, ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. - ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers), + ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers), + ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as ( + patched_tokenizer, + ti_manager, + ), ): assert isinstance(text_encoder, CLIPTextModel) + assert isinstance(tokenizer, CLIPTokenizer) compel = Compel( - tokenizer=tokenizer, + tokenizer=patched_tokenizer, text_encoder=text_encoder, textual_inversion_manager=ti_manager, dtype_for_device_getter=TorchDevice.choose_torch_dtype, @@ -106,7 +104,7 @@ class CompelInvocation(BaseInvocation): conjunction = Compel.parse_prompt_string(self.prompt) if context.config.get().log_tokenization: - log_tokenization_for_conjunction(conjunction, tokenizer) + log_tokenization_for_conjunction(conjunction, patched_tokenizer) c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction) @@ -136,11 +134,7 @@ class SDXLPromptInvocationBase: zero_on_empty: bool, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: tokenizer_info = context.models.load(clip_field.tokenizer) - tokenizer_model = tokenizer_info.model - assert isinstance(tokenizer_model, CLIPTokenizer) text_encoder_info = context.models.load(clip_field.text_encoder) - text_encoder_model = text_encoder_info.model - assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection)) # return zero on empty if prompt == "" and zero_on_empty: @@ -177,20 +171,23 @@ class SDXLPromptInvocationBase: ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context) with ( - ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as ( - tokenizer, - ti_manager, - ), + # apply all patches while the model is on the target device text_encoder_info as text_encoder, - # Apply the LoRA after text_encoder has been moved to its target device for faster patching. + tokenizer_info as tokenizer, ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. - ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers), + ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers), + ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as ( + patched_tokenizer, + ti_manager, + ), ): assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)) + assert isinstance(tokenizer, CLIPTokenizer) + text_encoder = cast(CLIPTextModel, text_encoder) compel = Compel( - tokenizer=tokenizer, + tokenizer=patched_tokenizer, text_encoder=text_encoder, textual_inversion_manager=ti_manager, dtype_for_device_getter=TorchDevice.choose_torch_dtype, @@ -203,7 +200,7 @@ class SDXLPromptInvocationBase: if context.config.get().log_tokenization: # TODO: better logging for and syntax - log_tokenization_for_conjunction(conjunction, tokenizer) + log_tokenization_for_conjunction(conjunction, patched_tokenizer) # TODO: ask for optimizations? to not run text_encoder twice c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index b3ac3973bf..a88eff0fcb 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -930,9 +930,9 @@ class DenoiseLatentsInvocation(BaseInvocation): assert isinstance(unet_info.model, UNet2DConditionModel) with ( ExitStack() as exit_stack, - ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config), - set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME unet_info as unet, + ModelPatcher.apply_freeu(unet, self.unet.freeu_config), + set_seamless(unet, self.unet.seamless_axes), # FIXME # Apply the LoRA after unet has been moved to its target device for faster patching. ModelPatcher.apply_lora_unet(unet, _lora_loader()), ): diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py index a8c2dd3e92..2ecb3b5d79 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -42,10 +42,26 @@ T = TypeVar("T") @dataclass class CacheRecord(Generic[T]): - """Elements of the cache.""" + """ + Elements of the cache: + + key: Unique key for each model, same as used in the models database. + model: Model in memory. + state_dict: A read-only copy of the model's state dict in RAM. It will be + used as a template for creating a copy in the VRAM. + size: Size of the model + loaded: True if the model's state dict is currently in VRAM + + Before a model is executed, the state_dict template is copied into VRAM, + and then injected into the model. When the model is finished, the VRAM + copy of the state dict is deleted, and the RAM version is reinjected + into the model. + """ key: str model: T + device: torch.device + state_dict: Optional[Dict[str, torch.Tensor]] size: int loaded: bool = False _locks: int = 0 diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 2ffe954e11..a3016a63ef 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -20,7 +20,6 @@ context. Use like this: import gc import math -import sys import time from contextlib import suppress from logging import Logger @@ -162,7 +161,9 @@ class ModelCache(ModelCacheBase[AnyModel]): if key in self._cached_models: return self.make_room(size) - cache_record = CacheRecord(key, model, size) + + state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None + cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size) self._cached_models[key] = cache_record self._cache_stack.append(key) @@ -257,17 +258,37 @@ class ModelCache(ModelCacheBase[AnyModel]): if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")): return - source_device = cache_entry.model.device + source_device = cache_entry.device # Note: We compare device types only so that 'cuda' == 'cuda:0'. # This would need to be revised to support multi-GPU. if torch.device(source_device).type == torch.device(target_device).type: return + # This roundabout method for moving the model around is done to avoid + # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM. + # When moving to VRAM, we copy (not move) each element of the state dict from + # RAM to a new state dict in VRAM, and then inject it into the model. + # This operation is slightly faster than running `to()` on the whole model. + # + # When the model needs to be removed from VRAM we simply delete the copy + # of the state dict in VRAM, and reinject the state dict that is cached + # in RAM into the model. So this operation is very fast. start_model_to_time = time.time() snapshot_before = self._capture_memory_snapshot() + try: + if cache_entry.state_dict is not None: + assert hasattr(cache_entry.model, "load_state_dict") + if target_device == self.storage_device: + cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True) + else: + new_dict: Dict[str, torch.Tensor] = {} + for k, v in cache_entry.state_dict.items(): + new_dict[k] = v.to(torch.device(target_device), copy=True) + cache_entry.model.load_state_dict(new_dict, assign=True) cache_entry.model.to(target_device) + cache_entry.device = target_device except Exception as e: # blow away cache entry self._delete_cache_entry(cache_entry) raise e @@ -347,43 +368,12 @@ class ModelCache(ModelCacheBase[AnyModel]): while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack): model_key = self._cache_stack[pos] cache_entry = self._cached_models[model_key] - - refs = sys.getrefcount(cache_entry.model) - - # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly - # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way: - # https://docs.python.org/3/library/gc.html#gc.get_referrers - - # manualy clear local variable references of just finished function calls - # for some reason python don't want to collect it even by gc.collect() immidiately - if refs > 2: - while True: - cleared = False - for referrer in gc.get_referrers(cache_entry.model): - if type(referrer).__name__ == "frame": - # RuntimeError: cannot clear an executing frame - with suppress(RuntimeError): - referrer.clear() - cleared = True - # break - - # repeat if referrers changes(due to frame clear), else exit loop - if cleared: - gc.collect() - else: - break - device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None self.logger.debug( - f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}," - f" refs: {refs}" + f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}" ) - # Expected refs: - # 1 from cache_entry - # 1 from getrefcount function - # 1 from onnx runtime object - if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2): + if not cache_entry.locked: self.logger.debug( f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" )