diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py
index 0e7ebd6d69..a19dda30b8 100644
--- a/invokeai/app/invocations/flux_text_encoder.py
+++ b/invokeai/app/invocations/flux_text_encoder.py
@@ -40,7 +40,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
-        t5_embeddings, clip_embeddings = self._encode_prompt(context)
+        # Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
+        # scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
+        t5_embeddings = self._t5_encode(context)
+        clip_embeddings = self._clip_encode(context)

         conditioning_data = ConditioningFieldData(
             conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
         )
@@ -48,12 +51,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
         conditioning_name = context.conditioning.save(conditioning_data)
         return FluxConditioningOutput.build(conditioning_name)

-    def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
-        # Load CLIP.
-        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
-        clip_text_encoder_info = context.models.load(self.clip.text_encoder)
-
-        # Load T5.
+    def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
         t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
         t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)

@@ -70,6 +68,15 @@ class FluxTextEncoderInvocation(BaseInvocation):

             prompt_embeds = t5_encoder(prompt)

+        assert isinstance(prompt_embeds, torch.Tensor)
+        return prompt_embeds
+
+    def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
+        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
+        clip_text_encoder_info = context.models.load(self.clip.text_encoder)
+
+        prompt = [self.prompt]
+
         with (
             clip_text_encoder_info as clip_text_encoder,
             clip_tokenizer_info as clip_tokenizer,
@@ -81,6 +88,5 @@ class FluxTextEncoderInvocation(BaseInvocation):

             pooled_prompt_embeds = clip_encoder(prompt)

-        assert isinstance(prompt_embeds, torch.Tensor)
         assert isinstance(pooled_prompt_embeds, torch.Tensor)
-        return prompt_embeds, pooled_prompt_embeds
+        return pooled_prompt_embeds
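The new comment explains the motivation: once `_t5_encode()` returns, no strong reference to the T5 model remains in the invocation, so the cache can reclaim it before `_clip_encode()` loads the CLIP model. A minimal, self-contained illustration of that Python scoping behaviour (not part of the patch; `BigModel` and `scoped_encode` are invented stand-ins):

```python
import gc
import weakref


class BigModel:
    """Stand-in for a large model object held by a cache-backed loader."""


def scoped_encode() -> weakref.ref:
    model = BigModel()  # imagine this is the freshly loaded T5 encoder
    # ... run the encoder and keep only its (small) output ...
    return weakref.ref(model)  # the last strong reference dies when this frame returns


ref = scoped_encode()
gc.collect()
print(ref() is None)  # True: the "T5" object is collectable before the next model is loaded
```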
diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py
index b6ff06c67b..248122d8cd 100644
--- a/invokeai/app/invocations/flux_text_to_image.py
+++ b/invokeai/app/invocations/flux_text_to_image.py
@@ -58,13 +58,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        # Load the conditioning data.
-        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
-        assert len(cond_data.conditionings) == 1
-        flux_conditioning = cond_data.conditionings[0]
-        assert isinstance(flux_conditioning, FLUXConditioningInfo)
-
-        latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
+        latents = self._run_diffusion(context)
         image = self._run_vae_decoding(context, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)
@@ -72,12 +66,20 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def _run_diffusion(
         self,
         context: InvocationContext,
-        clip_embeddings: torch.Tensor,
-        t5_embeddings: torch.Tensor,
     ):
-        transformer_info = context.models.load(self.transformer.transformer)
         inference_dtype = torch.bfloat16

+        # Load the conditioning data.
+        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
+        assert len(cond_data.conditionings) == 1
+        flux_conditioning = cond_data.conditionings[0]
+        assert isinstance(flux_conditioning, FLUXConditioningInfo)
+        flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
+        t5_embeddings = flux_conditioning.t5_embeds
+        clip_embeddings = flux_conditioning.clip_embeds
+
+        transformer_info = context.models.load(self.transformer.transformer)
+
         # Prepare input noise.
         x = get_noise(
             num_samples=1,
@@ -88,24 +90,19 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             seed=self.seed,
         )

-        img, img_ids = prepare_latent_img_patches(x)
+        x, img_ids = prepare_latent_img_patches(x)

         is_schnell = "schnell" in transformer_info.config.config_path
         timesteps = get_schedule(
             num_steps=self.num_steps,
-            image_seq_len=img.shape[1],
+            image_seq_len=x.shape[1],
             shift=not is_schnell,
         )

         bs, t5_seq_len, _ = t5_embeddings.shape
         txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())

-        # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
-        # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
-        # if the cache is not empty.
-        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
-
         with transformer_info as transformer:
             assert isinstance(transformer, Flux)
@@ -140,7 +137,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

             x = denoise(
                 model=transformer,
-                img=img,
+                img=x,
                 img_ids=img_ids,
                 txt=t5_embeddings,
                 txt_ids=txt_ids,
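A hedged sketch of the dtype flow that replaces the cast-everything block removed from `denoise()` below: the conditioning is cast once to the inference dtype, and the id tensors are created directly in that dtype. The shapes are illustrative assumptions, not the real FLUX dimensions:

```python
import torch

inference_dtype = torch.bfloat16

# Stand-ins for the loaded conditioning tensors (shapes are placeholders).
t5_embeddings = torch.randn(1, 512, 4096).to(dtype=inference_dtype)
clip_embeddings = torch.randn(1, 768).to(dtype=inference_dtype)

bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype)

# Every tensor entering the sampling loop already agrees on dtype, so no re-casting is needed.
assert {t.dtype for t in (t5_embeddings, clip_embeddings, txt_ids)} == {inference_dtype}
```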
diff --git a/invokeai/backend/flux/sampling.py b/invokeai/backend/flux/sampling.py
index 19de48ae81..7a35b0aedf 100644
--- a/invokeai/backend/flux/sampling.py
+++ b/invokeai/backend/flux/sampling.py
@@ -111,16 +111,7 @@ def denoise(
     step_callback: Callable[[], None],
     guidance: float = 4.0,
 ):
-    dtype = model.txt_in.bias.dtype
-
-    # TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
-    img = img.to(dtype=dtype)
-    img_ids = img_ids.to(dtype=dtype)
-    txt = txt.to(dtype=dtype)
-    txt_ids = txt_ids.to(dtype=dtype)
-    vec = vec.to(dtype=dtype)
-
-    # this is ignored for schnell
+    # guidance_vec is ignored for schnell.
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
     for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
@@ -168,9 +159,9 @@ def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor,
     img = repeat(img, "1 ... -> bs ...", bs=bs)

     # Generate patch position ids.
-    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
-    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
-    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
+    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

     return img, img_ids
diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py
index ce9811534e..d4e88857fa 100644
--- a/invokeai/backend/model_manager/load/load_default.py
+++ b/invokeai/backend/model_manager/load/load_default.py
@@ -72,6 +72,7 @@ class ModelLoader(ModelLoaderBase):
             pass

         config.path = str(self._get_model_path(config))
+        self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
         loaded_model = self._load_model(config, submodel_type)

         self._ram_cache.put(
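The `load_default.py` change makes room in the RAM cache based on the model's size on disk before deserializing it, which is what allows the hard-coded 24GB `make_room()` hack to be deleted from `flux_text_to_image.py` above. A toy sketch of the pattern (the class and function below are invented for illustration; the real cache tracks sizes and eviction differently):

```python
from pathlib import Path


class TinyRamCache:
    """Toy stand-in for the RAM cache; sizes are tracked in bytes."""

    def __init__(self, max_bytes: int):
        self.max_bytes = max_bytes
        self.entries: dict[str, int] = {}

    def size(self) -> int:
        return sum(self.entries.values())

    def make_room(self, bytes_needed: int) -> None:
        # Evict cached entries (oldest-inserted first here) until the new model would fit.
        while self.entries and self.size() + bytes_needed > self.max_bytes:
            self.entries.pop(next(iter(self.entries)))


def load_model(cache: TinyRamCache, key: str, model_path: Path) -> None:
    size_on_disk = model_path.stat().st_size
    cache.make_room(size_on_disk)  # free memory *before* deserializing, not after an OOM
    cache.entries[key] = size_on_disk  # stand-in for reading the weights into RAM
```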
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
index 012fd42d55..97fd401da0 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
@@ -193,15 +193,6 @@ class ModelCacheBase(ABC, Generic[T]):
         """
         pass

-    @abstractmethod
-    def exists(
-        self,
-        key: str,
-        submodel_type: Optional[SubModelType] = None,
-    ) -> bool:
-        """Return true if the model identified by key and submodel_type is in the cache."""
-        pass
-
     @abstractmethod
     def cache_size(self) -> int:
         """Get the total size of the models currently cached."""
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 482585e8e7..4b0ebbd40e 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -1,22 +1,6 @@
 # Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
 # TODO: Add Stalker's proper name to copyright
-"""
-Manage a RAM cache of diffusion/transformer models for fast switching.
-They are moved between GPU VRAM and CPU RAM as necessary. If the cache
-grows larger than a preset maximum, then the least recently used
-model will be cleared and (re)loaded from disk when next needed.
-
-The cache returns context manager generators designed to load the
-model into the GPU within the context, and unload outside the
-context. Use like this:
-
-   cache = ModelCache(max_cache_size=7.5)
-   with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
-        cache.get_model('stabilityai/stable-diffusion-2') as SD2:
-       do_something_in_GPU(SD1,SD2)
-
-
-"""
+""" """

 import gc
 import math
@@ -40,45 +24,64 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_da
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

-# Maximum size of the cache, in gigs
-# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
-DEFAULT_MAX_CACHE_SIZE = 6.0
-
-# amount of GPU memory to hold in reserve for use by generations (GB)
-DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
-
-# actual size of a gig
-GIG = 1073741824
+# Size of a GB in bytes.
+GB = 2**30

 # Size of a MB in bytes.
 MB = 2**20


 class ModelCache(ModelCacheBase[AnyModel]):
-    """Implementation of ModelCacheBase."""
+    """A cache for managing models in memory.
+
+    The cache is based on two levels of model storage:
+    - execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
+    - storage_device: The device where models are offloaded when not in active use (typically "cpu").
+
+    The model cache is based on the following assumptions:
+    - storage_device_mem_size > execution_device_mem_size
+    - disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
+
+    A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
+    the execution_device.
+
+    Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
+    on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
+    policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
+
+    Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
+    policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
+    configuration.
+
+    The cache returns context manager generators designed to load the model into the execution device (often GPU) within
+    the context, and unload outside the context.
+
+    Example usage:
+    ```
+    cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
+    with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
+        do_something_on_gpu(SD1)
+    ```
+    """

     def __init__(
         self,
-        max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
-        max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
+        max_cache_size: float,
+        max_vram_cache_size: float,
         execution_device: torch.device = torch.device("cuda"),
         storage_device: torch.device = torch.device("cpu"),
-        precision: torch.dtype = torch.float16,
-        sequential_offload: bool = False,
         lazy_offloading: bool = True,
-        sha_chunksize: int = 16777216,
         log_memory_usage: bool = False,
         logger: Optional[Logger] = None,
     ):
         """
         Initialize the model RAM cache.

-        :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
+        :param max_cache_size: Maximum size of the storage_device cache in GBs.
+        :param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
         :param execution_device: Torch device to load active model into [torch.device('cuda')]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
-        :param precision: Precision for loaded models [torch.float16]
-        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
-        :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
+        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded.
         :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
             operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
             snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
@@ -86,7 +89,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         """
         # allow lazy offloading only when vram cache enabled
         self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
-        self._precision: torch.dtype = precision
         self._max_cache_size: float = max_cache_size
         self._max_vram_cache_size: float = max_vram_cache_size
         self._execution_device: torch.device = execution_device
@@ -145,15 +147,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 total += cache_record.size
         return total

-    def exists(
-        self,
-        key: str,
-        submodel_type: Optional[SubModelType] = None,
-    ) -> bool:
-        """Return true if the model identified by key and submodel_type is in the cache."""
-        key = self._make_cache_key(key, submodel_type)
-        return key in self._cached_models
-
     def put(
         self,
         key: str,
@@ -203,7 +196,7 @@
         # more stats
         if self.stats:
             stats_name = stats_name or key
-            self.stats.cache_size = int(self._max_cache_size * GIG)
+            self.stats.cache_size = int(self._max_cache_size * GB)
             self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
             self.stats.in_cache = len(self._cached_models)
             self.stats.loaded_model_sizes[stats_name] = max(
@@ -231,10 +224,13 @@
             return model_key

     def offload_unlocked_models(self, size_required: int) -> None:
-        """Move any unused models from VRAM."""
-        reserved = self._max_vram_cache_size * GIG
+        """Offload models from the execution_device to make room for size_required.
+
+        :param size_required: The amount of space to clear in the execution_device cache, in bytes.
+        """
+        reserved = self._max_vram_cache_size * GB
         vram_in_use = torch.cuda.memory_allocated() + size_required
-        self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
+        self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
         for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
             if vram_in_use <= reserved:
                 break
@@ -245,7 +241,7 @@
                 cache_entry.loaded = False
                 vram_in_use = torch.cuda.memory_allocated() + size_required
                 self.logger.debug(
-                    f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
+                    f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
                 )

         TorchDevice.empty_cache()
@@ -303,7 +299,7 @@
             self.logger.debug(
                 f"Moved model '{cache_entry.key}' from {source_device} to"
                 f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
-                f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
+                f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
                 f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
             )

@@ -326,14 +322,14 @@
                 f"Moving model '{cache_entry.key}' from {source_device} to"
                 f" {target_device} caused an unexpected change in VRAM usage. The model's"
                 " estimated size may be incorrect. Estimated model size:"
-                f" {(cache_entry.size/GIG):.3f} GB.\n"
+                f" {(cache_entry.size/GB):.3f} GB.\n"
                 f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
             )

     def print_cuda_stats(self) -> None:
         """Log CUDA diagnostics."""
-        vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
-        ram = "%4.2fG" % (self.cache_size() / GIG)
+        vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
+        ram = "%4.2fG" % (self.cache_size() / GB)

         in_ram_models = 0
         in_vram_models = 0
@@ -353,17 +349,20 @@
         )

     def make_room(self, size: int) -> None:
-        """Make enough room in the cache to accommodate a new model of indicated size."""
-        # calculate how much memory this model will require
-        # multiplier = 2 if self.precision==torch.float32 else 1
+        """Make enough room in the cache to accommodate a new model of indicated size.
+
+        Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
+        external references to the model, there's nothing that the cache can do about it, and those models will not be
+        garbage-collected.
+        """
         bytes_needed = size
-        maximum_size = self.max_cache_size * GIG  # stored in GB, convert to bytes
+        maximum_size = self.max_cache_size * GB  # stored in GB, convert to bytes
         current_size = self.cache_size()

         if current_size + bytes_needed > maximum_size:
             self.logger.debug(
-                f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
-                f" {(bytes_needed/GIG):.2f} GB"
+                f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
+                f" {(bytes_needed/GB):.2f} GB"
             )

         self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
@@ -380,7 +379,7 @@

             if not cache_entry.locked:
                 self.logger.debug(
-                    f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
+                    f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
                 )
                 current_size -= cache_entry.size
                 models_cleared += 1
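A simplified sketch of the execution-device offload policy described in the new `ModelCache` docstring: unlocked models are offloaded smallest-first until VRAM usage fits under the configured reserve. (The real `offload_unlocked_models()` works in bytes and re-reads `torch.cuda.memory_allocated()` after each offload; the dict below is an invented stand-in.)

```python
def plan_vram_offload(model_sizes: dict[str, int], vram_in_use: int, reserved: int) -> list[str]:
    """Return the cache keys to offload, smallest models first, until usage fits under `reserved`."""
    offloaded = []
    for key, size in sorted(model_sizes.items(), key=lambda kv: kv[1]):
        if vram_in_use <= reserved:
            break
        offloaded.append(key)
        vram_in_use -= size
    return offloaded


# Sizes in GB for readability.
print(plan_vram_offload({"unet": 8, "clip": 2, "vae": 1}, vram_in_use=10, reserved=8))
# -> ['vae', 'clip']
```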
+ """ + reserved = self._max_vram_cache_size * GB vram_in_use = torch.cuda.memory_allocated() + size_required - self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB") + self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB") for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): if vram_in_use <= reserved: break @@ -245,7 +241,7 @@ class ModelCache(ModelCacheBase[AnyModel]): cache_entry.loaded = False vram_in_use = torch.cuda.memory_allocated() + size_required self.logger.debug( - f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB" + f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB" ) TorchDevice.empty_cache() @@ -303,7 +299,7 @@ class ModelCache(ModelCacheBase[AnyModel]): self.logger.debug( f"Moved model '{cache_entry.key}' from {source_device} to" f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s." - f"Estimated model size: {(cache_entry.size/GIG):.3f} GB." + f"Estimated model size: {(cache_entry.size/GB):.3f} GB." f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" ) @@ -326,14 +322,14 @@ class ModelCache(ModelCacheBase[AnyModel]): f"Moving model '{cache_entry.key}' from {source_device} to" f" {target_device} caused an unexpected change in VRAM usage. The model's" " estimated size may be incorrect. Estimated model size:" - f" {(cache_entry.size/GIG):.3f} GB.\n" + f" {(cache_entry.size/GB):.3f} GB.\n" f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" ) def print_cuda_stats(self) -> None: """Log CUDA diagnostics.""" - vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) - ram = "%4.2fG" % (self.cache_size() / GIG) + vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB) + ram = "%4.2fG" % (self.cache_size() / GB) in_ram_models = 0 in_vram_models = 0 @@ -353,17 +349,20 @@ class ModelCache(ModelCacheBase[AnyModel]): ) def make_room(self, size: int) -> None: - """Make enough room in the cache to accommodate a new model of indicated size.""" - # calculate how much memory this model will require - # multiplier = 2 if self.precision==torch.float32 else 1 + """Make enough room in the cache to accommodate a new model of indicated size. + + Note: This function deletes all of the cache's internal references to a model in order to free it. If there are + external references to the model, there's nothing that the cache can do about it, and those models will not be + garbage-collected. 
+ """ bytes_needed = size - maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes + maximum_size = self.max_cache_size * GB # stored in GB, convert to bytes current_size = self.cache_size() if current_size + bytes_needed > maximum_size: self.logger.debug( - f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional" - f" {(bytes_needed/GIG):.2f} GB" + f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional" + f" {(bytes_needed/GB):.2f} GB" ) self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}") @@ -380,7 +379,7 @@ class ModelCache(ModelCacheBase[AnyModel]): if not cache_entry.locked: self.logger.debug( - f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" + f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)" ) current_size -= cache_entry.size models_cleared += 1 diff --git a/invokeai/backend/quantization/bnb_llm_int8.py b/invokeai/backend/quantization/bnb_llm_int8.py index b92717cbc5..02f94936e9 100644 --- a/invokeai/backend/quantization/bnb_llm_int8.py +++ b/invokeai/backend/quantization/bnb_llm_int8.py @@ -54,8 +54,10 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt): # See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format. scb = state_dict.pop(prefix + "SCB", None) - # weight_format is unused, but we pop it so we can validate that there are no unexpected keys. - _weight_format = state_dict.pop(prefix + "weight_format", None) + + # Currently, we only support weight_format=0. + weight_format = state_dict.pop(prefix + "weight_format", None) + assert weight_format == 0 # TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs` # rather than raising an exception to correctly implement this API. @@ -89,6 +91,14 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt): ) self.bias = bias if bias is None else torch.nn.Parameter(bias) + # Reset the state. The persisted fields are based on the initialization behaviour in + # `bnb.nn.Linear8bitLt.__init__()`. 
diff --git a/invokeai/backend/util/__init__.py b/invokeai/backend/util/__init__.py
index 101215640a..f24b6db3e1 100644
--- a/invokeai/backend/util/__init__.py
+++ b/invokeai/backend/util/__init__.py
@@ -3,10 +3,9 @@ Initialization file for invokeai.backend.util
 """

 from invokeai.backend.util.logging import InvokeAILogger
-from invokeai.backend.util.util import GIG, Chdir, directory_size
+from invokeai.backend.util.util import Chdir, directory_size

 __all__ = [
-    "GIG",
     "directory_size",
     "Chdir",
     "InvokeAILogger",
diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py
index b3466ddba9..cc654e4d39 100644
--- a/invokeai/backend/util/util.py
+++ b/invokeai/backend/util/util.py
@@ -7,9 +7,6 @@ from pathlib import Path

 from PIL import Image

-# actual size of a gig
-GIG = 1073741824
-

 def slugify(value: str, allow_unicode: bool = False) -> str:
     """
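For reference, the renamed constants used throughout the model cache are numerically identical to the removed `GIG` literal, and cache limits are still given in GB and converted to bytes the same way:

```python
GB = 2**30  # same value as the removed GIG = 1073741824
MB = 2**20

assert GB == 1073741824

max_cache_size = 6.0  # GB, as passed to ModelCache(max_cache_size=...)
print(int(max_cache_size * GB))  # 6442450944 bytes
```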