Mirror of https://github.com/invoke-ai/InvokeAI
new OffloadingDevice loads one model at a time, on demand (#2596)
* new OffloadingDevice loads one model at a time, on demand
* fixup! new OffloadingDevice loads one model at a time, on demand
* fix(prompt_to_embeddings): call the text encoder directly instead of its forward method, allowing any associated hooks to run with it
* more attempts to get things on the right device from the offloader
* more attempts to get things on the right device from the offloader
* make offloading methods an explicit part of the pipeline interface
* inlining some calls where device is only used once
* ensure model group is ready after pipeline.to is called
* fixup! Strategize slicing based on free [V]RAM (#2572)
* doc(offloading): docstrings for offloading.ModelGroup
* doc(offloading): docstrings for offloading-related pipeline methods
* refactor(offloading): s/SimpleModelGroup/FullyLoadedModelGroup
* refactor(offloading): s/HotSeatModelGroup/LazilyLoadedModelGroup to frame it in the same terms as "FullyLoadedModelGroup"

Co-authored-by: Damian Stewart <null@damianstewart.com>
This commit is contained in:
parent 2468ba7445, commit 8a0d45ac5a
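The heart of this change is on-demand loading driven by PyTorch forward-pre-hooks: a submodel is moved to the execution device right before its forward pass runs, displacing whichever submodel was loaded before it. The sketch below is a stripped-down, self-contained illustration of that idea only; it is not code from this commit (the real implementation is the new ldm/invoke/offloading.py further down), and it uses CPU for both devices so it runs anywhere.

# Minimal illustration of on-demand submodel loading via forward-pre-hooks,
# the mechanism behind LazilyLoadedModelGroup below. Not part of the commit;
# uses CPU for both devices so it runs without a GPU.
import torch

EXECUTION_DEVICE = torch.device("cpu")   # cuda/mps on a real install
OFFLOAD_DEVICE = torch.device("cpu")

current = None  # the single module allowed on the execution device

def pre_hook(module, args):
    global current
    if current is not None and current is not module:
        current.to(OFFLOAD_DEVICE)         # displace the previously-loaded module
    module.to(EXECUTION_DEVICE)            # load this one just in time
    current = module
    # inputs must be positional so the hook can move them to the same device
    return tuple(a.to(EXECUTION_DEVICE) for a in args)

text_encoder, unet = torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)
for m in (text_encoder, unet):
    m.register_forward_pre_hook(pre_hook)

unet(torch.zeros(1, 8))           # loads the U-Net stand-in
text_encoder(torch.zeros(1, 8))   # offloads it, loads the text-encoder stand-in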
@@ -213,7 +213,9 @@ class Generate:
                 print('>> xformers not installed')
 
         # model caching system for fast switching
-        self.model_manager = ModelManager(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
+        self.model_manager = ModelManager(mconfig, self.device, self.precision,
+                                          max_loaded_models=max_loaded_models,
+                                          sequential_offload=self.free_gpu_mem)
         # don't accept invalid models
         fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
         model = model or fallback
@@ -480,7 +482,6 @@ class Generate:
             self.model.cond_stage_model.device = self.model.device
             self.model.cond_stage_model.to(self.model.device)
         except AttributeError:
-            print(">> Warning: '--free_gpu_mem' is not yet supported when generating image using model based on HuggingFace Diffuser.")
             pass
 
         try:
@@ -3,39 +3,34 @@ from __future__ import annotations
 import dataclasses
 import inspect
 import secrets
-import sys
+from collections.abc import Sequence
 from dataclasses import dataclass, field
 from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any
-
-if sys.version_info < (3, 10):
-    from typing_extensions import ParamSpec
-else:
-    from typing import ParamSpec
 
 import PIL.Image
 import einops
+import psutil
 import torch
 import torchvision.transforms as T
-from diffusers.utils.import_utils import is_xformers_available
-
-from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
-from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
-
-
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
-from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.outputs import BaseOutput
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from typing_extensions import ParamSpec
 
 from ldm.invoke.globals import Globals
 from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
 from ldm.modules.textual_inversion_manager import TextualInversionManager
+from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
+from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
+from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
 
 
 @dataclass
@@ -264,6 +259,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
+    _model_group: ModelGroup
 
     ID_LENGTH = 8
 
@@ -273,7 +269,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: KarrasDiffusionSchedulers,
         safety_checker: Optional[StableDiffusionSafetyChecker],
         feature_extractor: Optional[CLIPFeatureExtractor],
         requires_safety_checker: bool = False,
@@ -303,8 +299,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             textual_inversion_manager=self.textual_inversion_manager
         )
 
+        self._model_group = FullyLoadedModelGroup(self.unet.device)
+        self._model_group.install(*self._submodels)
+
-    def _adjust_memory_efficient_attention(self, latents: Torch.tensor):
+    def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
         """
         if xformers is available, use it, otherwise use sliced attention.
         """
@@ -322,7 +321,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         elif self.device.type == 'cuda':
             mem_free, _ = torch.cuda.mem_get_info(self.device)
         else:
-            raise ValueError(f"unrecognized device {device}")
+            raise ValueError(f"unrecognized device {self.device}")
         # input tensor of [1, 4, h/8, w/8]
         # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
         bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
@@ -336,6 +335,66 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             self.disable_attention_slicing()
 
+    def enable_offload_submodels(self, device: torch.device):
+        """
+        Offload each submodel when it's not in use.
+
+        Useful for low-vRAM situations where the size of the model in memory is a big chunk of
+        the total available resource, and you want to free up as much for inference as possible.
+
+        This requires more moving parts and may add some delay as the U-Net is swapped out for the
+        VAE and vice-versa.
+        """
+        models = self._submodels
+        if self._model_group is not None:
+            self._model_group.uninstall(*models)
+        group = LazilyLoadedModelGroup(device)
+        group.install(*models)
+        self._model_group = group
+
+    def disable_offload_submodels(self):
+        """
+        Leave all submodels loaded.
+
+        Appropriate for cases where the size of the model in memory is small compared to the memory
+        required for inference. Avoids the delay and complexity of shuffling the submodels to and
+        from the GPU.
+        """
+        models = self._submodels
+        if self._model_group is not None:
+            self._model_group.uninstall(*models)
+        group = FullyLoadedModelGroup(self._model_group.execution_device)
+        group.install(*models)
+        self._model_group = group
+
+    def offload_all(self):
+        """Offload all this pipeline's models to CPU."""
+        self._model_group.offload_current()
+
+    def ready(self):
+        """
+        Ready this pipeline's models.
+
+        i.e. pre-load them to the GPU if appropriate.
+        """
+        self._model_group.ready()
+
+    def to(self, torch_device: Optional[Union[str, torch.device]] = None):
+        if torch_device is None:
+            return self
+        self._model_group.set_device(torch_device)
+        self._model_group.ready()
+
+    @property
+    def device(self) -> torch.device:
+        return self._model_group.execution_device
+
+    @property
+    def _submodels(self) -> Sequence[torch.nn.Module]:
+        module_names, _, _ = self.extract_init_dict(dict(self.config))
+        values = [getattr(self, name) for name in module_names.keys()]
+        return [m for m in values if isinstance(m, torch.nn.Module)]
+
     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,
                               *,
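For orientation, a hedged sketch of how calling code might use the new pipeline methods added above; `pipe` is assumed to be an already-constructed StableDiffusionGeneratorPipeline and CUDA an available device, neither of which this commit provides by itself.

# Assumes `pipe` is a loaded StableDiffusionGeneratorPipeline and CUDA is available.
import torch

device = torch.device("cuda")

pipe.enable_offload_submodels(device)   # low VRAM: keep only the active submodel on the GPU
# ... run generation ...
pipe.disable_offload_submodels()        # ample VRAM: keep every submodel loaded
pipe.offload_all()                      # push the pipeline's models back to CPU when idle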
@@ -377,7 +436,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                               callback: Callable[[PipelineIntermediateState], None] = None
                               ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
         if timesteps is None:
-            self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
+            self.scheduler.set_timesteps(num_inference_steps, device=self._model_group.device_for(self.unet))
             timesteps = self.scheduler.timesteps
         infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
         result: PipelineIntermediateState = infer_latents_from_embeddings(
@@ -409,7 +468,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
         batch_size = latents.shape[0]
         batched_t = torch.full((batch_size,), timesteps[0],
-                               dtype=timesteps.dtype, device=self.unet.device)
+                               dtype=timesteps.dtype, device=self._model_group.device_for(self.unet))
         latents = self.scheduler.add_noise(latents, noise, batched_t)
 
         attention_map_saver: Optional[AttentionMapSaver] = None
@@ -493,9 +552,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype)
             ).add_mask_channels(latents)
 
-        return self.unet(sample=latents,
-                         timestep=t,
-                         encoder_hidden_states=text_embeddings,
+        # First three args should be positional, not keywords, so torch hooks can see them.
+        return self.unet(latents, t, text_embeddings,
                          cross_attention_kwargs=cross_attention_kwargs).sample
 
     def img2img_from_embeddings(self,
@@ -514,9 +572,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
 
         # 6. Prepare latent variables
-        device = self.unet.device
-        latents_dtype = self.unet.dtype
-        initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
+        initial_latents = self.non_noised_latents_from_image(
+            init_image, device=self._model_group.device_for(self.unet),
+            dtype=self.unet.dtype)
         noise = noise_func(initial_latents)
 
         return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps,
@@ -529,7 +587,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                              strength,
                                              noise: torch.Tensor, run_id=None, callback=None
                                              ) -> InvokeAIStableDiffusionPipelineOutput:
-        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, self.unet.device)
+        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength,
+                                                  device=self._model_group.device_for(self.unet))
         result_latents, result_attention_maps = self.latents_from_embeddings(
             initial_latents, num_inference_steps, conditioning_data,
             timesteps=timesteps,
@@ -568,7 +627,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                run_id=None,
                                noise_func=None,
                                ) -> InvokeAIStableDiffusionPipelineOutput:
-        device = self.unet.device
+        device = self._model_group.device_for(self.unet)
         latents_dtype = self.unet.dtype
 
         if isinstance(init_image, PIL.Image.Image):
@@ -632,6 +691,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             # TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
             self.vae.to('cpu')
             init_image = init_image.to('cpu')
+        else:
+            self._model_group.load(self.vae)
         init_latent_dist = self.vae.encode(init_image).latent_dist
         init_latents = init_latent_dist.sample().to(dtype=dtype)  # FIXME: uses torch.randn. make reproducible!
         if device.type == 'mps':
@@ -643,8 +704,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
     def check_for_safety(self, output, dtype):
         with torch.inference_mode():
-            screened_images, has_nsfw_concept = self.run_safety_checker(
-                output.images, device=self._execution_device, dtype=dtype)
+            screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype)
         screened_attention_map_saver = None
         if has_nsfw_concept is None or not has_nsfw_concept:
             screened_attention_map_saver = output.attention_map_saver
@@ -653,6 +713,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 # block the attention maps if NSFW content is detected
                 attention_map_saver=screened_attention_map_saver)
 
+    def run_safety_checker(self, image, device=None, dtype=None):
+        # overriding to use the model group for device info instead of requiring the caller to know.
+        if self.safety_checker is not None:
+            device = self._model_group.device_for(self.safety_checker)
+        return super().run_safety_checker(image, device, dtype)
+
     @torch.inference_mode()
     def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
         """
@@ -662,7 +728,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             text=c,
             fragment_weights=fragment_weights,
             should_return_tokens=return_tokens,
-            device=self.device)
+            device=self._model_group.device_for(self.unet))
 
     @property
     def cond_stage_model(self):
@@ -683,6 +749,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         """Compatible with DiffusionWrapper"""
         return self.unet.in_channels
 
+    def decode_latents(self, latents):
+        # Explicit call to get the vae loaded, since `decode` isn't the forward method.
+        self._model_group.load(self.vae)
+        return super().decode_latents(latents)
+
     def debug_latents(self, latents, msg):
         with torch.inference_mode():
             from ldm.util import debug_image
@@ -25,8 +25,6 @@ import torch
 import transformers
 from diffusers import AutoencoderKL
 from diffusers import logging as dlogging
-from diffusers.utils.logging import (get_verbosity, set_verbosity,
-                                     set_verbosity_error)
 from huggingface_hub import scan_cache_dir
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
@@ -49,9 +47,10 @@ class ModelManager(object):
     def __init__(
         self,
         config: OmegaConf,
-        device_type: str = "cpu",
+        device_type: str | torch.device = "cpu",
         precision: str = "float16",
         max_loaded_models=DEFAULT_MAX_MODELS,
+        sequential_offload=False
     ):
         """
         Initialize with the path to the models.yaml config file,
@@ -69,6 +68,7 @@ class ModelManager(object):
         self.models = {}
         self.stack = []  # this is an LRU FIFO
         self.current_model = None
+        self.sequential_offload = sequential_offload
 
     def valid_model(self, model_name: str) -> bool:
         """
@@ -529,6 +529,9 @@ class ModelManager(object):
             dlogging.set_verbosity(verbosity)
         assert pipeline is not None, OSError(f'"{name_or_path}" could not be loaded')
 
-        pipeline.to(self.device)
+        if self.sequential_offload:
+            pipeline.enable_offload_submodels(self.device)
+        else:
+            pipeline.to(self.device)
 
         model_hash = self._diffuser_sha256(name_or_path)
@@ -748,7 +751,6 @@ class ModelManager(object):
         into models.yaml.
         """
         new_config = None
-        import transformers
 
         from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser
 
@@ -995,12 +997,12 @@ class ModelManager(object):
         if self.device == "cpu":
             return model
 
-        # diffusers really really doesn't like us moving a float16 model onto CPU
-        verbosity = get_verbosity()
-        set_verbosity_error()
+        if isinstance(model, StableDiffusionGeneratorPipeline):
+            model.offload_all()
+            return model
+
         model.cond_stage_model.device = "cpu"
         model.to("cpu")
-        set_verbosity(verbosity)
 
         for submodel in ("first_stage_model", "cond_stage_model", "model"):
             try:
@@ -1013,6 +1015,10 @@ class ModelManager(object):
         if self.device == "cpu":
             return model
 
+        if isinstance(model, StableDiffusionGeneratorPipeline):
+            model.ready()
+            return model
+
         model.to(self.device)
         model.cond_stage_model.device = self.device
 
@@ -1163,7 +1169,7 @@ class ModelManager(object):
             strategy.execute()
 
     @staticmethod
-    def _abs_path(path: Union(str, Path)) -> Path:
+    def _abs_path(path: str | Path) -> Path:
         if path is None or Path(path).is_absolute():
             return path
         return Path(Globals.root, path).resolve()
ldm/invoke/offloading.py (new file, 247 lines)
@@ -0,0 +1,247 @@
+from __future__ import annotations
+
+import warnings
+import weakref
+from abc import ABCMeta, abstractmethod
+from collections.abc import MutableMapping
+from typing import Callable
+
+import torch
+from accelerate.utils import send_to_device
+from torch.utils.hooks import RemovableHandle
+
+OFFLOAD_DEVICE = torch.device("cpu")
+
+
+class _NoModel:
+    """Symbol that indicates no model is loaded.
+
+    (We can't weakref.ref(None), so this was my best idea at the time to come up with something
+    type-checkable.)
+    """
+
+    def __bool__(self):
+        return False
+
+    def to(self, device: torch.device):
+        pass
+
+    def __repr__(self):
+        return "<NO MODEL>"
+
+
+NO_MODEL = _NoModel()
+
+
+class ModelGroup(metaclass=ABCMeta):
+    """
+    A group of models.
+
+    The use case I had in mind when writing this is the sub-models used by a DiffusionPipeline,
+    e.g. its text encoder, U-net, VAE, etc.
+
+    Those models are :py:class:`diffusers.ModelMixin`, but "model" is interchangeable with
+    :py:class:`torch.nn.Module` here.
+    """
+
+    def __init__(self, execution_device: torch.device):
+        self.execution_device = execution_device
+
+    @abstractmethod
+    def install(self, *models: torch.nn.Module):
+        """Add models to this group."""
+        pass
+
+    @abstractmethod
+    def uninstall(self, models: torch.nn.Module):
+        """Remove models from this group."""
+        pass
+
+    @abstractmethod
+    def uninstall_all(self):
+        """Remove all models from this group."""
+
+    @abstractmethod
+    def load(self, model: torch.nn.Module):
+        """Load this model to the execution device."""
+        pass
+
+    @abstractmethod
+    def offload_current(self):
+        """Offload the current model(s) from the execution device."""
+        pass
+
+    @abstractmethod
+    def ready(self):
+        """Ready this group for use."""
+        pass
+
+    @abstractmethod
+    def set_device(self, device: torch.device):
+        """Change which device models from this group will execute on."""
+        pass
+
+    @abstractmethod
+    def device_for(self, model) -> torch.device:
+        """Get the device the given model will execute on.
+
+        The model should already be a member of this group.
+        """
+        pass
+
+    @abstractmethod
+    def __contains__(self, model):
+        """Check if the model is a member of this group."""
+        pass
+
+    def __repr__(self) -> str:
+        return f"<{self.__class__.__name__} object at {id(self):x}: " \
+               f"device={self.execution_device} >"
+
+
+class LazilyLoadedModelGroup(ModelGroup):
+    """
+    Only one model from this group is loaded on the GPU at a time.
+
+    Running the forward method of a model will displace the previously-loaded model,
+    offloading it to CPU.
+
+    If you call other methods on the model, e.g. ``model.encode(x)`` instead of ``model(x)``,
+    you will need to explicitly load it with :py:method:`.load(model)`.
+
+    This implementation relies on pytorch forward-pre-hooks, and it will copy forward arguments
+    to the appropriate execution device, as long as they are positional arguments and not keyword
+    arguments. (I didn't make the rules; that's the way the pytorch 1.13 API works for hooks.)
+    """
+
+    _hooks: MutableMapping[torch.nn.Module, RemovableHandle]
+    _current_model_ref: Callable[[], torch.nn.Module | _NoModel]
+
+    def __init__(self, execution_device: torch.device):
+        super().__init__(execution_device)
+        self._hooks = weakref.WeakKeyDictionary()
+        self._current_model_ref = weakref.ref(NO_MODEL)
+
+    def install(self, *models: torch.nn.Module):
+        for model in models:
+            self._hooks[model] = model.register_forward_pre_hook(self._pre_hook)
+
+    def uninstall(self, *models: torch.nn.Module):
+        for model in models:
+            hook = self._hooks.pop(model)
+            hook.remove()
+            if self.is_current_model(model):
+                # no longer hooked by this object, so don't claim to manage it
+                self.clear_current_model()
+
+    def uninstall_all(self):
+        self.uninstall(*self._hooks.keys())
+
+    def _pre_hook(self, module: torch.nn.Module, forward_input):
+        self.load(module)
+        if len(forward_input) == 0:
+            warnings.warn(f"Hook for {module.__class__.__name__} got no input. "
+                          f"Inputs must be positional, not keywords.", stacklevel=3)
+        return send_to_device(forward_input, self.execution_device)
+
+    def load(self, module):
+        if not self.is_current_model(module):
+            self.offload_current()
+            self._load(module)
+
+    def offload_current(self):
+        module = self._current_model_ref()
+        if module is not NO_MODEL:
+            module.to(device=OFFLOAD_DEVICE)
+        self.clear_current_model()
+
+    def _load(self, module: torch.nn.Module) -> torch.nn.Module:
+        assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}"
+        module = module.to(self.execution_device)
+        self.set_current_model(module)
+        return module
+
+    def is_current_model(self, model: torch.nn.Module) -> bool:
+        """Is the given model the one currently loaded on the execution device?"""
+        return self._current_model_ref() is model
+
+    def is_empty(self):
+        """Are none of this group's models loaded on the execution device?"""
+        return self._current_model_ref() is NO_MODEL
+
+    def set_current_model(self, value):
+        self._current_model_ref = weakref.ref(value)
+
+    def clear_current_model(self):
+        self._current_model_ref = weakref.ref(NO_MODEL)
+
+    def set_device(self, device: torch.device):
+        if device == self.execution_device:
+            return
+        self.execution_device = device
+        current = self._current_model_ref()
+        if current is not NO_MODEL:
+            current.to(device)
+
+    def device_for(self, model):
+        if model not in self:
+            raise KeyError(f"This does not manage this model {type(model).__name__}", model)
+        return self.execution_device  # this implementation only dispatches to one device
+
+    def ready(self):
+        pass  # always ready to load on-demand
+
+    def __contains__(self, model):
+        return model in self._hooks
+
+    def __repr__(self) -> str:
+        return f"<{self.__class__.__name__} object at {id(self):x}: " \
+               f"current_model={type(self._current_model_ref()).__name__} >"
+
+
+class FullyLoadedModelGroup(ModelGroup):
+    """
+    A group of models without any implicit loading or unloading.
+
+    :py:meth:`.ready` loads _all_ the models to the execution device at once.
+    """
+
+    _models: weakref.WeakSet
+
+    def __init__(self, execution_device: torch.device):
+        super().__init__(execution_device)
+        self._models = weakref.WeakSet()
+
+    def install(self, *models: torch.nn.Module):
+        for model in models:
+            self._models.add(model)
+            model.to(device=self.execution_device)
+
+    def uninstall(self, *models: torch.nn.Module):
+        for model in models:
+            self._models.remove(model)
+
+    def uninstall_all(self):
+        self.uninstall(*self._models)
+
+    def load(self, model):
+        model.to(device=self.execution_device)
+
+    def offload_current(self):
+        for model in self._models:
+            model.to(device=OFFLOAD_DEVICE)
+
+    def ready(self):
+        for model in self._models:
+            self.load(model)
+
+    def set_device(self, device: torch.device):
+        self.execution_device = device
+        for model in self._models:
+            if model.device != OFFLOAD_DEVICE:
+                model.to(device=device)
+
+    def device_for(self, model):
+        if model not in self:
+            raise KeyError("This does not manage this model f{type(model).__name__}", model)
+        return self.execution_device  # this implementation only dispatches to one device
+
+    def __contains__(self, model):
+        return model in self._models
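To make the two group types concrete, here is a minimal usage sketch with toy nn.Linear modules standing in for the pipeline's submodels. It assumes torch and accelerate are installed and that the new module above is importable as ldm.invoke.offloading; it uses CPU as the execution device so it runs anywhere, whereas a real install would pass a CUDA or MPS device.

# Toy walkthrough of the model groups defined above; assumptions noted in the lead-in.
import torch

from ldm.invoke.offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup

text_encoder = torch.nn.Linear(4, 4)   # stand-ins for the pipeline submodels
unet = torch.nn.Linear(4, 4)

# Lazy group: the forward-pre-hook loads a module right before it runs and
# offloads whichever module was previously on the execution device.
lazy = LazilyLoadedModelGroup(torch.device("cpu"))
lazy.install(text_encoder, unet)
unet(torch.zeros(1, 4))
assert lazy.is_current_model(unet)
text_encoder(torch.zeros(1, 4))          # displaces the U-Net stand-in
assert lazy.is_current_model(text_encoder)
lazy.uninstall_all()

# Fully-loaded group: install() and ready() keep everything on the execution device.
full = FullyLoadedModelGroup(torch.device("cpu"))
full.install(text_encoder, unet)
full.ready()
print(full.device_for(unet))             # -> cpu (the group's execution device)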
@@ -214,7 +214,7 @@ class WeightedPromptFragmentsToEmbeddingsConverter():
 
     def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor) -> torch.Tensor:
         '''
-        Build a tensor that embeds the passed-in token IDs and applyies the given per_token weights
+        Build a tensor that embeds the passed-in token IDs and applies the given per_token weights
         :param token_ids: A tensor of shape `[self.max_length]` containing token IDs (ints)
         :param per_token_weights: A tensor of shape `[self.max_length]` containing weights (floats)
         :return: A tensor of shape `[1, self.max_length, token_dim]` representing the requested weighted embeddings
@@ -224,13 +224,12 @@ class WeightedPromptFragmentsToEmbeddingsConverter():
         if token_ids.shape != torch.Size([self.max_length]):
             raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]")
 
-        z = self.text_encoder.forward(input_ids=token_ids.unsqueeze(0),
-                                      return_dict=False)[0]
+        z = self.text_encoder(token_ids.unsqueeze(0), return_dict=False)[0]
         empty_token_ids = torch.tensor([self.tokenizer.bos_token_id] +
                                        [self.tokenizer.pad_token_id] * (self.max_length-2) +
-                                       [self.tokenizer.eos_token_id], dtype=torch.int, device=token_ids.device).unsqueeze(0)
-        empty_z = self.text_encoder(input_ids=empty_token_ids).last_hidden_state
-        batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
+                                       [self.tokenizer.eos_token_id], dtype=torch.int, device=z.device).unsqueeze(0)
+        empty_z = self.text_encoder(empty_token_ids).last_hidden_state
+        batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape).to(z)
         z_delta_from_empty = z - empty_z
         weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)