2023-02-16 23:48:27 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import warnings
|
|
|
|
import weakref
|
|
|
|
from abc import ABCMeta, abstractmethod
|
|
|
|
from collections.abc import MutableMapping
|
2023-07-03 14:55:04 +00:00
|
|
|
from typing import Callable, Union
|
2023-02-16 23:48:27 +00:00
|
|
|
|
|
|
|
import torch
|
|
|
|
from accelerate.utils import send_to_device
|
|
|
|
from torch.utils.hooks import RemovableHandle
|
|
|
|
|
|
|
|
# Device that models are moved to when they are offloaded from the execution device.
OFFLOAD_DEVICE = torch.device("cpu")
|
|
|
|
|
2023-03-03 06:02:00 +00:00
|
|
|
|
2023-02-16 23:48:27 +00:00
|
|
|
class _NoModel:
    """Sentinel type meaning "no model is currently loaded".

    ``weakref.ref(None)`` is not allowed, so a dedicated, weak-referenceable and
    type-checkable placeholder object is used in place of ``None``.
    """

    def __repr__(self):
        return "<NO MODEL>"

    def __bool__(self):
        # The placeholder is always falsy, so ``if model:`` treats it as "nothing loaded".
        return False

    def to(self, device: torch.device):
        """Accept a device like a real model would, but do nothing."""
        return None


# Shared singleton placeholder; compare against it with ``is``.
NO_MODEL = _NoModel()
|
|
|
|
|
|
|
|
|
|
|
|
class ModelGroup(metaclass=ABCMeta):
    """
    A group of models.

    The use case I had in mind when writing this is the sub-models used by a DiffusionPipeline,
    e.g. its text encoder, U-net, VAE, etc.

    Those models are :py:class:`diffusers.ModelMixin`, but "model" is interchangeable with
    :py:class:`torch.nn.Module` here.

    Subclasses decide when members are moved to ``execution_device`` and when they are
    offloaded; this base class only stores the device and defines the interface.
    """

    def __init__(self, execution_device: torch.device):
        """Create a group whose models execute on ``execution_device``."""
        self.execution_device = execution_device

    @abstractmethod
    def install(self, *models: torch.nn.Module):
        """Add models to this group."""

    @abstractmethod
    def uninstall(self, *models: torch.nn.Module):
        """Remove models from this group.

        Variadic, mirroring :py:meth:`install` and the concrete implementations.
        """

    @abstractmethod
    def uninstall_all(self):
        """Remove all models from this group."""

    @abstractmethod
    def load(self, model: torch.nn.Module):
        """Load this model to the execution device."""

    @abstractmethod
    def offload_current(self):
        """Offload the current model(s) from the execution device."""

    @abstractmethod
    def ready(self):
        """Ready this group for use."""

    @abstractmethod
    def set_device(self, device: torch.device):
        """Change which device models from this group will execute on."""

    @abstractmethod
    def device_for(self, model) -> torch.device:
        """Get the device the given model will execute on.

        The model should already be a member of this group.
        """

    @abstractmethod
    def __contains__(self, model):
        """Check if the model is a member of this group."""

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} object at {id(self):x}: device={self.execution_device} >"
|
2023-02-16 23:48:27 +00:00
|
|
|
|
|
|
|
|
|
|
|
class LazilyLoadedModelGroup(ModelGroup):
    """
    Only one model from this group is loaded on the GPU at a time.

    Running the forward method of a model will displace the previously-loaded model,
    offloading it to CPU.

    If you call other methods on the model, e.g. ``model.encode(x)`` instead of ``model(x)``,
    you will need to explicitly load it with :py:meth:`.load`.

    This implementation relies on pytorch forward-pre-hooks, and it will copy forward arguments
    to the appropriate execution device, as long as they are positional arguments and not keyword
    arguments. (I didn't make the rules; that's the way the pytorch 1.13 API works for hooks.)

    Models are only weakly referenced, both as members and as the "current" model, so
    membership in the group does not keep a model alive.
    """

    # Installed models mapped to the handle of the forward-pre-hook registered on each.
    _hooks: MutableMapping[torch.nn.Module, RemovableHandle]
    # Weak reference to the model currently on the execution device (or to NO_MODEL).
    _current_model_ref: Callable[[], Union[torch.nn.Module, _NoModel]]

    def __init__(self, execution_device: torch.device):
        """Create an empty group with no model loaded."""
        super().__init__(execution_device)
        self._hooks = weakref.WeakKeyDictionary()
        self._current_model_ref = weakref.ref(NO_MODEL)

    def _current_model(self):
        """Return the live current model, or ``NO_MODEL`` if there is none.

        A weak reference yields ``None`` once its referent has been garbage-collected;
        that case is folded into ``NO_MODEL`` so callers never call ``.to`` on ``None``.
        """
        current = self._current_model_ref()
        return NO_MODEL if current is None else current

    def install(self, *models: torch.nn.Module):
        """Add models to this group by registering a forward-pre-hook on each.

        Models that are already installed are skipped: registering again would overwrite
        the stored handle and leave the first hook impossible to remove.
        """
        for model in models:
            if model in self._hooks:
                continue
            self._hooks[model] = model.register_forward_pre_hook(self._pre_hook)

    def uninstall(self, *models: torch.nn.Module):
        """Remove models from this group and detach their hooks.

        Raises:
            KeyError: if a model is not a member of this group.
        """
        for model in models:
            handle = self._hooks.pop(model)
            handle.remove()
            if self.is_current_model(model):
                # no longer hooked by this object, so don't claim to manage it
                self.clear_current_model()

    def uninstall_all(self):
        """Remove all models from this group."""
        # Snapshot the keys first: uninstall() mutates the mapping.
        self.uninstall(*list(self._hooks))

    def _pre_hook(self, module: torch.nn.Module, forward_input):
        """Forward-pre-hook: load ``module`` and move its positional inputs to the execution device.

        Returns the inputs after sending them to ``self.execution_device``.
        """
        self.load(module)
        if not forward_input:
            warnings.warn(
                f"Hook for {module.__class__.__name__} got no input. Inputs must be positional, not keywords.",
                stacklevel=3,
            )
        return send_to_device(forward_input, self.execution_device)

    def load(self, module):
        """Load this model to the execution device, offloading the current one first if it differs."""
        if not self.is_current_model(module):
            self.offload_current()
            self._load(module)

    def offload_current(self):
        """Move the current model (if any is still alive) to ``OFFLOAD_DEVICE`` and forget it."""
        module = self._current_model()
        if module is not NO_MODEL:
            module.to(OFFLOAD_DEVICE)
        self.clear_current_model()

    def _load(self, module: torch.nn.Module) -> torch.nn.Module:
        """Move ``module`` to the execution device and record it as current.

        Expects that no model is currently loaded.
        """
        assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}"
        module = module.to(self.execution_device)
        self.set_current_model(module)
        return module

    def is_current_model(self, model: torch.nn.Module) -> bool:
        """Is the given model the one currently loaded on the execution device?"""
        return self._current_model_ref() is model

    def is_empty(self):
        """Are none of this group's models loaded on the execution device?"""
        return self._current_model() is NO_MODEL

    def set_current_model(self, value):
        """Record ``value`` (weakly) as the model on the execution device."""
        self._current_model_ref = weakref.ref(value)

    def clear_current_model(self):
        """Record that no model is on the execution device."""
        self._current_model_ref = weakref.ref(NO_MODEL)

    def set_device(self, device: torch.device):
        """Change the execution device, moving the current model (if any) along with it."""
        if device == self.execution_device:
            return
        self.execution_device = device
        current = self._current_model()
        if current is not NO_MODEL:
            current.to(device)

    def device_for(self, model):
        """Get the device the given model will execute on.

        Raises:
            KeyError: if the model is not a member of this group.
        """
        if model not in self:
            raise KeyError(f"This does not manage this model {type(model).__name__}", model)
        return self.execution_device  # this implementation only dispatches to one device

    def ready(self):
        """Ready this group for use."""
        pass  # always ready to load on-demand

    def __contains__(self, model):
        """Check if the model is a member of this group."""
        return model in self._hooks

    def __repr__(self) -> str:
        return (
            f"<{self.__class__.__name__} object at {id(self):x}: "
            f"current_model={type(self._current_model()).__name__} >"
        )
|
2023-02-16 23:48:27 +00:00
|
|
|
|
|
|
|
|
|
|
|
class FullyLoadedModelGroup(ModelGroup):
    """
    A group of models without any implicit loading or unloading.

    :py:meth:`.ready` loads _all_ the models to the execution device at once.
    """

    # Members of the group; held weakly so membership does not keep a model alive.
    _models: weakref.WeakSet

    def __init__(self, execution_device: torch.device):
        """Create an empty group whose models execute on ``execution_device``."""
        super().__init__(execution_device)
        self._models = weakref.WeakSet()

    def install(self, *models: torch.nn.Module):
        """Add models to this group and move each to the execution device."""
        for model in models:
            self._models.add(model)
            model.to(self.execution_device)

    def uninstall(self, *models: torch.nn.Module):
        """Remove models from this group.

        Raises:
            KeyError: if a model is not a member of this group.
        """
        for model in models:
            self._models.remove(model)

    def uninstall_all(self):
        """Remove all models from this group."""
        # Snapshot the members first: uninstall() mutates the set.
        self.uninstall(*list(self._models))

    def load(self, model):
        """Load this model to the execution device."""
        model.to(self.execution_device)

    def offload_current(self):
        """Move every model in the group to ``OFFLOAD_DEVICE``."""
        for model in self._models:
            model.to(OFFLOAD_DEVICE)

    def ready(self):
        """Ready this group for use by loading every member to the execution device."""
        for model in self._models:
            self.load(model)

    def set_device(self, device: torch.device):
        """Change the execution device and move members that are not offloaded onto it."""
        self.execution_device = device
        for model in self._models:
            # NOTE(review): relies on a ``device`` attribute; plain torch.nn.Module does not
            # define one (diffusers' ModelMixin presumably does) -- confirm for all members.
            if model.device != OFFLOAD_DEVICE:
                model.to(device)

    def device_for(self, model):
        """Get the device the given model will execute on.

        Raises:
            KeyError: if the model is not a member of this group.
        """
        if model not in self:
            raise KeyError(f"This does not manage this model {type(model).__name__}", model)
        return self.execution_device  # this implementation only dispatches to one device

    def __contains__(self, model):
        """Check if the model is a member of this group."""
        return model in self._models
|