support VRAM caching of dict models that lack to()

Lincoln Stein 2024-04-28 13:41:06 -04:00
parent a26667d3ca
commit 7c39929758
5 changed files with 12 additions and 20 deletions


@@ -58,7 +58,7 @@ class ModelLoadServiceBase(ABC):
  Args:
    model_path: A pathlib.Path to a checkpoint-style models file
-   loader: A Callable that expects a Path and returns a Dict[str|int, Any]
+   loader: A Callable that expects a Path and returns a Dict[str, Tensor]
  Returns:
    A LoadedModel object.
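
The loader documented here is just a callable that maps a checkpoint path to a flat state dict of tensors. A minimal sketch of such a loader (illustrative only, not part of this commit; the nested "state_dict" unwrapping is an assumption about common checkpoint layouts):

from pathlib import Path
from typing import Dict

import torch


def checkpoint_loader(path: Path) -> Dict[str, torch.Tensor]:
    # Read the checkpoint onto the CPU; VRAM placement is left to the model cache.
    state_dict = torch.load(path, map_location="cpu")
    # Some checkpoints wrap their weights under a "state_dict" key; unwrap if present.
    return state_dict.get("state_dict", state_dict)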


@@ -109,7 +109,7 @@ class ModelLoadService(ModelLoadServiceBase):
  Args:
    model_path: A pathlib.Path to a checkpoint-style models file
-   loader: A Callable that expects a Path and returns a Dict[str|int, Any]
+   loader: A Callable that expects a Path and returns a Dict[str, Tensor]
  Returns:
    A LoadedModel object.


@@ -437,7 +437,7 @@ class ModelsInterface(InvocationContextInterface):
  def download_and_cache_ckpt(
      self,
-     source: Union[str, AnyHttpUrl],
+     source: str | AnyHttpUrl,
      access_token: Optional[str] = None,
      timeout: Optional[int] = 0,
  ) -> Path:
@@ -501,7 +501,7 @@ class ModelsInterface(InvocationContextInterface):
      loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
  ) -> LoadedModel:
      """
-     Download, cache, and Load the model file located at the indicated URL.
+     Download, cache, and load the model file located at the indicated URL.
      This will check the model download cache for the model designated
      by the provided URL and download it if needed using download_and_cache_ckpt().
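
The edited docstring belongs to the invocation-context helper that downloads a checkpoint by URL and loads it through the model cache. A hedged usage sketch from inside an invocation (the URL is a placeholder, checkpoint_loader is the loader sketched earlier, run_inference is a stand-in, and load_ckpt_from_url is assumed to be the name of the method documented above):

loaded = context.models.load_ckpt_from_url(
    source="https://example.com/models/custom_vae.safetensors",  # placeholder URL
    loader=checkpoint_loader,  # optional custom loader (the parameter defaults to None)
)
with loaded as model:
    # Entering the context locks the cache entry and moves it to the execution
    # device; with this commit that also works when `model` is a plain
    # Dict[str, torch.Tensor] rather than an nn.Module.
    run_inference(model)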


@@ -252,23 +252,22 @@ class ModelCache(ModelCacheBase[AnyModel]):
  May raise a torch.cuda.OutOfMemoryError
  """
- # These attributes are not in the base ModelMixin class but in various derived classes.
- # Some models don't have these attributes, in which case they run in RAM/CPU.
  self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
- if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
-     return
+ model = cache_entry.model
- source_device = cache_entry.model.device
  # Note: We compare device types only so that 'cuda' == 'cuda:0'.
  # This would need to be revised to support multi-GPU.
+ source_device = model.device if hasattr(model, "device") else self.storage_device
  if torch.device(source_device).type == torch.device(target_device).type:
      return
  start_model_to_time = time.time()
  snapshot_before = self._capture_memory_snapshot()
  try:
-     cache_entry.model.to(target_device)
+     if hasattr(model, "to"):
+         model.to(target_device)
+     elif isinstance(model, dict):
+         for _, v in model.items():
+             if hasattr(v, "to"):
+                 v.to(target_device)
  except Exception as e:  # blow away cache entry
      self._delete_cache_entry(cache_entry)
      raise e
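
With this change, anything that exposes a to() method is moved wholesale, while a raw dict (e.g. a bare state dict) is walked and each tensor-like value is moved individually. A standalone sketch of the same idea (not InvokeAI code; note that torch.nn.Module.to() moves parameters in place, whereas torch.Tensor.to() returns a new tensor, so this sketch writes the result back into the dict):

from typing import Any, Dict, Union

import torch


def move_to_device(model: Union[torch.nn.Module, Dict[str, Any]], device: torch.device) -> None:
    if hasattr(model, "to"):
        # Modules (and anything else .to()-capable) migrate themselves in place.
        model.to(device)
    elif isinstance(model, dict):
        for key, value in model.items():
            if isinstance(value, torch.Tensor):
                # Tensor.to() returns a copy, so store it back under the same key.
                model[key] = value.to(device)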


@@ -29,10 +29,6 @@ class ModelLocker(ModelLockerBase):
  def lock(self) -> AnyModel:
      """Move the model into the execution device (GPU) and lock it."""
-     if not hasattr(self.model, "to"):
-         return self.model
-     # NOTE that the model has to have the to() method in order for this code to move it into GPU!
      self._cache_entry.lock()
      try:
          if self._cache.lazy_offloading:
@@ -55,9 +51,6 @@ class ModelLocker(ModelLockerBase):
  def unlock(self) -> None:
      """Call upon exit from context."""
-     if not hasattr(self.model, "to"):
-         return
      self._cache_entry.unlock()
      if not self._cache.lazy_offloading:
          self._cache.offload_unlocked_models(self._cache_entry.size)
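
With the hasattr guards removed, dict models now follow the same lock/unlock path as regular modules, so the cache's offload bookkeeping accounts for them too. A condensed sketch of that flow (simplified and hypothetical, not the actual ModelLocker; execution_device is an assumed attribute of the cache, and move_to_device() is the helper sketched above):

class SimpleLocker:
    """Illustrative lock/unlock flow for any cached model, module or state dict."""

    def __init__(self, cache, cache_entry):
        self._cache = cache
        self._cache_entry = cache_entry

    def lock(self):
        self._cache_entry.lock()
        try:
            # Works for nn.Module and for raw state dicts alike.
            move_to_device(self._cache_entry.model, self._cache.execution_device)
        except Exception:
            self._cache_entry.unlock()
            raise
        return self._cache_entry.model

    def unlock(self):
        self._cache_entry.unlock()
        if not self._cache.lazy_offloading:
            self._cache.offload_unlocked_models(self._cache_entry.size)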