From 8a0d45ac5ae80e1949c2a4a5643f70a2f6df8ecc Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Thu, 16 Feb 2023 15:48:27 -0800 Subject: [PATCH] new OffloadingDevice loads one model at a time, on demand (#2596) * new OffloadingDevice loads one model at a time, on demand * fixup! new OffloadingDevice loads one model at a time, on demand * fix(prompt_to_embeddings): call the text encoder directly instead of its forward method allowing any associated hooks to run with it. * more attempts to get things on the right device from the offloader * more attempts to get things on the right device from the offloader * make offloading methods an explicit part of the pipeline interface * inlining some calls where device is only used once * ensure model group is ready after pipeline.to is called * fixup! Strategize slicing based on free [V]RAM (#2572) * doc(offloading): docstrings for offloading.ModelGroup * doc(offloading): docstrings for offloading-related pipeline methods * refactor(offloading): s/SimpleModelGroup/FullyLoadedModelGroup * refactor(offloading): s/HotSeatModelGroup/LazilyLoadedModelGroup to frame it is the same terms as "FullyLoadedModelGroup" --------- Co-authored-by: Damian Stewart --- ldm/generate.py | 5 +- ldm/invoke/generator/diffusers_pipeline.py | 129 +++++++-- ldm/invoke/model_manager.py | 26 +- ldm/invoke/offloading.py | 247 ++++++++++++++++++ ldm/modules/prompt_to_embeddings_converter.py | 11 +- 5 files changed, 371 insertions(+), 47 deletions(-) create mode 100644 ldm/invoke/offloading.py diff --git a/ldm/generate.py b/ldm/generate.py index 8cb3058694..1b07122628 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -213,7 +213,9 @@ class Generate: print('>> xformers not installed') # model caching system for fast switching - self.model_manager = ModelManager(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models) + self.model_manager = ModelManager(mconfig, self.device, self.precision, + max_loaded_models=max_loaded_models, + sequential_offload=self.free_gpu_mem) # don't accept invalid models fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME model = model or fallback @@ -480,7 +482,6 @@ class Generate: self.model.cond_stage_model.device = self.model.device self.model.cond_stage_model.to(self.model.device) except AttributeError: - print(">> Warning: '--free_gpu_mem' is not yet supported when generating image using model based on HuggingFace Diffuser.") pass try: diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 686fb40d3a..5990eb42a1 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -3,39 +3,34 @@ from __future__ import annotations import dataclasses import inspect import secrets -import sys +from collections.abc import Sequence from dataclasses import dataclass, field from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any -if sys.version_info < (3, 10): - from typing_extensions import ParamSpec -else: - from typing import ParamSpec - import PIL.Image import einops +import psutil import torch import torchvision.transforms as T -from diffusers.utils.import_utils import is_xformers_available - -from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver -from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter - - from diffusers.models import AutoencoderKL, UNet2DConditionModel from 
diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput -from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.outputs import BaseOutput from torchvision.transforms.functional import resize as tv_resize from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from typing_extensions import ParamSpec from ldm.invoke.globals import Globals from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings from ldm.modules.textual_inversion_manager import TextualInversionManager +from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup +from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver +from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter @dataclass @@ -264,6 +259,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + _model_group: ModelGroup ID_LENGTH = 8 @@ -273,7 +269,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: KarrasDiffusionSchedulers, safety_checker: Optional[StableDiffusionSafetyChecker], feature_extractor: Optional[CLIPFeatureExtractor], requires_safety_checker: bool = False, @@ -303,8 +299,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): textual_inversion_manager=self.textual_inversion_manager ) + self._model_group = FullyLoadedModelGroup(self.unet.device) + self._model_group.install(*self._submodels) - def _adjust_memory_efficient_attention(self, latents: Torch.tensor): + + def _adjust_memory_efficient_attention(self, latents: torch.Tensor): """ if xformers is available, use it, otherwise use sliced attention. """ @@ -322,7 +321,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): elif self.device.type == 'cuda': mem_free, _ = torch.cuda.mem_get_info(self.device) else: - raise ValueError(f"unrecognized device {device}") + raise ValueError(f"unrecognized device {self.device}") # input tensor of [1, 4, h/8, w/8] # output tensor of [16, (h/8 * w/8), (h/8 * w/8)] bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4 @@ -336,6 +335,66 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): self.disable_attention_slicing() + def enable_offload_submodels(self, device: torch.device): + """ + Offload each submodel when it's not in use. + + Useful for low-vRAM situations where the size of the model in memory is a big chunk of + the total available resource, and you want to free up as much for inference as possible. + + This requires more moving parts and may add some delay as the U-Net is swapped out for the + VAE and vice-versa. 
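[Editor's sketch, not part of the diff: how a caller might choose between the two loading modes this method pair introduces. `pipe` stands for a StableDiffusionGeneratorPipeline instance and a CUDA device is assumed.]

```python
import torch

def configure_loading(pipe, device=torch.device("cuda"), low_vram=True):
    if low_vram:
        # Sequential offload: only one submodel (text encoder, U-Net, or VAE)
        # occupies the execution device at a time.
        pipe.enable_offload_submodels(device)
    else:
        # Enough VRAM: keep every submodel resident on the device.
        pipe.to(device)
    return pipe
```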
+ """ + models = self._submodels + if self._model_group is not None: + self._model_group.uninstall(*models) + group = LazilyLoadedModelGroup(device) + group.install(*models) + self._model_group = group + + def disable_offload_submodels(self): + """ + Leave all submodels loaded. + + Appropriate for cases where the size of the model in memory is small compared to the memory + required for inference. Avoids the delay and complexity of shuffling the submodels to and + from the GPU. + """ + models = self._submodels + if self._model_group is not None: + self._model_group.uninstall(*models) + group = FullyLoadedModelGroup(self._model_group.execution_device) + group.install(*models) + self._model_group = group + + def offload_all(self): + """Offload all this pipeline's models to CPU.""" + self._model_group.offload_current() + + def ready(self): + """ + Ready this pipeline's models. + + i.e. pre-load them to the GPU if appropriate. + """ + self._model_group.ready() + + def to(self, torch_device: Optional[Union[str, torch.device]] = None): + if torch_device is None: + return self + self._model_group.set_device(torch_device) + self._model_group.ready() + + @property + def device(self) -> torch.device: + return self._model_group.execution_device + + @property + def _submodels(self) -> Sequence[torch.nn.Module]: + module_names, _, _ = self.extract_init_dict(dict(self.config)) + values = [getattr(self, name) for name in module_names.keys()] + return [m for m in values if isinstance(m, torch.nn.Module)] + def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int, conditioning_data: ConditioningData, *, @@ -377,7 +436,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): callback: Callable[[PipelineIntermediateState], None] = None ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: if timesteps is None: - self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device) + self.scheduler.set_timesteps(num_inference_steps, device=self._model_group.device_for(self.unet)) timesteps = self.scheduler.timesteps infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState) result: PipelineIntermediateState = infer_latents_from_embeddings( @@ -409,7 +468,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): batch_size = latents.shape[0] batched_t = torch.full((batch_size,), timesteps[0], - dtype=timesteps.dtype, device=self.unet.device) + dtype=timesteps.dtype, device=self._model_group.device_for(self.unet)) latents = self.scheduler.add_noise(latents, noise, batched_t) attention_map_saver: Optional[AttentionMapSaver] = None @@ -493,9 +552,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype) ).add_mask_channels(latents) - return self.unet(sample=latents, - timestep=t, - encoder_hidden_states=text_embeddings, + # First three args should be positional, not keywords, so torch hooks can see them. + return self.unet(latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs).sample def img2img_from_embeddings(self, @@ -514,9 +572,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): init_image = einops.rearrange(init_image, 'c h w -> 1 c h w') # 6. 
Prepare latent variables - device = self.unet.device - latents_dtype = self.unet.dtype - initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype) + initial_latents = self.non_noised_latents_from_image( + init_image, device=self._model_group.device_for(self.unet), + dtype=self.unet.dtype) noise = noise_func(initial_latents) return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps, @@ -529,7 +587,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): strength, noise: torch.Tensor, run_id=None, callback=None ) -> InvokeAIStableDiffusionPipelineOutput: - timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, self.unet.device) + timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, + device=self._model_group.device_for(self.unet)) result_latents, result_attention_maps = self.latents_from_embeddings( initial_latents, num_inference_steps, conditioning_data, timesteps=timesteps, @@ -568,7 +627,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): run_id=None, noise_func=None, ) -> InvokeAIStableDiffusionPipelineOutput: - device = self.unet.device + device = self._model_group.device_for(self.unet) latents_dtype = self.unet.dtype if isinstance(init_image, PIL.Image.Image): @@ -632,6 +691,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # TODO remove this workaround once kulinseth#222 is merged to pytorch mainline self.vae.to('cpu') init_image = init_image.to('cpu') + else: + self._model_group.load(self.vae) init_latent_dist = self.vae.encode(init_image).latent_dist init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible! if device.type == 'mps': @@ -643,8 +704,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): def check_for_safety(self, output, dtype): with torch.inference_mode(): - screened_images, has_nsfw_concept = self.run_safety_checker( - output.images, device=self._execution_device, dtype=dtype) + screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype) screened_attention_map_saver = None if has_nsfw_concept is None or not has_nsfw_concept: screened_attention_map_saver = output.attention_map_saver @@ -653,6 +713,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # block the attention maps if NSFW content is detected attention_map_saver=screened_attention_map_saver) + def run_safety_checker(self, image, device=None, dtype=None): + # overriding to use the model group for device info instead of requiring the caller to know. + if self.safety_checker is not None: + device = self._model_group.device_for(self.safety_checker) + return super().run_safety_checker(image, device, dtype) + @torch.inference_mode() def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None): """ @@ -662,7 +728,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): text=c, fragment_weights=fragment_weights, should_return_tokens=return_tokens, - device=self.device) + device=self._model_group.device_for(self.unet)) @property def cond_stage_model(self): @@ -683,6 +749,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): """Compatible with DiffusionWrapper""" return self.unet.in_channels + def decode_latents(self, latents): + # Explicit call to get the vae loaded, since `decode` isn't the forward method. 
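[Editor's sketch, not part of the diff: the pattern these hunks adopt. `pipe` is assumed to be a StableDiffusionGeneratorPipeline and `latents_shape` the usual `[B, 4, H/8, W/8]`. The point is to ask the model group where a submodel will execute rather than reading `self.unet.device`, which reports the offload device while the U-Net is parked on the CPU.]

```python
import torch

def make_initial_noise(pipe, latents_shape):
    # While submodel offloading is active, pipe.unet.device may be "cpu";
    # the model group knows the device the U-Net will actually run on.
    device = pipe._model_group.device_for(pipe.unet)
    return torch.randn(latents_shape, device=device, dtype=pipe.unet.dtype)
```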
+ self._model_group.load(self.vae) + return super().decode_latents(latents) + def debug_latents(self, latents, msg): with torch.inference_mode(): from ldm.util import debug_image diff --git a/ldm/invoke/model_manager.py b/ldm/invoke/model_manager.py index 99e2bdfd86..27b5d064ef 100644 --- a/ldm/invoke/model_manager.py +++ b/ldm/invoke/model_manager.py @@ -25,8 +25,6 @@ import torch import transformers from diffusers import AutoencoderKL from diffusers import logging as dlogging -from diffusers.utils.logging import (get_verbosity, set_verbosity, - set_verbosity_error) from huggingface_hub import scan_cache_dir from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig @@ -49,9 +47,10 @@ class ModelManager(object): def __init__( self, config: OmegaConf, - device_type: str = "cpu", + device_type: str | torch.device = "cpu", precision: str = "float16", max_loaded_models=DEFAULT_MAX_MODELS, + sequential_offload = False ): """ Initialize with the path to the models.yaml config file, @@ -69,6 +68,7 @@ class ModelManager(object): self.models = {} self.stack = [] # this is an LRU FIFO self.current_model = None + self.sequential_offload = sequential_offload def valid_model(self, model_name: str) -> bool: """ @@ -529,7 +529,10 @@ class ModelManager(object): dlogging.set_verbosity(verbosity) assert pipeline is not None, OSError(f'"{name_or_path}" could not be loaded') - pipeline.to(self.device) + if self.sequential_offload: + pipeline.enable_offload_submodels(self.device) + else: + pipeline.to(self.device) model_hash = self._diffuser_sha256(name_or_path) @@ -748,7 +751,6 @@ class ModelManager(object): into models.yaml. """ new_config = None - import transformers from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser @@ -995,12 +997,12 @@ class ModelManager(object): if self.device == "cpu": return model - # diffusers really really doesn't like us moving a float16 model onto CPU - verbosity = get_verbosity() - set_verbosity_error() + if isinstance(model, StableDiffusionGeneratorPipeline): + model.offload_all() + return model + model.cond_stage_model.device = "cpu" model.to("cpu") - set_verbosity(verbosity) for submodel in ("first_stage_model", "cond_stage_model", "model"): try: @@ -1013,6 +1015,10 @@ class ModelManager(object): if self.device == "cpu": return model + if isinstance(model, StableDiffusionGeneratorPipeline): + model.ready() + return model + model.to(self.device) model.cond_stage_model.device = self.device @@ -1163,7 +1169,7 @@ class ModelManager(object): strategy.execute() @staticmethod - def _abs_path(path: Union(str, Path)) -> Path: + def _abs_path(path: str | Path) -> Path: if path is None or Path(path).is_absolute(): return path return Path(Globals.root, path).resolve() diff --git a/ldm/invoke/offloading.py b/ldm/invoke/offloading.py new file mode 100644 index 0000000000..e049f5fe09 --- /dev/null +++ b/ldm/invoke/offloading.py @@ -0,0 +1,247 @@ +from __future__ import annotations + +import warnings +import weakref +from abc import ABCMeta, abstractmethod +from collections.abc import MutableMapping +from typing import Callable + +import torch +from accelerate.utils import send_to_device +from torch.utils.hooks import RemovableHandle + +OFFLOAD_DEVICE = torch.device("cpu") + +class _NoModel: + """Symbol that indicates no model is loaded. + + (We can't weakref.ref(None), so this was my best idea at the time to come up with something + type-checkable.) 
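[Editor's sketch, not part of the diff: the ModelManager flow these hunks add. The enclosing method names are not all visible in the hunks, so the function names below are stand-ins. Sequential offload routes loading through `enable_offload_submodels`, and parking a cached model now calls `offload_all()` / `ready()` instead of moving the whole fp16 pipeline with `model.to("cpu")`, which is why the diffusers verbosity workaround could be dropped.]

```python
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline

def place_on_device(manager, pipeline):
    # Stand-in for the diffusers loading path shown above.
    if manager.sequential_offload:
        pipeline.enable_offload_submodels(manager.device)
    else:
        pipeline.to(manager.device)
    return pipeline

def park(manager, model):
    # Stand-in for the model-to-CPU path: diffusers pipelines offload their
    # submodels; legacy checkpoints still move wholesale to the CPU.
    if isinstance(model, StableDiffusionGeneratorPipeline):
        model.offload_all()
    else:
        model.to("cpu")
    return model
```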
+ """ + + def __bool__(self): + return False + + def to(self, device: torch.device): + pass + + def __repr__(self): + return "" + +NO_MODEL = _NoModel() + + +class ModelGroup(metaclass=ABCMeta): + """ + A group of models. + + The use case I had in mind when writing this is the sub-models used by a DiffusionPipeline, + e.g. its text encoder, U-net, VAE, etc. + + Those models are :py:class:`diffusers.ModelMixin`, but "model" is interchangeable with + :py:class:`torch.nn.Module` here. + """ + + def __init__(self, execution_device: torch.device): + self.execution_device = execution_device + + @abstractmethod + def install(self, *models: torch.nn.Module): + """Add models to this group.""" + pass + + @abstractmethod + def uninstall(self, models: torch.nn.Module): + """Remove models from this group.""" + pass + + @abstractmethod + def uninstall_all(self): + """Remove all models from this group.""" + + @abstractmethod + def load(self, model: torch.nn.Module): + """Load this model to the execution device.""" + pass + + @abstractmethod + def offload_current(self): + """Offload the current model(s) from the execution device.""" + pass + + @abstractmethod + def ready(self): + """Ready this group for use.""" + pass + + @abstractmethod + def set_device(self, device: torch.device): + """Change which device models from this group will execute on.""" + pass + + @abstractmethod + def device_for(self, model) -> torch.device: + """Get the device the given model will execute on. + + The model should already be a member of this group. + """ + pass + + @abstractmethod + def __contains__(self, model): + """Check if the model is a member of this group.""" + pass + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} object at {id(self):x}: " \ + f"device={self.execution_device} >" + + +class LazilyLoadedModelGroup(ModelGroup): + """ + Only one model from this group is loaded on the GPU at a time. + + Running the forward method of a model will displace the previously-loaded model, + offloading it to CPU. + + If you call other methods on the model, e.g. ``model.encode(x)`` instead of ``model(x)``, + you will need to explicitly load it with :py:method:`.load(model)`. + + This implementation relies on pytorch forward-pre-hooks, and it will copy forward arguments + to the appropriate execution device, as long as they are positional arguments and not keyword + arguments. (I didn't make the rules; that's the way the pytorch 1.13 API works for hooks.) + """ + + _hooks: MutableMapping[torch.nn.Module, RemovableHandle] + _current_model_ref: Callable[[], torch.nn.Module | _NoModel] + + def __init__(self, execution_device: torch.device): + super().__init__(execution_device) + self._hooks = weakref.WeakKeyDictionary() + self._current_model_ref = weakref.ref(NO_MODEL) + + def install(self, *models: torch.nn.Module): + for model in models: + self._hooks[model] = model.register_forward_pre_hook(self._pre_hook) + + def uninstall(self, *models: torch.nn.Module): + for model in models: + hook = self._hooks.pop(model) + hook.remove() + if self.is_current_model(model): + # no longer hooked by this object, so don't claim to manage it + self.clear_current_model() + + def uninstall_all(self): + self.uninstall(*self._hooks.keys()) + + def _pre_hook(self, module: torch.nn.Module, forward_input): + self.load(module) + if len(forward_input) == 0: + warnings.warn(f"Hook for {module.__class__.__name__} got no input. 
" + f"Inputs must be positional, not keywords.", stacklevel=3) + return send_to_device(forward_input, self.execution_device) + + def load(self, module): + if not self.is_current_model(module): + self.offload_current() + self._load(module) + + def offload_current(self): + module = self._current_model_ref() + if module is not NO_MODEL: + module.to(device=OFFLOAD_DEVICE) + self.clear_current_model() + + def _load(self, module: torch.nn.Module) -> torch.nn.Module: + assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}" + module = module.to(self.execution_device) + self.set_current_model(module) + return module + + def is_current_model(self, model: torch.nn.Module) -> bool: + """Is the given model the one currently loaded on the execution device?""" + return self._current_model_ref() is model + + def is_empty(self): + """Are none of this group's models loaded on the execution device?""" + return self._current_model_ref() is NO_MODEL + + def set_current_model(self, value): + self._current_model_ref = weakref.ref(value) + + def clear_current_model(self): + self._current_model_ref = weakref.ref(NO_MODEL) + + def set_device(self, device: torch.device): + if device == self.execution_device: + return + self.execution_device = device + current = self._current_model_ref() + if current is not NO_MODEL: + current.to(device) + + def device_for(self, model): + if model not in self: + raise KeyError(f"This does not manage this model {type(model).__name__}", model) + return self.execution_device # this implementation only dispatches to one device + + def ready(self): + pass # always ready to load on-demand + + def __contains__(self, model): + return model in self._hooks + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} object at {id(self):x}: " \ + f"current_model={type(self._current_model_ref()).__name__} >" + + +class FullyLoadedModelGroup(ModelGroup): + """ + A group of models without any implicit loading or unloading. + + :py:meth:`.ready` loads _all_ the models to the execution device at once. 
+ """ + _models: weakref.WeakSet + + def __init__(self, execution_device: torch.device): + super().__init__(execution_device) + self._models = weakref.WeakSet() + + def install(self, *models: torch.nn.Module): + for model in models: + self._models.add(model) + model.to(device=self.execution_device) + + def uninstall(self, *models: torch.nn.Module): + for model in models: + self._models.remove(model) + + def uninstall_all(self): + self.uninstall(*self._models) + + def load(self, model): + model.to(device=self.execution_device) + + def offload_current(self): + for model in self._models: + model.to(device=OFFLOAD_DEVICE) + + def ready(self): + for model in self._models: + self.load(model) + + def set_device(self, device: torch.device): + self.execution_device = device + for model in self._models: + if model.device != OFFLOAD_DEVICE: + model.to(device=device) + + def device_for(self, model): + if model not in self: + raise KeyError("This does not manage this model f{type(model).__name__}", model) + return self.execution_device # this implementation only dispatches to one device + + def __contains__(self, model): + return model in self._models diff --git a/ldm/modules/prompt_to_embeddings_converter.py b/ldm/modules/prompt_to_embeddings_converter.py index dea15d61b4..84d927d48b 100644 --- a/ldm/modules/prompt_to_embeddings_converter.py +++ b/ldm/modules/prompt_to_embeddings_converter.py @@ -214,7 +214,7 @@ class WeightedPromptFragmentsToEmbeddingsConverter(): def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor) -> torch.Tensor: ''' - Build a tensor that embeds the passed-in token IDs and applyies the given per_token weights + Build a tensor that embeds the passed-in token IDs and applies the given per_token weights :param token_ids: A tensor of shape `[self.max_length]` containing token IDs (ints) :param per_token_weights: A tensor of shape `[self.max_length]` containing weights (floats) :return: A tensor of shape `[1, self.max_length, token_dim]` representing the requested weighted embeddings @@ -224,13 +224,12 @@ class WeightedPromptFragmentsToEmbeddingsConverter(): if token_ids.shape != torch.Size([self.max_length]): raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]") - z = self.text_encoder.forward(input_ids=token_ids.unsqueeze(0), - return_dict=False)[0] + z = self.text_encoder(token_ids.unsqueeze(0), return_dict=False)[0] empty_token_ids = torch.tensor([self.tokenizer.bos_token_id] + [self.tokenizer.pad_token_id] * (self.max_length-2) + - [self.tokenizer.eos_token_id], dtype=torch.int, device=token_ids.device).unsqueeze(0) - empty_z = self.text_encoder(input_ids=empty_token_ids).last_hidden_state - batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape) + [self.tokenizer.eos_token_id], dtype=torch.int, device=z.device).unsqueeze(0) + empty_z = self.text_encoder(empty_token_ids).last_hidden_state + batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape).to(z) z_delta_from_empty = z - empty_z weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)