new OffloadingDevice loads one model at a time, on demand (#2596)

* new OffloadingDevice loads one model at a time, on demand

* fixup! new OffloadingDevice loads one model at a time, on demand

* fix(prompt_to_embeddings): call the text encoder directly instead of its forward method

allowing any associated hooks to run with it.

* more attempts to get things on the right device from the offloader

* more attempts to get things on the right device from the offloader

* make offloading methods an explicit part of the pipeline interface

* inlining some calls where device is only used once

* ensure model group is ready after pipeline.to is called

* fixup! Strategize slicing based on free [V]RAM (#2572)

* doc(offloading): docstrings for offloading.ModelGroup

* doc(offloading): docstrings for offloading-related pipeline methods

* refactor(offloading): s/SimpleModelGroup/FullyLoadedModelGroup

* refactor(offloading): s/HotSeatModelGroup/LazilyLoadedModelGroup

to frame it in the same terms as "FullyLoadedModelGroup"

---------

Co-authored-by: Damian Stewart <null@damianstewart.com>
Kevin Turner 2023-02-16 15:48:27 -08:00 committed by GitHub
parent 2468ba7445
commit 8a0d45ac5a
5 changed files with 371 additions and 47 deletions
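For orientation, here is a minimal sketch of how the new offloading path is driven from the model manager; the import path and config filename are assumptions for illustration, not part of this diff:

import torch
from omegaconf import OmegaConf
from ldm.invoke.model_manager import ModelManager  # assumed import path

# sequential_offload is the new constructor flag; Generate wires it to
# self.free_gpu_mem in the hunk below.
mconfig = OmegaConf.load("configs/models.yaml")  # placeholder config path
manager = ModelManager(mconfig, torch.device("cuda"), "float16",
                       sequential_offload=True)

# When a diffusers pipeline is loaded with sequential_offload enabled, the
# manager calls pipeline.enable_offload_submodels(device) instead of
# pipeline.to(device), so only one submodel occupies the GPU at a time.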


@@ -213,7 +213,9 @@ class Generate:
print('>> xformers not installed')
# model caching system for fast switching
self.model_manager = ModelManager(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
self.model_manager = ModelManager(mconfig, self.device, self.precision,
max_loaded_models=max_loaded_models,
sequential_offload=self.free_gpu_mem)
# don't accept invalid models
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
model = model or fallback
@@ -480,7 +482,6 @@ class Generate:
self.model.cond_stage_model.device = self.model.device
self.model.cond_stage_model.to(self.model.device)
except AttributeError:
print(">> Warning: '--free_gpu_mem' is not yet supported when generating image using model based on HuggingFace Diffuser.")
pass
try:


@@ -3,39 +3,34 @@ from __future__ import annotations
import dataclasses
import inspect
import secrets
import sys
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any
if sys.version_info < (3, 10):
from typing_extensions import ParamSpec
else:
from typing import ParamSpec
import PIL.Image
import einops
import psutil
import torch
import torchvision.transforms as T
from diffusers.utils.import_utils import is_xformers_available
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec
from ldm.invoke.globals import Globals
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
from ldm.modules.textual_inversion_manager import TextualInversionManager
from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
@dataclass
@@ -264,6 +259,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_model_group: ModelGroup
ID_LENGTH = 8
@@ -273,7 +269,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
scheduler: KarrasDiffusionSchedulers,
safety_checker: Optional[StableDiffusionSafetyChecker],
feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False,
@@ -303,8 +299,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
textual_inversion_manager=self.textual_inversion_manager
)
self._model_group = FullyLoadedModelGroup(self.unet.device)
self._model_group.install(*self._submodels)
def _adjust_memory_efficient_attention(self, latents: Torch.tensor):
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
"""
if xformers is available, use it, otherwise use sliced attention.
"""
@@ -322,7 +321,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
elif self.device.type == 'cuda':
mem_free, _ = torch.cuda.mem_get_info(self.device)
else:
raise ValueError(f"unrecognized device {device}")
raise ValueError(f"unrecognized device {self.device}")
# input tensor of [1, 4, h/8, w/8]
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
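As a rough worked example of the estimate described in the comments above (a sketch assuming fp16 latents for a 512×512 image, i.e. latents of shape [1, 4, 64, 64]):

element_size = 2                      # bytes per fp16 latent element
bytes_per_element = element_size + 4  # plus the baddbmm duplication overhead
h8 = w8 = 64                          # 512 / 8
attention_bytes = 16 * (h8 * w8) ** 2 * bytes_per_element
# 16 * 4096 * 4096 * 6 = 1,610,612,736 bytes, roughly 1.6 GB, which is what
# gets weighed against mem_free to decide whether to slice attention.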
@@ -336,6 +335,66 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self.disable_attention_slicing()
def enable_offload_submodels(self, device: torch.device):
"""
Offload each submodel when it's not in use.
Useful for low-vRAM situations where the size of the model in memory is a big chunk of
the total available resource, and you want to free up as much for inference as possible.
This requires more moving parts and may add some delay as the U-Net is swapped out for the
VAE and vice-versa.
"""
models = self._submodels
if self._model_group is not None:
self._model_group.uninstall(*models)
group = LazilyLoadedModelGroup(device)
group.install(*models)
self._model_group = group
def disable_offload_submodels(self):
"""
Leave all submodels loaded.
Appropriate for cases where the size of the model in memory is small compared to the memory
required for inference. Avoids the delay and complexity of shuffling the submodels to and
from the GPU.
"""
models = self._submodels
if self._model_group is not None:
self._model_group.uninstall(*models)
group = FullyLoadedModelGroup(self._model_group.execution_device)
group.install(*models)
self._model_group = group
def offload_all(self):
"""Offload all this pipeline's models to CPU."""
self._model_group.offload_current()
def ready(self):
"""
Ready this pipeline's models.
i.e. pre-load them to the GPU if appropriate.
"""
self._model_group.ready()
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
if torch_device is None:
return self
self._model_group.set_device(torch_device)
self._model_group.ready()
@property
def device(self) -> torch.device:
return self._model_group.execution_device
@property
def _submodels(self) -> Sequence[torch.nn.Module]:
module_names, _, _ = self.extract_init_dict(dict(self.config))
values = [getattr(self, name) for name in module_names.keys()]
return [m for m in values if isinstance(m, torch.nn.Module)]
def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
conditioning_data: ConditioningData,
*,
@@ -377,7 +436,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
callback: Callable[[PipelineIntermediateState], None] = None
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if timesteps is None:
self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
self.scheduler.set_timesteps(num_inference_steps, device=self._model_group.device_for(self.unet))
timesteps = self.scheduler.timesteps
infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
result: PipelineIntermediateState = infer_latents_from_embeddings(
@@ -409,7 +468,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
batch_size = latents.shape[0]
batched_t = torch.full((batch_size,), timesteps[0],
dtype=timesteps.dtype, device=self.unet.device)
dtype=timesteps.dtype, device=self._model_group.device_for(self.unet))
latents = self.scheduler.add_noise(latents, noise, batched_t)
attention_map_saver: Optional[AttentionMapSaver] = None
@@ -493,9 +552,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype)
).add_mask_channels(latents)
return self.unet(sample=latents,
timestep=t,
encoder_hidden_states=text_embeddings,
# First three args should be positional, not keywords, so torch hooks can see them.
return self.unet(latents, t, text_embeddings,
cross_attention_kwargs=cross_attention_kwargs).sample
def img2img_from_embeddings(self,
@@ -514,9 +572,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
# 6. Prepare latent variables
device = self.unet.device
latents_dtype = self.unet.dtype
initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
initial_latents = self.non_noised_latents_from_image(
init_image, device=self._model_group.device_for(self.unet),
dtype=self.unet.dtype)
noise = noise_func(initial_latents)
return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps,
@@ -529,7 +587,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
strength,
noise: torch.Tensor, run_id=None, callback=None
) -> InvokeAIStableDiffusionPipelineOutput:
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, self.unet.device)
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength,
device=self._model_group.device_for(self.unet))
result_latents, result_attention_maps = self.latents_from_embeddings(
initial_latents, num_inference_steps, conditioning_data,
timesteps=timesteps,
@@ -568,7 +627,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id=None,
noise_func=None,
) -> InvokeAIStableDiffusionPipelineOutput:
device = self.unet.device
device = self._model_group.device_for(self.unet)
latents_dtype = self.unet.dtype
if isinstance(init_image, PIL.Image.Image):
@@ -632,6 +691,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
self.vae.to('cpu')
init_image = init_image.to('cpu')
else:
self._model_group.load(self.vae)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible!
if device.type == 'mps':
@@ -643,8 +704,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
def check_for_safety(self, output, dtype):
with torch.inference_mode():
screened_images, has_nsfw_concept = self.run_safety_checker(
output.images, device=self._execution_device, dtype=dtype)
screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype)
screened_attention_map_saver = None
if has_nsfw_concept is None or not has_nsfw_concept:
screened_attention_map_saver = output.attention_map_saver
@@ -653,6 +713,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# block the attention maps if NSFW content is detected
attention_map_saver=screened_attention_map_saver)
def run_safety_checker(self, image, device=None, dtype=None):
# overriding to use the model group for device info instead of requiring the caller to know.
if self.safety_checker is not None:
device = self._model_group.device_for(self.safety_checker)
return super().run_safety_checker(image, device, dtype)
@torch.inference_mode()
def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
"""
@@ -662,7 +728,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
text=c,
fragment_weights=fragment_weights,
should_return_tokens=return_tokens,
device=self.device)
device=self._model_group.device_for(self.unet))
@property
def cond_stage_model(self):
@@ -683,6 +749,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
"""Compatible with DiffusionWrapper"""
return self.unet.in_channels
def decode_latents(self, latents):
# Explicit call to get the vae loaded, since `decode` isn't the forward method.
self._model_group.load(self.vae)
return super().decode_latents(latents)
def debug_latents(self, latents, msg):
with torch.inference_mode():
from ldm.util import debug_image
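Taken together, the offloading-related methods this file adds form a small public surface. A usage sketch (pipe stands for an already-constructed StableDiffusionGeneratorPipeline; CUDA availability is assumed):

import torch

device = torch.device("cuda")

# Keep only the submodel whose forward method is running on the GPU;
# the text encoder, U-Net and VAE swap in and out via forward-pre-hooks.
pipe.enable_offload_submodels(device)

# ... run generation; the pipeline now asks
# self._model_group.device_for(self.unet) instead of self.unet.device ...

pipe.offload_all()  # push every managed submodel back to the CPU
pipe.ready()        # no-op for the lazy group; reloads all models for the full group

# Or keep everything resident on the execution device, as before:
pipe.disable_offload_submodels()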


@@ -25,8 +25,6 @@ import torch
import transformers
from diffusers import AutoencoderKL
from diffusers import logging as dlogging
from diffusers.utils.logging import (get_verbosity, set_verbosity,
set_verbosity_error)
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
@@ -49,9 +47,10 @@ class ModelManager(object):
def __init__(
self,
config: OmegaConf,
device_type: str = "cpu",
device_type: str | torch.device = "cpu",
precision: str = "float16",
max_loaded_models=DEFAULT_MAX_MODELS,
sequential_offload = False
):
"""
Initialize with the path to the models.yaml config file,
@@ -69,6 +68,7 @@ class ModelManager(object):
self.models = {}
self.stack = [] # this is an LRU FIFO
self.current_model = None
self.sequential_offload = sequential_offload
def valid_model(self, model_name: str) -> bool:
"""
@@ -529,7 +529,10 @@ class ModelManager(object):
dlogging.set_verbosity(verbosity)
assert pipeline is not None, OSError(f'"{name_or_path}" could not be loaded')
pipeline.to(self.device)
if self.sequential_offload:
pipeline.enable_offload_submodels(self.device)
else:
pipeline.to(self.device)
model_hash = self._diffuser_sha256(name_or_path)
@@ -748,7 +751,6 @@ class ModelManager(object):
into models.yaml.
"""
new_config = None
import transformers
from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser
@@ -995,12 +997,12 @@ class ModelManager(object):
if self.device == "cpu":
return model
# diffusers really really doesn't like us moving a float16 model onto CPU
verbosity = get_verbosity()
set_verbosity_error()
if isinstance(model, StableDiffusionGeneratorPipeline):
model.offload_all()
return model
model.cond_stage_model.device = "cpu"
model.to("cpu")
set_verbosity(verbosity)
for submodel in ("first_stage_model", "cond_stage_model", "model"):
try:
@@ -1013,6 +1015,10 @@ class ModelManager(object):
if self.device == "cpu":
return model
if isinstance(model, StableDiffusionGeneratorPipeline):
model.ready()
return model
model.to(self.device)
model.cond_stage_model.device = self.device
@@ -1163,7 +1169,7 @@ class ModelManager(object):
strategy.execute()
@staticmethod
def _abs_path(path: Union(str, Path)) -> Path:
def _abs_path(path: str | Path) -> Path:
if path is None or Path(path).is_absolute():
return path
return Path(Globals.root, path).resolve()

ldm/invoke/offloading.py (new file, 247 lines)

@@ -0,0 +1,247 @@
from __future__ import annotations
import warnings
import weakref
from abc import ABCMeta, abstractmethod
from collections.abc import MutableMapping
from typing import Callable
import torch
from accelerate.utils import send_to_device
from torch.utils.hooks import RemovableHandle
OFFLOAD_DEVICE = torch.device("cpu")
class _NoModel:
"""Symbol that indicates no model is loaded.
(We can't weakref.ref(None), so this was my best idea at the time to come up with something
type-checkable.)
"""
def __bool__(self):
return False
def to(self, device: torch.device):
pass
def __repr__(self):
return "<NO MODEL>"
NO_MODEL = _NoModel()
class ModelGroup(metaclass=ABCMeta):
"""
A group of models.
The use case I had in mind when writing this is the sub-models used by a DiffusionPipeline,
e.g. its text encoder, U-net, VAE, etc.
Those models are :py:class:`diffusers.ModelMixin`, but "model" is interchangeable with
:py:class:`torch.nn.Module` here.
"""
def __init__(self, execution_device: torch.device):
self.execution_device = execution_device
@abstractmethod
def install(self, *models: torch.nn.Module):
"""Add models to this group."""
pass
@abstractmethod
def uninstall(self, models: torch.nn.Module):
"""Remove models from this group."""
pass
@abstractmethod
def uninstall_all(self):
"""Remove all models from this group."""
@abstractmethod
def load(self, model: torch.nn.Module):
"""Load this model to the execution device."""
pass
@abstractmethod
def offload_current(self):
"""Offload the current model(s) from the execution device."""
pass
@abstractmethod
def ready(self):
"""Ready this group for use."""
pass
@abstractmethod
def set_device(self, device: torch.device):
"""Change which device models from this group will execute on."""
pass
@abstractmethod
def device_for(self, model) -> torch.device:
"""Get the device the given model will execute on.
The model should already be a member of this group.
"""
pass
@abstractmethod
def __contains__(self, model):
"""Check if the model is a member of this group."""
pass
def __repr__(self) -> str:
return f"<{self.__class__.__name__} object at {id(self):x}: " \
f"device={self.execution_device} >"
class LazilyLoadedModelGroup(ModelGroup):
"""
Only one model from this group is loaded on the GPU at a time.
Running the forward method of a model will displace the previously-loaded model,
offloading it to CPU.
If you call other methods on the model, e.g. ``model.encode(x)`` instead of ``model(x)``,
you will need to explicitly load it with :py:method:`.load(model)`.
This implementation relies on pytorch forward-pre-hooks, and it will copy forward arguments
to the appropriate execution device, as long as they are positional arguments and not keyword
arguments. (I didn't make the rules; that's the way the pytorch 1.13 API works for hooks.)
"""
_hooks: MutableMapping[torch.nn.Module, RemovableHandle]
_current_model_ref: Callable[[], torch.nn.Module | _NoModel]
def __init__(self, execution_device: torch.device):
super().__init__(execution_device)
self._hooks = weakref.WeakKeyDictionary()
self._current_model_ref = weakref.ref(NO_MODEL)
def install(self, *models: torch.nn.Module):
for model in models:
self._hooks[model] = model.register_forward_pre_hook(self._pre_hook)
def uninstall(self, *models: torch.nn.Module):
for model in models:
hook = self._hooks.pop(model)
hook.remove()
if self.is_current_model(model):
# no longer hooked by this object, so don't claim to manage it
self.clear_current_model()
def uninstall_all(self):
self.uninstall(*self._hooks.keys())
def _pre_hook(self, module: torch.nn.Module, forward_input):
self.load(module)
if len(forward_input) == 0:
warnings.warn(f"Hook for {module.__class__.__name__} got no input. "
f"Inputs must be positional, not keywords.", stacklevel=3)
return send_to_device(forward_input, self.execution_device)
def load(self, module):
if not self.is_current_model(module):
self.offload_current()
self._load(module)
def offload_current(self):
module = self._current_model_ref()
if module is not NO_MODEL:
module.to(device=OFFLOAD_DEVICE)
self.clear_current_model()
def _load(self, module: torch.nn.Module) -> torch.nn.Module:
assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}"
module = module.to(self.execution_device)
self.set_current_model(module)
return module
def is_current_model(self, model: torch.nn.Module) -> bool:
"""Is the given model the one currently loaded on the execution device?"""
return self._current_model_ref() is model
def is_empty(self):
"""Are none of this group's models loaded on the execution device?"""
return self._current_model_ref() is NO_MODEL
def set_current_model(self, value):
self._current_model_ref = weakref.ref(value)
def clear_current_model(self):
self._current_model_ref = weakref.ref(NO_MODEL)
def set_device(self, device: torch.device):
if device == self.execution_device:
return
self.execution_device = device
current = self._current_model_ref()
if current is not NO_MODEL:
current.to(device)
def device_for(self, model):
if model not in self:
raise KeyError(f"This does not manage this model {type(model).__name__}", model)
return self.execution_device # this implementation only dispatches to one device
def ready(self):
pass # always ready to load on-demand
def __contains__(self, model):
return model in self._hooks
def __repr__(self) -> str:
return f"<{self.__class__.__name__} object at {id(self):x}: " \
f"current_model={type(self._current_model_ref()).__name__} >"
class FullyLoadedModelGroup(ModelGroup):
"""
A group of models without any implicit loading or unloading.
:py:meth:`.ready` loads _all_ the models to the execution device at once.
"""
_models: weakref.WeakSet
def __init__(self, execution_device: torch.device):
super().__init__(execution_device)
self._models = weakref.WeakSet()
def install(self, *models: torch.nn.Module):
for model in models:
self._models.add(model)
model.to(device=self.execution_device)
def uninstall(self, *models: torch.nn.Module):
for model in models:
self._models.remove(model)
def uninstall_all(self):
self.uninstall(*self._models)
def load(self, model):
model.to(device=self.execution_device)
def offload_current(self):
for model in self._models:
model.to(device=OFFLOAD_DEVICE)
def ready(self):
for model in self._models:
self.load(model)
def set_device(self, device: torch.device):
self.execution_device = device
for model in self._models:
if model.device != OFFLOAD_DEVICE:
model.to(device=device)
def device_for(self, model):
if model not in self:
raise KeyError("This does not manage this model f{type(model).__name__}", model)
return self.execution_device # this implementation only dispatches to one device
def __contains__(self, model):
return model in self._models
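A standalone sketch of the two ModelGroup implementations outside of any pipeline (the Linear modules are placeholders, and a CUDA device is assumed to be available):

import torch
from ldm.invoke.offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup

encoder = torch.nn.Linear(8, 8)  # stand-ins for real submodels
decoder = torch.nn.Linear(8, 8)

group = LazilyLoadedModelGroup(torch.device("cuda"))
group.install(encoder, decoder)

x = torch.randn(1, 8)
y = encoder(x)  # pre-hook moves `encoder` and `x` to cuda before forward runs
z = decoder(y)  # `encoder` is offloaded to cpu, `decoder` takes its place

# Methods other than forward() bypass the pre-hook, so load explicitly first
# (this is why decode_latents and non_noised_latents_from_image call load()):
group.load(encoder)

group.uninstall_all()

# FullyLoadedModelGroup keeps every installed model on the device instead:
eager = FullyLoadedModelGroup(torch.device("cuda"))
eager.install(encoder, decoder)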


@@ -214,7 +214,7 @@ class WeightedPromptFragmentsToEmbeddingsConverter():
def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor) -> torch.Tensor:
'''
Build a tensor that embeds the passed-in token IDs and applyies the given per_token weights
Build a tensor that embeds the passed-in token IDs and applies the given per_token weights
:param token_ids: A tensor of shape `[self.max_length]` containing token IDs (ints)
:param per_token_weights: A tensor of shape `[self.max_length]` containing weights (floats)
:return: A tensor of shape `[1, self.max_length, token_dim]` representing the requested weighted embeddings
@@ -224,13 +224,12 @@ class WeightedPromptFragmentsToEmbeddingsConverter():
if token_ids.shape != torch.Size([self.max_length]):
raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]")
z = self.text_encoder.forward(input_ids=token_ids.unsqueeze(0),
return_dict=False)[0]
z = self.text_encoder(token_ids.unsqueeze(0), return_dict=False)[0]
empty_token_ids = torch.tensor([self.tokenizer.bos_token_id] +
[self.tokenizer.pad_token_id] * (self.max_length-2) +
[self.tokenizer.eos_token_id], dtype=torch.int, device=token_ids.device).unsqueeze(0)
empty_z = self.text_encoder(input_ids=empty_token_ids).last_hidden_state
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
[self.tokenizer.eos_token_id], dtype=torch.int, device=z.device).unsqueeze(0)
empty_z = self.text_encoder(empty_token_ids).last_hidden_state
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape).to(z)
z_delta_from_empty = z - empty_z
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
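The prompt_to_embeddings change above matters because nn.Module.__call__ is what dispatches registered hooks; calling .forward() directly skips them, so the LazilyLoadedModelGroup pre-hook would never get a chance to pull the text encoder onto the execution device. The new call also passes the input positionally, which the pre-hook needs in order to move it to the right device. A toy illustration (names are placeholders):

import torch

linear = torch.nn.Linear(4, 4)
linear.register_forward_pre_hook(lambda module, args: print("pre-hook ran"))

x = torch.randn(1, 4)
linear(x)          # prints "pre-hook ran": __call__ runs the hooks
linear.forward(x)  # prints nothing: forward() bypasses hook dispatch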