added notes

Lincoln Stein 2024-04-01 13:30:02 -04:00
parent 9df0980c46
commit eca29c41d0


@@ -33,6 +33,20 @@ class ModelLocker(ModelLockerBase):
"""Return the model without moving it around."""
return self._cache_entry.model
# ---------------------------- NOTE -----------------
# Ryan suggests keeping a copy of the model's state dict on the CPU and copying it
# to the GPU with code like this:
#
# def state_dict_to(state_dict: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
# new_state_dict: dict[str, torch.Tensor] = {}
# for k, v in state_dict.items():
# new_state_dict[k] = v.to(device=device, copy=True, non_blocking=True)
# return new_state_dict
#
# I believe we'd then use load_state_dict() to inject the state dict into the model.
# See: https://pytorch.org/tutorials/beginner/saving_loading_models.html
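#
# A rough sketch of what that injection might look like (an assumption, not the
# final design: the helper name is hypothetical, and `assign=True` requires a
# torch version that supports it):
#
# def inject_state_dict(model: torch.nn.Module, cpu_state_dict: dict[str, torch.Tensor], device: torch.device) -> None:
#     # Copy the cached CPU tensors to the target device, then have the module
#     # adopt the new tensors rather than copying values back into the CPU params.
#     model.load_state_dict(state_dict_to(cpu_state_dict, device), assign=True)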
# ---------------------------- NOTE -----------------
def lock(self) -> AnyModel:
"""Move the model into the execution device (GPU) and lock it."""
if not hasattr(self.model, "to"):