mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
support VRAM caching of dict models that lack to()
This commit is contained in:
parent a26667d3ca
commit 7c39929758
@@ -58,7 +58,7 @@ class ModelLoadServiceBase(ABC):

         Args:
           model_path: A pathlib.Path to a checkpoint-style models file
-          loader: A Callable that expects a Path and returns a Dict[str|int, Any]
+          loader: A Callable that expects a Path and returns a Dict[str, Tensor]

         Returns:
           A LoadedModel object.
@@ -109,7 +109,7 @@ class ModelLoadService(ModelLoadServiceBase):

         Args:
           model_path: A pathlib.Path to a checkpoint-style models file
-          loader: A Callable that expects a Path and returns a Dict[str|int, Any]
+          loader: A Callable that expects a Path and returns a Dict[str, Tensor]

         Returns:
           A LoadedModel object.
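Both hunks above narrow the documented loader contract from Dict[str|int, Any] to Dict[str, Tensor]. As a rough, hypothetical illustration (not part of this commit), a loader satisfying the Callable[[Path], Dict[str, torch.Tensor]] signature could look like the sketch below; the function name and the torch.load-based approach are assumptions.

from pathlib import Path
from typing import Dict

import torch


def example_state_dict_loader(path: Path) -> Dict[str, torch.Tensor]:
    # Load onto CPU first; the model cache decides later whether to move the tensors to VRAM.
    state = torch.load(path, map_location="cpu")
    # Many checkpoints nest their weights under a "state_dict" key; unwrap it if present.
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    # Keep only string-keyed tensor entries so the result matches Dict[str, Tensor].
    return {k: v for k, v in state.items() if isinstance(k, str) and isinstance(v, torch.Tensor)}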
@@ -437,7 +437,7 @@ class ModelsInterface(InvocationContextInterface):

     def download_and_cache_ckpt(
         self,
-        source: Union[str, AnyHttpUrl],
+        source: str | AnyHttpUrl,
         access_token: Optional[str] = None,
         timeout: Optional[int] = 0,
     ) -> Path:
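The only change in this hunk is annotation style: str | AnyHttpUrl is the PEP 604 union syntax available on Python 3.10+, and it describes the same type as Union[str, AnyHttpUrl] without needing the typing import. A two-line check, using plain built-in types for illustration:

from typing import Union

# On Python 3.10+ the pipe syntax and typing.Union produce equal union types.
assert (str | int) == Union[str, int]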
@@ -501,7 +501,7 @@ class ModelsInterface(InvocationContextInterface):
         loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
     ) -> LoadedModel:
         """
-        Download, cache, and Load the model file located at the indicated URL.
+        Download, cache, and load the model file located at the indicated URL.

         This will check the model download cache for the model designated
         by the provided URL and download it if needed using download_and_cache_ckpt().
@@ -252,23 +252,22 @@ class ModelCache(ModelCacheBase[AnyModel]):

         May raise a torch.cuda.OutOfMemoryError
         """
         # These attributes are not in the base ModelMixin class but in various derived classes.
         # Some models don't have these attributes, in which case they run in RAM/CPU.
         self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
-        if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
-            return
+        model = cache_entry.model

-        source_device = cache_entry.model.device
-
         # Note: We compare device types only so that 'cuda' == 'cuda:0'.
         # This would need to be revised to support multi-GPU.
+        source_device = model.device if hasattr(model, "device") else self.storage_device
         if torch.device(source_device).type == torch.device(target_device).type:
             return

         start_model_to_time = time.time()
         snapshot_before = self._capture_memory_snapshot()
         try:
-            cache_entry.model.to(target_device)
+            if hasattr(model, "to"):
+                model.to(target_device)
+            elif isinstance(model, dict):
+                for _, v in model.items():
+                    if hasattr(v, "to"):
+                        v.to(target_device)
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)
             raise e
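The new try block is the heart of the commit: a model that is just a dict of tensors, with no .to() of its own, is now moved entry by entry instead of being skipped. Below is a minimal, self-contained sketch of that per-value move; the tensor names are made up, and because torch.Tensor.to() returns a new tensor rather than moving in place, the sketch stores each result back into the dict.

from typing import Any, Dict

import torch

# A "model" that is just a state dict, as a Dict[str, Tensor] loader would produce.
model: Dict[str, Any] = {
    "weight": torch.randn(4, 4),
    "bias": torch.zeros(4),
}

target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if hasattr(model, "to"):
    # nn.Module-style models move themselves (and do so in place).
    model.to(target_device)
elif isinstance(model, dict):
    # Dict models are moved value by value; tensors return a moved copy, so reassign.
    for k, v in model.items():
        if hasattr(v, "to"):
            model[k] = v.to(target_device)

print({k: v.device for k, v in model.items()})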
@@ -29,10 +29,6 @@ class ModelLocker(ModelLockerBase):

     def lock(self) -> AnyModel:
         """Move the model into the execution device (GPU) and lock it."""
-        if not hasattr(self.model, "to"):
-            return self.model
-
-        # NOTE that the model has to have the to() method in order for this code to move it into GPU!
         self._cache_entry.lock()
         try:
             if self._cache.lazy_offloading:
@@ -55,9 +51,6 @@

     def unlock(self) -> None:
         """Call upon exit from context."""
-        if not hasattr(self.model, "to"):
-            return
-
         self._cache_entry.unlock()
         if not self._cache.lazy_offloading:
             self._cache.offload_unlocked_models(self._cache_entry.size)
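With the cache's device-move logic (previous hunk) now tolerating dict models, lock() and unlock() no longer need their own hasattr(self.model, "to") early returns and can always go through the cache. The following is a rough, self-contained sketch of that guard-free flow; ToyCacheEntry and move_to are stand-ins invented for illustration, not InvokeAI classes.

from dataclasses import dataclass
from typing import Any

import torch


@dataclass
class ToyCacheEntry:
    key: str
    model: Any
    locks: int = 0

    def lock(self) -> None:
        self.locks += 1

    def unlock(self) -> None:
        self.locks -= 1


def move_to(model: Any, device: torch.device) -> Any:
    """Device move that tolerates nn.Module-style models, dicts of tensors, and anything else."""
    if hasattr(model, "to"):
        return model.to(device)
    if isinstance(model, dict):
        return {k: (v.to(device) if hasattr(v, "to") else v) for k, v in model.items()}
    return model  # nothing movable: leave it where it is


# Guard-free lock/unlock: no up-front hasattr(model, "to") check is needed any more.
entry = ToyCacheEntry(key="demo", model={"weight": torch.ones(2, 2)})
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

entry.lock()
entry.model = move_to(entry.model, device)  # a dict model is handled like any other
# ... run inference here ...
entry.unlock()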