replace load_and_cache_model() with load_remote_model() and load_local_odel()

This commit is contained in:
Lincoln Stein
2024-06-06 00:31:41 -04:00
committed by psychedelicious
parent 9f9379682e
commit dc134935c8
12 changed files with 106 additions and 69 deletions

View File

@ -1585,9 +1585,9 @@ Within invocations, the following methods are available from the
### context.download_and_cache_model(source) -> Path
This method accepts a `source` of a model, downloads and caches it
locally, and returns a Path to the local model. The source can be a
local file or directory, a URL, or a HuggingFace repo_id.
This method accepts a `source` of a remote model, downloads and caches
it locally, and then returns a Path to the local model. The source can
be a direct download URL or a HuggingFace repo_id.
In the case of HuggingFace repo_id, the following variants are
recognized:
@ -1602,16 +1602,34 @@ directory using this syntax:
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
### context.load_and_cache_model(source, [loader]) -> LoadedModel
### context.load_local_model(model_path, [loader]) -> LoadedModel
This method takes a model source, downloads it, caches it, and then
loads it into the RAM cache for use in inference. The optional loader
is a Callable that accepts a Path to the object, and returns a
`Dict[str, torch.Tensor]`. If no loader is provided, then the method
will use `torch.load()` for a .ckpt or .bin checkpoint file,
`safetensors.torch.load_file()` for a safetensors checkpoint file, or
`*.from_pretrained()` for a directory that looks like a
diffusers directory.
This method loads a local model from the indicated path, returning a
`LoadedModel`. The optional loader is a Callable that accepts a Path
to the object, and returns a `AnyModel` object. If no loader is
provided, then the method will use `torch.load()` for a .ckpt or .bin
checkpoint file, `safetensors.torch.load_file()` for a safetensors
checkpoint file, or `cls.from_pretrained()` for a directory that looks
like a diffusers directory.
### context.load_remote_model(source, [loader]) -> LoadedModel
This method accepts a `source` of a remote model, downloads and caches
it locally, loads it, and returns a `LoadedModel`. The source can be a
direct download URL or a HuggingFace repo_id.
In the case of HuggingFace repo_id, the following variants are
recognized:
* stabilityai/stable-diffusion-v4 -- default model
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
You can also point at an arbitrary individual file within a repo_id
directory using this syntax:
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors