refactor multifile download code

Lincoln Stein 2024-05-17 22:29:19 -04:00
parent 2dae5eb7ad
commit d968c6f379
13 changed files with 262 additions and 144 deletions
View File
@ -128,7 +128,8 @@ The queue operates on a series of download job objects. These objects
specify the source and destination of the download, and keep track of
the progress of the download.
-The only job type currently implemented is `DownloadJob`, a pydantic object with the
+Two job types are defined. `DownloadJob` and
+`MultiFileDownloadJob`. The former is a pydantic object with the
following fields:
| **Field** | **Type** | **Default** | **Description** |
@ -138,7 +139,7 @@ following fields:
| `dest` | Path | | Where to download to |
| `access_token` | str | | [optional] string containing authentication token for access |
| `on_start` | Callable | | [optional] callback when the download starts |
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
| `on_complete` | Callable | | [optional] callback called after successful download completion |
| `on_error` | Callable | | [optional] callback called after an error occurs |
| `id` | int | auto assigned | Job ID, an integer >= 0 |
@ -190,6 +191,33 @@ A cancelled job will have status `DownloadJobStatus.ERROR` and an
`error_type` field of "DownloadJobCancelledException". In addition,
the job's `cancelled` property will be set to True.
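For example, a caller that has requested cancellation can confirm the final state like this (a minimal sketch; `queue` and `job` are created as in the examples further down, and `DownloadJobStatus` is assumed to be importable from the same download service package):

```
queue.cancel_job(job)    # ask the queue to cancel this job
queue.wait_for_job(job)  # block until the queue has finished processing it
if job.cancelled:
    assert job.status == DownloadJobStatus.ERROR
    assert job.error_type == "DownloadJobCancelledException"
```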
The `MultiFileDownloadJob` is used for diffusers model downloads,
which contain multiple files and directories under a common root:
| **Field** | **Type** | **Default** | **Description** |
|----------------|-----------------|---------------|-----------------|
| _Fields passed in at job creation time_ |
| `download_parts` | Set[DownloadJob]| | Component download jobs |
| `dest` | Path | | Where to download to |
| `on_start` | Callable | | [optional] callback when the download starts |
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
| `on_complete` | Callable | | [optional] callback called after successful download completion |
| `on_error` | Callable | | [optional] callback called after an error occurs |
| `id` | int | auto assigned | Job ID, an integer >= 0 |
| _Fields updated over the course of the download task_ |
| `status` | DownloadJobStatus| | Status code |
| `download_path` | Path | | Path to the root of the downloaded files |
| `bytes` | int | 0 | Bytes downloaded so far |
| `total_bytes` | int | 0 | Total size of the files at the remote site |
| `error_type` | str | | String version of the exception that caused an error during download |
| `error` | str | | String version of the traceback associated with an error |
| `cancelled` | bool | False | Set to true if the job was cancelled by the caller|
Note that the MultiFileDownloadJob does not support the `priority`,
`job_started`, `job_ended` or `content_type` attributes. You can get
these from the individual download jobs in `download_parts`.
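Per-file details can instead be read from the component jobs. A small sketch, assuming a completed `MultiFileDownloadJob` named `job` such as the one created in the example further down:

```
for part in job.download_parts:
    print(part.source, part.content_type, part.bytes, part.total_bytes)
```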
### Callbacks
Download jobs can be associated with a series of callbacks, each with
@ -251,11 +279,40 @@ jobs using `list_jobs()`, fetch a single job by its with
running jobs with `cancel_all_jobs()`, and wait for all jobs to finish
with `join()`.
-#### job = queue.download(source, dest, priority, access_token)
+#### job = queue.download(source, dest, priority, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
Create a new download job and put it on the queue, returning the
DownloadJob object.
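A minimal single-file sketch; the import path for `DownloadQueueService` and `TqdmProgress` is an assumption here, and the URL is illustrative:

```
from pathlib import Path

from invokeai.app.services.download import DownloadQueueService, TqdmProgress  # assumed import path

queue = DownloadQueueService()
job = queue.download(
    source="https://www.foo.bar/my/pytorch_model.safetensors",  # illustrative URL
    dest=Path("/tmp/downloads"),
    on_progress=TqdmProgress().update,
)
queue.join()  # wait for all queued jobs to finish
print(f"downloaded to {job.download_path}")
```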
#### multifile_job = queue.multifile_download(parts, dest, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
This is similar to download(), but instead of taking a single source,
it accepts a `parts` argument consisting of a list of
`RemoteModelFile` objects. Each part corresponds to a URL/Path pair,
where the URL is the location of the remote file, and the Path is the
destination.
`RemoteModelFile` can be imported from `invokeai.backend.model_manager.metadata`, and
consists of a url/path pair. Note that the path *must* be relative.
The method returns a `MultiFileDownloadJob`.
```
from invokeai.backend.model_manager.metadata import RemoteModelFile
remote_file_1 = RemoteModelFile(url='http://www.foo.bar/my/pytorch_model.safetensors',
path='my_model/textencoder/pytorch_model.safetensors'
)
remote_file_2 = RemoteModelFile(url='http://www.bar.baz/vae.ckpt',
path='my_model/vae/diffusers_model.safetensors'
)
job = queue.multifile_download(parts=[remote_file_1, remote_file_2],
dest='/tmp/downloads',
on_progress=TqdmProgress().update)
queue.wait_for_job(job)
print(f"The files were downloaded to {job.download_path}")
```
#### jobs = queue.list_jobs()
Return a list of all active and inactive `DownloadJob`s.
View File
@ -1577,3 +1577,41 @@ This method takes a model key, looks it up using the
`ModelRecordServiceBase` object in `mm.store`, and passes the returned
model configuration to `load_model_by_config()`. It may raise a
`NotImplementedException`.
## Invocation Context Model Manager API
Within invocations, the following methods are available from the
`InvocationContext` object:
### context.download_and_cache_model(source) -> Path
This method accepts a `source` of a model, downloads and caches it
locally, and returns a Path to the local model. The source can be a
local file or directory, a URL, or a HuggingFace repo_id.
In the case of HuggingFace repo_id, the following variants are
recognized:
* stabilityai/stable-diffusion-v4 -- default model
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
You can also point at an arbitrary individual file within a repo_id
directory using this syntax:
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
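As an illustration, an invocation might fetch the fp16 VAE subfolder like this (a sketch using the illustrative repo_id above; note that the invocations and tests in this commit reach the method through `context.models`):

```
from pathlib import Path

# `context` is the InvocationContext passed to invoke()
model_dir: Path = context.models.download_and_cache_model("stabilityai/stable-diffusion-v4:fp16:vae")
print(model_dir)  # path of the locally cached copy
```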
### context.load_and_cache_model(source, [loader]) -> LoadedModel
This method takes a model source, downloads it, caches it, and then
loads it into the RAM cache for use in inference. The optional loader
is a Callable that accepts a Path to the object, and returns a
`Dict[str, torch.Tensor]`. If no loader is provided, then the method
will use `torch.load()` for a .ckpt or .bin checkpoint file,
`safetensors.torch.load_file()` for a safetensors checkpoint file, or
`*.from_pretrained()` for a directory that looks like a
diffusers directory.
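A sketch of the pattern the invocations below are converted to; the URL and the custom loader are illustrative, and the returned `LoadedModel` is used as a context manager:

```
from pathlib import Path
from typing import Dict

import torch

def my_loader(path: Path) -> Dict[str, torch.Tensor]:
    # illustrative loader: read a plain checkpoint onto the CPU
    return torch.load(path, map_location="cpu")

with context.models.load_and_cache_model(
    source="https://www.test.foo/download/test_embedding.safetensors",  # illustrative URL
    loader=my_loader,
) as model:
    ...  # use the loaded model here
```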
View File
@ -611,7 +611,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
)
-with context.models.load_ckpt_from_url(source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader) as model:
+with context.models.load_and_cache_model(source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader) as model:
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
return processed_image
@ -634,8 +634,8 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
mm = context.models
-onnx_det = mm.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"])
-onnx_pose = mm.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
+onnx_det = mm.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
+onnx_pose = mm.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
processed_image = dw_openpose(
View File
@ -133,7 +133,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using the LaMa model""" """Infills transparent areas of an image using the LaMa model"""
def infill(self, image: Image.Image, context: InvocationContext): def infill(self, image: Image.Image, context: InvocationContext):
with context.models.load_ckpt_from_url( with context.models.load_and_cache_model(
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
loader=LaMA.load_jit_model, loader=LaMA.load_jit_model,
) as model: ) as model:
View File
@ -91,7 +91,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
context.logger.error(msg)
raise ValueError(msg)
-loadnet = context.models.load_ckpt_from_url(
+loadnet = context.models.load_and_cache_model(
source=ESRGAN_MODEL_URLS[self.model_name],
)
View File
@ -387,12 +387,13 @@ class ModelInstallService(ModelInstallServiceBase):
model_path.mkdir(parents=True, exist_ok=True)
model_source = self._guess_source(source)
remote_files, _ = self._remote_files_from_source(model_source)
-job = self._download_queue.multifile_download(
-parts=remote_files,
+job = self._multifile_download(
dest=model_path,
+remote_files=remote_files,
+subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
)
-files_string = "file" if len(remote_files) == 1 else "file"
+files_string = "file" if len(remote_files) == 1 else "files"
-self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
+self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
self._download_queue.wait_for_job(job)
if job.complete:
assert job.download_path is not None
@ -734,26 +735,12 @@ class ModelInstallService(ModelInstallServiceBase):
)
# remember the temporary directory for later removal
install_job._install_tmpdir = destdir
install_job.total_bytes = sum((x.size or 0) for x in remote_files)
# In the event that there is a subfolder specified in the source,
# we need to remove it from the destination path in order to avoid
# creating unwanted subfolders
if isinstance(source, HFModelSource) and source.subfolder:
root = Path(remote_files[0].path.parts[0])
subfolder = root / source.subfolder
else:
root = Path(".")
subfolder = Path(".")
parts: List[RemoteModelFile] = []
for model_file in remote_files:
assert install_job.total_bytes is not None
assert model_file.size is not None
install_job.total_bytes += model_file.size
parts.append(RemoteModelFile(url=model_file.url, path=model_file.path.relative_to(subfolder)))
multifile_job = self._multifile_download(
-parts=parts,
+remote_files=remote_files,
dest=destdir,
+subfolder=source.subfolder if isinstance(source, HFModelSource) else None,
access_token=source.access_token,
submit_job=False, # Important! Don't submit the job until we have set our _download_cache dict
)
@ -776,8 +763,35 @@ class ModelInstallService(ModelInstallServiceBase):
return size
def _multifile_download(
-self, parts: List[RemoteModelFile], dest: Path, access_token: Optional[str] = None, submit_job: bool = True
+self,
+remote_files: List[RemoteModelFile],
+dest: Path,
+subfolder: Optional[Path] = None,
+access_token: Optional[str] = None,
+submit_job: bool = True,
) -> MultiFileDownloadJob:
# HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
# we are installing the "vae" subfolder, we do not want to create an additional folder level, such
# as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
# So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
if subfolder:
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
path_to_add = Path(f"{top}_{subfolder}")
else:
path_to_remove = Path(".")
path_to_add = Path(".")
parts: List[RemoteModelFile] = []
for model_file in remote_files:
assert model_file.size is not None
parts.append(
RemoteModelFile(
url=model_file.url, # if a subfolder, then sdxl-turbo_vae/config.json
path=path_to_add / model_file.path.relative_to(path_to_remove),
)
)
return self._download_queue.multifile_download(
parts=parts,
dest=dest,
@ -795,56 +809,53 @@ class ModelInstallService(ModelInstallServiceBase):
# ------------------------------------------------------------------
def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
-install_job = self._download_cache[download_job.id]
+if install_job := self._download_cache.get(download_job.id, None):
install_job.status = InstallStatus.DOWNLOADING
assert download_job.download_path
if install_job.local_path == install_job._install_tmpdir: # first time
install_job.local_path = download_job.download_path
install_job.total_bytes = download_job.total_bytes
def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
-install_job = self._download_cache[download_job.id]
+if install_job := self._download_cache.get(download_job.id, None):
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
self._download_queue.cancel_job(download_job)
else:
# update sizes
install_job.bytes = sum(x.bytes for x in download_job.download_parts)
self._signal_job_downloading(install_job)
def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
-install_job = self._download_cache.pop(download_job.id)
+if install_job := self._download_cache.pop(download_job.id, None):
self._signal_job_downloads_done(install_job)
self._put_in_queue(install_job) # this starts the installation and registration
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
with self._lock:
-install_job = self._download_cache.pop(download_job.id)
-assert install_job is not None
-assert excp is not None
-install_job.set_error(excp)
-self._download_queue.cancel_job(download_job)
+if install_job := self._download_cache.pop(download_job.id, None):
+assert excp is not None
+install_job.set_error(excp)
+self._download_queue.cancel_job(download_job)
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
-install_job = self._download_cache.pop(download_job.id, None)
-if not install_job:
-return
-self._downloads_changed_event.set()
-# if install job has already registered an error, then do not replace its status with cancelled
-if not install_job.errored:
-install_job.cancel()
+if install_job := self._download_cache.pop(download_job.id, None):
+self._downloads_changed_event.set()
+# if install job has already registered an error, then do not replace its status with cancelled
+if not install_job.errored:
+install_job.cancel()
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
# ------------------------------------------------------------------------------------------------
# Internal methods that put events on the event bus
View File
@ -43,11 +43,11 @@ class ModelLoadServiceBase(ABC):
"""Return the checkpoint convert cache used by this loader.""" """Return the checkpoint convert cache used by this loader."""
@abstractmethod @abstractmethod
def load_ckpt_from_path( def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
) -> LoadedModel: ) -> LoadedModel:
""" """
Load the checkpoint-format model file located at the indicated Path. Load the model file or directory located at the indicated Path.
This will load an arbitrary model file into the RAM cache. If the optional loader This will load an arbitrary model file into the RAM cache. If the optional loader
argument is provided, the loader will be invoked to load the model into argument is provided, the loader will be invoked to load the model into
View File
@ -20,6 +20,8 @@ from invokeai.backend.model_manager.load import (
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from .model_load_base import ModelLoadServiceBase
@ -94,7 +96,7 @@ class ModelLoadService(ModelLoadServiceBase):
)
return loaded_model
-def load_ckpt_from_path(
+def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
) -> LoadedModel:
"""
@ -128,6 +130,16 @@ class ModelLoadService(ModelLoadServiceBase):
result: Dict[str, Tensor] = torch_load(checkpoint, map_location="cpu")
return result
def diffusers_load_directory(directory: Path) -> AnyModel:
load_class = GenericDiffusersLoader(
app_config=self._app_config,
logger=self._logger,
ram_cache=self._ram_cache,
convert_cache=self.convert_cache,
).get_hf_load_class(directory)
result: AnyModel = load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
return result
if loader is None:
loader = (
torch_load_file
View File
@ -72,7 +72,7 @@ class ModelManagerServiceBase(ABC):
pass
@abstractmethod
-def load_ckpt_from_url(
+def load_model_from_url(
self,
source: str | AnyHttpUrl,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
View File
@ -64,6 +64,34 @@ class ModelManagerService(ModelManagerServiceBase):
if hasattr(service, "stop"): if hasattr(service, "stop"):
service.stop(invoker) service.stop(invoker)
def load_model_from_url(
self,
source: str | AnyHttpUrl,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel:
"""
Download, cache, and Load the model file located at the indicated URL.
This will check the model download cache for the model designated
by the provided URL and download it if needed using download_and_cache_ckpt().
It will then load the model into the RAM cache. If the optional loader
argument is provided, the loader will be invoked to load the model into
memory. Otherwise the method will call safetensors.torch.load_file() or
torch.load() as appropriate to the file suffix.
Be aware that the LoadedModel object will have a `config` attribute of None.
Args:
source: A URL or a string that can be converted in one. Repo_ids
do not work here.
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
Returns:
A LoadedModel object.
"""
model_path = self.install.download_and_cache_model(source=str(source))
return self.load.load_model_from_path(model_path=model_path, loader=loader)
@classmethod
def build_model_manager(
cls,
@ -102,31 +130,3 @@ class ModelManagerService(ModelManagerServiceBase):
event_bus=events,
)
return cls(store=model_record_service, install=installer, load=loader)
def load_ckpt_from_url(
self,
source: str | AnyHttpUrl,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel:
"""
Download, cache, and Load the model file located at the indicated URL.
This will check the model download cache for the model designated
by the provided URL and download it if needed using download_and_cache_ckpt().
It will then load the model into the RAM cache. If the optional loader
argument is provided, the loader will be invoked to load the model into
memory. Otherwise the method will call safetensors.torch.load_file() or
torch.load() as appropriate to the file suffix.
Be aware that the LoadedModel object will have a `config` attribute of None.
Args:
source: A URL or a string that can be converted in one. Repo_ids
do not work here.
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
Returns:
A LoadedModel object.
"""
model_path = self.install.download_and_cache_ckpt(source=source)
return self.load.load_ckpt_from_path(model_path=model_path, loader=loader)
View File
@ -435,11 +435,9 @@ class ModelsInterface(InvocationContextInterface):
)
return result
-def download_and_cache_ckpt(
+def download_and_cache_model(
self,
source: str | AnyHttpUrl,
-access_token: Optional[str] = None,
-timeout: Optional[int] = 0,
) -> Path:
""" """
Download the model file located at source to the models cache and return its Path. Download the model file located at source to the models cache and return its Path.
@ -449,12 +447,7 @@ class ModelsInterface(InvocationContextInterface):
installed, the cached path will be returned. Otherwise it will be downloaded.
Args:
-source: A URL or a string that can be converted in one. Repo_ids
-do not work here.
-access_token: Optional access token for restricted resources.
-timeout: Wait up to the indicated number of seconds before timing
-out long downloads.
+source: A model path, URL or repo_id.
Result:
Path to the downloaded model
@ -463,39 +456,14 @@ class ModelsInterface(InvocationContextInterface):
TimeoutError
"""
installer = self._services.model_manager.install
-path: Path = installer.download_and_cache_ckpt(
+path: Path = installer.download_and_cache_model(
source=source,
-access_token=access_token,
-timeout=timeout,
)
return path
-def load_ckpt_from_path(
+def load_and_cache_model(
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None
) -> LoadedModel:
"""
Load the checkpoint-format model file located at the indicated Path.
This will load an arbitrary model file into the RAM cache. If the optional loader
argument is provided, the loader will be invoked to load the model into
memory. Otherwise the method will call safetensors.torch.load_file() or
torch.load() as appropriate to the file suffix.
Be aware that the LoadedModel object will have a `config` attribute of None.
Args:
model_path: A pathlib.Path to a checkpoint-style models file
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
Returns:
A LoadedModel object.
"""
result: LoadedModel = self._services.model_manager.load.load_ckpt_from_path(model_path, loader=loader)
return result
def load_ckpt_from_url(
self,
-source: str | AnyHttpUrl,
+source: Path | str | AnyHttpUrl,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel:
"""
@ -511,14 +479,17 @@ class ModelsInterface(InvocationContextInterface):
Be aware that the LoadedModel object will have a `config` attribute of None. Be aware that the LoadedModel object will have a `config` attribute of None.
Args: Args:
source: A URL or a string that can be converted in one. Repo_ids source: A model Path, URL, or repoid.
do not work here.
loader: A Callable that expects a Path and returns a Dict[str|int, Any] loader: A Callable that expects a Path and returns a Dict[str|int, Any]
Returns: Returns:
A LoadedModel object. A LoadedModel object.
""" """
result: LoadedModel = self._services.model_manager.load_ckpt_from_url(source=source, loader=loader) result: LoadedModel = (
self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
if isinstance(source, Path)
else self._services.model_manager.load_model_from_url(source=source, loader=loader)
)
return result return result

View File

@ -304,6 +304,19 @@ def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None:
queue.stop()
def test_multifile_no_rel_paths(tmp_path: Path, mm2_session: Session) -> None:
queue = DownloadQueueService(
requests_session=mm2_session,
)
with pytest.raises(AssertionError) as error:
queue.multifile_download(
parts=[RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("/etc/passwd"))],
dest=tmp_path,
)
assert str(error.value) == "only relative download paths accepted"
@contextmanager
def clear_config() -> Generator[None, None, None]:
try:
View File
@ -24,7 +24,7 @@ def mock_context(
def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) -> None:
-downloaded_path = mock_context.models.download_and_cache_ckpt(
+downloaded_path = mock_context.models.download_and_cache_model(
"https://www.test.foo/download/test_embedding.safetensors"
)
assert downloaded_path.is_file()
@ -32,24 +32,24 @@ def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path)
assert downloaded_path.name == "test_embedding.safetensors"
assert downloaded_path.parent.parent == mm2_root_dir / "models/.download_cache"
-downloaded_path_2 = mock_context.models.download_and_cache_ckpt(
+downloaded_path_2 = mock_context.models.download_and_cache_model(
"https://www.test.foo/download/test_embedding.safetensors"
)
assert downloaded_path == downloaded_path_2
def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -> None:
-downloaded_path = mock_context.models.download_and_cache_ckpt(
+downloaded_path = mock_context.models.download_and_cache_model(
"https://www.test.foo/download/test_embedding.safetensors"
)
-loaded_model_1 = mock_context.models.load_ckpt_from_path(downloaded_path)
+loaded_model_1 = mock_context.models.load_and_cache_model(downloaded_path)
assert isinstance(loaded_model_1, LoadedModel)
-loaded_model_2 = mock_context.models.load_ckpt_from_path(downloaded_path)
+loaded_model_2 = mock_context.models.load_and_cache_model(downloaded_path)
assert isinstance(loaded_model_2, LoadedModel)
assert loaded_model_1.model is loaded_model_2.model
-loaded_model_3 = mock_context.models.load_ckpt_from_path(embedding_file)
+loaded_model_3 = mock_context.models.load_and_cache_model(embedding_file)
assert isinstance(loaded_model_3, LoadedModel)
assert loaded_model_1.model is not loaded_model_3.model
assert isinstance(loaded_model_1.model, dict)
@ -58,9 +58,25 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
def test_download_and_load(mock_context: InvocationContext) -> None:
-loaded_model_1 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors")
+loaded_model_1 = mock_context.models.load_and_cache_model(
+"https://www.test.foo/download/test_embedding.safetensors"
+)
assert isinstance(loaded_model_1, LoadedModel)
-loaded_model_2 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors")
+loaded_model_2 = mock_context.models.load_and_cache_model(
+"https://www.test.foo/download/test_embedding.safetensors"
+)
assert isinstance(loaded_model_2, LoadedModel)
assert loaded_model_1.model is loaded_model_2.model # should be cached copy
def test_download_diffusers(mock_context: InvocationContext) -> None:
model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo")
assert (model_path / "model_index.json").exists()
assert (model_path / "vae").is_dir()
def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None:
model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae")
assert model_path.is_dir()
assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists()