support VRAM caching of dict models that lack to()

Lincoln Stein 2024-04-28 13:41:06 -04:00
parent a26667d3ca
commit 7c39929758
5 changed files with 12 additions and 20 deletions

View File

@@ -58,7 +58,7 @@ class ModelLoadServiceBase(ABC):
         Args:
           model_path: A pathlib.Path to a checkpoint-style models file
-          loader: A Callable that expects a Path and returns a Dict[str|int, Any]
+          loader: A Callable that expects a Path and returns a Dict[str, Tensor]
         Returns:
           A LoadedModel object.

View File

@@ -109,7 +109,7 @@ class ModelLoadService(ModelLoadServiceBase):
         Args:
           model_path: A pathlib.Path to a checkpoint-style models file
-          loader: A Callable that expects a Path and returns a Dict[str|int, Any]
+          loader: A Callable that expects a Path and returns a Dict[str, Tensor]
         Returns:
           A LoadedModel object.
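Both hunks above (the abstract base and the concrete service) make the same documentation fix: the loader callable takes a Path and returns a Dict[str, Tensor] rather than Dict[str|int, Any]. A minimal sketch of a conforming loader follows; it is not part of this commit, and the safetensors/torch.load branching is illustrative only.

from pathlib import Path
from typing import Dict

import torch
from safetensors.torch import load_file


def example_ckpt_loader(model_path: Path) -> Dict[str, torch.Tensor]:
    # Illustrative loader matching the documented Callable[[Path], Dict[str, Tensor]] contract.
    if model_path.suffix == ".safetensors":
        # load_file() already returns a flat {name: tensor} mapping on CPU.
        return load_file(model_path)
    ckpt = torch.load(model_path, map_location="cpu")
    # Many .ckpt files nest the weights under a "state_dict" key; unwrap it when present.
    return ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt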

View File

@@ -437,7 +437,7 @@ class ModelsInterface(InvocationContextInterface):
     def download_and_cache_ckpt(
         self,
-        source: Union[str, AnyHttpUrl],
+        source: str | AnyHttpUrl,
         access_token: Optional[str] = None,
         timeout: Optional[int] = 0,
     ) -> Path:
@@ -501,7 +501,7 @@ class ModelsInterface(InvocationContextInterface):
         loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
     ) -> LoadedModel:
         """
-        Download, cache, and Load the model file located at the indicated URL.
+        Download, cache, and load the model file located at the indicated URL.
         This will check the model download cache for the model designated
         by the provided URL and download it if needed using download_and_cache_ckpt().
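A hedged usage sketch of the download_and_cache_ckpt() signature shown above, as it might be called from inside an invocation; the context.models attribute and the placeholder URL are assumptions for illustration, not taken from this diff.

from pathlib import Path


def fetch_weights(context) -> Path:
    # Sketch only: assumes the invocation context exposes the interface above as `context.models`.
    # Returns the local path of the downloaded-and-cached checkpoint file.
    return context.models.download_and_cache_ckpt(
        source="https://example.com/weights/model.safetensors",  # placeholder URL
        access_token=None,
        timeout=0,
    )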

View File

@@ -252,23 +252,22 @@ class ModelCache(ModelCacheBase[AnyModel]):
         May raise a torch.cuda.OutOfMemoryError
         """
-        # These attributes are not in the base ModelMixin class but in various derived classes.
-        # Some models don't have these attributes, in which case they run in RAM/CPU.
         self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
-        if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
-            return
+        model = cache_entry.model
-        source_device = cache_entry.model.device
+        source_device = model.device if hasattr(model, "device") else self.storage_device
-        # Note: We compare device types only so that 'cuda' == 'cuda:0'.
-        # This would need to be revised to support multi-GPU.
         if torch.device(source_device).type == torch.device(target_device).type:
             return
         start_model_to_time = time.time()
         snapshot_before = self._capture_memory_snapshot()
         try:
-            cache_entry.model.to(target_device)
+            if hasattr(model, "to"):
+                model.to(target_device)
+            elif isinstance(model, dict):
+                for _, v in model.items():
+                    if hasattr(v, "to"):
+                        v.to(target_device)
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)
             raise e
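This hunk is the core of the commit: a cached model that lacks .to() but is a plain dict (e.g. a raw state dict) is now moved to the target device value by value, and a model with neither attribute simply stays on the storage device. A standalone sketch of that dispatch follows; note that Tensor.to() (unlike nn.Module.to()) returns a new tensor, so the sketch reassigns dict entries instead of relying on an in-place move.

from typing import Any

import torch


def move_model_to(model: Any, target_device: torch.device) -> Any:
    # Modules (and similar objects) expose .to() and handle their own device transfer.
    if hasattr(model, "to"):
        return model.to(target_device)
    # A dict-style model is moved entry by entry. Tensor.to() is not in-place,
    # so the moved tensor is stored back into the dict.
    if isinstance(model, dict):
        for k, v in model.items():
            if hasattr(v, "to"):
                model[k] = v.to(target_device)
    return model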

View File

@@ -29,10 +29,6 @@ class ModelLocker(ModelLockerBase):
     def lock(self) -> AnyModel:
         """Move the model into the execution device (GPU) and lock it."""
-        if not hasattr(self.model, "to"):
-            return self.model
-
-        # NOTE that the model has to have the to() method in order for this code to move it into GPU!
         self._cache_entry.lock()
         try:
             if self._cache.lazy_offloading:
@@ -55,9 +51,6 @@
     def unlock(self) -> None:
         """Call upon exit from context."""
-        if not hasattr(self.model, "to"):
-            return
-
         self._cache_entry.unlock()
         if not self._cache.lazy_offloading:
             self._cache.offload_unlocked_models(self._cache_entry.size)
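With the hasattr(self.model, "to") guards removed, the locker no longer special-cases models that cannot move themselves; the device transfer is handled entirely by the cache code in the previous file. A small self-contained check of the assumption behind that change (not part of the commit):

import torch

# A bare state dict cannot be moved as a whole, but each tensor value can,
# which is exactly what the per-value move in the cache relies on.
state_dict = {"weight": torch.ones(2, 2), "bias": torch.zeros(2)}
print(hasattr(state_dict, "to"))                            # False
print(all(hasattr(v, "to") for v in state_dict.values()))   # True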