FLUX memory management improvements (#6791)

## Summary

This PR contains several improvements to memory management for FLUX
workflows.

It is now possible to achieve better FLUX model caching performance, but
this still requires users to manually configure their `ram`/`vram`
settings. E.g. a `vram` setting of 16.0 should allow for all quantized
FLUX models to be kept in memory on the GPU.

Changes:
- Check the size of a model on disk and free the requisite space in the
model cache before loading it. (This behaviour existed previously, but
was removed in https://github.com/invoke-ai/InvokeAI/pull/6072/files.
The removal did not seem to be intentional).
- Removed the hack to free 24GB of space in the cache before loading the
FLUX model.
- Split the T5 embedding and CLIP embedding steps into separate
functions so that the two models don't both have to be held in RAM at
the same time.
- Fix a bug in `InvokeLinear8bitLt` that was causing some tensors to be
left on the GPU when the model was offloaded to the CPU. (This class is
getting very messy due to the non-standard state_dict handling in
`bnb.nn.Linear8bitLt`. )
- Tidy up some dtype handling in FluxTextToImageInvocation to avoid
situations where we hold references to two copies of the same tensor
unnecessarily.
- (minor) Misc cleanup of ModelCache: improve docs and remove unused
vars.

Future:
We should revisit our default ram/vram configs. The current defaults are
very conservative, and users could see major performance improvements
from tuning these values.

## QA Instructions

I tested the FLUX workflow with the following configurations and
verified that the cache hit rates and memory usage matched the expected
behaviour:
- `ram = 16` and `vram = 16`
- `ram = 16` and `vram = 1`
- `ram = 1` and `vram = 1`

Note that the changes in this PR are not isolated to FLUX. Since we now
check the size of models on disk, we may see slight changes in model
cache offload patterns for other models as well.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
This commit is contained in:
Ryan Dick 2024-08-29 15:17:45 -04:00 committed by GitHub
commit 87261bdbc9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 114 additions and 118 deletions

View File

@ -40,7 +40,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
t5_embeddings, clip_embeddings = self._encode_prompt(context)
# Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
# scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
t5_embeddings = self._t5_encode(context)
clip_embeddings = self._clip_encode(context)
conditioning_data = ConditioningFieldData(
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
)
@ -48,12 +51,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
conditioning_name = context.conditioning.save(conditioning_data)
return FluxConditioningOutput.build(conditioning_name)
def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
# Load CLIP.
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
# Load T5.
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
@ -70,6 +68,15 @@ class FluxTextEncoderInvocation(BaseInvocation):
prompt_embeds = t5_encoder(prompt)
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
prompt = [self.prompt]
with (
clip_text_encoder_info as clip_text_encoder,
clip_tokenizer_info as clip_tokenizer,
@ -81,6 +88,5 @@ class FluxTextEncoderInvocation(BaseInvocation):
pooled_prompt_embeds = clip_encoder(prompt)
assert isinstance(prompt_embeds, torch.Tensor)
assert isinstance(pooled_prompt_embeds, torch.Tensor)
return prompt_embeds, pooled_prompt_embeds
return pooled_prompt_embeds

View File

@ -58,13 +58,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Load the conditioning data.
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
latents = self._run_diffusion(context)
image = self._run_vae_decoding(context, latents)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)
@ -72,12 +66,20 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def _run_diffusion(
self,
context: InvocationContext,
clip_embeddings: torch.Tensor,
t5_embeddings: torch.Tensor,
):
transformer_info = context.models.load(self.transformer.transformer)
inference_dtype = torch.bfloat16
# Load the conditioning data.
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
t5_embeddings = flux_conditioning.t5_embeds
clip_embeddings = flux_conditioning.clip_embeds
transformer_info = context.models.load(self.transformer.transformer)
# Prepare input noise.
x = get_noise(
num_samples=1,
@ -88,24 +90,19 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
seed=self.seed,
)
img, img_ids = prepare_latent_img_patches(x)
x, img_ids = prepare_latent_img_patches(x)
is_schnell = "schnell" in transformer_info.config.config_path
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=img.shape[1],
image_seq_len=x.shape[1],
shift=not is_schnell,
)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
# if the cache is not empty.
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
with transformer_info as transformer:
assert isinstance(transformer, Flux)
@ -140,7 +137,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
x = denoise(
model=transformer,
img=img,
img=x,
img_ids=img_ids,
txt=t5_embeddings,
txt_ids=txt_ids,

View File

@ -111,16 +111,7 @@ def denoise(
step_callback: Callable[[], None],
guidance: float = 4.0,
):
dtype = model.txt_in.bias.dtype
# TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
img = img.to(dtype=dtype)
img_ids = img_ids.to(dtype=dtype)
txt = txt.to(dtype=dtype)
txt_ids = txt_ids.to(dtype=dtype)
vec = vec.to(dtype=dtype)
# this is ignored for schnell
# guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
@ -168,9 +159,9 @@ def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor,
img = repeat(img, "1 ... -> bs ...", bs=bs)
# Generate patch position ids.
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
return img, img_ids

View File

@ -72,6 +72,7 @@ class ModelLoader(ModelLoaderBase):
pass
config.path = str(self._get_model_path(config))
self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
loaded_model = self._load_model(config, submodel_type)
self._ram_cache.put(

View File

@ -193,15 +193,6 @@ class ModelCacheBase(ABC, Generic[T]):
"""
pass
@abstractmethod
def exists(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
) -> bool:
"""Return true if the model identified by key and submodel_type is in the cache."""
pass
@abstractmethod
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""

View File

@ -1,22 +1,6 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
# TODO: Add Stalker's proper name to copyright
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
The cache returns context manager generators designed to load the
model into the GPU within the context, and unload outside the
context. Use like this:
cache = ModelCache(max_cache_size=7.5)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
do_something_in_GPU(SD1,SD2)
"""
""" """
import gc
import math
@ -40,45 +24,64 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_da
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
# amount of GPU memory to hold in reserve for use by generations (GB)
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
# actual size of a gig
GIG = 1073741824
# Size of a GB in bytes.
GB = 2**30
# Size of a MB in bytes.
MB = 2**20
class ModelCache(ModelCacheBase[AnyModel]):
"""Implementation of ModelCacheBase."""
"""A cache for managing models in memory.
The cache is based on two levels of model storage:
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
The model cache is based on the following assumptions:
- storage_device_mem_size > execution_device_mem_size
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
the execution_device.
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
configuration.
The cache returns context manager generators designed to load the model into the execution device (often GPU) within
the context, and unload outside the context.
Example usage:
```
cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
do_something_on_gpu(SD1)
```
"""
def __init__(
self,
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
max_cache_size: float,
max_vram_cache_size: float,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16,
sequential_offload: bool = False,
lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
):
"""
Initialize the model RAM cache.
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
:param max_cache_size: Maximum size of the storage_device cache in GBs.
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded.
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
@ -86,7 +89,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""
# allow lazy offloading only when vram cache enabled
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self._precision: torch.dtype = precision
self._max_cache_size: float = max_cache_size
self._max_vram_cache_size: float = max_vram_cache_size
self._execution_device: torch.device = execution_device
@ -145,15 +147,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
total += cache_record.size
return total
def exists(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
) -> bool:
"""Return true if the model identified by key and submodel_type is in the cache."""
key = self._make_cache_key(key, submodel_type)
return key in self._cached_models
def put(
self,
key: str,
@ -203,7 +196,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
# more stats
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GIG)
self.stats.cache_size = int(self._max_cache_size * GB)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
@ -231,10 +224,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
return model_key
def offload_unlocked_models(self, size_required: int) -> None:
"""Move any unused models from VRAM."""
reserved = self._max_vram_cache_size * GIG
"""Offload models from the execution_device to make room for size_required.
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
"""
reserved = self._max_vram_cache_size * GB
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
@ -245,7 +241,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
)
TorchDevice.empty_cache()
@ -303,7 +299,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self.logger.debug(
f"Moved model '{cache_entry.key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
@ -326,14 +322,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
f"Moving model '{cache_entry.key}' from {source_device} to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GIG):.3f} GB.\n"
f" {(cache_entry.size/GB):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
def print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
ram = "%4.2fG" % (self.cache_size() / GIG)
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
ram = "%4.2fG" % (self.cache_size() / GB)
in_ram_models = 0
in_vram_models = 0
@ -353,17 +349,20 @@ class ModelCache(ModelCacheBase[AnyModel]):
)
def make_room(self, size: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size."""
# calculate how much memory this model will require
# multiplier = 2 if self.precision==torch.float32 else 1
"""Make enough room in the cache to accommodate a new model of indicated size.
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
external references to the model, there's nothing that the cache can do about it, and those models will not be
garbage-collected.
"""
bytes_needed = size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
maximum_size = self.max_cache_size * GB # stored in GB, convert to bytes
current_size = self.cache_size()
if current_size + bytes_needed > maximum_size:
self.logger.debug(
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GIG):.2f} GB"
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GB):.2f} GB"
)
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
@ -380,7 +379,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
if not cache_entry.locked:
self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
)
current_size -= cache_entry.size
models_cleared += 1

View File

@ -54,8 +54,10 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
scb = state_dict.pop(prefix + "SCB", None)
# weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
_weight_format = state_dict.pop(prefix + "weight_format", None)
# Currently, we only support weight_format=0.
weight_format = state_dict.pop(prefix + "weight_format", None)
assert weight_format == 0
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
# rather than raising an exception to correctly implement this API.
@ -89,6 +91,14 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
)
self.bias = bias if bias is None else torch.nn.Parameter(bias)
# Reset the state. The persisted fields are based on the initialization behaviour in
# `bnb.nn.Linear8bitLt.__init__()`.
new_state = bnb.MatmulLtState()
new_state.threshold = self.state.threshold
new_state.has_fp16_weights = False
new_state.use_pool = self.state.use_pool
self.state = new_state
def _convert_linear_layers_to_llm_8bit(
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""

View File

@ -43,6 +43,11 @@ class FLUXConditioningInfo:
clip_embeds: torch.Tensor
t5_embeds: torch.Tensor
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
return self
@dataclass
class ConditioningFieldData:

View File

@ -3,10 +3,9 @@ Initialization file for invokeai.backend.util
"""
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.util import GIG, Chdir, directory_size
from invokeai.backend.util.util import Chdir, directory_size
__all__ = [
"GIG",
"directory_size",
"Chdir",
"InvokeAILogger",

View File

@ -7,9 +7,6 @@ from pathlib import Path
from PIL import Image
# actual size of a gig
GIG = 1073741824
def slugify(value: str, allow_unicode: bool = False) -> str:
"""