refactor multifile download code
(https://github.com/invoke-ai/InvokeAI, commit d968c6f379, parent 2dae5eb7ad)
@@ -128,7 +128,8 @@ The queue operates on a series of download job objects. These objects
 specify the source and destination of the download, and keep track of
 the progress of the download.

-The only job type currently implemented is `DownloadJob`, a pydantic object with the
+Two job types are defined. `DownloadJob` and
+`MultiFileDownloadJob`. The former is a pydantic object with the
 following fields:

 | **Field** | **Type** | **Default** | **Description** |
@@ -138,7 +139,7 @@ following fields:
 | `dest` | Path | | Where to download to |
 | `access_token` | str | | [optional] string containing authentication token for access |
 | `on_start` | Callable | | [optional] callback when the download starts |
 | `on_progress` | Callable | | [optional] callback called at intervals during download progress |
 | `on_complete` | Callable | | [optional] callback called after successful download completion |
 | `on_error` | Callable | | [optional] callback called after an error occurs |
 | `id` | int | auto assigned | Job ID, an integer >= 0 |
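For orientation, here is a minimal sketch of a single-file download using the creation-time fields tabulated above. It is an assumption-laden example, not upstream code: it presumes a running `DownloadQueueService` instance named `queue`, keyword-argument calling, and a hypothetical URL.

```python
# Sketch: a single-file DownloadJob using the fields tabulated above.
# Assumes `queue` is a DownloadQueueService; the URL is hypothetical.
from pathlib import Path

def on_complete(job) -> None:
    # download_path is populated once the transfer finishes
    print(f"{job.source} saved to {job.download_path}")

job = queue.download(
    source="https://www.example.org/models/some_model.safetensors",  # hypothetical
    dest=Path("/tmp/downloads"),
    on_complete=on_complete,
)
queue.join()  # block until all queued jobs finish
```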
@@ -190,6 +191,33 @@ A cancelled job will have status `DownloadJobStatus.ERROR` and an
 `error_type` field of "DownloadJobCancelledException". In addition,
 the job's `cancelled` property will be set to True.

+The `MultiFileDownloadJob` is used for diffusers model downloads,
+which contain multiple files and directories under a common root:
+
+| **Field** | **Type** | **Default** | **Description** |
+|----------------|-----------------|---------------|-----------------|
+| _Fields passed in at job creation time_ |
+| `download_parts` | Set[DownloadJob]| | Component download jobs |
+| `dest` | Path | | Where to download to |
+| `on_start` | Callable | | [optional] callback when the download starts |
+| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
+| `on_complete` | Callable | | [optional] callback called after successful download completion |
+| `on_error` | Callable | | [optional] callback called after an error occurs |
+| `id` | int | auto assigned | Job ID, an integer >= 0 |
+| _Fields updated over the course of the download task_ |
+| `status` | DownloadJobStatus| | Status code |
+| `download_path` | Path | | Path to the root of the downloaded files |
+| `bytes` | int | 0 | Bytes downloaded so far |
+| `total_bytes` | int | 0 | Total size of the file at the remote site |
+| `error_type` | str | | String version of the exception that caused an error during download |
+| `error` | str | | String version of the traceback associated with an error |
+| `cancelled` | bool | False | Set to true if the job was cancelled by the caller |
+
+Note that the MultiFileDownloadJob does not support the `priority`,
+`job_started`, `job_ended` or `content_type` attributes. You can get
+these from the individual download jobs in `download_parts`.
+
+
 ### Callbacks

 Download jobs can be associated with a series of callbacks, each with
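Since the parent job carries only aggregate counters while `priority`, `job_started`, `job_ended` and `content_type` live on the parts, a small hedged sketch of walking `download_parts` may help. Field names are taken from the tables above; the `report` helper itself is illustrative.

```python
# Sketch: inspecting a MultiFileDownloadJob and its component DownloadJobs.
def report(job) -> None:
    # Aggregate counters live on the parent MultiFileDownloadJob...
    print(f"overall: {job.bytes}/{job.total_bytes} bytes, status={job.status}")
    # ...per-part details live on the component DownloadJobs.
    for part in job.download_parts:  # Set[DownloadJob]
        print(f"  {part.download_path}: {part.bytes}/{part.total_bytes} ({part.content_type})")
```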
@@ -251,11 +279,40 @@ jobs using `list_jobs()`, fetch a single job by its with
 running jobs with `cancel_all_jobs()`, and wait for all jobs to finish
 with `join()`.

-#### job = queue.download(source, dest, priority, access_token)
+#### job = queue.download(source, dest, priority, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)

 Create a new download job and put it on the queue, returning the
 DownloadJob object.

+#### multifile_job = queue.multifile_download(parts, dest, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
+
+This is similar to download(), but instead of taking a single source,
+it accepts a `parts` argument consisting of a list of
+`RemoteModelFile` objects. Each part corresponds to a URL/Path pair,
+where the URL is the location of the remote file, and the Path is the
+destination.
+
+`RemoteModelFile` can be imported from `invokeai.backend.model_manager.metadata`, and
+consists of a url/path pair. Note that the path *must* be relative.
+
+The method returns a `MultiFileDownloadJob`.
+
+
+```
+from invokeai.backend.model_manager.metadata import RemoteModelFile
+remote_file_1 = RemoteModelFile(url='http://www.foo.bar/my/pytorch_model.safetensors',
+                                path='my_model/textencoder/pytorch_model.safetensors'
+                                )
+remote_file_2 = RemoteModelFile(url='http://www.bar.baz/vae.ckpt',
+                                path='my_model/vae/diffusers_model.safetensors'
+                                )
+job = queue.multifile_download(parts=[remote_file_1, remote_file_2],
+                               dest='/tmp/downloads',
+                               on_progress=TqdmProgress().update)
+queue.wait_for_job(job)
+print(f"The files were downloaded to {job.download_path}")
+```

 #### jobs = queue.list_jobs()

 Return a list of all active and inactive `DownloadJob`s.
@@ -1577,3 +1577,41 @@ This method takes a model key, looks it up using the
 `ModelRecordServiceBase` object in `mm.store`, and passes the returned
 model configuration to `load_model_by_config()`. It may raise a
 `NotImplementedException`.
+
+## Invocation Context Model Manager API
+
+Within invocations, the following methods are available from the
+`InvocationContext` object:
+
+### context.download_and_cache_model(source) -> Path
+
+This method accepts a `source` of a model, downloads and caches it
+locally, and returns a Path to the local model. The source can be a
+local file or directory, a URL, or a HuggingFace repo_id.
+
+In the case of a HuggingFace repo_id, the following variants are
+recognized:
+
+* stabilityai/stable-diffusion-v4 -- default model
+* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
+* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
+* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
+
+You can also point at an arbitrary individual file within a repo_id
+directory using this syntax:
+
+* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
+
+### context.load_and_cache_model(source, [loader]) -> LoadedModel
+
+This method takes a model source, downloads it, caches it, and then
+loads it into the RAM cache for use in inference. The optional loader
+is a Callable that accepts a Path to the object, and returns a
+`Dict[str, torch.Tensor]`. If no loader is provided, then the method
+will use `torch.load()` for a .ckpt or .bin checkpoint file,
+`safetensors.torch.load_file()` for a safetensors checkpoint file, or
+`*.from_pretrained()` for a directory that looks like a
+diffusers directory.
+
+
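A hedged sketch of how the two context methods above compose inside an invocation; the URL, the repo_id variant, and the surrounding `run` function are illustrative only, not upstream code.

```python
# Sketch: the new InvocationContext model-manager API (illustrative names/URLs).
from pathlib import Path
from safetensors.torch import load_file

def run(context) -> None:
    # Download and cache only: returns a Path without occupying the RAM cache.
    weights: Path = context.models.download_and_cache_model(
        "stabilityai/stable-diffusion-v4:fp16"  # repo_id variant syntax from the list above
    )
    print(f"cached at {weights}")

    # Download, cache, and load into the RAM cache in one step; the returned
    # LoadedModel is usable as a context manager, as the invocation hunks below show.
    with context.models.load_and_cache_model(
        source="https://www.example.org/weights/detector.safetensors",  # hypothetical URL
        loader=load_file,  # optional; by default the file suffix picks the loader
    ) as model:
        ...  # run inference with `model`
```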
@@ -611,7 +611,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
                 model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
             )

-        with context.models.load_ckpt_from_url(source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader) as model:
+        with context.models.load_and_cache_model(source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader) as model:
             depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
             processed_image = depth_anything_detector(image=image, resolution=self.resolution)
             return processed_image
@@ -634,8 +634,8 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):

     def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         mm = context.models
-        onnx_det = mm.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"])
-        onnx_pose = mm.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
+        onnx_det = mm.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
+        onnx_pose = mm.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])

         dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
         processed_image = dw_openpose(
@@ -133,7 +133,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using the LaMa model"""

     def infill(self, image: Image.Image, context: InvocationContext):
-        with context.models.load_ckpt_from_url(
+        with context.models.load_and_cache_model(
             source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
             loader=LaMA.load_jit_model,
         ) as model:
@@ -91,7 +91,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
             context.logger.error(msg)
             raise ValueError(msg)

-        loadnet = context.models.load_ckpt_from_url(
+        loadnet = context.models.load_and_cache_model(
             source=ESRGAN_MODEL_URLS[self.model_name],
         )

@@ -387,12 +387,13 @@ class ModelInstallService(ModelInstallServiceBase):
         model_path.mkdir(parents=True, exist_ok=True)
         model_source = self._guess_source(source)
         remote_files, _ = self._remote_files_from_source(model_source)
-        job = self._download_queue.multifile_download(
-            parts=remote_files,
+        job = self._multifile_download(
             dest=model_path,
+            remote_files=remote_files,
+            subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
         )
-        files_string = "file" if len(remote_files) == 1 else "file"
-        self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
+        files_string = "file" if len(remote_files) == 1 else "files"
+        self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
         self._download_queue.wait_for_job(job)
         if job.complete:
             assert job.download_path is not None
@@ -734,26 +735,12 @@ class ModelInstallService(ModelInstallServiceBase):
         )
         # remember the temporary directory for later removal
         install_job._install_tmpdir = destdir
+        install_job.total_bytes = sum((x.size or 0) for x in remote_files)

-        # In the event that there is a subfolder specified in the source,
-        # we need to remove it from the destination path in order to avoid
-        # creating unwanted subfolders
-        if isinstance(source, HFModelSource) and source.subfolder:
-            root = Path(remote_files[0].path.parts[0])
-            subfolder = root / source.subfolder
-        else:
-            root = Path(".")
-            subfolder = Path(".")
-
-        parts: List[RemoteModelFile] = []
-        for model_file in remote_files:
-            assert install_job.total_bytes is not None
-            assert model_file.size is not None
-            install_job.total_bytes += model_file.size
-            parts.append(RemoteModelFile(url=model_file.url, path=model_file.path.relative_to(subfolder)))
         multifile_job = self._multifile_download(
-            parts=parts,
+            remote_files=remote_files,
             dest=destdir,
+            subfolder=source.subfolder if isinstance(source, HFModelSource) else None,
             access_token=source.access_token,
             submit_job=False,  # Important! Don't submit the job until we have set our _download_cache dict
         )
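One behavioral nuance of the consolidation above: the new one-liner tolerates parts whose size is unknown, where the removed loop asserted `model_file.size is not None`. A tiny self-contained illustration (the `FakeRemoteFile` stand-in is hypothetical):

```python
# Sketch: None-tolerant byte total, matching the new install_job.total_bytes line.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FakeRemoteFile:  # hypothetical stand-in for RemoteModelFile
    size: Optional[int]

remote_files: List[FakeRemoteFile] = [FakeRemoteFile(1024), FakeRemoteFile(None), FakeRemoteFile(512)]
total = sum((x.size or 0) for x in remote_files)
assert total == 1536  # unknown sizes count as 0 instead of raising
```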
@@ -776,8 +763,35 @@ class ModelInstallService(ModelInstallServiceBase):
         return size

     def _multifile_download(
-        self, parts: List[RemoteModelFile], dest: Path, access_token: Optional[str] = None, submit_job: bool = True
+        self,
+        remote_files: List[RemoteModelFile],
+        dest: Path,
+        subfolder: Optional[Path] = None,
+        access_token: Optional[str] = None,
+        submit_job: bool = True,
     ) -> MultiFileDownloadJob:
+        # HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
+        # we are installing the "vae" subfolder, we do not want to create an additional folder level, such
+        # as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
+        # So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
+        if subfolder:
+            top = Path(remote_files[0].path.parts[0])  # e.g. "sdxl-turbo/"
+            path_to_remove = top / subfolder.parts[-1]  # sdxl-turbo/vae/
+            path_to_add = Path(f"{top}_{subfolder}")
+        else:
+            path_to_remove = Path(".")
+            path_to_add = Path(".")
+
+        parts: List[RemoteModelFile] = []
+        for model_file in remote_files:
+            assert model_file.size is not None
+            parts.append(
+                RemoteModelFile(
+                    url=model_file.url,  # if a subfolder, then sdxl-turbo_vae/config.json
+                    path=path_to_add / model_file.path.relative_to(path_to_remove),
+                )
+            )
+
         return self._download_queue.multifile_download(
             parts=parts,
             dest=dest,
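The subfolder bookkeeping above is the subtle part of the refactor. A worked, runnable example of the path arithmetic, using the same "sdxl-turbo" names the comments assume:

```python
# Sketch: the folder synthesis _multifile_download performs for a subfolder install.
from pathlib import Path

remote_path = Path("sdxl-turbo/vae/config.json")  # as listed by the remote repo
subfolder = Path("vae")

top = Path(remote_path.parts[0])            # sdxl-turbo
path_to_remove = top / subfolder.parts[-1]  # sdxl-turbo/vae
path_to_add = Path(f"{top}_{subfolder}")    # sdxl-turbo_vae

local = path_to_add / remote_path.relative_to(path_to_remove)
assert local == Path("sdxl-turbo_vae/config.json")  # no extra "vae" level, no flattening
```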
@@ -795,56 +809,53 @@ class ModelInstallService(ModelInstallServiceBase):
     # ------------------------------------------------------------------
     def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None:
         with self._lock:
-            install_job = self._download_cache[download_job.id]
+            if install_job := self._download_cache.get(download_job.id, None):
                 install_job.status = InstallStatus.DOWNLOADING

                 assert download_job.download_path
                 if install_job.local_path == install_job._install_tmpdir:  # first time
                     install_job.local_path = download_job.download_path
                     install_job.total_bytes = download_job.total_bytes

     def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
         with self._lock:
-            install_job = self._download_cache[download_job.id]
+            if install_job := self._download_cache.get(download_job.id, None):
                 if install_job.cancelled:  # This catches the case in which the caller directly calls job.cancel()
                     self._download_queue.cancel_job(download_job)
                 else:
                     # update sizes
                     install_job.bytes = sum(x.bytes for x in download_job.download_parts)
                     self._signal_job_downloading(install_job)

     def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
         with self._lock:
-            install_job = self._download_cache.pop(download_job.id)
+            if install_job := self._download_cache.pop(download_job.id, None):
                 self._signal_job_downloads_done(install_job)
                 self._put_in_queue(install_job)  # this starts the installation and registration

         # Let other threads know that the number of downloads has changed
         self._downloads_changed_event.set()

     def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
         with self._lock:
-            install_job = self._download_cache.pop(download_job.id)
-            assert install_job is not None
-            assert excp is not None
-            install_job.set_error(excp)
-            self._download_queue.cancel_job(download_job)
+            if install_job := self._download_cache.pop(download_job.id, None):
+                assert excp is not None
+                install_job.set_error(excp)
+                self._download_queue.cancel_job(download_job)

         # Let other threads know that the number of downloads has changed
         self._downloads_changed_event.set()

     def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None:
         with self._lock:
-            install_job = self._download_cache.pop(download_job.id, None)
-            if not install_job:
-                return
-            self._downloads_changed_event.set()
-            # if install job has already registered an error, then do not replace its status with cancelled
-            if not install_job.errored:
-                install_job.cancel()
+            if install_job := self._download_cache.pop(download_job.id, None):
+                self._downloads_changed_event.set()
+                # if install job has already registered an error, then do not replace its status with cancelled
+                if not install_job.errored:
+                    install_job.cancel()

         # Let other threads know that the number of downloads has changed
         self._downloads_changed_event.set()

     # ------------------------------------------------------------------------------------------------
     # Internal methods that put events on the event bus
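The common thread in the callback changes above is replacing an unguarded cache lookup with a guarded `pop`/`get`, so a callback that fires after its cache entry is gone becomes a no-op instead of raising `KeyError`. The pattern in isolation (names here are stand-ins, not the service's own):

```python
# Sketch: guarded lookup with the walrus operator, as adopted by the callbacks above.
from typing import Dict

download_cache: Dict[int, str] = {}  # stand-in for self._download_cache

def on_complete(job_id: int) -> None:
    # pop(..., None) never raises; the body runs only while the entry still exists.
    if install_job := download_cache.pop(job_id, None):
        print(f"finishing {install_job}")
    # A late or duplicate callback falls through harmlessly.

on_complete(42)  # no KeyError even though job 42 was never cached
```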
@@ -43,11 +43,11 @@ class ModelLoadServiceBase(ABC):
         """Return the checkpoint convert cache used by this loader."""

     @abstractmethod
-    def load_ckpt_from_path(
+    def load_model_from_path(
         self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
     ) -> LoadedModel:
         """
-        Load the checkpoint-format model file located at the indicated Path.
+        Load the model file or directory located at the indicated Path.

         This will load an arbitrary model file into the RAM cache. If the optional loader
         argument is provided, the loader will be invoked to load the model into
@@ -20,6 +20,8 @@ from invokeai.backend.model_manager.load import (
 )
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
+from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

 from .model_load_base import ModelLoadServiceBase
@@ -94,7 +96,7 @@ class ModelLoadService(ModelLoadServiceBase):
         )
         return loaded_model

-    def load_ckpt_from_path(
+    def load_model_from_path(
         self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
     ) -> LoadedModel:
         """
@@ -128,6 +130,16 @@ class ModelLoadService(ModelLoadServiceBase):
             result: Dict[str, Tensor] = torch_load(checkpoint, map_location="cpu")
             return result

+        def diffusers_load_directory(directory: Path) -> AnyModel:
+            load_class = GenericDiffusersLoader(
+                app_config=self._app_config,
+                logger=self._logger,
+                ram_cache=self._ram_cache,
+                convert_cache=self.convert_cache,
+            ).get_hf_load_class(directory)
+            result: AnyModel = load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
+            return result
+
         if loader is None:
             loader = (
                 torch_load_file
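With the nested `diffusers_load_directory` helper in place, the default loader can be chosen from the path alone. A hedged sketch of that dispatch, not the exact upstream code (the diffusers branch is passed in as a callable because upstream's helper closes over service state):

```python
# Sketch: suffix/type-based default-loader choice described in the docs hunk earlier.
from functools import partial
from pathlib import Path
from typing import Callable

import torch
from safetensors.torch import load_file as safetensors_load_file

def pick_default_loader(model_path: Path, diffusers_load_directory: Callable):
    if model_path.is_dir():
        return diffusers_load_directory  # GenericDiffusersLoader / from_pretrained
    if model_path.suffix == ".safetensors":
        return safetensors_load_file
    return partial(torch.load, map_location="cpu")  # .ckpt / .bin checkpoints
```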
@@ -72,7 +72,7 @@ class ModelManagerServiceBase(ABC):
         pass

     @abstractmethod
-    def load_ckpt_from_url(
+    def load_model_from_url(
         self,
         source: str | AnyHttpUrl,
         loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
@@ -64,6 +64,34 @@ class ModelManagerService(ModelManagerServiceBase):
             if hasattr(service, "stop"):
                 service.stop(invoker)

+    def load_model_from_url(
+        self,
+        source: str | AnyHttpUrl,
+        loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
+    ) -> LoadedModel:
+        """
+        Download, cache, and load the model file located at the indicated URL.
+
+        This will check the model download cache for the model designated
+        by the provided URL and download it if needed using download_and_cache_model().
+        It will then load the model into the RAM cache. If the optional loader
+        argument is provided, the loader will be invoked to load the model into
+        memory. Otherwise the method will call safetensors.torch.load_file() or
+        torch.load() as appropriate to the file suffix.
+
+        Be aware that the LoadedModel object will have a `config` attribute of None.
+
+        Args:
+            source: A URL or a string that can be converted into one. Repo_ids
+              do not work here.
+            loader: A Callable that expects a Path and returns a Dict[str|int, Any]
+
+        Returns:
+            A LoadedModel object.
+        """
+        model_path = self.install.download_and_cache_model(source=str(source))
+        return self.load.load_model_from_path(model_path=model_path, loader=loader)
+
     @classmethod
     def build_model_manager(
         cls,
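For context, a usage sketch of the method added above; `mm` and the URL are assumptions, not upstream code:

```python
# Sketch: calling the new ModelManagerService.load_model_from_url.
from safetensors.torch import load_file

loaded = mm.load_model_from_url(  # `mm` is an existing ModelManagerService
    source="https://www.example.org/models/detector.safetensors",  # hypothetical URL; repo_ids are rejected
    loader=load_file,  # optional; if omitted, the file suffix picks torch.load or safetensors
)
assert loaded.config is None  # per the docstring, config is not populated
```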
@@ -102,31 +130,3 @@ class ModelManagerService(ModelManagerServiceBase):
             event_bus=events,
         )
         return cls(store=model_record_service, install=installer, load=loader)
-
-    def load_ckpt_from_url(
-        self,
-        source: str | AnyHttpUrl,
-        loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
-    ) -> LoadedModel:
-        """
-        Download, cache, and Load the model file located at the indicated URL.
-
-        This will check the model download cache for the model designated
-        by the provided URL and download it if needed using download_and_cache_ckpt().
-        It will then load the model into the RAM cache. If the optional loader
-        argument is provided, the loader will be invoked to load the model into
-        memory. Otherwise the method will call safetensors.torch.load_file() or
-        torch.load() as appropriate to the file suffix.
-
-        Be aware that the LoadedModel object will have a `config` attribute of None.
-
-        Args:
-            source: A URL or a string that can be converted in one. Repo_ids
-            do not work here.
-            loader: A Callable that expects a Path and returns a Dict[str|int, Any]
-
-        Returns:
-            A LoadedModel object.
-        """
-        model_path = self.install.download_and_cache_ckpt(source=source)
-        return self.load.load_ckpt_from_path(model_path=model_path, loader=loader)
@@ -435,11 +435,9 @@ class ModelsInterface(InvocationContextInterface):
         )
         return result

-    def download_and_cache_ckpt(
+    def download_and_cache_model(
         self,
         source: str | AnyHttpUrl,
-        access_token: Optional[str] = None,
-        timeout: Optional[int] = 0,
     ) -> Path:
         """
         Download the model file located at source to the models cache and return its Path.
@@ -449,12 +447,7 @@ class ModelsInterface(InvocationContextInterface):
         installed, the cached path will be returned. Otherwise it will be downloaded.

         Args:
-            source: A URL or a string that can be converted in one. Repo_ids
-            do not work here.
-            access_token: Optional access token for restricted resources.
-            timeout: Wait up to the indicated number of seconds before timing
-            out long downloads.
+            source: A model path, URL or repo_id.

         Result:
             Path to the downloaded model
@@ -463,39 +456,14 @@ class ModelsInterface(InvocationContextInterface):
             TimeoutError
         """
         installer = self._services.model_manager.install
-        path: Path = installer.download_and_cache_ckpt(
+        path: Path = installer.download_and_cache_model(
             source=source,
-            access_token=access_token,
-            timeout=timeout,
         )
         return path

-    def load_ckpt_from_path(
-        self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None
-    ) -> LoadedModel:
-        """
-        Load the checkpoint-format model file located at the indicated Path.
-
-        This will load an arbitrary model file into the RAM cache. If the optional loader
-        argument is provided, the loader will be invoked to load the model into
-        memory. Otherwise the method will call safetensors.torch.load_file() or
-        torch.load() as appropriate to the file suffix.
-
-        Be aware that the LoadedModel object will have a `config` attribute of None.
-
-        Args:
-            model_path: A pathlib.Path to a checkpoint-style models file
-            loader: A Callable that expects a Path and returns a Dict[str|int, Any]
-
-        Returns:
-            A LoadedModel object.
-        """
-        result: LoadedModel = self._services.model_manager.load.load_ckpt_from_path(model_path, loader=loader)
-        return result
-
-    def load_ckpt_from_url(
+    def load_and_cache_model(
         self,
-        source: str | AnyHttpUrl,
+        source: Path | str | AnyHttpUrl,
         loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
     ) -> LoadedModel:
         """
@@ -511,14 +479,17 @@ class ModelsInterface(InvocationContextInterface):
         Be aware that the LoadedModel object will have a `config` attribute of None.

         Args:
-            source: A URL or a string that can be converted in one. Repo_ids
-            do not work here.
+            source: A model Path, URL, or repo_id.
             loader: A Callable that expects a Path and returns a Dict[str|int, Any]

         Returns:
             A LoadedModel object.
         """
-        result: LoadedModel = self._services.model_manager.load_ckpt_from_url(source=source, loader=loader)
+        result: LoadedModel = (
+            self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
+            if isinstance(source, Path)
+            else self._services.model_manager.load_model_from_url(source=source, loader=loader)
+        )
         return result

@@ -304,6 +304,19 @@ def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None:
     queue.stop()


+def test_multifile_no_rel_paths(tmp_path: Path, mm2_session: Session) -> None:
+    queue = DownloadQueueService(
+        requests_session=mm2_session,
+    )
+
+    with pytest.raises(AssertionError) as error:
+        queue.multifile_download(
+            parts=[RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("/etc/passwd"))],
+            dest=tmp_path,
+        )
+    assert str(error.value) == "only relative download paths accepted"
+
+
 @contextmanager
 def clear_config() -> Generator[None, None, None]:
     try:
@@ -24,7 +24,7 @@ def mock_context(


 def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) -> None:
-    downloaded_path = mock_context.models.download_and_cache_ckpt(
+    downloaded_path = mock_context.models.download_and_cache_model(
         "https://www.test.foo/download/test_embedding.safetensors"
     )
     assert downloaded_path.is_file()
@@ -32,24 +32,24 @@ def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path)
     assert downloaded_path.name == "test_embedding.safetensors"
     assert downloaded_path.parent.parent == mm2_root_dir / "models/.download_cache"

-    downloaded_path_2 = mock_context.models.download_and_cache_ckpt(
+    downloaded_path_2 = mock_context.models.download_and_cache_model(
         "https://www.test.foo/download/test_embedding.safetensors"
     )
     assert downloaded_path == downloaded_path_2


 def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -> None:
-    downloaded_path = mock_context.models.download_and_cache_ckpt(
+    downloaded_path = mock_context.models.download_and_cache_model(
         "https://www.test.foo/download/test_embedding.safetensors"
     )
-    loaded_model_1 = mock_context.models.load_ckpt_from_path(downloaded_path)
+    loaded_model_1 = mock_context.models.load_and_cache_model(downloaded_path)
     assert isinstance(loaded_model_1, LoadedModel)

-    loaded_model_2 = mock_context.models.load_ckpt_from_path(downloaded_path)
+    loaded_model_2 = mock_context.models.load_and_cache_model(downloaded_path)
     assert isinstance(loaded_model_2, LoadedModel)
     assert loaded_model_1.model is loaded_model_2.model

-    loaded_model_3 = mock_context.models.load_ckpt_from_path(embedding_file)
+    loaded_model_3 = mock_context.models.load_and_cache_model(embedding_file)
     assert isinstance(loaded_model_3, LoadedModel)
     assert loaded_model_1.model is not loaded_model_3.model
     assert isinstance(loaded_model_1.model, dict)
@@ -58,9 +58,25 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -


 def test_download_and_load(mock_context: InvocationContext) -> None:
-    loaded_model_1 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors")
+    loaded_model_1 = mock_context.models.load_and_cache_model(
+        "https://www.test.foo/download/test_embedding.safetensors"
+    )
     assert isinstance(loaded_model_1, LoadedModel)

-    loaded_model_2 = mock_context.models.load_ckpt_from_url("https://www.test.foo/download/test_embedding.safetensors")
+    loaded_model_2 = mock_context.models.load_and_cache_model(
+        "https://www.test.foo/download/test_embedding.safetensors"
+    )
     assert isinstance(loaded_model_2, LoadedModel)
     assert loaded_model_1.model is loaded_model_2.model  # should be cached copy
+
+
+def test_download_diffusers(mock_context: InvocationContext) -> None:
+    model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo")
+    assert (model_path / "model_index.json").exists()
+    assert (model_path / "vae").is_dir()
+
+
+def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None:
+    model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae")
+    assert model_path.is_dir()
+    assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists()