From d968c6f379dec510eed914b185c3872d3196e7d2 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 17 May 2024 22:29:19 -0400 Subject: [PATCH] refactor multifile download code --- docs/contributing/DOWNLOAD_QUEUE.md | 63 ++++++++- docs/contributing/MODEL_MANAGER.md | 38 ++++++ .../controlnet_image_processors.py | 6 +- invokeai/app/invocations/infill.py | 2 +- invokeai/app/invocations/upscale.py | 2 +- .../model_install/model_install_default.py | 123 ++++++++++-------- .../services/model_load/model_load_base.py | 4 +- .../services/model_load/model_load_default.py | 14 +- .../model_manager/model_manager_base.py | 2 +- .../model_manager/model_manager_default.py | 56 ++++---- .../app/services/shared/invocation_context.py | 51 ++------ .../services/download/test_download_queue.py | 13 ++ .../app/services/model_load/test_load_api.py | 32 +++-- 13 files changed, 262 insertions(+), 144 deletions(-) diff --git a/docs/contributing/DOWNLOAD_QUEUE.md b/docs/contributing/DOWNLOAD_QUEUE.md index d43c670d2c..960180961e 100644 --- a/docs/contributing/DOWNLOAD_QUEUE.md +++ b/docs/contributing/DOWNLOAD_QUEUE.md @@ -128,7 +128,8 @@ The queue operates on a series of download job objects. These objects specify the source and destination of the download, and keep track of the progress of the download. -The only job type currently implemented is `DownloadJob`, a pydantic object with the +Two job types are defined. `DownloadJob` and +`MultiFileDownloadJob`. The former is a pydantic object with the following fields: | **Field** | **Type** | **Default** | **Description** | @@ -138,7 +139,7 @@ following fields: | `dest` | Path | | Where to download to | | `access_token` | str | | [optional] string containing authentication token for access | | `on_start` | Callable | | [optional] callback when the download starts | -| `on_progress` | Callable | | [optional] callback called at intervals during download progress | +| `on_progress` | Callable | | [optional] callback called at intervals during download progress | | `on_complete` | Callable | | [optional] callback called after successful download completion | | `on_error` | Callable | | [optional] callback called after an error occurs | | `id` | int | auto assigned | Job ID, an integer >= 0 | @@ -190,6 +191,33 @@ A cancelled job will have status `DownloadJobStatus.ERROR` and an `error_type` field of "DownloadJobCancelledException". In addition, the job's `cancelled` property will be set to True. +The `MultiFileDownloadJob` is used for diffusers model downloads, +which contain multiple files and directories under a common root: + +| **Field** | **Type** | **Default** | **Description** | +|----------------|-----------------|---------------|-----------------| +| _Fields passed in at job creation time_ | +| `download_parts` | Set[DownloadJob]| | Component download jobs | +| `dest` | Path | | Where to download to | +| `on_start` | Callable | | [optional] callback when the download starts | +| `on_progress` | Callable | | [optional] callback called at intervals during download progress | +| `on_complete` | Callable | | [optional] callback called after successful download completion | +| `on_error` | Callable | | [optional] callback called after an error occurs | +| `id` | int | auto assigned | Job ID, an integer >= 0 | +| _Fields updated over the course of the download task_ +| `status` | DownloadJobStatus| | Status code | +| `download_path` | Path | | Path to the root of the downloaded files | +| `bytes` | int | 0 | Bytes downloaded so far | +| `total_bytes` | int | 0 | Total size of the file at the remote site | +| `error_type` | str | | String version of the exception that caused an error during download | +| `error` | str | | String version of the traceback associated with an error | +| `cancelled` | bool | False | Set to true if the job was cancelled by the caller| + +Note that the MultiFileDownloadJob does not support the `priority`, +`job_started`, `job_ended` or `content_type` attributes. You can get +these from the individual download jobs in `download_parts`. + + ### Callbacks Download jobs can be associated with a series of callbacks, each with @@ -251,11 +279,40 @@ jobs using `list_jobs()`, fetch a single job by its with running jobs with `cancel_all_jobs()`, and wait for all jobs to finish with `join()`. -#### job = queue.download(source, dest, priority, access_token) +#### job = queue.download(source, dest, priority, access_token, on_start, on_progress, on_complete, on_cancelled, on_error) Create a new download job and put it on the queue, returning the DownloadJob object. +#### multifile_job = queue.multifile_download(parts, dest, access_token, on_start, on_progress, on_complete, on_cancelled, on_error) + +This is similar to download(), but instead of taking a single source, +it accepts a `parts` argument consisting of a list of +`RemoteModelFile` objects. Each part corresponds to a URL/Path pair, +where the URL is the location of the remote file, and the Path is the +destination. + +`RemoteModelFile` can be imported from `invokeai.backend.model_manager.metadata`, and +consists of a url/path pair. Note that the path *must* be relative. + +The method returns a `MultiFileDownloadJob`. + + +``` +from invokeai.backend.model_manager.metadata import RemoteModelFile +remote_file_1 = RemoteModelFile(url='http://www.foo.bar/my/pytorch_model.safetensors'', + path='my_model/textencoder/pytorch_model.safetensors' + ) +remote_file_2 = RemoteModelFile(url='http://www.bar.baz/vae.ckpt', + path='my_model/vae/diffusers_model.safetensors' + ) +job = queue.multifile_download(parts=[remote_file_1, remote_file_2], + dest='/tmp/downloads', + on_progress=TqdmProgress().update) +queue.wait_for_job(job) +print(f"The files were downloaded to {job.download_path}") +``` + #### jobs = queue.list_jobs() Return a list of all active and inactive `DownloadJob`s. diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index d53198b98e..fbc9079d49 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -1577,3 +1577,41 @@ This method takes a model key, looks it up using the `ModelRecordServiceBase` object in `mm.store`, and passes the returned model configuration to `load_model_by_config()`. It may raise a `NotImplementedException`. + +## Invocation Context Model Manager API + +Within invocations, the following methods are available from the +`InvocationContext` object: + +### context.download_and_cache_model(source) -> Path + +This method accepts a `source` of a model, downloads and caches it +locally, and returns a Path to the local model. The source can be a +local file or directory, a URL, or a HuggingFace repo_id. + +In the case of HuggingFace repo_id, the following variants are +recognized: + +* stabilityai/stable-diffusion-v4 -- default model +* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant +* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder +* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder + +You can also point at an arbitrary individual file within a repo_id +directory using this syntax: + +* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors + +### context.load_and_cache_model(source, [loader]) -> LoadedModel + +This method takes a model source, downloads it, caches it, and then +loads it into the RAM cache for use in inference. The optional loader +is a Callable that accepts a Path to the object, and returns a +`Dict[str, torch.Tensor]`. If no loader is provided, then the method +will use `torch.load()` for a .ckpt or .bin checkpoint file, +`safetensors.torch.load_file()` for a safetensors checkpoint file, or +`*.from_pretrained()` for a directory that looks like a +diffusers directory. + + + diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 971179ac93..e69f4b54ad 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -611,7 +611,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation): model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device() ) - with context.models.load_ckpt_from_url(source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader) as model: + with context.models.load_and_cache_model(source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader) as model: depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device()) processed_image = depth_anything_detector(image=image, resolution=self.resolution) return processed_image @@ -634,8 +634,8 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation): def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: mm = context.models - onnx_det = mm.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"]) - onnx_pose = mm.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]) + onnx_det = mm.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"]) + onnx_pose = mm.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]) dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose) processed_image = dw_openpose( diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index f8358d1df5..ddd11cf93f 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -133,7 +133,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation): """Infills transparent areas of an image using the LaMa model""" def infill(self, image: Image.Image, context: InvocationContext): - with context.models.load_ckpt_from_url( + with context.models.load_and_cache_model( source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", loader=LaMA.load_jit_model, ) as model: diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index 29cf7819de..670082f120 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -91,7 +91,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard): context.logger.error(msg) raise ValueError(msg) - loadnet = context.models.load_ckpt_from_url( + loadnet = context.models.load_and_cache_model( source=ESRGAN_MODEL_URLS[self.model_name], ) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index a6bb7ad10d..cde9a6502e 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -387,12 +387,13 @@ class ModelInstallService(ModelInstallServiceBase): model_path.mkdir(parents=True, exist_ok=True) model_source = self._guess_source(source) remote_files, _ = self._remote_files_from_source(model_source) - job = self._download_queue.multifile_download( - parts=remote_files, + job = self._multifile_download( dest=model_path, + remote_files=remote_files, + subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None, ) - files_string = "file" if len(remote_files) == 1 else "file" - self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})") + files_string = "file" if len(remote_files) == 1 else "files" + self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})") self._download_queue.wait_for_job(job) if job.complete: assert job.download_path is not None @@ -734,26 +735,12 @@ class ModelInstallService(ModelInstallServiceBase): ) # remember the temporary directory for later removal install_job._install_tmpdir = destdir + install_job.total_bytes = sum((x.size or 0) for x in remote_files) - # In the event that there is a subfolder specified in the source, - # we need to remove it from the destination path in order to avoid - # creating unwanted subfolders - if isinstance(source, HFModelSource) and source.subfolder: - root = Path(remote_files[0].path.parts[0]) - subfolder = root / source.subfolder - else: - root = Path(".") - subfolder = Path(".") - - parts: List[RemoteModelFile] = [] - for model_file in remote_files: - assert install_job.total_bytes is not None - assert model_file.size is not None - install_job.total_bytes += model_file.size - parts.append(RemoteModelFile(url=model_file.url, path=model_file.path.relative_to(subfolder))) multifile_job = self._multifile_download( - parts=parts, + remote_files=remote_files, dest=destdir, + subfolder=source.subfolder if isinstance(source, HFModelSource) else None, access_token=source.access_token, submit_job=False, # Important! Don't submit the job until we have set our _download_cache dict ) @@ -776,8 +763,35 @@ class ModelInstallService(ModelInstallServiceBase): return size def _multifile_download( - self, parts: List[RemoteModelFile], dest: Path, access_token: Optional[str] = None, submit_job: bool = True + self, + remote_files: List[RemoteModelFile], + dest: Path, + subfolder: Optional[Path] = None, + access_token: Optional[str] = None, + submit_job: bool = True, ) -> MultiFileDownloadJob: + # HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and + # we are installing the "vae" subfolder, we do not want to create an additional folder level, such + # as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo". + # So what we do is to synthesize a folder named "sdxl-turbo_vae" here. + if subfolder: + top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/" + path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/ + path_to_add = Path(f"{top}_{subfolder}") + else: + path_to_remove = Path(".") + path_to_add = Path(".") + + parts: List[RemoteModelFile] = [] + for model_file in remote_files: + assert model_file.size is not None + parts.append( + RemoteModelFile( + url=model_file.url, # if a subfolder, then sdxl-turbo_vae/config.json + path=path_to_add / model_file.path.relative_to(path_to_remove), + ) + ) + return self._download_queue.multifile_download( parts=parts, dest=dest, @@ -795,56 +809,53 @@ class ModelInstallService(ModelInstallServiceBase): # ------------------------------------------------------------------ def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache[download_job.id] - install_job.status = InstallStatus.DOWNLOADING + if install_job := self._download_cache.get(download_job.id, None): + install_job.status = InstallStatus.DOWNLOADING - assert download_job.download_path - if install_job.local_path == install_job._install_tmpdir: # first time - install_job.local_path = download_job.download_path - install_job.total_bytes = download_job.total_bytes + assert download_job.download_path + if install_job.local_path == install_job._install_tmpdir: # first time + install_job.local_path = download_job.download_path + install_job.total_bytes = download_job.total_bytes def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache[download_job.id] - if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel() - self._download_queue.cancel_job(download_job) - else: - # update sizes - install_job.bytes = sum(x.bytes for x in download_job.download_parts) - self._signal_job_downloading(install_job) + if install_job := self._download_cache.get(download_job.id, None): + if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel() + self._download_queue.cancel_job(download_job) + else: + # update sizes + install_job.bytes = sum(x.bytes for x in download_job.download_parts) + self._signal_job_downloading(install_job) def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache.pop(download_job.id) - self._signal_job_downloads_done(install_job) - self._put_in_queue(install_job) # this starts the installation and registration + if install_job := self._download_cache.pop(download_job.id, None): + self._signal_job_downloads_done(install_job) + self._put_in_queue(install_job) # this starts the installation and registration - # Let other threads know that the number of downloads has changed - self._downloads_changed_event.set() + # Let other threads know that the number of downloads has changed + self._downloads_changed_event.set() def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None: with self._lock: - install_job = self._download_cache.pop(download_job.id) - assert install_job is not None - assert excp is not None - install_job.set_error(excp) - self._download_queue.cancel_job(download_job) + if install_job := self._download_cache.pop(download_job.id, None): + assert excp is not None + install_job.set_error(excp) + self._download_queue.cancel_job(download_job) - # Let other threads know that the number of downloads has changed - self._downloads_changed_event.set() + # Let other threads know that the number of downloads has changed + self._downloads_changed_event.set() def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache.pop(download_job.id, None) - if not install_job: - return - self._downloads_changed_event.set() - # if install job has already registered an error, then do not replace its status with cancelled - if not install_job.errored: - install_job.cancel() + if install_job := self._download_cache.pop(download_job.id, None): + self._downloads_changed_event.set() + # if install job has already registered an error, then do not replace its status with cancelled + if not install_job.errored: + install_job.cancel() - # Let other threads know that the number of downloads has changed - self._downloads_changed_event.set() + # Let other threads know that the number of downloads has changed + self._downloads_changed_event.set() # ------------------------------------------------------------------------------------------------ # Internal methods that put events on the event bus diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index 32fc62fa5b..7de36793fb 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -43,11 +43,11 @@ class ModelLoadServiceBase(ABC): """Return the checkpoint convert cache used by this loader.""" @abstractmethod - def load_ckpt_from_path( + def load_model_from_path( self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None ) -> LoadedModel: """ - Load the checkpoint-format model file located at the indicated Path. + Load the model file or directory located at the indicated Path. This will load an arbitrary model file into the RAM cache. If the optional loader argument is provided, the loader will be invoked to load the model into diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index af211c260e..cd14235ee0 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -20,6 +20,8 @@ from invokeai.backend.model_manager.load import ( ) from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase +from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader +from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger from .model_load_base import ModelLoadServiceBase @@ -94,7 +96,7 @@ class ModelLoadService(ModelLoadServiceBase): ) return loaded_model - def load_ckpt_from_path( + def load_model_from_path( self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None ) -> LoadedModel: """ @@ -128,6 +130,16 @@ class ModelLoadService(ModelLoadServiceBase): result: Dict[str, Tensor] = torch_load(checkpoint, map_location="cpu") return result + def diffusers_load_directory(directory: Path) -> AnyModel: + load_class = GenericDiffusersLoader( + app_config=self._app_config, + logger=self._logger, + ram_cache=self._ram_cache, + convert_cache=self.convert_cache, + ).get_hf_load_class(directory) + result: AnyModel = load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype()) + return result + if loader is None: loader = ( torch_load_file diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index d16c00302e..063979ebe6 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -72,7 +72,7 @@ class ModelManagerServiceBase(ABC): pass @abstractmethod - def load_ckpt_from_url( + def load_model_from_url( self, source: str | AnyHttpUrl, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index ed274266f3..dd78f1f3b2 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -64,6 +64,34 @@ class ModelManagerService(ModelManagerServiceBase): if hasattr(service, "stop"): service.stop(invoker) + def load_model_from_url( + self, + source: str | AnyHttpUrl, + loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, + ) -> LoadedModel: + """ + Download, cache, and Load the model file located at the indicated URL. + + This will check the model download cache for the model designated + by the provided URL and download it if needed using download_and_cache_ckpt(). + It will then load the model into the RAM cache. If the optional loader + argument is provided, the loader will be invoked to load the model into + memory. Otherwise the method will call safetensors.torch.load_file() or + torch.load() as appropriate to the file suffix. + + Be aware that the LoadedModel object will have a `config` attribute of None. + + Args: + source: A URL or a string that can be converted in one. Repo_ids + do not work here. + loader: A Callable that expects a Path and returns a Dict[str|int, Any] + + Returns: + A LoadedModel object. + """ + model_path = self.install.download_and_cache_model(source=str(source)) + return self.load.load_model_from_path(model_path=model_path, loader=loader) + @classmethod def build_model_manager( cls, @@ -102,31 +130,3 @@ class ModelManagerService(ModelManagerServiceBase): event_bus=events, ) return cls(store=model_record_service, install=installer, load=loader) - - def load_ckpt_from_url( - self, - source: str | AnyHttpUrl, - loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, - ) -> LoadedModel: - """ - Download, cache, and Load the model file located at the indicated URL. - - This will check the model download cache for the model designated - by the provided URL and download it if needed using download_and_cache_ckpt(). - It will then load the model into the RAM cache. If the optional loader - argument is provided, the loader will be invoked to load the model into - memory. Otherwise the method will call safetensors.torch.load_file() or - torch.load() as appropriate to the file suffix. - - Be aware that the LoadedModel object will have a `config` attribute of None. - - Args: - source: A URL or a string that can be converted in one. Repo_ids - do not work here. - loader: A Callable that expects a Path and returns a Dict[str|int, Any] - - Returns: - A LoadedModel object. - """ - model_path = self.install.download_and_cache_ckpt(source=source) - return self.load.load_ckpt_from_path(model_path=model_path, loader=loader) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index c7602760f7..32d32e227b 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -435,11 +435,9 @@ class ModelsInterface(InvocationContextInterface): ) return result - def download_and_cache_ckpt( + def download_and_cache_model( self, source: str | AnyHttpUrl, - access_token: Optional[str] = None, - timeout: Optional[int] = 0, ) -> Path: """ Download the model file located at source to the models cache and return its Path. @@ -449,12 +447,7 @@ class ModelsInterface(InvocationContextInterface): installed, the cached path will be returned. Otherwise it will be downloaded. Args: - source: A URL or a string that can be converted in one. Repo_ids - do not work here. - access_token: Optional access token for restricted resources. - timeout: Wait up to the indicated number of seconds before timing - out long downloads. - + source: A model path, URL or repo_id. Result: Path to the downloaded model @@ -463,39 +456,14 @@ class ModelsInterface(InvocationContextInterface): TimeoutError """ installer = self._services.model_manager.install - path: Path = installer.download_and_cache_ckpt( + path: Path = installer.download_and_cache_model( source=source, - access_token=access_token, - timeout=timeout, ) return path - def load_ckpt_from_path( - self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None - ) -> LoadedModel: - """ - Load the checkpoint-format model file located at the indicated Path. - - This will load an arbitrary model file into the RAM cache. If the optional loader - argument is provided, the loader will be invoked to load the model into - memory. Otherwise the method will call safetensors.torch.load_file() or - torch.load() as appropriate to the file suffix. - - Be aware that the LoadedModel object will have a `config` attribute of None. - - Args: - model_path: A pathlib.Path to a checkpoint-style models file - loader: A Callable that expects a Path and returns a Dict[str|int, Any] - - Returns: - A LoadedModel object. - """ - result: LoadedModel = self._services.model_manager.load.load_ckpt_from_path(model_path, loader=loader) - return result - - def load_ckpt_from_url( + def load_and_cache_model( self, - source: str | AnyHttpUrl, + source: Path | str | AnyHttpUrl, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, ) -> LoadedModel: """ @@ -511,14 +479,17 @@ class ModelsInterface(InvocationContextInterface): Be aware that the LoadedModel object will have a `config` attribute of None. Args: - source: A URL or a string that can be converted in one. Repo_ids - do not work here. + source: A model Path, URL, or repoid. loader: A Callable that expects a Path and returns a Dict[str|int, Any] Returns: A LoadedModel object. """ - result: LoadedModel = self._services.model_manager.load_ckpt_from_url(source=source, loader=loader) + result: LoadedModel = ( + self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader) + if isinstance(source, Path) + else self._services.model_manager.load_model_from_url(source=source, loader=loader) + ) return result diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py index 564d9c30a0..c9317163c8 100644 --- a/tests/app/services/download/test_download_queue.py +++ b/tests/app/services/download/test_download_queue.py @@ -304,6 +304,19 @@ def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None: queue.stop() +def test_multifile_no_rel_paths(tmp_path: Path, mm2_session: Session) -> None: + queue = DownloadQueueService( + requests_session=mm2_session, + ) + + with pytest.raises(AssertionError) as error: + queue.multifile_download( + parts=[RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("/etc/passwd"))], + dest=tmp_path, + ) + assert str(error.value) == "only relative download paths accepted" + + @contextmanager def clear_config() -> Generator[None, None, None]: try: diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py index 7eb09fb375..bd3a67a894 100644 --- a/tests/app/services/model_load/test_load_api.py +++ b/tests/app/services/model_load/test_load_api.py @@ -24,7 +24,7 @@ def mock_context( def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) -> None: - downloaded_path = mock_context.models.download_and_cache_ckpt( + downloaded_path = mock_context.models.download_and_cache_model( "https://www.test.foo/download/test_embedding.safetensors" ) assert downloaded_path.is_file() @@ -32,24 +32,24 @@ def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) assert downloaded_path.name == "test_embedding.safetensors" assert downloaded_path.parent.parent == mm2_root_dir / "models/.download_cache" - downloaded_path_2 = mock_context.models.download_and_cache_ckpt( + downloaded_path_2 = mock_context.models.download_and_cache_model( "https://www.test.foo/download/test_embedding.safetensors" ) assert downloaded_path == downloaded_path_2 def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -> None: - downloaded_path = mock_context.models.download_and_cache_ckpt( + downloaded_path = mock_context.models.download_and_cache_model( "https://www.test.foo/download/test_embedding.safetensors" ) - loaded_model_1 = mock_context.models.load_ckpt_from_path(downloaded_path) + loaded_model_1 = mock_context.models.load_and_cache_model(downloaded_path) assert isinstance(loaded_model_1, LoadedModel) - loaded_model_2 = mock_context.models.load_ckpt_from_path(downloaded_path) + loaded_model_2 = mock_context.models.load_and_cache_model(downloaded_path) assert isinstance(loaded_model_2, LoadedModel) assert loaded_model_1.model is loaded_model_2.model - loaded_model_3 = mock_context.models.load_ckpt_from_path(embedding_file) + loaded_model_3 = mock_context.models.load_and_cache_model(embedding_file) assert isinstance(loaded_model_3, LoadedModel) assert loaded_model_1.model is not loaded_model_3.model assert isinstance(loaded_model_1.model, dict) @@ -58,9 +58,25 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) - def test_download_and_load(mock_context: InvocationContext) -> None: - loaded_model_1 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors") + loaded_model_1 = mock_context.models.load_and_cache_model( + "https://www.test.foo/download/test_embedding.safetensors" + ) assert isinstance(loaded_model_1, LoadedModel) - loaded_model_2 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors") + loaded_model_2 = mock_context.models.load_and_cache_model( + "https://www.test.foo/download/test_embedding.safetensors" + ) assert isinstance(loaded_model_2, LoadedModel) assert loaded_model_1.model is loaded_model_2.model # should be cached copy + + +def test_download_diffusers(mock_context: InvocationContext) -> None: + model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo") + assert (model_path / "model_index.json").exists() + assert (model_path / "vae").is_dir() + + +def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None: + model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae") + assert model_path.is_dir() + assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists()