support VRAM caching of dict models that lack to()

This commit is contained in:
Lincoln Stein
2024-04-28 13:41:06 -04:00
parent a26667d3ca
commit 7c39929758
5 changed files with 12 additions and 20 deletions

View File

@@ -58,7 +58,7 @@ class ModelLoadServiceBase(ABC):
Args:
model_path: A pathlib.Path to a checkpoint-style models file
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
loader: A Callable that expects a Path and returns a Dict[str, Tensor]
Returns:
A LoadedModel object.

View File

@@ -109,7 +109,7 @@ class ModelLoadService(ModelLoadServiceBase):
Args:
model_path: A pathlib.Path to a checkpoint-style models file
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
loader: A Callable that expects a Path and returns a Dict[str, Tensor]
Returns:
A LoadedModel object.

View File

@@ -437,7 +437,7 @@ class ModelsInterface(InvocationContextInterface):
def download_and_cache_ckpt(
self,
source: Union[str, AnyHttpUrl],
source: str | AnyHttpUrl,
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
) -> Path:
@@ -501,7 +501,7 @@ class ModelsInterface(InvocationContextInterface):
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel:
"""
Download, cache, and Load the model file located at the indicated URL.
Download, cache, and load the model file located at the indicated URL.
This will check the model download cache for the model designated
by the provided URL and download it if needed using download_and_cache_ckpt().