diff --git a/invokeai/app/invocations/sd3.py b/invokeai/app/invocations/sd3.py
index 385cfe3aca..8c5915112a 100644
--- a/invokeai/app/invocations/sd3.py
+++ b/invokeai/app/invocations/sd3.py
@@ -1,3 +1,4 @@
+from contextlib import ExitStack
 from typing import cast
 
 import torch
@@ -23,7 +24,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.misc import SEED_MAX
 from invokeai.backend.model_manager.config import SubModelType
 from invokeai.backend.model_manager.load.load_base import LoadedModel
-from invokeai.backend.util.devices import TorchDevice
 
 sd3_pipeline: Optional[StableDiffusion3Pipeline] = None
 transformer_info: Optional[LoadedModel] = None
@@ -148,39 +148,35 @@ class StableDiffusion3Invocation(BaseInvocation):
         return v % (SEED_MAX + 1)
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        global sd3_pipeline, transformer_info, tokenizer_1_info, tokenizer_2_info, tokenizer_3_info, text_encoder_1_info, text_encoder_2_info, text_encoder_3_info
+        app_config = context.config.get()
+        load_te3 = app_config.load_sd3_encoder_3
 
-        if not transformer_info:
-            transformer_info = context.models.load(self.transformer.transformer)
-        if not tokenizer_1_info:
-            tokenizer_1_info = context.models.load(self.clip.tokenizer_1)
-        if not tokenizer_2_info:
-            tokenizer_2_info = context.models.load(self.clip.tokenizer_2)
-        if not tokenizer_3_info:
-            tokenizer_3_info = context.models.load(self.clip.tokenizer_3)
-        if not text_encoder_1_info:
-            text_encoder_1_info = context.models.load(self.clip.text_encoder_1)
-        if not text_encoder_2_info:
-            text_encoder_2_info = context.models.load(self.clip.text_encoder_2)
-        if not text_encoder_3_info:
-            text_encoder_3_info = context.models.load(self.clip.text_encoder_3)
+        transformer_info = context.models.load(self.transformer.transformer)
+        tokenizer_1_info = context.models.load(self.clip.tokenizer_1)
+        tokenizer_2_info = context.models.load(self.clip.tokenizer_2)
+        text_encoder_1_info = context.models.load(self.clip.text_encoder_1)
+        text_encoder_2_info = context.models.load(self.clip.text_encoder_2)
 
-        with (
-            tokenizer_1_info as tokenizer_1,
-            tokenizer_2_info as tokenizer_2,
-            tokenizer_3_info as tokenizer_3,
-            text_encoder_1_info as text_encoder_1,
-            text_encoder_2_info as text_encoder_2,
-            text_encoder_3_info as text_encoder_3,
-            transformer_info as transformer,
-        ):
+        with ExitStack() as stack:
+            tokenizer_1 = stack.enter_context(tokenizer_1_info)
+            tokenizer_2 = stack.enter_context(tokenizer_2_info)
+            text_encoder_1 = stack.enter_context(text_encoder_1_info)
+            text_encoder_2 = stack.enter_context(text_encoder_2_info)
+            transformer = stack.enter_context(transformer_info)
             assert isinstance(transformer, SD3Transformer2DModel)
             assert isinstance(text_encoder_1, CLIPTextModelWithProjection)
             assert isinstance(text_encoder_2, CLIPTextModelWithProjection)
-            assert isinstance(text_encoder_3, T5EncoderModel)
             assert isinstance(tokenizer_1, CLIPTokenizer)
             assert isinstance(tokenizer_2, CLIPTokenizer)
-            assert isinstance(tokenizer_3, T5TokenizerFast)
+
+            if load_te3:
+                tokenizer_3 = stack.enter_context(context.models.load(self.clip.tokenizer_3))
+                text_encoder_3 = stack.enter_context(context.models.load(self.clip.text_encoder_3))
+                assert isinstance(text_encoder_3, T5EncoderModel)
+                assert isinstance(tokenizer_3, T5TokenizerFast)
+            else:
+                tokenizer_3 = None
+                text_encoder_3 = None
 
             scheduler = get_scheduler(
                 context=context,
@@ -189,21 +185,17 @@ class StableDiffusion3Invocation(BaseInvocation):
                 seed=self.seed,
             )
 
-            if not isinstance(sd3_pipeline, StableDiffusion3Pipeline):
-                sd3_pipeline = StableDiffusion3Pipeline(
-                    transformer=transformer,
-                    vae=FakeVae(),
-                    text_encoder=text_encoder_1,
-                    text_encoder_2=text_encoder_2,
-                    text_encoder_3=text_encoder_3,
-                    tokenizer=tokenizer_1,
-                    tokenizer_2=tokenizer_2,
-                    tokenizer_3=tokenizer_3,
-                    scheduler=scheduler,
-                )
-
-            sd3_pipeline.components["scheduler"] = scheduler
-            sd3_pipeline.to(TorchDevice.choose_torch_device().type)
+            sd3_pipeline = StableDiffusion3Pipeline(
+                transformer=transformer,
+                vae=FakeVae(),
+                text_encoder=text_encoder_1,
+                text_encoder_2=text_encoder_2,
+                text_encoder_3=text_encoder_3,
+                tokenizer=tokenizer_1,
+                tokenizer_2=tokenizer_2,
+                tokenizer_3=tokenizer_3,
+                scheduler=scheduler,
+            )
 
             results = sd3_pipeline(
                 self.positive_prompt,
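A note on the hunk above: the invocation no longer caches models in module-level globals, and the T5 tokenizer/encoder are only loaded and entered as context managers when `load_sd3_encoder_3` is enabled; otherwise `tokenizer_3`/`text_encoder_3` are passed to the pipeline as `None`. Below is a minimal, self-contained sketch of that conditional-acquisition pattern with `ExitStack`. The names `load_model`, `clip_l`, `clip_g`, `t5_xxl` and `run` are made-up stand-ins for `context.models.load(...)` and the real submodels, not InvokeAI APIs.

```python
from contextlib import ExitStack, contextmanager
from typing import Iterator, Optional


@contextmanager
def load_model(name: str) -> Iterator[str]:
    # Hypothetical stand-in for context.models.load(...): yields a fake handle.
    print(f"loading {name}")
    try:
        yield f"<{name}>"
    finally:
        print(f"releasing {name}")


def run(load_te3: bool) -> None:
    with ExitStack() as stack:
        # Always-required components go onto the stack unconditionally.
        text_encoder_1 = stack.enter_context(load_model("clip_l"))
        text_encoder_2 = stack.enter_context(load_model("clip_g"))

        # The heavy T5 encoder is entered only when the flag asks for it; either
        # way, everything entered so far is released when the stack unwinds.
        text_encoder_3: Optional[str] = (
            stack.enter_context(load_model("t5_xxl")) if load_te3 else None
        )
        print(text_encoder_1, text_encoder_2, text_encoder_3)


run(load_te3=False)
run(load_te3=True)
```

Whatever was entered on the stack is released when the `with` block exits, whether or not the optional encoder was added, which is the property a flat `with (a, b, c):` statement cannot express conditionally.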
diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index 496988e853..46c870396b 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -104,6 +104,7 @@ class InvokeAIAppConfig(BaseSettings):
         vram: Amount of VRAM reserved for model storage (GB).
         convert_cache: Maximum size of on-disk converted models cache (GB).
         lazy_offload: Keep models in VRAM until their space is needed.
+        load_sd3_encoder_3: Load the memory-intensive SD3 text_encoder_3.
         log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
         device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities. Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
         precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system. Valid values: `auto`, `float16`, `bfloat16`, `float32`
@@ -173,6 +174,7 @@ class InvokeAIAppConfig(BaseSettings):
     vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
     convert_cache: float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB).")
     lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
+    load_sd3_encoder_3: bool = Field(default=False, description="Load the memory-intensive SD3 text_encoder_3.")
     log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
 
     # DEVICE
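The new `load_sd3_encoder_3` field defaults to `False`, so the large T5 text encoder is skipped unless the user opts in, and the invocation reads it as `context.config.get().load_sd3_encoder_3` (see the sd3.py hunk above). As a rough illustration of how such a flag behaves as a pydantic-settings field, here is a toy `AppConfigSketch`; it is not the real `InvokeAIAppConfig`, which layers its own loading logic (invokeai.yaml, environment, defaults) on top.

```python
from pydantic import Field
from pydantic_settings import BaseSettings


class AppConfigSketch(BaseSettings):
    """Toy stand-in showing how the new boolean flag is declared and read."""

    # Off by default, so the memory-hungry T5 encoder is never loaded unasked.
    load_sd3_encoder_3: bool = Field(
        default=False,
        description="Load the memory-intensive SD3 text_encoder_3.",
    )


config = AppConfigSketch()          # with default pydantic-settings behaviour,
                                    # an env var like LOAD_SD3_ENCODER_3=1 would override
print(config.load_sd3_encoder_3)    # -> False unless overridden
```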
diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py
index a63cc66a86..717f73268d 100644
--- a/invokeai/backend/model_manager/load/load_default.py
+++ b/invokeai/backend/model_manager/load/load_default.py
@@ -84,6 +84,8 @@ class ModelLoader(ModelLoaderBase):
         except IndexError:
             pass
 
+        self._logger.info(f"Loading {config.key}:{submodel_type}")
+
         cache_path: Path = self._convert_cache.cache_path(str(model_path))
         if self._needs_conversion(config, model_path, cache_path):
             loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index a071570c22..5924f5613a 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -161,11 +161,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self.make_room(size)
 
         is_quantized = hasattr(model, "is_quantized") and model.is_quantized
-        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
+        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not is_quantized else None
         cache_record = CacheRecord(
             key=key,
             model=model,
-            device=self._storage_device,
+            device=self._execution_device
+            if is_quantized
+            else self._storage_device,  # quantized models are loaded directly into CUDA
             is_quantized=is_quantized,
             state_dict=state_dict,
             size=size,
@@ -235,26 +237,28 @@ class ModelCache(ModelCacheBase[AnyModel]):
         reserved = self._max_vram_cache_size * GIG
         vram_in_use = torch.cuda.memory_allocated() + size_required
         self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
-        delete_it = False
         for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
             if vram_in_use <= reserved:
                 break
+
+            # only way to remove a quantized model from VRAM is to
+            # delete it completely - it can't be moved from device to device
+            if cache_entry.is_quantized:
+                self._delete_cache_entry(cache_entry)
+                vram_in_use = torch.cuda.memory_allocated() + size_required
+                continue
+
             if not cache_entry.loaded:
                 continue
+
             if not cache_entry.locked:
-                if cache_entry.is_quantized:
-                    self._delete_cache_entry(cache_entry)
-                    delete_it = True
-                else:
-                    self.move_model_to_device(cache_entry, self.storage_device)
-                    cache_entry.loaded = False
+                self.move_model_to_device(cache_entry, self.storage_device)
+                cache_entry.loaded = False
                 vram_in_use = torch.cuda.memory_allocated() + size_required
                 self.logger.debug(
                     f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
                 )
-
-        if delete_it:
-            del cache_entry
+        gc.collect()
         TorchDevice.empty_cache()
 
     def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
@@ -268,7 +272,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
 
         source_device = cache_entry.device
-        # Note: We compare device types only so that 'cuda' == 'cuda:0'.
+        # Note: We compare device types so that 'cuda' == 'cuda:0'.
         # This would need to be revised to support multi-GPU.
         if torch.device(source_device).type == torch.device(target_device).type:
             return
@@ -277,9 +281,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         if not hasattr(cache_entry.model, "to"):
             return
 
-        if cache_entry.is_quantized:  # can't move quantized models around
-            return
-
         # This roundabout method for moving the model around is done to avoid
         # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
         # When moving to VRAM, we copy (not move) each element of the state dict from
@@ -422,5 +423,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
     def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
         self._cache_stack.remove(cache_entry.key)
         del self._cached_models[cache_entry.key]
+        del cache_entry
         gc.collect()
         TorchDevice.empty_cache()
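The cache changes above treat quantized (e.g. bitsandbytes-style) models specially: no separate CPU state dict is kept, the cache record is marked as living on the execution device because such modules cannot be moved with `.to()`, and the only way `make_room` can reclaim their VRAM is to drop the entry entirely, while ordinary models are still offloaded to the storage device. The toy sketch below illustrates that eviction policy; `ToyCacheEntry` and `make_room_sketch` are hypothetical stand-ins for `CacheRecord` and `ModelCache.make_room`, with sizes in arbitrary units and no real torch allocations.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class ToyCacheEntry:
    key: str
    size: int           # VRAM the entry occupies, arbitrary units
    is_quantized: bool  # bitsandbytes-style modules cannot be moved with .to()
    loaded: bool = True
    locked: bool = False


def make_room_sketch(cache: Dict[str, ToyCacheEntry], vram_in_use: int, reserved: int) -> int:
    """Evict smallest-first until the requested VRAM fits under the budget."""
    for key, entry in sorted(cache.items(), key=lambda kv: kv[1].size):
        if vram_in_use <= reserved:
            break
        if entry.is_quantized:
            # The only way a quantized model leaves VRAM is by deleting its entry.
            # sorted() already materialized the items, so deleting here is safe.
            del cache[key]
            vram_in_use -= entry.size
            continue
        if entry.loaded and not entry.locked:
            # Ordinary models are just offloaded back to the storage (CPU) device.
            entry.loaded = False
            vram_in_use -= entry.size
    return vram_in_use


cache = {
    "clip": ToyCacheEntry("clip", size=2, is_quantized=False),
    "t5_q8": ToyCacheEntry("t5_q8", size=5, is_quantized=True),
    "transformer": ToyCacheEntry("transformer", size=8, is_quantized=False, locked=True),
}
print(make_room_sketch(cache, vram_in_use=15, reserved=6))  # -> 8; the locked entry stays put
```

The added `del cache_entry` plus `gc.collect()` in `_delete_cache_entry` serves the same goal: the memory held by a deleted entry can only be returned to torch once the last Python reference to the module is gone.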