mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Model Manager Refactor: Install remote models and store their tags and other metadata (#5361)
* add basic functionality for model metadata fetching from hf and civitai * add storage * start unit tests * add unit tests and documentation * add missing dependency for pytests * remove redundant fetch; add modified/published dates; updated docs * add code to select diffusers files based on the variant type * implement Civitai installs * make huggingface parallel downloading work * add unit tests for model installation manager - Fixed race condition on selection of download destination path - Add fixtures common to several model_manager_2 unit tests - Added dummy model files for testing diffusers and safetensors downloading/probing - Refactored code for selecting proper variant from list of huggingface repo files - Regrouped ordering of methods in model_install_default.py * improve Civitai model downloading - Provide a better error message when Civitai requires an access token (doesn't give a 403 forbidden, but redirects to the HTML of an authorization page -- arrgh) - Handle case of Civitai providing a primary download link plus additional links for VAEs, config files, etc * add routes for retrieving metadata and tags * code tidying and documentation * fix ruff errors * add file needed to maintain test root diretory in repo for unit tests * fix self->cls in classmethod * add pydantic plugin for mypy * use TestSession instead of requests.Session to prevent any internet activity improve logging fix error message formatting fix logging again fix forward vs reverse slash issue in Windows install tests * Several fixes of problems detected during PR review: - Implement cancel_model_install_job and get_model_install_job routes to allow for better control of model download and install. - Fix thread deadlock that occurred after cancelling an install. - Remove unneeded pytest_plugins section from tests/conftest.py - Remove unused _in_terminal_state() from model_install_default. - Remove outdated documentation from several spots. - Add workaround for Civitai API results which don't return correct URL for the default model. * fix docs and tests to match get_job_by_source() rather than get_job() * Update invokeai/backend/model_manager/metadata/fetch/huggingface.py Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> * Call CivitaiMetadata.model_validate_json() directly Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> * Second round of revisions suggested by @ryanjdick: - Fix type mismatch in `list_all_metadata()` route. - Do not have a default value for the model install job id - Remove static class variable declarations from non Pydantic classes - Change `id` field to `model_id` for the sqlite3 `model_tags` table. - Changed AFTER DELETE triggers to ON DELETE CASCADE for the metadata and tags tables. - Made the `id` field of the `model_metadata` table into a primary key to achieve uniqueness. * Code cleanup suggested in PR review: - Narrowed the declaration of the `parts` attribute of the download progress event - Removed auto-conversion of str to Url in Url-containing sources - Fixed handling of `InvalidModelConfigException` - Made unknown sources raise `NotImplementedError` rather than `Exception` - Improved status reporting on cached HuggingFace access tokens * Multiple fixes: - `job.total_size` returns a valid size for locally installed models - new route `list_models` returns a paged summary of model, name, description, tags and other essential info - fix a few type errors * consolidated all invokeai root pytest fixtures into a single location * Update invokeai/backend/model_manager/metadata/metadata_store.py Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com> * Small tweaks in response to review comments: - Remove flake8 configuration from pyproject.toml - Use `id` rather than `modelId` for huggingface `ModelInfo` object - Use `last_modified` rather than `LastModified` for huggingface `ModelInfo` object - Add `sha256` field to file metadata downloaded from huggingface - Add `Invoker` argument to the model installer `start()` and `stop()` routines (but made it optional in order to facilitate use of the service outside the API) - Removed redundant `PRAGMA foreign_keys` from metadata store initialization code. * Additional tweaks and minor bug fixes - Fix calculation of aggregate diffusers model size to only count the size of files, not files + directories (which gives different unit test results on different filesystems). - Refactor _get_metadata() and _get_download_urls() to have distinct code paths for Civitai, HuggingFace and URL sources. - Forward the `inplace` flag from the source to the job and added unit test for this. - Attach cached model metadata to the job rather than to the model install service. * fix unit test that was breaking on windows due to CR/LF changing size of test json files * fix ruff formatting * a few last minor fixes before merging: - Turn job `error` and `error_type` into properties derived from the exception. - Add TODO comment about the reason for handling temporary directory destruction manually rather than using tempfile.tmpdir(). * add unit tests for reporting HTTP download errors --------- Co-authored-by: Lincoln Stein <lstein@gmail.com> Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
parent
426a7b900f
commit
4536e4a8b6
@ -15,8 +15,13 @@ model. These are the:
|
||||
their metadata, and `ModelRecordServiceBase` to store that
|
||||
information. It is also responsible for managing the InvokeAI
|
||||
`models` directory and its contents.
|
||||
|
||||
* _DownloadQueueServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
|
||||
|
||||
* _ModelMetadataStore_ and _ModelMetaDataFetch_ Backend modules that
|
||||
are able to retrieve metadata from online model repositories,
|
||||
transform them into Pydantic models, and cache them to the InvokeAI
|
||||
SQL database.
|
||||
|
||||
* _DownloadQueueServiceBase_
|
||||
A multithreaded downloader responsible
|
||||
for downloading models from a remote source to disk. The download
|
||||
queue has special methods for downloading repo_id folders from
|
||||
@ -30,13 +35,13 @@ model. These are the:
|
||||
|
||||
## Location of the Code
|
||||
|
||||
All four of these services can be found in
|
||||
The four main services can be found in
|
||||
`invokeai/app/services` in the following directories:
|
||||
|
||||
* `invokeai/app/services/model_records/`
|
||||
* `invokeai/app/services/model_install/`
|
||||
* `invokeai/app/services/downloads/`
|
||||
* `invokeai/app/services/model_loader/` (**under development**)
|
||||
* `invokeai/app/services/downloads/`(**under development**)
|
||||
|
||||
Code related to the FastAPI web API can be found in
|
||||
`invokeai/app/api/routers/model_records.py`.
|
||||
@ -402,15 +407,18 @@ functionality:
|
||||
the download, installation and registration process.
|
||||
|
||||
- Downloading a model from an arbitrary URL and installing it in
|
||||
`models_dir` (_implementation pending_).
|
||||
`models_dir`.
|
||||
|
||||
- Special handling for Civitai model URLs which allow the user to
|
||||
paste in a model page's URL or download link (_implementation pending_).
|
||||
|
||||
paste in a model page's URL or download link
|
||||
|
||||
- Special handling for HuggingFace repo_ids to recursively download
|
||||
the contents of the repository, paying attention to alternative
|
||||
variants such as fp16. (_implementation pending_)
|
||||
variants such as fp16.
|
||||
|
||||
- Saving tags and other metadata about the model into the invokeai database
|
||||
when fetching from a repo that provides that type of information,
|
||||
(currently only Civitai and HuggingFace).
|
||||
|
||||
### Initializing the installer
|
||||
|
||||
@ -426,16 +434,24 @@ following initialization pattern:
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_records import ModelRecordServiceSQL
|
||||
from invokeai.app.services.model_install import ModelInstallService
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args()
|
||||
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
db = SqliteDatabase(config, logger)
|
||||
record_store = ModelRecordServiceSQL(db)
|
||||
queue = DownloadQueueService()
|
||||
queue.start()
|
||||
|
||||
store = ModelRecordServiceSQL(db)
|
||||
installer = ModelInstallService(config, store)
|
||||
installer = ModelInstallService(app_config=config,
|
||||
record_store=record_store,
|
||||
download_queue=queue
|
||||
)
|
||||
installer.start()
|
||||
```
|
||||
|
||||
The full form of `ModelInstallService()` takes the following
|
||||
@ -443,9 +459,12 @@ required parameters:
|
||||
|
||||
| **Argument** | **Type** | **Description** |
|
||||
|------------------|------------------------------|------------------------------|
|
||||
| `config` | InvokeAIAppConfig | InvokeAI app configuration object |
|
||||
| `app_config` | InvokeAIAppConfig | InvokeAI app configuration object |
|
||||
| `record_store` | ModelRecordServiceBase | Config record storage database |
|
||||
| `event_bus` | EventServiceBase | Optional event bus to send download/install progress events to |
|
||||
| `download_queue` | DownloadQueueServiceBase | Download queue object |
|
||||
| `metadata_store` | Optional[ModelMetadataStore] | Metadata storage object |
|
||||
|`session` | Optional[requests.Session] | Swap in a different Session object (usually for debugging) |
|
||||
|
||||
|
||||
Once initialized, the installer will provide the following methods:
|
||||
|
||||
@ -474,14 +493,14 @@ source7 = URLModelSource(url='https://civitai.com/api/download/models/63006', ac
|
||||
for source in [source1, source2, source3, source4, source5, source6, source7]:
|
||||
install_job = installer.install_model(source)
|
||||
|
||||
source2job = installer.wait_for_installs()
|
||||
source2job = installer.wait_for_installs(timeout=120)
|
||||
for source in sources:
|
||||
job = source2job[source]
|
||||
if job.status == "completed":
|
||||
if job.complete:
|
||||
model_config = job.config_out
|
||||
model_key = model_config.key
|
||||
print(f"{source} installed as {model_key}")
|
||||
elif job.status == "error":
|
||||
elif job.errored:
|
||||
print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")
|
||||
|
||||
```
|
||||
@ -515,43 +534,117 @@ The full list of arguments to `import_model()` is as follows:
|
||||
|
||||
| **Argument** | **Type** | **Default** | **Description** |
|
||||
|------------------|------------------------------|-------------|-------------------------------------------|
|
||||
| `source` | Union[str, Path, AnyHttpUrl] | | The source of the model, Path, URL or repo_id |
|
||||
| `inplace` | bool | True | Leave a local model in its current location |
|
||||
| `variant` | str | None | Desired variant, such as 'fp16' or 'onnx' (HuggingFace only) |
|
||||
| `subfolder` | str | None | Repository subfolder (HuggingFace only) |
|
||||
| `source` | ModelSource | None | The source of the model, Path, URL or repo_id |
|
||||
| `config` | Dict[str, Any] | None | Override all or a portion of model's probed attributes |
|
||||
| `access_token` | str | None | Provide authorization information needed to download |
|
||||
|
||||
|
||||
The `inplace` field controls how local model Paths are handled. If
|
||||
True (the default), then the model is simply registered in its current
|
||||
location by the installer's `ModelConfigRecordService`. Otherwise, a
|
||||
copy of the model put into the location specified by the `models_dir`
|
||||
application configuration parameter.
|
||||
|
||||
The `variant` field is used for HuggingFace repo_ids only. If
|
||||
provided, the repo_id download handler will look for and download
|
||||
tensors files that follow the convention for the selected variant:
|
||||
|
||||
- "fp16" will select files named "*model.fp16.{safetensors,bin}"
|
||||
- "onnx" will select files ending with the suffix ".onnx"
|
||||
- "openvino" will select files beginning with "openvino_model"
|
||||
|
||||
In the special case of the "fp16" variant, the installer will select
|
||||
the 32-bit version of the files if the 16-bit version is unavailable.
|
||||
|
||||
`subfolder` is used for HuggingFace repo_ids only. If provided, the
|
||||
model will be downloaded from the designated subfolder rather than the
|
||||
top-level repository folder. If a subfolder is attached to the repo_id
|
||||
using the format `repo_owner/repo_name:subfolder`, then the subfolder
|
||||
specified by the repo_id will override the subfolder argument.
|
||||
The next few sections describe the various types of ModelSource that
|
||||
can be passed to `import_model()`.
|
||||
|
||||
`config` can be used to override all or a portion of the configuration
|
||||
attributes returned by the model prober. See the section below for
|
||||
details.
|
||||
|
||||
`access_token` is passed to the download queue and used to access
|
||||
repositories that require it.
|
||||
|
||||
#### LocalModelSource
|
||||
|
||||
This is used for a model that is located on a locally-accessible Posix
|
||||
filesystem, such as a local disk or networked fileshare.
|
||||
|
||||
|
||||
| **Argument** | **Type** | **Default** | **Description** |
|
||||
|------------------|------------------------------|-------------|-------------------------------------------|
|
||||
| `path` | str | Path | None | Path to the model file or directory |
|
||||
| `inplace` | bool | False | If set, the model file(s) will be left in their location; otherwise they will be copied into the InvokeAI root's `models` directory |
|
||||
|
||||
#### URLModelSource
|
||||
|
||||
This is used for a single-file model that is accessible via a URL. The
|
||||
fields are:
|
||||
|
||||
| **Argument** | **Type** | **Default** | **Description** |
|
||||
|------------------|------------------------------|-------------|-------------------------------------------|
|
||||
| `url` | AnyHttpUrl | None | The URL for the model file. |
|
||||
| `access_token` | str | None | An access token needed to gain access to this file. |
|
||||
|
||||
The `AnyHttpUrl` class can be imported from `pydantic.networks`.
|
||||
|
||||
Ordinarily, no metadata is retrieved from these sources. However,
|
||||
there is special-case code in the installer that looks for HuggingFace
|
||||
and Civitai URLs and fetches the corresponding model metadata from
|
||||
the corresponding repo.
|
||||
|
||||
#### CivitaiModelSource
|
||||
|
||||
This is used for a model that is hosted by the Civitai web site.
|
||||
|
||||
| **Argument** | **Type** | **Default** | **Description** |
|
||||
|------------------|------------------------------|-------------|-------------------------------------------|
|
||||
| `version_id` | int | None | The ID of the particular version of the desired model. |
|
||||
| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |
|
||||
|
||||
Civitai has two model IDs, both of which are integers. The `model_id`
|
||||
corresponds to a collection of model versions that may different in
|
||||
arbitrary ways, such as derivation from different checkpoint training
|
||||
steps, SFW vs NSFW generation, pruned vs non-pruned, etc. The
|
||||
`version_id` points to a specific version. Please use the latter.
|
||||
|
||||
Some Civitai models require an access token to download. These can be
|
||||
generated from the Civitai profile page of a logged-in
|
||||
account. Somewhat annoyingly, if you fail to provide the access token
|
||||
when downloading a model that needs it, Civitai generates a redirect
|
||||
to a login page rather than a 403 Forbidden error. The installer
|
||||
attempts to catch this event and issue an informative error
|
||||
message. Otherwise you will get an "unrecognized model suffix" error
|
||||
when the model prober tries to identify the type of the HTML login
|
||||
page.
|
||||
|
||||
#### HFModelSource
|
||||
|
||||
HuggingFace has the most complicated `ModelSource` structure:
|
||||
|
||||
| **Argument** | **Type** | **Default** | **Description** |
|
||||
|------------------|------------------------------|-------------|-------------------------------------------|
|
||||
| `repo_id` | str | None | The ID of the desired model. |
|
||||
| `variant` | ModelRepoVariant | ModelRepoVariant('fp16') | The desired variant. |
|
||||
| `subfolder` | Path | None | Look for the model in a subfolder of the repo. |
|
||||
| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |
|
||||
|
||||
|
||||
The `repo_id` is the repository ID, such as `stabilityai/sdxl-turbo`.
|
||||
|
||||
The `variant` is one of the various diffusers formats that HuggingFace
|
||||
supports and is used to pick out from the hodgepodge of files that in
|
||||
a typical HuggingFace repository the particular components needed for
|
||||
a complete diffusers model. `ModelRepoVariant` is an enum that can be
|
||||
imported from `invokeai.backend.model_manager` and has the following
|
||||
values:
|
||||
|
||||
| **Name** | **String Value** |
|
||||
|----------------------------|---------------------------|
|
||||
| ModelRepoVariant.DEFAULT | "default" |
|
||||
| ModelRepoVariant.FP16 | "fp16" |
|
||||
| ModelRepoVariant.FP32 | "fp32" |
|
||||
| ModelRepoVariant.ONNX | "onnx" |
|
||||
| ModelRepoVariant.OPENVINO | "openvino" |
|
||||
| ModelRepoVariant.FLAX | "flax" |
|
||||
|
||||
You can also pass the string forms to `variant` directly. Note that
|
||||
InvokeAI may not be able to load and run all variants. At the current
|
||||
time, specifying `ModelRepoVariant.DEFAULT` will retrieve model files
|
||||
that are unqualified, e.g. `pytorch_model.safetensors` rather than
|
||||
`pytorch_model.fp16.safetensors`. These are usually the 32-bit
|
||||
safetensors forms of the model.
|
||||
|
||||
If `subfolder` is specified, then the requested model resides in a
|
||||
subfolder of the main model repository. This is typically used to
|
||||
fetch and install VAEs.
|
||||
|
||||
Some models require you to be registered with HuggingFace and logged
|
||||
in. To download these files, you must provide an
|
||||
`access_token`. Internally, if no access token is provided, then
|
||||
`HfFolder.get_token()` will be called to fill it in with the cached
|
||||
one.
|
||||
|
||||
|
||||
#### Monitoring the install job process
|
||||
|
||||
@ -563,7 +656,8 @@ The `ModelInstallJob` class has the following structure:
|
||||
|
||||
| **Attribute** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `status` | `InstallStatus` | An enum of ["waiting", "running", "completed" and "error" |
|
||||
| `id` | `int` | Integer ID for this job |
|
||||
| `status` | `InstallStatus` | An enum of [`waiting`, `downloading`, `running`, `completed`, `error` and `cancelled`]|
|
||||
| `config_in` | `dict` | Overriding configuration values provided by the caller |
|
||||
| `config_out` | `AnyModelConfig`| After successful completion, contains the configuration record written to the database |
|
||||
| `inplace` | `boolean` | True if the caller asked to install the model in place using its local path |
|
||||
@ -578,30 +672,70 @@ broadcast to the InvokeAI event bus. The events will appear on the bus
|
||||
as an event of type `EventServiceBase.model_event`, a timestamp and
|
||||
the following event names:
|
||||
|
||||
- `model_install_started`
|
||||
##### `model_install_downloading`
|
||||
|
||||
The payload will contain the keys `timestamp` and `source`. The latter
|
||||
indicates the requested model source for installation.
|
||||
For remote models only, `model_install_downloading` events will be issued at regular
|
||||
intervals as the download progresses. The event's payload contains the
|
||||
following keys:
|
||||
|
||||
- `model_install_progress`
|
||||
| **Key** | **Type** | **Description** |
|
||||
|----------------|-----------|------------------|
|
||||
| `source` | str | String representation of the requested source |
|
||||
| `local_path` | str | String representation of the path to the downloading model (usually a temporary directory) |
|
||||
| `bytes` | int | How many bytes downloaded so far |
|
||||
| `total_bytes` | int | Total size of all the files that make up the model |
|
||||
| `parts` | List[Dict]| Information on the progress of the individual files that make up the model |
|
||||
|
||||
Emitted at regular intervals when downloading a remote model, the
|
||||
payload will contain the keys `timestamp`, `source`, `current_bytes`
|
||||
and `total_bytes`. These events are _not_ emitted when a local model
|
||||
already on the filesystem is imported.
|
||||
|
||||
- `model_install_completed`
|
||||
The parts is a list of dictionaries that give information on each of
|
||||
the components pieces of the download. The dictionary's keys are
|
||||
`source`, `local_path`, `bytes` and `total_bytes`, and correspond to
|
||||
the like-named keys in the main event.
|
||||
|
||||
Issued once at the end of a successful installation. The payload will
|
||||
contain the keys `timestamp`, `source` and `key`, where `key` is the
|
||||
ID under which the model has been registered.
|
||||
Note that downloading events will not be issued for local models, and
|
||||
that downloading events occur *before* the running event.
|
||||
|
||||
- `model_install_error`
|
||||
##### `model_install_running`
|
||||
|
||||
`model_install_running` is issued when all the required downloads have completed (if applicable) and the
|
||||
model probing, copying and registration process has now started.
|
||||
|
||||
The payload will contain the key `source`.
|
||||
|
||||
##### `model_install_completed`
|
||||
|
||||
`model_install_completed` is issued once at the end of a successful
|
||||
installation. The payload will contain the keys `source`,
|
||||
`total_bytes` and `key`, where `key` is the ID under which the model
|
||||
has been registered.
|
||||
|
||||
##### `model_install_error`
|
||||
|
||||
`model_install_error` is emitted if the installation process fails for
|
||||
some reason. The payload will contain the keys `source`, `error_type`
|
||||
and `error`. `error_type` is a short message indicating the nature of
|
||||
the error, and `error` is the long traceback to help debug the
|
||||
problem.
|
||||
|
||||
##### `model_install_cancelled`
|
||||
|
||||
`model_install_cancelled` is issued if the model installation is
|
||||
cancelled, or if one or more of its files' downloads are
|
||||
cancelled. The payload will contain `source`.
|
||||
|
||||
##### Following the model status
|
||||
|
||||
You may poll the `ModelInstallJob` object returned by `import_model()`
|
||||
to ascertain the state of the install. The job status can be read from
|
||||
the job's `status` attribute, an `InstallStatus` enum which has the
|
||||
enumerated values `WAITING`, `DOWNLOADING`, `RUNNING`, `COMPLETED`,
|
||||
`ERROR` and `CANCELLED`.
|
||||
|
||||
For convenience, install jobs also provided the following boolean
|
||||
properties: `waiting`, `downloading`, `running`, `complete`, `errored`
|
||||
and `cancelled`, as well as `in_terminal_state`. The last will return
|
||||
True if the job is in the complete, errored or cancelled states.
|
||||
|
||||
Emitted if the installation process fails for some reason. The payload
|
||||
will contain the keys `timestamp`, `source`, `error_type` and
|
||||
`error`. `error_type` is a short message indicating the nature of the
|
||||
error, and `error` is the long traceback to help debug the problem.
|
||||
|
||||
#### Model confguration and probing
|
||||
|
||||
@ -621,17 +755,9 @@ overriding values for any of the model's configuration
|
||||
attributes. Here is an example of setting the
|
||||
`SchedulerPredictionType` and `name` for an sd-2 model:
|
||||
|
||||
This is typically used to set
|
||||
the model's name and description, but can also be used to overcome
|
||||
cases in which automatic probing is unable to (correctly) determine
|
||||
the model's attribute. The most common situation is the
|
||||
`prediction_type` field for sd-2 (and rare sd-1) models. Here is an
|
||||
example of how it works:
|
||||
|
||||
```
|
||||
install_job = installer.import_model(
|
||||
source='stabilityai/stable-diffusion-2-1',
|
||||
variant='fp16',
|
||||
source=HFModelSource(repo_id='stabilityai/stable-diffusion-2-1',variant='fp32'),
|
||||
config=dict(
|
||||
prediction_type=SchedulerPredictionType('v_prediction')
|
||||
name='stable diffusion 2 base model',
|
||||
@ -643,29 +769,38 @@ install_job = installer.import_model(
|
||||
|
||||
This section describes additional methods provided by the installer class.
|
||||
|
||||
#### jobs = installer.wait_for_installs()
|
||||
#### jobs = installer.wait_for_installs([timeout])
|
||||
|
||||
Block until all pending installs are completed or errored and then
|
||||
returns a list of completed jobs.
|
||||
returns a list of completed jobs. The optional `timeout` argument will
|
||||
return from the call if jobs aren't completed in the specified
|
||||
time. An argument of 0 (the default) will block indefinitely.
|
||||
|
||||
#### jobs = installer.list_jobs([source])
|
||||
#### jobs = installer.list_jobs()
|
||||
|
||||
Return a list of all active and complete `ModelInstallJobs`. An
|
||||
optional `source` argument allows you to filter the returned list by a
|
||||
model source string pattern using a partial string match.
|
||||
Return a list of all active and complete `ModelInstallJobs`.
|
||||
|
||||
#### jobs = installer.get_job(source)
|
||||
#### jobs = installer.get_job_by_source(source)
|
||||
|
||||
Return a list of `ModelInstallJob` corresponding to the indicated
|
||||
model source.
|
||||
|
||||
#### jobs = installer.get_job_by_id(id)
|
||||
|
||||
Return a list of `ModelInstallJob` corresponding to the indicated
|
||||
model id.
|
||||
|
||||
#### jobs = installer.cancel_job(job)
|
||||
|
||||
Cancel the indicated job.
|
||||
|
||||
#### installer.prune_jobs
|
||||
|
||||
Remove non-pending jobs (completed or errored) from the job list
|
||||
returned by `list_jobs()` and `get_job()`.
|
||||
Remove jobs that are in a terminal state (i.e. complete, errored or
|
||||
cancelled) from the job list returned by `list_jobs()` and
|
||||
`get_job()`.
|
||||
|
||||
#### installer.app_config, installer.record_store,
|
||||
installer.event_bus
|
||||
#### installer.app_config, installer.record_store, installer.event_bus
|
||||
|
||||
Properties that provide access to the installer's `InvokeAIAppConfig`,
|
||||
`ModelRecordServiceBase` and `EventServiceBase` objects.
|
||||
@ -726,120 +861,6 @@ the API starts up. Its effect is to call `sync_to_config()` to
|
||||
synchronize the model record store database with what's currently on
|
||||
disk.
|
||||
|
||||
# The remainder of this documentation is provisional, pending implementation of the Download and Load services
|
||||
|
||||
## Let's get loaded, the lowdown on ModelLoadService
|
||||
|
||||
The `ModelLoadService` is responsible for loading a named model into
|
||||
memory so that it can be used for inference. Despite the fact that it
|
||||
does a lot under the covers, it is very straightforward to use.
|
||||
|
||||
An application-wide model loader is created at API initialization time
|
||||
and stored in
|
||||
`ApiDependencies.invoker.services.model_loader`. However, you can
|
||||
create alternative instances if you wish.
|
||||
|
||||
### Creating a ModelLoadService object
|
||||
|
||||
The class is defined in
|
||||
`invokeai.app.services.model_loader_service`. It is initialized with
|
||||
an InvokeAIAppConfig object, from which it gets configuration
|
||||
information such as the user's desired GPU and precision, and with a
|
||||
previously-created `ModelRecordServiceBase` object, from which it
|
||||
loads the requested model's configuration information.
|
||||
|
||||
Here is a typical initialization pattern:
|
||||
|
||||
```
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.app.services.model_loader_service import ModelLoadService
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
store = ModelRecordServiceBase.open(config)
|
||||
loader = ModelLoadService(config, store)
|
||||
```
|
||||
|
||||
Note that we are relying on the contents of the application
|
||||
configuration to choose the implementation of
|
||||
`ModelRecordServiceBase`.
|
||||
|
||||
### get_model(key, [submodel_type], [context]) -> ModelInfo:
|
||||
|
||||
*** TO DO: change to get_model(key, context=None, **kwargs)
|
||||
|
||||
The `get_model()` method, like its similarly-named cousin in
|
||||
`ModelRecordService`, receives the unique key that identifies the
|
||||
model. It loads the model into memory, gets the model ready for use,
|
||||
and returns a `ModelInfo` object.
|
||||
|
||||
The optional second argument, `subtype` is a `SubModelType` string
|
||||
enum, such as "vae". It is mandatory when used with a main model, and
|
||||
is used to select which part of the main model to load.
|
||||
|
||||
The optional third argument, `context` can be provided by
|
||||
an invocation to trigger model load event reporting. See below for
|
||||
details.
|
||||
|
||||
The returned `ModelInfo` object shares some fields in common with
|
||||
`ModelConfigBase`, but is otherwise a completely different beast:
|
||||
|
||||
| **Field Name** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `key` | str | The model key derived from the ModelRecordService database |
|
||||
| `name` | str | Name of this model |
|
||||
| `base_model` | BaseModelType | Base model for this model |
|
||||
| `type` | ModelType or SubModelType | Either the model type (non-main) or the submodel type (main models)|
|
||||
| `location` | Path or str | Location of the model on the filesystem |
|
||||
| `precision` | torch.dtype | The torch.precision to use for inference |
|
||||
| `context` | ModelCache.ModelLocker | A context class used to lock the model in VRAM while in use |
|
||||
|
||||
The types for `ModelInfo` and `SubModelType` can be imported from
|
||||
`invokeai.app.services.model_loader_service`.
|
||||
|
||||
To use the model, you use the `ModelInfo` as a context manager using
|
||||
the following pattern:
|
||||
|
||||
```
|
||||
model_info = loader.get_model('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
|
||||
with model_info as vae:
|
||||
image = vae.decode(latents)[0]
|
||||
```
|
||||
|
||||
The `vae` model will stay locked in the GPU during the period of time
|
||||
it is in the context manager's scope.
|
||||
|
||||
`get_model()` may raise any of the following exceptions:
|
||||
|
||||
- `UnknownModelException` -- key not in database
|
||||
- `ModelNotFoundException` -- key in database but model not found at path
|
||||
- `InvalidModelException` -- the model is guilty of a variety of sins
|
||||
|
||||
** TO DO: ** Resolve discrepancy between ModelInfo.location and
|
||||
ModelConfig.path.
|
||||
|
||||
### Emitting model loading events
|
||||
|
||||
When the `context` argument is passed to `get_model()`, it will
|
||||
retrieve the invocation event bus from the passed `InvocationContext`
|
||||
object to emit events on the invocation bus. The two events are
|
||||
"model_load_started" and "model_load_completed". Both carry the
|
||||
following payload:
|
||||
|
||||
```
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_key=model_key,
|
||||
submodel=submodel,
|
||||
hash=model_info.hash,
|
||||
location=str(model_info.location),
|
||||
precision=str(model_info.precision),
|
||||
)
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
## Get on line: The Download Queue
|
||||
@ -879,7 +900,6 @@ following fields:
|
||||
| `job_started` | float | | Timestamp for when the job started running |
|
||||
| `job_ended` | float | | Timestamp for when the job completed or errored out |
|
||||
| `job_sequence` | int | | A counter that is incremented each time a model is dequeued |
|
||||
| `preserve_partial_downloads`| bool | False | Resume partial downloads when relaunched. |
|
||||
| `error` | Exception | | A copy of the Exception that caused an error during download |
|
||||
|
||||
When you create a job, you can assign it a `priority`. If multiple
|
||||
@ -1184,3 +1204,362 @@ other resources that it might have been using.
|
||||
This will start/pause/cancel all jobs that have been submitted to the
|
||||
queue and have not yet reached a terminal state.
|
||||
|
||||
***
|
||||
|
||||
## This Meta be Good: Model Metadata Storage
|
||||
|
||||
The modules found under `invokeai.backend.model_manager.metadata`
|
||||
provide a straightforward API for fetching model metadatda from online
|
||||
repositories. Currently two repositories are supported: HuggingFace
|
||||
and Civitai. However, the modules are easily extended for additional
|
||||
repos, provided that they have defined APIs for metadata access.
|
||||
|
||||
Metadata comprises any descriptive information that is not essential
|
||||
for getting the model to run. For example "author" is metadata, while
|
||||
"type", "base" and "format" are not. The latter fields are part of the
|
||||
model's config, as defined in `invokeai.backend.model_manager.config`.
|
||||
|
||||
### Example Usage:
|
||||
|
||||
```
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
AnyModelRepoMetadata,
|
||||
CivitaiMetadataFetch,
|
||||
CivitaiMetadata
|
||||
ModelMetadataStore,
|
||||
)
|
||||
# to access the initialized sql database
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
|
||||
civitai = CivitaiMetadataFetch()
|
||||
|
||||
# fetch the metadata
|
||||
model_metadata = civitai.from_url("https://civitai.com/models/215796")
|
||||
|
||||
# get some common metadata fields
|
||||
author = model_metadata.author
|
||||
tags = model_metadata.tags
|
||||
|
||||
# get some Civitai-specific fields
|
||||
assert isinstance(model_metadata, CivitaiMetadata)
|
||||
|
||||
trained_words = model_metadata.trained_words
|
||||
base_model = model_metadata.base_model_trained_on
|
||||
thumbnail = model_metadata.thumbnail_url
|
||||
|
||||
# cache the metadata to the database using the key corresponding to
|
||||
# an existing model config record in the `model_config` table
|
||||
sql_cache = ModelMetadataStore(ApiDependencies.invoker.services.db)
|
||||
sql_cache.add_metadata('fb237ace520b6716adc98bcb16e8462c', model_metadata)
|
||||
|
||||
# now we can search the database by tag, author or model name
|
||||
# matches will contain a list of model keys that match the search
|
||||
matches = sql_cache.search_by_tag({"tool", "turbo"})
|
||||
```
|
||||
|
||||
### Structure of the Metadata objects
|
||||
|
||||
There is a short class hierarchy of Metadata objects, all of which
|
||||
descend from the Pydantic `BaseModel`.
|
||||
|
||||
#### `ModelMetadataBase`
|
||||
|
||||
This is the common base class for metadata:
|
||||
|
||||
| **Field Name** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `name` | str | Repository's name for the model |
|
||||
| `author` | str | Model's author |
|
||||
| `tags` | Set[str] | Model tags |
|
||||
|
||||
|
||||
Note that the model config record also has a `name` field. It is
|
||||
intended that the config record version be locally customizable, while
|
||||
the metadata version is read-only. However, enforcing this is expected
|
||||
to be part of the business logic.
|
||||
|
||||
Descendents of the base add additional fields.
|
||||
|
||||
#### `HuggingFaceMetadata`
|
||||
|
||||
This descends from `ModelMetadataBase` and adds the following fields:
|
||||
|
||||
| **Field Name** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `type` | Literal["huggingface"] | Used for the discriminated union of metadata classes|
|
||||
| `id` | str | HuggingFace repo_id |
|
||||
| `tag_dict` | Dict[str, Any] | A dictionary of tag/value pairs provided in addition to `tags` |
|
||||
| `last_modified`| datetime | Date of last commit of this model to the repo |
|
||||
| `files` | List[Path] | List of the files in the model repo |
|
||||
|
||||
|
||||
#### `CivitaiMetadata`
|
||||
|
||||
This descends from `ModelMetadataBase` and adds the following fields:
|
||||
|
||||
| **Field Name** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `type` | Literal["civitai"] | Used for the discriminated union of metadata classes|
|
||||
| `id` | int | Civitai model id |
|
||||
| `version_name` | str | Name of this version of the model (distinct from model name) |
|
||||
| `version_id` | int | Civitai model version id (distinct from model id) |
|
||||
| `created` | datetime | Date this version of the model was created |
|
||||
| `updated` | datetime | Date this version of the model was last updated |
|
||||
| `published` | datetime | Date this version of the model was published to Civitai |
|
||||
| `description` | str | Model description. Quite verbose and contains HTML tags |
|
||||
| `version_description` | str | Model version description, usually describes changes to the model |
|
||||
| `nsfw` | bool | Whether the model tends to generate NSFW content |
|
||||
| `restrictions` | LicenseRestrictions | An object that describes what is and isn't allowed with this model |
|
||||
| `trained_words`| Set[str] | Trigger words for this model, if any |
|
||||
| `download_url` | AnyHttpUrl | URL for downloading this version of the model |
|
||||
| `base_model_trained_on` | str | Name of the model that this version was trained on |
|
||||
| `thumbnail_url` | AnyHttpUrl | URL to access a representative thumbnail image of the model's output |
|
||||
| `weight_min` | int | For LoRA sliders, the minimum suggested weight to apply |
|
||||
| `weight_max` | int | For LoRA sliders, the maximum suggested weight to apply |
|
||||
|
||||
Note that `weight_min` and `weight_max` are not currently populated
|
||||
and take the default values of (-1.0, +2.0). The issue is that these
|
||||
values aren't part of the structured data but appear in the text
|
||||
description. Some regular expression or LLM coding may be able to
|
||||
extract these values.
|
||||
|
||||
Also be aware that `base_model_trained_on` is free text and doesn't
|
||||
correspond to our `ModelType` enum.
|
||||
|
||||
`CivitaiMetadata` also defines some convenience properties relating to
|
||||
licensing restrictions: `credit_required`, `allow_commercial_use`,
|
||||
`allow_derivatives` and `allow_different_license`.
|
||||
|
||||
#### `AnyModelRepoMetadata`
|
||||
|
||||
This is a discriminated Union of `CivitaiMetadata` and
|
||||
`HuggingFaceMetadata`.
|
||||
|
||||
### Fetching Metadata from Online Repos
|
||||
|
||||
The `HuggingFaceMetadataFetch` and `CivitaiMetadataFetch` classes will
|
||||
retrieve metadata from their corresponding repositories and return
|
||||
`AnyModelRepoMetadata` objects. Their base class
|
||||
`ModelMetadataFetchBase` is an abstract class that defines two
|
||||
methods: `from_url()` and `from_id()`. The former accepts the type of
|
||||
model URLs that the user will try to cut and paste into the model
|
||||
import form. The latter accepts a string ID in the format recognized
|
||||
by the repository of choice. Both methods return an
|
||||
`AnyModelRepoMetadata`.
|
||||
|
||||
The base class also has a class method `from_json()` which will take
|
||||
the JSON representation of a `ModelMetadata` object, validate it, and
|
||||
return the corresponding `AnyModelRepoMetadata` object.
|
||||
|
||||
When initializing one of the metadata fetching classes, you may
|
||||
provide a `requests.Session` argument. This allows you to customize
|
||||
the low-level HTTP fetch requests and is used, for instance, in the
|
||||
testing suite to avoid hitting the internet.
|
||||
|
||||
The HuggingFace and Civitai fetcher subclasses add additional
|
||||
repo-specific fetching methods:
|
||||
|
||||
|
||||
#### HuggingFaceMetadataFetch
|
||||
|
||||
This overrides its base class `from_json()` method to return a
|
||||
`HuggingFaceMetadata` object directly.
|
||||
|
||||
#### CivitaiMetadataFetch
|
||||
|
||||
This adds the following methods:
|
||||
|
||||
`from_civitai_modelid()` This takes the ID of a model, finds the
|
||||
default version of the model, and then retrieves the metadata for
|
||||
that version, returning a `CivitaiMetadata` object directly.
|
||||
|
||||
`from_civitai_versionid()` This takes the ID of a model version and
|
||||
retrieves its metadata. Functionally equivalent to `from_id()`, the
|
||||
only difference is that it returna a `CivitaiMetadata` object rather
|
||||
than an `AnyModelRepoMetadata`.
|
||||
|
||||
|
||||
### Metadata Storage
|
||||
|
||||
The `ModelMetadataStore` provides a simple facility to store model
|
||||
metadata in the `invokeai.db` database. The data is stored as a JSON
|
||||
blob, with a few common fields (`name`, `author`, `tags`) broken out
|
||||
to be searchable.
|
||||
|
||||
When a metadata object is saved to the database, it is identified
|
||||
using the model key, _and this key must correspond to an existing
|
||||
model key in the model_config table_. There is a foreign key integrity
|
||||
constraint between the `model_config.id` field and the
|
||||
`model_metadata.id` field such that if you attempt to save metadata
|
||||
under an unknown key, the attempt will result in an
|
||||
`UnknownModelException`. Likewise, when a model is deleted from
|
||||
`model_config`, the deletion of the corresponding metadata record will
|
||||
be triggered.
|
||||
|
||||
Tags are stored in a normalized fashion in the tables `model_tags` and
|
||||
`tags`. Triggers keep the tag table in sync with the `model_metadata`
|
||||
table.
|
||||
|
||||
To create the storage object, initialize it with the InvokeAI
|
||||
`SqliteDatabase` object. This is often done this way:
|
||||
|
||||
```
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
metadata_store = ModelMetadataStore(ApiDependencies.invoker.services.db)
|
||||
```
|
||||
|
||||
You can then access the storage with the following methods:
|
||||
|
||||
#### `add_metadata(key, metadata)`
|
||||
|
||||
Add the metadata using a previously-defined model key.
|
||||
|
||||
There is currently no `delete_metadata()` method. The metadata will
|
||||
persist until the matching config is deleted from the `model_config`
|
||||
table.
|
||||
|
||||
#### `get_metadata(key) -> AnyModelRepoMetadata`
|
||||
|
||||
Retrieve the metadata corresponding to the model key.
|
||||
|
||||
#### `update_metadata(key, new_metadata)`
|
||||
|
||||
Update an existing metadata record with new metadata.
|
||||
|
||||
#### `search_by_tag(tags: Set[str]) -> Set[str]`
|
||||
|
||||
Given a set of tags, find models that are tagged with them. If
|
||||
multiple tags are provided then a matching model must be tagged with
|
||||
*all* the tags in the set. This method returns a set of model keys and
|
||||
is intended to be used in conjunction with the `ModelRecordService`:
|
||||
|
||||
```
|
||||
model_config_store = ApiDependencies.invoker.services.model_records
|
||||
matches = metadata_store.search_by_tag({'license:other'})
|
||||
models = [model_config_store.get(x) for x in matches]
|
||||
```
|
||||
|
||||
#### `search_by_name(name: str) -> Set[str]
|
||||
|
||||
Find all model metadata records that have the given name and return a
|
||||
set of keys to the corresponding model config objects.
|
||||
|
||||
#### `search_by_author(author: str) -> Set[str]
|
||||
|
||||
Find all model metadata records that have the given author and return
|
||||
a set of keys to the corresponding model config objects.
|
||||
|
||||
# The remainder of this documentation is provisional, pending implementation of the Load service
|
||||
|
||||
## Let's get loaded, the lowdown on ModelLoadService
|
||||
|
||||
The `ModelLoadService` is responsible for loading a named model into
|
||||
memory so that it can be used for inference. Despite the fact that it
|
||||
does a lot under the covers, it is very straightforward to use.
|
||||
|
||||
An application-wide model loader is created at API initialization time
|
||||
and stored in
|
||||
`ApiDependencies.invoker.services.model_loader`. However, you can
|
||||
create alternative instances if you wish.
|
||||
|
||||
### Creating a ModelLoadService object
|
||||
|
||||
The class is defined in
|
||||
`invokeai.app.services.model_loader_service`. It is initialized with
|
||||
an InvokeAIAppConfig object, from which it gets configuration
|
||||
information such as the user's desired GPU and precision, and with a
|
||||
previously-created `ModelRecordServiceBase` object, from which it
|
||||
loads the requested model's configuration information.
|
||||
|
||||
Here is a typical initialization pattern:
|
||||
|
||||
```
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.app.services.model_loader_service import ModelLoadService
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
store = ModelRecordServiceBase.open(config)
|
||||
loader = ModelLoadService(config, store)
|
||||
```
|
||||
|
||||
Note that we are relying on the contents of the application
|
||||
configuration to choose the implementation of
|
||||
`ModelRecordServiceBase`.
|
||||
|
||||
### get_model(key, [submodel_type], [context]) -> ModelInfo:
|
||||
|
||||
*** TO DO: change to get_model(key, context=None, **kwargs)
|
||||
|
||||
The `get_model()` method, like its similarly-named cousin in
|
||||
`ModelRecordService`, receives the unique key that identifies the
|
||||
model. It loads the model into memory, gets the model ready for use,
|
||||
and returns a `ModelInfo` object.
|
||||
|
||||
The optional second argument, `subtype` is a `SubModelType` string
|
||||
enum, such as "vae". It is mandatory when used with a main model, and
|
||||
is used to select which part of the main model to load.
|
||||
|
||||
The optional third argument, `context` can be provided by
|
||||
an invocation to trigger model load event reporting. See below for
|
||||
details.
|
||||
|
||||
The returned `ModelInfo` object shares some fields in common with
|
||||
`ModelConfigBase`, but is otherwise a completely different beast:
|
||||
|
||||
| **Field Name** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `key` | str | The model key derived from the ModelRecordService database |
|
||||
| `name` | str | Name of this model |
|
||||
| `base_model` | BaseModelType | Base model for this model |
|
||||
| `type` | ModelType or SubModelType | Either the model type (non-main) or the submodel type (main models)|
|
||||
| `location` | Path or str | Location of the model on the filesystem |
|
||||
| `precision` | torch.dtype | The torch.precision to use for inference |
|
||||
| `context` | ModelCache.ModelLocker | A context class used to lock the model in VRAM while in use |
|
||||
|
||||
The types for `ModelInfo` and `SubModelType` can be imported from
|
||||
`invokeai.app.services.model_loader_service`.
|
||||
|
||||
To use the model, you use the `ModelInfo` as a context manager using
|
||||
the following pattern:
|
||||
|
||||
```
|
||||
model_info = loader.get_model('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
|
||||
with model_info as vae:
|
||||
image = vae.decode(latents)[0]
|
||||
```
|
||||
|
||||
The `vae` model will stay locked in the GPU during the period of time
|
||||
it is in the context manager's scope.
|
||||
|
||||
`get_model()` may raise any of the following exceptions:
|
||||
|
||||
- `UnknownModelException` -- key not in database
|
||||
- `ModelNotFoundException` -- key in database but model not found at path
|
||||
- `InvalidModelException` -- the model is guilty of a variety of sins
|
||||
|
||||
** TO DO: ** Resolve discrepancy between ModelInfo.location and
|
||||
ModelConfig.path.
|
||||
|
||||
### Emitting model loading events
|
||||
|
||||
When the `context` argument is passed to `get_model()`, it will
|
||||
retrieve the invocation event bus from the passed `InvocationContext`
|
||||
object to emit events on the invocation bus. The two events are
|
||||
"model_load_started" and "model_load_completed". Both carry the
|
||||
following payload:
|
||||
|
||||
```
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_key=model_key,
|
||||
submodel=submodel,
|
||||
hash=model_info.hash,
|
||||
location=str(model_info.location),
|
||||
precision=str(model_info.precision),
|
||||
)
|
||||
```
|
||||
|
||||
|
@ -3,6 +3,7 @@
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
@ -61,7 +62,7 @@ class ApiDependencies:
|
||||
invoker: Invoker
|
||||
|
||||
@staticmethod
|
||||
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
|
||||
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None:
|
||||
logger.info(f"InvokeAI version {__version__}")
|
||||
logger.info(f"Root directory = {str(config.root_path)}")
|
||||
logger.debug(f"Internet connectivity is {config.internet_available}")
|
||||
@ -87,8 +88,13 @@ class ApiDependencies:
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
model_record_service = ModelRecordServiceSQL(db=db)
|
||||
download_queue_service = DownloadQueueService(event_bus=events)
|
||||
metadata_store = ModelMetadataStore(db=db)
|
||||
model_install_service = ModelInstallService(
|
||||
app_config=config, record_store=model_record_service, event_bus=events
|
||||
app_config=config,
|
||||
record_store=model_record_service,
|
||||
download_queue=download_queue_service,
|
||||
metadata_store=metadata_store,
|
||||
event_bus=events,
|
||||
)
|
||||
names = SimpleNameService()
|
||||
performance_statistics = InvocationStatsService()
|
||||
@ -131,6 +137,6 @@ class ApiDependencies:
|
||||
db.clean()
|
||||
|
||||
@staticmethod
|
||||
def shutdown():
|
||||
def shutdown() -> None:
|
||||
if ApiDependencies.invoker:
|
||||
ApiDependencies.invoker.stop()
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
from hashlib import sha1
|
||||
from random import randbytes
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
@ -16,13 +16,18 @@ from invokeai.app.services.model_install import ModelInstallJob, ModelSource
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelRecordOrderBy,
|
||||
ModelSummary,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
@ -32,11 +37,20 @@ model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager
|
||||
class ModelsList(BaseModel):
|
||||
"""Return list of configs."""
|
||||
|
||||
models: list[AnyModelConfig]
|
||||
models: List[AnyModelConfig]
|
||||
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
class ModelTagSet(BaseModel):
|
||||
"""Return tags for a set of models."""
|
||||
|
||||
key: str
|
||||
name: str
|
||||
author: str
|
||||
tags: Set[str]
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/",
|
||||
operation_id="list_model_records",
|
||||
@ -45,7 +59,7 @@ async def list_model_records(
|
||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
|
||||
model_format: Optional[str] = Query(
|
||||
model_format: Optional[ModelFormat] = Query(
|
||||
default=None, description="Exact match on the format of the model (e.g. 'diffusers')"
|
||||
),
|
||||
) -> ModelsList:
|
||||
@ -86,6 +100,59 @@ async def get_model_record(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_records_router.get("/meta", operation_id="list_model_summary")
|
||||
async def list_model_summary(
|
||||
page: int = Query(default=0, description="The page to get"),
|
||||
per_page: int = Query(default=10, description="The number of models per page"),
|
||||
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Gets a page of model summary data."""
|
||||
return ApiDependencies.invoker.services.model_records.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/meta/i/{key}",
|
||||
operation_id="get_model_metadata",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "No metadata available"},
|
||||
},
|
||||
)
|
||||
async def get_model_metadata(
|
||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||
) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Get a model metadata object."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
result = record_store.get_metadata(key)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="No metadata for a model with this key")
|
||||
return result
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/tags",
|
||||
operation_id="list_tags",
|
||||
)
|
||||
async def list_tags() -> Set[str]:
|
||||
"""Get a unique set of all the model tags."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
return record_store.list_tags()
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/tags/search",
|
||||
operation_id="search_by_metadata_tags",
|
||||
)
|
||||
async def search_by_metadata_tags(
|
||||
tags: Set[str] = Query(default=None, description="Tags to search for"),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
results = record_store.search_by_metadata_tag(tags)
|
||||
return ModelsList(models=results)
|
||||
|
||||
|
||||
@model_records_router.patch(
|
||||
"/i/{key}",
|
||||
operation_id="update_model_record",
|
||||
@ -159,9 +226,7 @@ async def del_model_record(
|
||||
async def add_model_record(
|
||||
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model using the configuration information appropriate for its type.
|
||||
"""
|
||||
"""Add a model using the configuration information appropriate for its type."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
if config.key == "<NOKEY>":
|
||||
@ -243,7 +308,7 @@ async def import_model(
|
||||
Installation occurs in the background. Either use list_model_install_jobs()
|
||||
to poll for completion, or listen on the event bus for the following events:
|
||||
|
||||
"model_install_started"
|
||||
"model_install_running"
|
||||
"model_install_completed"
|
||||
"model_install_error"
|
||||
|
||||
@ -279,16 +344,46 @@ async def import_model(
|
||||
operation_id="list_model_install_jobs",
|
||||
)
|
||||
async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
"""
|
||||
Return list of model install jobs.
|
||||
|
||||
If the optional 'source' argument is provided, then the list will be filtered
|
||||
for partial string matches against the install source.
|
||||
"""
|
||||
"""Return list of model install jobs."""
|
||||
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs()
|
||||
return jobs
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/import/{id}",
|
||||
operation_id="get_model_install_job",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
404: {"description": "No such job"},
|
||||
},
|
||||
)
|
||||
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
|
||||
"""Return model install job corresponding to the given source."""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.model_install.get_job_by_id(id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_records_router.delete(
|
||||
"/import/{id}",
|
||||
operation_id="cancel_model_install_job",
|
||||
responses={
|
||||
201: {"description": "The job was cancelled successfully"},
|
||||
415: {"description": "No such job"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
|
||||
"""Cancel the model install job(s) corresponding to the given job ID."""
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
try:
|
||||
job = installer.get_job_by_id(id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=415, detail=str(e))
|
||||
installer.cancel_job(job)
|
||||
|
||||
|
||||
@model_records_router.patch(
|
||||
"/import",
|
||||
operation_id="prune_model_install_jobs",
|
||||
@ -298,9 +393,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
},
|
||||
)
|
||||
async def prune_model_install_jobs() -> Response:
|
||||
"""
|
||||
Prune all completed and errored jobs from the install job list.
|
||||
"""
|
||||
"""Prune all completed and errored jobs from the install job list."""
|
||||
ApiDependencies.invoker.services.model_install.prune_jobs()
|
||||
return Response(status_code=204)
|
||||
|
||||
@ -315,7 +408,9 @@ async def prune_model_install_jobs() -> Response:
|
||||
)
|
||||
async def sync_models_to_config() -> Response:
|
||||
"""
|
||||
Traverse the models and autoimport directories. Model files without a corresponding
|
||||
Traverse the models and autoimport directories.
|
||||
|
||||
Model files without a corresponding
|
||||
record in the database are added. Orphan records without a models file are deleted.
|
||||
"""
|
||||
ApiDependencies.invoker.services.model_install.sync_to_config()
|
||||
|
@ -209,7 +209,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
"""Configuration object for InvokeAI App."""
|
||||
|
||||
singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
|
||||
singleton_init: ClassVar[Optional[Dict]] = None
|
||||
singleton_init: ClassVar[Optional[Dict[str, Any]]] = None
|
||||
|
||||
# fmt: off
|
||||
type: Literal["InvokeAI"] = "InvokeAI"
|
||||
@ -301,8 +301,8 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
self,
|
||||
argv: Optional[list[str]] = None,
|
||||
conf: Optional[DictConfig] = None,
|
||||
clobber=False,
|
||||
):
|
||||
clobber: Optional[bool] = False,
|
||||
) -> None:
|
||||
"""
|
||||
Update settings with contents of init file, environment, and command-line settings.
|
||||
|
||||
@ -337,7 +337,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, **kwargs: Dict[str, Any]) -> InvokeAIAppConfig:
|
||||
def get_config(cls, **kwargs: Any) -> InvokeAIAppConfig:
|
||||
"""Return a singleton InvokeAIAppConfig configuration object."""
|
||||
if (
|
||||
cls.singleton_config is None
|
||||
@ -455,7 +455,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
return _find_root()
|
||||
|
||||
|
||||
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
|
||||
def get_invokeai_config(**kwargs: Any) -> InvokeAIAppConfig:
|
||||
"""Legacy function which returns InvokeAIAppConfig.get_config()."""
|
||||
return InvokeAIAppConfig.get_config(**kwargs)
|
||||
|
||||
|
@ -34,6 +34,7 @@ class ServiceInactiveException(Exception):
|
||||
|
||||
|
||||
DownloadEventHandler = Callable[["DownloadJob"], None]
|
||||
DownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
|
||||
|
||||
|
||||
@total_ordering
|
||||
@ -55,6 +56,7 @@ class DownloadJob(BaseModel):
|
||||
job_ended: Optional[str] = Field(
|
||||
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
|
||||
)
|
||||
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
|
||||
bytes: int = Field(default=0, description="Bytes downloaded so far")
|
||||
total_bytes: int = Field(default=0, description="Total file size (bytes)")
|
||||
|
||||
@ -70,7 +72,11 @@ class DownloadJob(BaseModel):
|
||||
_on_progress: Optional[DownloadEventHandler] = PrivateAttr(default=None)
|
||||
_on_complete: Optional[DownloadEventHandler] = PrivateAttr(default=None)
|
||||
_on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
|
||||
_on_error: Optional[DownloadEventHandler] = PrivateAttr(default=None)
|
||||
_on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash of the string representation of this object, for indexing."""
|
||||
return hash(str(self))
|
||||
|
||||
def __le__(self, other: "DownloadJob") -> bool:
|
||||
"""Return True if this job's priority is less than another's."""
|
||||
@ -87,6 +93,26 @@ class DownloadJob(BaseModel):
|
||||
"""Call to cancel the job."""
|
||||
return self._cancelled
|
||||
|
||||
@property
|
||||
def complete(self) -> bool:
|
||||
"""Return true if job completed without errors."""
|
||||
return self.status == DownloadJobStatus.COMPLETED
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""Return true if the job is running."""
|
||||
return self.status == DownloadJobStatus.RUNNING
|
||||
|
||||
@property
|
||||
def errored(self) -> bool:
|
||||
"""Return true if the job is errored."""
|
||||
return self.status == DownloadJobStatus.ERROR
|
||||
|
||||
@property
|
||||
def in_terminal_state(self) -> bool:
|
||||
"""Return true if job has finished, one way or another."""
|
||||
return self.status not in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]
|
||||
|
||||
@property
|
||||
def on_start(self) -> Optional[DownloadEventHandler]:
|
||||
"""Return the on_start event handler."""
|
||||
@ -103,7 +129,7 @@ class DownloadJob(BaseModel):
|
||||
return self._on_complete
|
||||
|
||||
@property
|
||||
def on_error(self) -> Optional[DownloadEventHandler]:
|
||||
def on_error(self) -> Optional[DownloadExceptionHandler]:
|
||||
"""Return the on_error event handler."""
|
||||
return self._on_error
|
||||
|
||||
@ -118,7 +144,7 @@ class DownloadJob(BaseModel):
|
||||
on_progress: Optional[DownloadEventHandler] = None,
|
||||
on_complete: Optional[DownloadEventHandler] = None,
|
||||
on_cancelled: Optional[DownloadEventHandler] = None,
|
||||
on_error: Optional[DownloadEventHandler] = None,
|
||||
on_error: Optional[DownloadExceptionHandler] = None,
|
||||
) -> None:
|
||||
"""Set the callbacks for download events."""
|
||||
self._on_start = on_start
|
||||
@ -150,10 +176,10 @@ class DownloadQueueServiceBase(ABC):
|
||||
on_progress: Optional[DownloadEventHandler] = None,
|
||||
on_complete: Optional[DownloadEventHandler] = None,
|
||||
on_cancelled: Optional[DownloadEventHandler] = None,
|
||||
on_error: Optional[DownloadEventHandler] = None,
|
||||
on_error: Optional[DownloadExceptionHandler] = None,
|
||||
) -> DownloadJob:
|
||||
"""
|
||||
Create a download job.
|
||||
Create and enqueue download job.
|
||||
|
||||
:param source: Source of the download as a URL.
|
||||
:param dest: Path to download to. See below.
|
||||
@ -175,6 +201,25 @@ class DownloadQueueServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def submit_download_job(
|
||||
self,
|
||||
job: DownloadJob,
|
||||
on_start: Optional[DownloadEventHandler] = None,
|
||||
on_progress: Optional[DownloadEventHandler] = None,
|
||||
on_complete: Optional[DownloadEventHandler] = None,
|
||||
on_cancelled: Optional[DownloadEventHandler] = None,
|
||||
on_error: Optional[DownloadExceptionHandler] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Enqueue a download job.
|
||||
|
||||
:param job: The DownloadJob
|
||||
:param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
|
||||
events.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_jobs(self) -> List[DownloadJob]:
|
||||
"""
|
||||
@ -197,21 +242,21 @@ class DownloadQueueServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_all_jobs(self):
|
||||
def cancel_all_jobs(self) -> None:
|
||||
"""Cancel all active and enquedjobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prune_jobs(self):
|
||||
def prune_jobs(self) -> None:
|
||||
"""Prune completed and errored queue items from the job list."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_job(self, job: DownloadJob):
|
||||
def cancel_job(self, job: DownloadJob) -> None:
|
||||
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def join(self):
|
||||
def join(self) -> None:
|
||||
"""Wait until all jobs are off the queue."""
|
||||
pass
|
||||
|
@ -5,10 +5,9 @@ import os
|
||||
import re
|
||||
import threading
|
||||
import traceback
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from queue import Empty, PriorityQueue
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
@ -21,6 +20,7 @@ from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .download_base import (
|
||||
DownloadEventHandler,
|
||||
DownloadExceptionHandler,
|
||||
DownloadJob,
|
||||
DownloadJobCancelledException,
|
||||
DownloadJobStatus,
|
||||
@ -36,18 +36,6 @@ DOWNLOAD_CHUNK_SIZE = 100000
|
||||
class DownloadQueueService(DownloadQueueServiceBase):
|
||||
"""Class for queued download of models."""
|
||||
|
||||
_jobs: Dict[int, DownloadJob]
|
||||
_max_parallel_dl: int = 5
|
||||
_worker_pool: Set[threading.Thread]
|
||||
_queue: PriorityQueue[DownloadJob]
|
||||
_stop_event: threading.Event
|
||||
_lock: threading.Lock
|
||||
_logger: Logger
|
||||
_events: Optional[EventServiceBase] = None
|
||||
_next_job_id: int = 0
|
||||
_accept_download_requests: bool = False
|
||||
_requests: requests.sessions.Session
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_parallel_dl: int = 5,
|
||||
@ -99,6 +87,33 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
self._stop_event.set()
|
||||
self._worker_pool.clear()
|
||||
|
||||
def submit_download_job(
|
||||
self,
|
||||
job: DownloadJob,
|
||||
on_start: Optional[DownloadEventHandler] = None,
|
||||
on_progress: Optional[DownloadEventHandler] = None,
|
||||
on_complete: Optional[DownloadEventHandler] = None,
|
||||
on_cancelled: Optional[DownloadEventHandler] = None,
|
||||
on_error: Optional[DownloadExceptionHandler] = None,
|
||||
) -> None:
|
||||
"""Enqueue a download job."""
|
||||
if not self._accept_download_requests:
|
||||
raise ServiceInactiveException(
|
||||
"The download service is not currently accepting requests. Please call start() to initialize the service."
|
||||
)
|
||||
with self._lock:
|
||||
job.id = self._next_job_id
|
||||
self._next_job_id += 1
|
||||
job.set_callbacks(
|
||||
on_start=on_start,
|
||||
on_progress=on_progress,
|
||||
on_complete=on_complete,
|
||||
on_cancelled=on_cancelled,
|
||||
on_error=on_error,
|
||||
)
|
||||
self._jobs[job.id] = job
|
||||
self._queue.put(job)
|
||||
|
||||
def download(
|
||||
self,
|
||||
source: AnyHttpUrl,
|
||||
@ -109,32 +124,27 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
on_progress: Optional[DownloadEventHandler] = None,
|
||||
on_complete: Optional[DownloadEventHandler] = None,
|
||||
on_cancelled: Optional[DownloadEventHandler] = None,
|
||||
on_error: Optional[DownloadEventHandler] = None,
|
||||
on_error: Optional[DownloadExceptionHandler] = None,
|
||||
) -> DownloadJob:
|
||||
"""Create a download job and return its ID."""
|
||||
"""Create and enqueue a download job and return it."""
|
||||
if not self._accept_download_requests:
|
||||
raise ServiceInactiveException(
|
||||
"The download service is not currently accepting requests. Please call start() to initialize the service."
|
||||
)
|
||||
with self._lock:
|
||||
id = self._next_job_id
|
||||
self._next_job_id += 1
|
||||
job = DownloadJob(
|
||||
id=id,
|
||||
source=source,
|
||||
dest=dest,
|
||||
priority=priority,
|
||||
access_token=access_token,
|
||||
)
|
||||
job.set_callbacks(
|
||||
on_start=on_start,
|
||||
on_progress=on_progress,
|
||||
on_complete=on_complete,
|
||||
on_cancelled=on_cancelled,
|
||||
on_error=on_error,
|
||||
)
|
||||
self._jobs[id] = job
|
||||
self._queue.put(job)
|
||||
job = DownloadJob(
|
||||
source=source,
|
||||
dest=dest,
|
||||
priority=priority,
|
||||
access_token=access_token,
|
||||
)
|
||||
self.submit_download_job(
|
||||
job,
|
||||
on_start=on_start,
|
||||
on_progress=on_progress,
|
||||
on_complete=on_complete,
|
||||
on_cancelled=on_cancelled,
|
||||
on_error=on_error,
|
||||
)
|
||||
return job
|
||||
|
||||
def join(self) -> None:
|
||||
@ -150,7 +160,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
with self._lock:
|
||||
to_delete = set()
|
||||
for job_id, job in self._jobs.items():
|
||||
if self._in_terminal_state(job):
|
||||
if job.in_terminal_state:
|
||||
to_delete.add(job_id)
|
||||
for job_id in to_delete:
|
||||
del self._jobs[job_id]
|
||||
@ -172,19 +182,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
with self._lock:
|
||||
job.cancel()
|
||||
|
||||
def cancel_all_jobs(self, preserve_partial: bool = False) -> None:
|
||||
def cancel_all_jobs(self) -> None:
|
||||
"""Cancel all jobs (those not in enqueued, running or paused state)."""
|
||||
for job in self._jobs.values():
|
||||
if not self._in_terminal_state(job):
|
||||
if not job.in_terminal_state:
|
||||
self.cancel_job(job)
|
||||
|
||||
def _in_terminal_state(self, job: DownloadJob) -> bool:
|
||||
return job.status in [
|
||||
DownloadJobStatus.COMPLETED,
|
||||
DownloadJobStatus.CANCELLED,
|
||||
DownloadJobStatus.ERROR,
|
||||
]
|
||||
|
||||
def _start_workers(self, max_workers: int) -> None:
|
||||
"""Start the requested number of worker threads."""
|
||||
self._stop_event.clear()
|
||||
@ -214,7 +217,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
except (OSError, HTTPError) as excp:
|
||||
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
|
||||
job.error = traceback.format_exc()
|
||||
self._signal_job_error(job)
|
||||
self._signal_job_error(job, excp)
|
||||
except DownloadJobCancelledException:
|
||||
self._signal_job_cancelled(job)
|
||||
self._cleanup_cancelled_job(job)
|
||||
@ -235,6 +238,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
resp = self._requests.get(str(url), headers=header, stream=True)
|
||||
if not resp.ok:
|
||||
raise HTTPError(resp.reason)
|
||||
|
||||
job.content_type = resp.headers.get("Content-Type")
|
||||
content_length = int(resp.headers.get("content-length", 0))
|
||||
job.total_bytes = content_length
|
||||
|
||||
@ -296,6 +301,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
self._signal_job_progress(job)
|
||||
|
||||
# if we get here we are done and can rename the file to the original dest
|
||||
self._logger.debug(f"{job.source}: saved to {job.download_path} (bytes={job.bytes})")
|
||||
in_progress_path.rename(job.download_path)
|
||||
|
||||
def _validate_filename(self, directory: str, filename: str) -> bool:
|
||||
@ -322,7 +328,9 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
try:
|
||||
job.on_start(job)
|
||||
except Exception as e:
|
||||
self._logger.error(e)
|
||||
self._logger.error(
|
||||
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
|
||||
@ -332,7 +340,9 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
try:
|
||||
job.on_progress(job)
|
||||
except Exception as e:
|
||||
self._logger.error(e)
|
||||
self._logger.error(
|
||||
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_progress(
|
||||
@ -348,7 +358,9 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
try:
|
||||
job.on_complete(job)
|
||||
except Exception as e:
|
||||
self._logger.error(e)
|
||||
self._logger.error(
|
||||
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_complete(
|
||||
@ -356,29 +368,36 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
)
|
||||
|
||||
def _signal_job_cancelled(self, job: DownloadJob) -> None:
|
||||
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
|
||||
return
|
||||
job.status = DownloadJobStatus.CANCELLED
|
||||
if job.on_cancelled:
|
||||
try:
|
||||
job.on_cancelled(job)
|
||||
except Exception as e:
|
||||
self._logger.error(e)
|
||||
self._logger.error(
|
||||
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_cancelled(str(job.source))
|
||||
|
||||
def _signal_job_error(self, job: DownloadJob) -> None:
|
||||
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
job.status = DownloadJobStatus.ERROR
|
||||
self._logger.error(f"{str(job.source)}: {traceback.format_exception(excp)}")
|
||||
if job.on_error:
|
||||
try:
|
||||
job.on_error(job)
|
||||
job.on_error(job, excp)
|
||||
except Exception as e:
|
||||
self._logger.error(e)
|
||||
self._logger.error(
|
||||
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.error_type
|
||||
assert job.error
|
||||
self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error)
|
||||
|
||||
def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
|
||||
self._logger.warning(f"Cleaning up leftover files from cancelled download job {job.download_path}")
|
||||
self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}")
|
||||
try:
|
||||
if job.download_path:
|
||||
partial_file = self._in_progress_path(job.download_path)
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
@ -404,53 +404,72 @@ class EventServiceBase:
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_install_started(self, source: str) -> None:
|
||||
def emit_model_install_downloading(
|
||||
self,
|
||||
source: str,
|
||||
local_path: str,
|
||||
bytes: int,
|
||||
total_bytes: int,
|
||||
parts: List[Dict[str, Union[str, int]]],
|
||||
) -> None:
|
||||
"""
|
||||
Emitted when an install job is started.
|
||||
Emit at intervals while the install job is in progress (remote models only).
|
||||
|
||||
:param source: Source of the model
|
||||
:param local_path: Where model is downloading to
|
||||
:param parts: Progress of downloading URLs that comprise the model, if any.
|
||||
:param bytes: Number of bytes downloaded so far.
|
||||
:param total_bytes: Total size of download, including all files.
|
||||
This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes".
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_downloading",
|
||||
payload={
|
||||
"source": source,
|
||||
"local_path": local_path,
|
||||
"bytes": bytes,
|
||||
"total_bytes": total_bytes,
|
||||
"parts": parts,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_install_running(self, source: str) -> None:
|
||||
"""
|
||||
Emit once when an install job becomes active.
|
||||
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_started",
|
||||
event_name="model_install_running",
|
||||
payload={"source": source},
|
||||
)
|
||||
|
||||
def emit_model_install_completed(self, source: str, key: str) -> None:
|
||||
def emit_model_install_completed(self, source: str, key: str, total_bytes: Optional[int] = None) -> None:
|
||||
"""
|
||||
Emitted when an install job is completed successfully.
|
||||
Emit when an install job is completed successfully.
|
||||
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
:param key: Model config record key
|
||||
:param total_bytes: Size of the model (may be None for installation of a local path)
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_completed",
|
||||
payload={
|
||||
"source": source,
|
||||
"total_bytes": total_bytes,
|
||||
"key": key,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_install_progress(
|
||||
self,
|
||||
source: str,
|
||||
current_bytes: int,
|
||||
total_bytes: int,
|
||||
) -> None:
|
||||
def emit_model_install_cancelled(self, source: str) -> None:
|
||||
"""
|
||||
Emitted while the install job is in progress.
|
||||
(Downloaded models only)
|
||||
Emit when an install job is cancelled.
|
||||
|
||||
:param source: Source of the model
|
||||
:param current_bytes: Number of bytes downloaded so far
|
||||
:param total_bytes: Total bytes to download
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_progress",
|
||||
payload={
|
||||
"source": source,
|
||||
"current_bytes": int,
|
||||
"total_bytes": int,
|
||||
},
|
||||
event_name="model_install_cancelled",
|
||||
payload={"source": source},
|
||||
)
|
||||
|
||||
def emit_model_install_error(
|
||||
@ -460,10 +479,11 @@ class EventServiceBase:
|
||||
error: str,
|
||||
) -> None:
|
||||
"""
|
||||
Emitted when an install job encounters an exception.
|
||||
Emit when an install job encounters an exception.
|
||||
|
||||
:param source: Source of the model
|
||||
:param exception: The exception that raised the error
|
||||
:param error_type: The name of the exception
|
||||
:param error: A text description of the exception
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_error",
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Initialization file for model install service package."""
|
||||
|
||||
from .model_install_base import (
|
||||
CivitaiModelSource,
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
@ -22,4 +23,5 @@ __all__ = [
|
||||
"LocalModelSource",
|
||||
"HFModelSource",
|
||||
"URLModelSource",
|
||||
"CivitaiModelSource",
|
||||
]
|
||||
|
@ -1,27 +1,42 @@
|
||||
# Copyright 2023 Lincoln D. Stein and the InvokeAI development team
|
||||
"""Baseclass definitions for the model installer."""
|
||||
|
||||
import re
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
"""State of an install job running in the background."""
|
||||
|
||||
WAITING = "waiting" # waiting to be dequeued
|
||||
DOWNLOADING = "downloading" # downloading of model files in process
|
||||
RUNNING = "running" # being processed
|
||||
COMPLETED = "completed" # finished running
|
||||
ERROR = "error" # terminated with an error message
|
||||
CANCELLED = "cancelled" # terminated with an error message
|
||||
|
||||
|
||||
class ModelInstallPart(BaseModel):
|
||||
url: AnyHttpUrl
|
||||
path: Path
|
||||
bytes: int = 0
|
||||
total_bytes: int = 0
|
||||
|
||||
|
||||
class UnknownInstallJobException(Exception):
|
||||
@ -74,12 +89,31 @@ class LocalModelSource(StringLikeSource):
|
||||
return Path(self.path).as_posix()
|
||||
|
||||
|
||||
class CivitaiModelSource(StringLikeSource):
|
||||
"""A Civitai version id, with optional variant and access token."""
|
||||
|
||||
version_id: int
|
||||
variant: Optional[ModelRepoVariant] = None
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["civitai"] = "civitai"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of repoid when string rep needed."""
|
||||
base: str = str(self.version_id)
|
||||
base += f" ({self.variant})" if self.variant else ""
|
||||
return base
|
||||
|
||||
|
||||
class HFModelSource(StringLikeSource):
|
||||
"""A HuggingFace repo_id, with optional variant and sub-folder."""
|
||||
"""
|
||||
A HuggingFace repo_id with optional variant, sub-folder and access token.
|
||||
Note that the variant option, if not provided to the constructor, will default to fp16, which is
|
||||
what people (almost) always want.
|
||||
"""
|
||||
|
||||
repo_id: str
|
||||
variant: Optional[str] = None
|
||||
subfolder: Optional[str | Path] = None
|
||||
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
|
||||
subfolder: Optional[Path] = None
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["hf"] = "hf"
|
||||
|
||||
@ -103,19 +137,22 @@ class URLModelSource(StringLikeSource):
|
||||
|
||||
url: AnyHttpUrl
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["generic_url"] = "generic_url"
|
||||
type: Literal["url"] = "url"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of the url when string rep needed."""
|
||||
return str(self.url)
|
||||
|
||||
|
||||
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
|
||||
ModelSource = Annotated[
|
||||
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
|
||||
]
|
||||
|
||||
|
||||
class ModelInstallJob(BaseModel):
|
||||
"""Object that tracks the current status of an install request."""
|
||||
|
||||
id: int = Field(description="Unique ID for this job")
|
||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||
config_in: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
|
||||
@ -128,15 +165,74 @@ class ModelInstallJob(BaseModel):
|
||||
)
|
||||
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
|
||||
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
|
||||
error_type: Optional[str] = Field(default=None, description="Class name of the exception that led to status==ERROR")
|
||||
error: Optional[str] = Field(default=None, description="Error traceback") # noqa #501
|
||||
bytes: Optional[int] = Field(
|
||||
default=None, description="For a remote model, the number of bytes downloaded so far (may not be available)"
|
||||
)
|
||||
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
|
||||
source_metadata: Optional[AnyModelRepoMetadata] = Field(
|
||||
default=None, description="Metadata provided by the model source"
|
||||
)
|
||||
download_parts: Set[DownloadJob] = Field(
|
||||
default_factory=set, description="Download jobs contributing to this install"
|
||||
)
|
||||
# internal flags and transitory settings
|
||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||
|
||||
def set_error(self, e: Exception) -> None:
|
||||
"""Record the error and traceback from an exception."""
|
||||
self.error_type = e.__class__.__name__
|
||||
self.error = "".join(traceback.format_exception(e))
|
||||
self._exception = e
|
||||
self.status = InstallStatus.ERROR
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Call to cancel the job."""
|
||||
self.status = InstallStatus.CANCELLED
|
||||
|
||||
@property
|
||||
def error_type(self) -> Optional[str]:
|
||||
"""Class name of the exception that led to status==ERROR."""
|
||||
return self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
@property
|
||||
def error(self) -> Optional[str]:
|
||||
"""Error traceback."""
|
||||
return "".join(traceback.format_exception(self._exception)) if self._exception else None
|
||||
|
||||
@property
|
||||
def cancelled(self) -> bool:
|
||||
"""Set status to CANCELLED."""
|
||||
return self.status == InstallStatus.CANCELLED
|
||||
|
||||
@property
|
||||
def errored(self) -> bool:
|
||||
"""Return true if job has errored."""
|
||||
return self.status == InstallStatus.ERROR
|
||||
|
||||
@property
|
||||
def waiting(self) -> bool:
|
||||
"""Return true if job is waiting to run."""
|
||||
return self.status == InstallStatus.WAITING
|
||||
|
||||
@property
|
||||
def downloading(self) -> bool:
|
||||
"""Return true if job is downloading."""
|
||||
return self.status == InstallStatus.DOWNLOADING
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""Return true if job is running."""
|
||||
return self.status == InstallStatus.RUNNING
|
||||
|
||||
@property
|
||||
def complete(self) -> bool:
|
||||
"""Return true if job completed without errors."""
|
||||
return self.status == InstallStatus.COMPLETED
|
||||
|
||||
@property
|
||||
def in_terminal_state(self) -> bool:
|
||||
"""Return true if job is in a terminal state."""
|
||||
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]
|
||||
|
||||
|
||||
class ModelInstallServiceBase(ABC):
|
||||
"""Abstract base class for InvokeAI model installation."""
|
||||
@ -146,6 +242,8 @@ class ModelInstallServiceBase(ABC):
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
metadata_store: ModelMetadataStore,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
):
|
||||
"""
|
||||
@ -156,12 +254,14 @@ class ModelInstallServiceBase(ABC):
|
||||
:param event_bus: InvokeAI event bus for reporting events to.
|
||||
"""
|
||||
|
||||
# make the invoker optional here because we don't need it and it
|
||||
# makes the installer harder to use outside the web app
|
||||
@abstractmethod
|
||||
def start(self, *args: Any, **kwarg: Any) -> None:
|
||||
def start(self, invoker: Optional[Invoker] = None) -> None:
|
||||
"""Start the installer service."""
|
||||
|
||||
@abstractmethod
|
||||
def stop(self, *args: Any, **kwarg: Any) -> None:
|
||||
def stop(self, invoker: Optional[Invoker] = None) -> None:
|
||||
"""Stop the model install service. After this the objection can be safely deleted."""
|
||||
|
||||
@property
|
||||
@ -264,9 +364,13 @@ class ModelInstallServiceBase(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_job(self, source: ModelSource) -> List[ModelInstallJob]:
|
||||
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
|
||||
"""Return the ModelInstallJob(s) corresponding to the provided source."""
|
||||
|
||||
@abstractmethod
|
||||
def get_job_by_id(self, id: int) -> ModelInstallJob:
|
||||
"""Return the ModelInstallJob corresponding to the provided id. Raises ValueError if no job has that ID."""
|
||||
|
||||
@abstractmethod
|
||||
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
|
||||
"""
|
||||
@ -278,16 +382,19 @@ class ModelInstallServiceBase(ABC):
|
||||
"""Prune all completed and errored jobs."""
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_installs(self) -> List[ModelInstallJob]:
|
||||
def cancel_job(self, job: ModelInstallJob) -> None:
|
||||
"""Cancel the indicated job."""
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]:
|
||||
"""
|
||||
Wait for all pending installs to complete.
|
||||
|
||||
This will block until all pending installs have
|
||||
completed, been cancelled, or errored out. It will
|
||||
block indefinitely if one or more jobs are in the
|
||||
paused state.
|
||||
completed, been cancelled, or errored out.
|
||||
|
||||
It will return the current list of jobs.
|
||||
:param timeout: Wait up to indicated number of seconds. Raise an Exception('timeout') if
|
||||
installs do not complete within the indicated time.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
|
@ -1,60 +1,72 @@
|
||||
"""Model installation class."""
|
||||
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from hashlib import sha256
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from queue import Empty, Queue
|
||||
from random import randbytes
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
from huggingface_hub import HfFolder
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests import Session
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, UnknownModelException
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
AnyModelRepoMetadata,
|
||||
CivitaiMetadataFetch,
|
||||
HuggingFaceMetadataFetch,
|
||||
ModelMetadataStore,
|
||||
ModelMetadataWithFiles,
|
||||
RemoteModelFile,
|
||||
)
|
||||
from invokeai.backend.model_manager.probe import ModelProbe
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.util import Chdir, InvokeAILogger
|
||||
from invokeai.backend.util.devices import choose_precision, choose_torch_device
|
||||
|
||||
from .model_install_base import (
|
||||
CivitaiModelSource,
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
URLModelSource,
|
||||
)
|
||||
|
||||
# marker that the queue is done and that thread should exit
|
||||
STOP_JOB = ModelInstallJob(
|
||||
source=LocalModelSource(path="stop"),
|
||||
local_path=Path("/dev/null"),
|
||||
)
|
||||
TMPDIR_PREFIX = "tmpinstall_"
|
||||
|
||||
|
||||
class ModelInstallService(ModelInstallServiceBase):
|
||||
"""class for InvokeAI model installation."""
|
||||
|
||||
_app_config: InvokeAIAppConfig
|
||||
_record_store: ModelRecordServiceBase
|
||||
_event_bus: Optional[EventServiceBase] = None
|
||||
_install_queue: Queue[ModelInstallJob]
|
||||
_install_jobs: List[ModelInstallJob]
|
||||
_logger: Logger
|
||||
_cached_model_paths: Set[Path]
|
||||
_models_installed: Set[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
metadata_store: Optional[ModelMetadataStore] = None,
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
session: Optional[Session] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the installer object.
|
||||
@ -67,10 +79,26 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._record_store = record_store
|
||||
self._event_bus = event_bus
|
||||
self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__)
|
||||
self._install_jobs = []
|
||||
self._install_queue = Queue()
|
||||
self._cached_model_paths = set()
|
||||
self._models_installed = set()
|
||||
self._install_jobs: List[ModelInstallJob] = []
|
||||
self._install_queue: Queue[ModelInstallJob] = Queue()
|
||||
self._cached_model_paths: Set[Path] = set()
|
||||
self._models_installed: Set[str] = set()
|
||||
self._lock = threading.Lock()
|
||||
self._stop_event = threading.Event()
|
||||
self._downloads_changed_event = threading.Event()
|
||||
self._download_queue = download_queue
|
||||
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
|
||||
self._running = False
|
||||
self._session = session
|
||||
self._next_job_id = 0
|
||||
# There may not necessarily be a metadata store initialized
|
||||
# so we create one and initialize it with the same sql database
|
||||
# used by the record store service.
|
||||
if metadata_store:
|
||||
self._metadata_store = metadata_store
|
||||
else:
|
||||
assert isinstance(record_store, ModelRecordServiceSQL)
|
||||
self._metadata_store = ModelMetadataStore(record_store.db)
|
||||
|
||||
@property
|
||||
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
||||
@ -84,69 +112,31 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
|
||||
return self._event_bus
|
||||
|
||||
def start(self, *args: Any, **kwarg: Any) -> None:
|
||||
# make the invoker optional here because we don't need it and it
|
||||
# makes the installer harder to use outside the web app
|
||||
def start(self, invoker: Optional[Invoker] = None) -> None:
|
||||
"""Start the installer thread."""
|
||||
self._start_installer_thread()
|
||||
self.sync_to_config()
|
||||
with self._lock:
|
||||
if self._running:
|
||||
raise Exception("Attempt to start the installer service twice")
|
||||
self._start_installer_thread()
|
||||
self._remove_dangling_install_dirs()
|
||||
self.sync_to_config()
|
||||
|
||||
def stop(self, *args: Any, **kwarg: Any) -> None:
|
||||
def stop(self, invoker: Optional[Invoker] = None) -> None:
|
||||
"""Stop the installer thread; after this the object can be deleted and garbage collected."""
|
||||
self._install_queue.put(STOP_JOB)
|
||||
|
||||
def _start_installer_thread(self) -> None:
|
||||
threading.Thread(target=self._install_next_item, daemon=True).start()
|
||||
|
||||
def _install_next_item(self) -> None:
|
||||
done = False
|
||||
while not done:
|
||||
job = self._install_queue.get()
|
||||
if job == STOP_JOB:
|
||||
done = True
|
||||
continue
|
||||
|
||||
assert job.local_path is not None
|
||||
try:
|
||||
self._signal_job_running(job)
|
||||
if job.inplace:
|
||||
key = self.register_path(job.local_path, job.config_in)
|
||||
else:
|
||||
key = self.install_path(job.local_path, job.config_in)
|
||||
job.config_out = self.record_store.get_model(key)
|
||||
self._signal_job_completed(job)
|
||||
|
||||
except (OSError, DuplicateModelException, InvalidModelConfigException) as excp:
|
||||
self._signal_job_errored(job, excp)
|
||||
finally:
|
||||
self._install_queue.task_done()
|
||||
self._logger.info("Install thread exiting")
|
||||
|
||||
def _signal_job_running(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.RUNNING
|
||||
self._logger.info(f"{job.source}: model installation started")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_started(str(job.source))
|
||||
|
||||
def _signal_job_completed(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.COMPLETED
|
||||
assert job.config_out
|
||||
self._logger.info(
|
||||
f"{job.source}: model installation completed. {job.local_path} registered key {job.config_out.key}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.local_path is not None
|
||||
assert job.config_out is not None
|
||||
key = job.config_out.key
|
||||
self._event_bus.emit_model_install_completed(str(job.source), key)
|
||||
|
||||
def _signal_job_errored(self, job: ModelInstallJob, excp: Exception) -> None:
|
||||
job.set_error(excp)
|
||||
self._logger.info(f"{job.source}: model installation encountered an exception: {job.error_type}")
|
||||
if self._event_bus:
|
||||
error_type = job.error_type
|
||||
error = job.error
|
||||
assert error_type is not None
|
||||
assert error is not None
|
||||
self._event_bus.emit_model_install_error(str(job.source), error_type, error)
|
||||
with self._lock:
|
||||
if not self._running:
|
||||
raise Exception("Attempt to stop the install service before it was started")
|
||||
self._stop_event.set()
|
||||
with self._install_queue.mutex:
|
||||
self._install_queue.queue.clear() # get rid of pending jobs
|
||||
active_jobs = [x for x in self.list_jobs() if x.running]
|
||||
if active_jobs:
|
||||
self._logger.warning("Waiting for active install job to complete")
|
||||
self.wait_for_installs()
|
||||
self._download_cache.clear()
|
||||
self._running = False
|
||||
|
||||
def register_path(
|
||||
self,
|
||||
@ -172,7 +162,12 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
info: AnyModelConfig = self._probe_model(Path(model_path), config)
|
||||
old_hash = info.original_hash
|
||||
dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name
|
||||
new_path = self._copy_model(model_path, dest_path)
|
||||
try:
|
||||
new_path = self._copy_model(model_path, dest_path)
|
||||
except FileExistsError as excp:
|
||||
raise DuplicateModelException(
|
||||
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
|
||||
) from excp
|
||||
new_hash = FastModelHash.hash(new_path)
|
||||
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
|
||||
|
||||
@ -182,43 +177,56 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
info,
|
||||
)
|
||||
|
||||
def import_model(
|
||||
self,
|
||||
source: ModelSource,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> ModelInstallJob: # noqa D102
|
||||
if not config:
|
||||
config = {}
|
||||
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
|
||||
if isinstance(source, LocalModelSource):
|
||||
install_job = self._import_local_model(source, config)
|
||||
self._install_queue.put(install_job) # synchronously install
|
||||
elif isinstance(source, CivitaiModelSource):
|
||||
install_job = self._import_from_civitai(source, config)
|
||||
elif isinstance(source, HFModelSource):
|
||||
install_job = self._import_from_hf(source, config)
|
||||
elif isinstance(source, URLModelSource):
|
||||
install_job = self._import_from_url(source, config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model source: '{type(source)}'")
|
||||
|
||||
# Installing a local path
|
||||
if isinstance(source, LocalModelSource) and Path(source.path).exists(): # a path that is already on disk
|
||||
job = ModelInstallJob(
|
||||
source=source,
|
||||
config_in=config,
|
||||
local_path=Path(source.path),
|
||||
)
|
||||
self._install_jobs.append(job)
|
||||
self._install_queue.put(job)
|
||||
return job
|
||||
|
||||
else: # here is where we'd download a URL or repo_id. Implementation pending download queue.
|
||||
raise UnknownModelException("File or directory not found")
|
||||
self._install_jobs.append(install_job)
|
||||
return install_job
|
||||
|
||||
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
|
||||
return self._install_jobs
|
||||
|
||||
def get_job(self, source: ModelSource) -> List[ModelInstallJob]: # noqa D102
|
||||
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]: # noqa D102
|
||||
return [x for x in self._install_jobs if x.source == source]
|
||||
|
||||
def wait_for_installs(self) -> List[ModelInstallJob]: # noqa D102
|
||||
def get_job_by_id(self, id: int) -> ModelInstallJob: # noqa D102
|
||||
jobs = [x for x in self._install_jobs if x.id == id]
|
||||
if not jobs:
|
||||
raise ValueError(f"No job with id {id} known")
|
||||
assert len(jobs) == 1
|
||||
assert isinstance(jobs[0], ModelInstallJob)
|
||||
return jobs[0]
|
||||
|
||||
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102
|
||||
"""Block until all installation jobs are done."""
|
||||
start = time.time()
|
||||
while len(self._download_cache) > 0:
|
||||
if self._downloads_changed_event.wait(timeout=5): # in case we miss an event
|
||||
self._downloads_changed_event.clear()
|
||||
if timeout > 0 and time.time() - start > timeout:
|
||||
raise Exception("Timeout exceeded")
|
||||
self._install_queue.join()
|
||||
return self._install_jobs
|
||||
|
||||
def cancel_job(self, job: ModelInstallJob) -> None:
|
||||
"""Cancel the indicated job."""
|
||||
job.cancel()
|
||||
with self._lock:
|
||||
self._cancel_download_parts(job)
|
||||
|
||||
def prune_jobs(self) -> None:
|
||||
"""Prune all completed and errored jobs."""
|
||||
unfinished_jobs = [
|
||||
x for x in self._install_jobs if x.status not in [InstallStatus.COMPLETED, InstallStatus.ERROR]
|
||||
]
|
||||
unfinished_jobs = [x for x in self._install_jobs if not x.in_terminal_state]
|
||||
self._install_jobs = unfinished_jobs
|
||||
|
||||
def sync_to_config(self) -> None:
|
||||
@ -234,10 +242,108 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
|
||||
callback = self._scan_install if install else self._scan_register
|
||||
search = ModelSearch(on_model_found=callback)
|
||||
self._models_installed: Set[str] = set()
|
||||
self._models_installed.clear()
|
||||
search.search(scan_dir)
|
||||
return list(self._models_installed)
|
||||
|
||||
def unregister(self, key: str) -> None: # noqa D102
|
||||
self.record_store.del_model(key)
|
||||
|
||||
def delete(self, key: str) -> None: # noqa D102
|
||||
"""Unregister the model. Delete its files only if they are within our models directory."""
|
||||
model = self.record_store.get_model(key)
|
||||
models_dir = self.app_config.models_path
|
||||
model_path = models_dir / model.path
|
||||
if model_path.is_relative_to(models_dir):
|
||||
self.unconditionally_delete(key)
|
||||
else:
|
||||
self.unregister(key)
|
||||
|
||||
def unconditionally_delete(self, key: str) -> None: # noqa D102
|
||||
model = self.record_store.get_model(key)
|
||||
path = self.app_config.models_path / model.path
|
||||
if path.is_dir():
|
||||
rmtree(path)
|
||||
else:
|
||||
path.unlink()
|
||||
self.unregister(key)
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Internal functions that manage the installer threads
|
||||
# --------------------------------------------------------------------------------------------
|
||||
def _start_installer_thread(self) -> None:
|
||||
threading.Thread(target=self._install_next_item, daemon=True).start()
|
||||
self._running = True
|
||||
|
||||
def _install_next_item(self) -> None:
|
||||
done = False
|
||||
while not done:
|
||||
if self._stop_event.is_set():
|
||||
done = True
|
||||
continue
|
||||
try:
|
||||
job = self._install_queue.get(timeout=1)
|
||||
except Empty:
|
||||
continue
|
||||
|
||||
assert job.local_path is not None
|
||||
try:
|
||||
if job.cancelled:
|
||||
self._signal_job_cancelled(job)
|
||||
|
||||
elif job.errored:
|
||||
self._signal_job_errored(job)
|
||||
|
||||
elif (
|
||||
job.waiting or job.downloading
|
||||
): # local jobs will be in waiting state, remote jobs will be downloading state
|
||||
job.total_bytes = self._stat_size(job.local_path)
|
||||
job.bytes = job.total_bytes
|
||||
self._signal_job_running(job)
|
||||
if job.inplace:
|
||||
key = self.register_path(job.local_path, job.config_in)
|
||||
else:
|
||||
key = self.install_path(job.local_path, job.config_in)
|
||||
job.config_out = self.record_store.get_model(key)
|
||||
|
||||
# enter the metadata, if there is any
|
||||
if job.source_metadata:
|
||||
self._metadata_store.add_metadata(key, job.source_metadata)
|
||||
self._signal_job_completed(job)
|
||||
|
||||
except InvalidModelConfigException as excp:
|
||||
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
|
||||
job.set_error(
|
||||
InvalidModelConfigException(
|
||||
f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
|
||||
)
|
||||
)
|
||||
else:
|
||||
job.set_error(excp)
|
||||
self._signal_job_errored(job)
|
||||
|
||||
except (OSError, DuplicateModelException) as excp:
|
||||
job.set_error(excp)
|
||||
self._signal_job_errored(job)
|
||||
|
||||
finally:
|
||||
# if this is an install of a remote file, then clean up the temporary directory
|
||||
if job._install_tmpdir is not None:
|
||||
rmtree(job._install_tmpdir)
|
||||
self._install_queue.task_done()
|
||||
|
||||
self._logger.info("Install thread exiting")
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Internal functions that manage the models directory
|
||||
# --------------------------------------------------------------------------------------------
|
||||
def _remove_dangling_install_dirs(self) -> None:
|
||||
"""Remove leftover tmpdirs from aborted installs."""
|
||||
path = self._app_config.models_path
|
||||
for tmpdir in path.glob(f"{TMPDIR_PREFIX}*"):
|
||||
self._logger.info(f"Removing dangling temporary directory {tmpdir}")
|
||||
rmtree(tmpdir)
|
||||
|
||||
def _scan_models_directory(self) -> None:
|
||||
"""
|
||||
Scan the models directory for new and missing models.
|
||||
@ -320,28 +426,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
pass
|
||||
return True
|
||||
|
||||
def unregister(self, key: str) -> None: # noqa D102
|
||||
self.record_store.del_model(key)
|
||||
|
||||
def delete(self, key: str) -> None: # noqa D102
|
||||
"""Unregister the model. Delete its files only if they are within our models directory."""
|
||||
model = self.record_store.get_model(key)
|
||||
models_dir = self.app_config.models_path
|
||||
model_path = models_dir / model.path
|
||||
if model_path.is_relative_to(models_dir):
|
||||
self.unconditionally_delete(key)
|
||||
else:
|
||||
self.unregister(key)
|
||||
|
||||
def unconditionally_delete(self, key: str) -> None: # noqa D102
|
||||
model = self.record_store.get_model(key)
|
||||
path = self.app_config.models_path / model.path
|
||||
if path.is_dir():
|
||||
rmtree(path)
|
||||
else:
|
||||
path.unlink()
|
||||
self.unregister(key)
|
||||
|
||||
def _copy_model(self, old_path: Path, new_path: Path) -> Path:
|
||||
if old_path == new_path:
|
||||
return old_path
|
||||
@ -397,3 +481,279 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
|
||||
self.record_store.add_model(key, info)
|
||||
return key
|
||||
|
||||
def _next_id(self) -> int:
|
||||
with self._lock:
|
||||
id = self._next_job_id
|
||||
self._next_job_id += 1
|
||||
return id
|
||||
|
||||
@staticmethod
|
||||
def _guess_variant() -> ModelRepoVariant:
|
||||
"""Guess the best HuggingFace variant type to download."""
|
||||
precision = choose_precision(choose_torch_device())
|
||||
return ModelRepoVariant.FP16 if precision == "float16" else ModelRepoVariant.DEFAULT
|
||||
|
||||
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
return ModelInstallJob(
|
||||
id=self._next_id(),
|
||||
source=source,
|
||||
config_in=config or {},
|
||||
local_path=Path(source.path),
|
||||
inplace=source.inplace,
|
||||
)
|
||||
|
||||
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
if not source.access_token:
|
||||
self._logger.info("No Civitai access token provided; some models may not be downloadable.")
|
||||
metadata = CivitaiMetadataFetch(self._session).from_id(str(source.version_id))
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
remote_files = metadata.download_urls(session=self._session)
|
||||
return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)
|
||||
|
||||
def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
# Add user's cached access token to HuggingFace requests
|
||||
source.access_token = source.access_token or HfFolder.get_token()
|
||||
if not source.access_token:
|
||||
self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
|
||||
|
||||
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id)
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
remote_files = metadata.download_urls(
|
||||
variant=source.variant or self._guess_variant(),
|
||||
subfolder=source.subfolder,
|
||||
session=self._session,
|
||||
)
|
||||
|
||||
return self._import_remote_model(
|
||||
source=source,
|
||||
config=config,
|
||||
remote_files=remote_files,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
# URLs from Civitai or HuggingFace will be handled specially
|
||||
url_patterns = {
|
||||
r"https?://civitai.com/": CivitaiMetadataFetch,
|
||||
r"https?://huggingface.co/": HuggingFaceMetadataFetch,
|
||||
}
|
||||
metadata = None
|
||||
for pattern, fetcher in url_patterns.items():
|
||||
if re.match(pattern, str(source.url), re.IGNORECASE):
|
||||
metadata = fetcher(self._session).from_url(source.url)
|
||||
break
|
||||
if metadata and isinstance(metadata, ModelMetadataWithFiles):
|
||||
remote_files = metadata.download_urls(session=self._session)
|
||||
else:
|
||||
remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
|
||||
|
||||
return self._import_remote_model(
|
||||
source=source,
|
||||
config=config,
|
||||
metadata=metadata,
|
||||
remote_files=remote_files,
|
||||
)
|
||||
|
||||
def _import_remote_model(
|
||||
self,
|
||||
source: ModelSource,
|
||||
remote_files: List[RemoteModelFile],
|
||||
metadata: Optional[AnyModelRepoMetadata],
|
||||
config: Optional[Dict[str, Any]],
|
||||
) -> ModelInstallJob:
|
||||
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
|
||||
# Currently the tmpdir isn't automatically removed at exit because it is
|
||||
# being held in a daemon thread.
|
||||
tmpdir = Path(
|
||||
mkdtemp(
|
||||
dir=self._app_config.models_path,
|
||||
prefix=TMPDIR_PREFIX,
|
||||
)
|
||||
)
|
||||
install_job = ModelInstallJob(
|
||||
id=self._next_id(),
|
||||
source=source,
|
||||
config_in=config or {},
|
||||
source_metadata=metadata,
|
||||
local_path=tmpdir, # local path may change once the download has started due to content-disposition handling
|
||||
bytes=0,
|
||||
total_bytes=0,
|
||||
)
|
||||
# we remember the path up to the top of the tmpdir so that it may be
|
||||
# removed safely at the end of the install process.
|
||||
install_job._install_tmpdir = tmpdir
|
||||
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
|
||||
|
||||
self._logger.info(f"Queuing {source} for downloading")
|
||||
for model_file in remote_files:
|
||||
url = model_file.url
|
||||
path = model_file.path
|
||||
self._logger.info(f"Downloading {url} => {path}")
|
||||
install_job.total_bytes += model_file.size
|
||||
assert hasattr(source, "access_token")
|
||||
dest = tmpdir / path.parent
|
||||
dest.mkdir(parents=True, exist_ok=True)
|
||||
download_job = DownloadJob(
|
||||
source=url,
|
||||
dest=dest,
|
||||
access_token=source.access_token,
|
||||
)
|
||||
self._download_cache[download_job.source] = install_job # matches a download job to an install job
|
||||
install_job.download_parts.add(download_job)
|
||||
|
||||
self._download_queue.submit_download_job(
|
||||
download_job,
|
||||
on_start=self._download_started_callback,
|
||||
on_progress=self._download_progress_callback,
|
||||
on_complete=self._download_complete_callback,
|
||||
on_error=self._download_error_callback,
|
||||
on_cancelled=self._download_cancelled_callback,
|
||||
)
|
||||
return install_job
|
||||
|
||||
def _stat_size(self, path: Path) -> int:
|
||||
size = 0
|
||||
if path.is_file():
|
||||
size = path.stat().st_size
|
||||
elif path.is_dir():
|
||||
for root, _, files in os.walk(path):
|
||||
size += sum(self._stat_size(Path(root, x)) for x in files)
|
||||
return size
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Callbacks are executed by the download queue in a separate thread
|
||||
# ------------------------------------------------------------------
|
||||
def _download_started_callback(self, download_job: DownloadJob) -> None:
|
||||
self._logger.info(f"{download_job.source}: model download started")
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.source]
|
||||
install_job.status = InstallStatus.DOWNLOADING
|
||||
|
||||
assert download_job.download_path
|
||||
if install_job.local_path == install_job._install_tmpdir:
|
||||
partial_path = download_job.download_path.relative_to(install_job._install_tmpdir)
|
||||
dest_name = partial_path.parts[0]
|
||||
install_job.local_path = install_job._install_tmpdir / dest_name
|
||||
|
||||
# Update the total bytes count for remote sources.
|
||||
if not install_job.total_bytes:
|
||||
install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts)
|
||||
|
||||
def _download_progress_callback(self, download_job: DownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.source]
|
||||
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
|
||||
self._cancel_download_parts(install_job)
|
||||
else:
|
||||
# update sizes
|
||||
install_job.bytes = sum(x.bytes for x in install_job.download_parts)
|
||||
self._signal_job_downloading(install_job)
|
||||
|
||||
def _download_complete_callback(self, download_job: DownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.source]
|
||||
self._download_cache.pop(download_job.source, None)
|
||||
|
||||
# are there any more active jobs left in this task?
|
||||
if all(x.complete for x in install_job.download_parts):
|
||||
# now enqueue job for actual installation into the models directory
|
||||
self._install_queue.put(install_job)
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache.pop(download_job.source, None)
|
||||
assert install_job is not None
|
||||
assert excp is not None
|
||||
install_job.set_error(excp)
|
||||
self._logger.error(
|
||||
f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}"
|
||||
)
|
||||
self._cancel_download_parts(install_job)
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
def _download_cancelled_callback(self, download_job: DownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache.pop(download_job.source, None)
|
||||
if not install_job:
|
||||
return
|
||||
self._downloads_changed_event.set()
|
||||
self._logger.warning(f"Download {download_job.source} cancelled.")
|
||||
# if install job has already registered an error, then do not replace its status with cancelled
|
||||
if not install_job.errored:
|
||||
install_job.cancel()
|
||||
self._cancel_download_parts(install_job)
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
def _cancel_download_parts(self, install_job: ModelInstallJob) -> None:
|
||||
# on multipart downloads, _cancel_components() will get called repeatedly from the download callbacks
|
||||
# do not lock here because it gets called within a locked context
|
||||
for s in install_job.download_parts:
|
||||
self._download_queue.cancel_job(s)
|
||||
|
||||
if all(x.in_terminal_state for x in install_job.download_parts):
|
||||
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
|
||||
self._install_queue.put(install_job)
|
||||
|
||||
# ------------------------------------------------------------------------------------------------
|
||||
# Internal methods that put events on the event bus
|
||||
# ------------------------------------------------------------------------------------------------
|
||||
def _signal_job_running(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.RUNNING
|
||||
self._logger.info(f"{job.source}: model installation started")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_running(str(job.source))
|
||||
|
||||
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
|
||||
if self._event_bus:
|
||||
parts: List[Dict[str, str | int]] = [
|
||||
{
|
||||
"url": str(x.source),
|
||||
"local_path": str(x.download_path),
|
||||
"bytes": x.bytes,
|
||||
"total_bytes": x.total_bytes,
|
||||
}
|
||||
for x in job.download_parts
|
||||
]
|
||||
assert job.bytes is not None
|
||||
assert job.total_bytes is not None
|
||||
self._event_bus.emit_model_install_downloading(
|
||||
str(job.source),
|
||||
local_path=job.local_path.as_posix(),
|
||||
parts=parts,
|
||||
bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
)
|
||||
|
||||
def _signal_job_completed(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.COMPLETED
|
||||
assert job.config_out
|
||||
self._logger.info(
|
||||
f"{job.source}: model installation completed. {job.local_path} registered key {job.config_out.key}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.local_path is not None
|
||||
assert job.config_out is not None
|
||||
key = job.config_out.key
|
||||
self._event_bus.emit_model_install_completed(str(job.source), key)
|
||||
|
||||
def _signal_job_errored(self, job: ModelInstallJob) -> None:
|
||||
self._logger.info(f"{job.source}: model installation encountered an exception: {job.error_type}\n{job.error}")
|
||||
if self._event_bus:
|
||||
error_type = job.error_type
|
||||
error = job.error
|
||||
assert error_type is not None
|
||||
assert error is not None
|
||||
self._event_bus.emit_model_install_error(str(job.source), error_type, error)
|
||||
|
||||
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
|
||||
self._logger.info(f"{job.source}: model installation was cancelled")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_cancelled(str(job.source))
|
||||
|
@ -4,6 +4,8 @@ from .model_records_base import ( # noqa F401
|
||||
InvalidModelException,
|
||||
ModelRecordServiceBase,
|
||||
UnknownModelException,
|
||||
ModelSummary,
|
||||
ModelRecordOrderBy,
|
||||
)
|
||||
from .model_records_sql import ModelRecordServiceSQL # noqa F401
|
||||
|
||||
@ -13,4 +15,6 @@ __all__ = [
|
||||
"DuplicateModelException",
|
||||
"InvalidModelException",
|
||||
"UnknownModelException",
|
||||
"ModelSummary",
|
||||
"ModelRecordOrderBy",
|
||||
]
|
||||
|
@ -4,10 +4,15 @@ Abstract base class for storing and retrieving model configuration records.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
@ -26,11 +31,33 @@ class ConfigFileVersionMismatchException(Exception):
|
||||
"""Raised on an attempt to open a config with an incompatible version."""
|
||||
|
||||
|
||||
class ModelRecordOrderBy(str, Enum):
|
||||
"""The order in which to return model summaries."""
|
||||
|
||||
Default = "default" # order by type, base, format and name
|
||||
Type = "type"
|
||||
Base = "base"
|
||||
Name = "name"
|
||||
Format = "format"
|
||||
|
||||
|
||||
class ModelSummary(BaseModel):
|
||||
"""A short summary of models for UI listing purposes."""
|
||||
|
||||
key: str = Field(description="model key")
|
||||
type: ModelType = Field(description="model type")
|
||||
base: BaseModelType = Field(description="base model")
|
||||
format: ModelFormat = Field(description="model format")
|
||||
name: str = Field(description="model name")
|
||||
description: str = Field(description="short description of model")
|
||||
tags: Set[str] = Field(description="tags associated with model")
|
||||
|
||||
|
||||
class ModelRecordServiceBase(ABC):
|
||||
"""Abstract base class for storage and retrieval of model configs."""
|
||||
|
||||
@abstractmethod
|
||||
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
||||
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
@ -54,7 +81,7 @@ class ModelRecordServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
||||
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
@ -75,6 +102,47 @@ class ModelRecordServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def metadata_store(self) -> ModelMetadataStore:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
Retrieve metadata (if any) from when model was downloaded from a repo.
|
||||
|
||||
:param key: Model key
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
||||
"""List metadata for all models that have it."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Search model metadata for ones with all listed tags and return their corresponding configs.
|
||||
|
||||
:param tags: Set of tags to search for. All tags must be present.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return a unique set of all the model tags in the metadata database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_models(
|
||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Return a paginated summary listing of each model in the database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
|
@ -42,9 +42,11 @@ Typical usage:
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
@ -52,11 +54,14 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
|
||||
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from .model_records_base import (
|
||||
DuplicateModelException,
|
||||
ModelRecordOrderBy,
|
||||
ModelRecordServiceBase,
|
||||
ModelSummary,
|
||||
UnknownModelException,
|
||||
)
|
||||
|
||||
@ -64,9 +69,6 @@ from .model_records_base import (
|
||||
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
||||
|
||||
_db: SqliteDatabase
|
||||
_cursor: sqlite3.Cursor
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
@ -78,7 +80,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
||||
@property
|
||||
def db(self) -> SqliteDatabase:
|
||||
"""Return the underlying database."""
|
||||
return self._db
|
||||
|
||||
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
@ -293,3 +300,95 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
|
||||
return results
|
||||
|
||||
@property
|
||||
def metadata_store(self) -> ModelMetadataStore:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
return ModelMetadataStore(self._db)
|
||||
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
Retrieve metadata (if any) from when model was downloaded from a repo.
|
||||
|
||||
:param key: Model key
|
||||
"""
|
||||
store = self.metadata_store
|
||||
try:
|
||||
metadata = store.get_metadata(key)
|
||||
return metadata
|
||||
except UnknownMetadataException:
|
||||
return None
|
||||
|
||||
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Search model metadata for ones with all listed tags and return their corresponding configs.
|
||||
|
||||
:param tags: Set of tags to search for. All tags must be present.
|
||||
"""
|
||||
store = ModelMetadataStore(self._db)
|
||||
keys = store.search_by_tag(tags)
|
||||
return [self.get_model(x) for x in keys]
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return a unique set of all the model tags in the metadata database."""
|
||||
store = ModelMetadataStore(self._db)
|
||||
return store.list_tags()
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
||||
"""List metadata for all models that have it."""
|
||||
store = ModelMetadataStore(self._db)
|
||||
return store.list_all_metadata()
|
||||
|
||||
def list_models(
|
||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Return a paginated summary listing of each model in the database."""
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name",
|
||||
ModelRecordOrderBy.Type: "a.type",
|
||||
ModelRecordOrderBy.Base: "a.base",
|
||||
ModelRecordOrderBy.Name: "a.name",
|
||||
ModelRecordOrderBy.Format: "a.format",
|
||||
}
|
||||
|
||||
def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]:
|
||||
"""Fix up results so that there are no null values."""
|
||||
result: Dict[str, Union[str, int, Set[str]]] = {}
|
||||
for key, item in summary.items():
|
||||
result[key] = item or ""
|
||||
result["tags"] = set(json.loads(summary["tags"] or "[]"))
|
||||
return result
|
||||
|
||||
# Lock so that the database isn't updated while we're doing the two queries.
|
||||
with self._db.lock:
|
||||
# query1: get the total number of model configs
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) from model_config;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
total = int(self._cursor.fetchone()[0])
|
||||
|
||||
# query2: fetch key fields from the join of model_config and model_metadata
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT a.id as key, a.type, a.base, a.format, a.name,
|
||||
json_extract(a.config, '$.description') as description,
|
||||
json_extract(b.metadata, '$.tags') as tags
|
||||
FROM model_config AS a
|
||||
LEFT JOIN model_metadata AS b on a.id=b.id
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
||||
LIMIT ?
|
||||
OFFSET ?;
|
||||
""",
|
||||
(
|
||||
per_page,
|
||||
page * per_page,
|
||||
),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
items = [ModelSummary.model_validate(_fixup(dict(x))) for x in rows]
|
||||
return PaginatedResults(
|
||||
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
|
||||
)
|
||||
|
@ -6,6 +6,7 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import build_migration_2
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import build_migration_3
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@ -28,7 +29,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator = SqliteMigrator(db=db)
|
||||
migrator.register_migration(build_migration_1())
|
||||
migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
|
||||
migrator.register_migration(build_migration_3())
|
||||
migrator.register_migration(build_migration_3(app_config=config, logger=logger))
|
||||
migrator.register_migration(build_migration_4())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
@ -11,8 +11,6 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
UnsafeWorkflowWithVersionValidator,
|
||||
)
|
||||
|
||||
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
|
||||
|
||||
|
||||
class Migration2Callback:
|
||||
def __init__(self, image_files: ImageFileStorageBase, logger: Logger):
|
||||
@ -25,8 +23,6 @@ class Migration2Callback:
|
||||
self._drop_old_workflow_tables(cursor)
|
||||
self._add_workflow_library(cursor)
|
||||
self._drop_model_manager_metadata(cursor)
|
||||
self._recreate_model_config(cursor)
|
||||
self._migrate_model_config_records(cursor)
|
||||
self._migrate_embedded_workflows(cursor)
|
||||
|
||||
def _add_images_has_workflow(self, cursor: sqlite3.Cursor) -> None:
|
||||
@ -100,45 +96,6 @@ class Migration2Callback:
|
||||
"""Drops the `model_manager_metadata` table."""
|
||||
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
|
||||
|
||||
def _recreate_model_config(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
Drops the `model_config` table, recreating it.
|
||||
|
||||
In 3.4.0, this table used explicit columns but was changed to use json_extract 3.5.0.
|
||||
|
||||
Because this table is not used in production, we are able to simply drop it and recreate it.
|
||||
"""
|
||||
|
||||
cursor.execute("DROP TABLE IF EXISTS model_config;")
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS model_config (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
-- The next 3 fields are enums in python, unrestricted string here
|
||||
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
|
||||
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
|
||||
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
|
||||
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
|
||||
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
|
||||
original_hash TEXT, -- could be null
|
||||
-- Serialized JSON representation of the whole config object,
|
||||
-- which will contain additional fields from subclasses
|
||||
config TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- unique constraint on combo of name, base and type
|
||||
UNIQUE(name, base, type)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""After updating the model config table, we repopulate it."""
|
||||
model_record_migrator = MigrateModelYamlToDb1(cursor)
|
||||
model_record_migrator.migrate()
|
||||
|
||||
def _migrate_embedded_workflows(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
|
||||
|
@ -1,13 +1,16 @@
|
||||
import sqlite3
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
|
||||
|
||||
|
||||
class Migration3Callback:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
|
||||
self._app_config = app_config
|
||||
self._logger = logger
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._drop_model_manager_metadata(cursor)
|
||||
@ -54,11 +57,12 @@ class Migration3Callback:
|
||||
|
||||
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""After updating the model config table, we repopulate it."""
|
||||
model_record_migrator = MigrateModelYamlToDb1(cursor)
|
||||
self._logger.info("Migrating model config records from models.yaml to database")
|
||||
model_record_migrator = MigrateModelYamlToDb1(self._app_config, self._logger, cursor)
|
||||
model_record_migrator.migrate()
|
||||
|
||||
|
||||
def build_migration_3() -> Migration:
|
||||
def build_migration_3(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
|
||||
"""
|
||||
Build the migration from database version 2 to 3.
|
||||
|
||||
@ -69,7 +73,7 @@ def build_migration_3() -> Migration:
|
||||
migration_3 = Migration(
|
||||
from_version=2,
|
||||
to_version=3,
|
||||
callback=Migration3Callback(),
|
||||
callback=Migration3Callback(app_config=app_config, logger=logger),
|
||||
)
|
||||
|
||||
return migration_3
|
||||
|
@ -0,0 +1,83 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration4Callback:
|
||||
"""Callback to do step 4 of migration."""
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None: # noqa D102
|
||||
self._create_model_metadata(cursor)
|
||||
self._create_model_tags(cursor)
|
||||
self._create_tags(cursor)
|
||||
self._create_triggers(cursor)
|
||||
|
||||
def _create_model_metadata(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Create the table used to store model metadata downloaded from remote sources."""
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS model_metadata (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
name TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.name')) VIRTUAL NOT NULL,
|
||||
author TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.author')) VIRTUAL NOT NULL,
|
||||
-- Serialized JSON representation of the whole metadata object,
|
||||
-- which will contain additional fields from subclasses
|
||||
metadata TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
FOREIGN KEY(id) REFERENCES model_config(id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def _create_model_tags(self, cursor: sqlite3.Cursor) -> None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS model_tags (
|
||||
model_id TEXT NOT NULL,
|
||||
tag_id INTEGER NOT NULL,
|
||||
FOREIGN KEY(model_id) REFERENCES model_config(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY(tag_id) REFERENCES tags(tag_id) ON DELETE CASCADE,
|
||||
UNIQUE(model_id,tag_id)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def _create_tags(self, cursor: sqlite3.Cursor) -> None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS tags (
|
||||
tag_id INTEGER NOT NULL PRIMARY KEY,
|
||||
tag_text TEXT NOT NULL UNIQUE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def _create_triggers(self, cursor: sqlite3.Cursor) -> None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS model_metadata_updated_at
|
||||
AFTER UPDATE
|
||||
ON model_metadata FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE model_metadata SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE id = old.id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def build_migration_4() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 3 to 4.
|
||||
|
||||
Adds the tables needed to store model metadata and tags.
|
||||
"""
|
||||
migration_4 = Migration(
|
||||
from_version=3,
|
||||
to_version=4,
|
||||
callback=Migration4Callback(),
|
||||
)
|
||||
|
||||
return migration_4
|
@ -23,7 +23,6 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
ModelsValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
@ -46,10 +45,9 @@ class MigrateModelYamlToDb1:
|
||||
logger: Logger
|
||||
cursor: sqlite3.Cursor
|
||||
|
||||
def __init__(self, cursor: sqlite3.Cursor = None) -> None:
|
||||
self.config = InvokeAIAppConfig.get_config()
|
||||
self.config.parse_args()
|
||||
self.logger = InvokeAILogger.get_logger()
|
||||
def __init__(self, config: InvokeAIAppConfig, logger: Logger, cursor: sqlite3.Cursor = None) -> None:
|
||||
self.config = config
|
||||
self.logger = logger
|
||||
self.cursor = cursor
|
||||
|
||||
def get_yaml(self) -> DictConfig:
|
||||
|
@ -6,6 +6,7 @@ from .config import (
|
||||
InvalidModelConfigException,
|
||||
ModelConfigFactory,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
@ -15,15 +16,16 @@ from .probe import ModelProbe
|
||||
from .search import ModelSearch
|
||||
|
||||
__all__ = [
|
||||
"ModelProbe",
|
||||
"ModelSearch",
|
||||
"AnyModelConfig",
|
||||
"BaseModelType",
|
||||
"ModelRepoVariant",
|
||||
"InvalidModelConfigException",
|
||||
"ModelConfigFactory",
|
||||
"BaseModelType",
|
||||
"ModelType",
|
||||
"SubModelType",
|
||||
"ModelVariantType",
|
||||
"ModelFormat",
|
||||
"ModelProbe",
|
||||
"ModelSearch",
|
||||
"ModelType",
|
||||
"ModelVariantType",
|
||||
"SchedulerPredictionType",
|
||||
"AnyModelConfig",
|
||||
"SubModelType",
|
||||
]
|
||||
|
@ -99,6 +99,17 @@ class SchedulerPredictionType(str, Enum):
|
||||
Sample = "sample"
|
||||
|
||||
|
||||
class ModelRepoVariant(str, Enum):
|
||||
"""Various hugging face variants on the diffusers format."""
|
||||
|
||||
DEFAULT = "default" # model files without "fp16" or other qualifier
|
||||
FP16 = "fp16"
|
||||
FP32 = "fp32"
|
||||
ONNX = "onnx"
|
||||
OPENVINO = "openvino"
|
||||
FLAX = "flax"
|
||||
|
||||
|
||||
class ModelConfigBase(BaseModel):
|
||||
"""Base class for model configuration information."""
|
||||
|
||||
|
50
invokeai/backend/model_manager/metadata/__init__.py
Normal file
50
invokeai/backend/model_manager/metadata/__init__.py
Normal file
@ -0,0 +1,50 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.model_manager.metadata
|
||||
|
||||
Usage:
|
||||
|
||||
from invokeai.backend.model_manager.metadata import(
|
||||
AnyModelRepoMetadata,
|
||||
CommercialUsage,
|
||||
LicenseRestrictions,
|
||||
HuggingFaceMetadata,
|
||||
CivitaiMetadata,
|
||||
)
|
||||
|
||||
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
|
||||
|
||||
data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
|
||||
assert isinstance(data, CivitaiMetadata)
|
||||
if data.allow_commercial_use:
|
||||
print("Commercial use of this model is allowed")
|
||||
"""
|
||||
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch
|
||||
from .metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
AnyModelRepoMetadataValidator,
|
||||
BaseMetadata,
|
||||
CivitaiMetadata,
|
||||
CommercialUsage,
|
||||
HuggingFaceMetadata,
|
||||
LicenseRestrictions,
|
||||
ModelMetadataWithFiles,
|
||||
RemoteModelFile,
|
||||
UnknownMetadataException,
|
||||
)
|
||||
from .metadata_store import ModelMetadataStore
|
||||
|
||||
__all__ = [
|
||||
"AnyModelRepoMetadata",
|
||||
"AnyModelRepoMetadataValidator",
|
||||
"CivitaiMetadata",
|
||||
"CivitaiMetadataFetch",
|
||||
"CommercialUsage",
|
||||
"HuggingFaceMetadata",
|
||||
"HuggingFaceMetadataFetch",
|
||||
"LicenseRestrictions",
|
||||
"ModelMetadataStore",
|
||||
"BaseMetadata",
|
||||
"ModelMetadataWithFiles",
|
||||
"RemoteModelFile",
|
||||
"UnknownMetadataException",
|
||||
]
|
21
invokeai/backend/model_manager/metadata/fetch/__init__.py
Normal file
21
invokeai/backend/model_manager/metadata/fetch/__init__.py
Normal file
@ -0,0 +1,21 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.model_manager.metadata.fetch
|
||||
|
||||
Usage:
|
||||
from invokeai.backend.model_manager.metadata.fetch import (
|
||||
CivitaiMetadataFetch,
|
||||
HuggingFaceMetadataFetch,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import CivitaiMetadata
|
||||
|
||||
data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
|
||||
assert isinstance(data, CivitaiMetadata)
|
||||
if data.allow_commercial_use:
|
||||
print("Commercial use of this model is allowed")
|
||||
"""
|
||||
|
||||
from .civitai import CivitaiMetadataFetch
|
||||
from .fetch_base import ModelMetadataFetchBase
|
||||
from .huggingface import HuggingFaceMetadataFetch
|
||||
|
||||
__all__ = ["ModelMetadataFetchBase", "CivitaiMetadataFetch", "HuggingFaceMetadataFetch"]
|
187
invokeai/backend/model_manager/metadata/fetch/civitai.py
Normal file
187
invokeai/backend/model_manager/metadata/fetch/civitai.py
Normal file
@ -0,0 +1,187 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
This module fetches model metadata objects from the Civitai model repository.
|
||||
In addition to the `from_url()` and `from_id()` methods inherited from the
|
||||
`ModelMetadataFetchBase` base class.
|
||||
|
||||
Civitai has two separate ID spaces: a model ID and a version ID. The
|
||||
version ID corresponds to a specific model, and is the ID accepted by
|
||||
`from_id()`. The model ID corresponds to a family of related models,
|
||||
such as different training checkpoints or 16 vs 32-bit versions. The
|
||||
`from_civitai_modelid()` method will accept a model ID and return the
|
||||
metadata from the default version within this model set. The default
|
||||
version is the same as what the user sees when they click on a model's
|
||||
thumbnail.
|
||||
|
||||
Usage:
|
||||
|
||||
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
|
||||
|
||||
fetcher = CivitaiMetadataFetch()
|
||||
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
|
||||
print(metadata.trained_words)
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from ..metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
CivitaiMetadata,
|
||||
CommercialUsage,
|
||||
LicenseRestrictions,
|
||||
RemoteModelFile,
|
||||
UnknownMetadataException,
|
||||
)
|
||||
from .fetch_base import ModelMetadataFetchBase
|
||||
|
||||
CIVITAI_MODEL_PAGE_RE = r"https?://civitai.com/models/(\d+)"
|
||||
CIVITAI_VERSION_PAGE_RE = r"https?://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
|
||||
CIVITAI_DOWNLOAD_RE = r"https?://civitai.com/api/download/models/(\d+)"
|
||||
|
||||
CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
|
||||
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
|
||||
|
||||
|
||||
class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
"""Fetch model metadata from Civitai."""
|
||||
|
||||
def __init__(self, session: Optional[Session] = None):
|
||||
"""
|
||||
Initialize the fetcher with an optional requests.sessions.Session object.
|
||||
|
||||
By providing a configurable Session object, we can support unit tests on
|
||||
this module without an internet connection.
|
||||
"""
|
||||
self._requests = session or requests.Session()
|
||||
|
||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Given a URL to a CivitAI model or version page, return a ModelMetadata object.
|
||||
|
||||
In the event that the URL points to a model page without the particular version
|
||||
indicated, the default model version is returned. Otherwise, the requested version
|
||||
is returned.
|
||||
"""
|
||||
if match := re.match(CIVITAI_VERSION_PAGE_RE, str(url), re.IGNORECASE):
|
||||
model_id = match.group(1)
|
||||
version_id = match.group(2)
|
||||
return self.from_civitai_versionid(int(version_id), int(model_id))
|
||||
elif match := re.match(CIVITAI_MODEL_PAGE_RE, str(url), re.IGNORECASE):
|
||||
model_id = match.group(1)
|
||||
return self.from_civitai_modelid(int(model_id))
|
||||
elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url), re.IGNORECASE):
|
||||
version_id = match.group(1)
|
||||
return self.from_civitai_versionid(int(version_id))
|
||||
raise UnknownMetadataException("The url '{url}' does not match any known Civitai URL patterns")
|
||||
|
||||
def from_id(self, id: str) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Given a Civitai model version ID, return a ModelRepoMetadata object.
|
||||
|
||||
May raise an `UnknownMetadataException`.
|
||||
"""
|
||||
return self.from_civitai_versionid(int(id))
|
||||
|
||||
def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
|
||||
"""
|
||||
Return metadata from the default version of the indicated model.
|
||||
|
||||
May raise an `UnknownMetadataException`.
|
||||
"""
|
||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||
model_json = self._requests.get(model_url).json()
|
||||
return self._from_model_json(model_json)
|
||||
|
||||
def _from_model_json(self, model_json: Dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
try:
|
||||
version_id = version_id or model_json["modelVersions"][0]["id"]
|
||||
except TypeError as excp:
|
||||
raise UnknownMetadataException from excp
|
||||
|
||||
# loop till we find the section containing the version requested
|
||||
version_sections = [x for x in model_json["modelVersions"] if x["id"] == version_id]
|
||||
if not version_sections:
|
||||
raise UnknownMetadataException(f"Version {version_id} not found in model metadata")
|
||||
|
||||
version_json = version_sections[0]
|
||||
safe_thumbnails = [x["url"] for x in version_json["images"] if x["nsfw"] == "None"]
|
||||
|
||||
# Civitai has one "primary" file plus others such as VAEs. We only fetch the primary.
|
||||
primary = [x for x in version_json["files"] if x.get("primary")]
|
||||
assert len(primary) == 1
|
||||
primary_file = primary[0]
|
||||
|
||||
url = primary_file["downloadUrl"]
|
||||
if "?" not in url: # work around apparent bug in civitai api
|
||||
metadata_string = ""
|
||||
for key, value in primary_file["metadata"].items():
|
||||
if not value:
|
||||
continue
|
||||
metadata_string += f"&{key}={value}"
|
||||
url = url + f"?type={primary_file['type']}{metadata_string}"
|
||||
model_files = [
|
||||
RemoteModelFile(
|
||||
url=url,
|
||||
path=Path(primary_file["name"]),
|
||||
size=int(primary_file["sizeKB"] * 1024),
|
||||
sha256=primary_file["hashes"]["SHA256"],
|
||||
)
|
||||
]
|
||||
return CivitaiMetadata(
|
||||
id=model_json["id"],
|
||||
name=version_json["name"],
|
||||
version_id=version_json["id"],
|
||||
version_name=version_json["name"],
|
||||
created=datetime.fromisoformat(_fix_timezone(version_json["createdAt"])),
|
||||
updated=datetime.fromisoformat(_fix_timezone(version_json["updatedAt"])),
|
||||
published=datetime.fromisoformat(_fix_timezone(version_json["publishedAt"])),
|
||||
base_model_trained_on=version_json["baseModel"], # note - need a dictionary to turn into a BaseModelType
|
||||
files=model_files,
|
||||
download_url=version_json["downloadUrl"],
|
||||
thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
|
||||
author=model_json["creator"]["username"],
|
||||
description=model_json["description"],
|
||||
version_description=version_json["description"] or "",
|
||||
tags=model_json["tags"],
|
||||
trained_words=version_json["trainedWords"],
|
||||
nsfw=model_json["nsfw"],
|
||||
restrictions=LicenseRestrictions(
|
||||
AllowNoCredit=model_json["allowNoCredit"],
|
||||
AllowCommercialUse=CommercialUsage(model_json["allowCommercialUse"]),
|
||||
AllowDerivatives=model_json["allowDerivatives"],
|
||||
AllowDifferentLicense=model_json["allowDifferentLicense"],
|
||||
),
|
||||
)
|
||||
|
||||
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
"""
|
||||
Return a CivitaiMetadata object given a model version id.
|
||||
|
||||
May raise an `UnknownMetadataException`.
|
||||
"""
|
||||
if model_id is None:
|
||||
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
|
||||
version = self._requests.get(version_url).json()
|
||||
model_id = version["modelId"]
|
||||
|
||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||
model_json = self._requests.get(model_url).json()
|
||||
return self._from_model_json(model_json, version_id)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: str) -> CivitaiMetadata:
|
||||
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
|
||||
metadata = CivitaiMetadata.model_validate_json(json)
|
||||
return metadata
|
||||
|
||||
|
||||
def _fix_timezone(date: str) -> str:
|
||||
return re.sub(r"Z$", "+00:00", date)
|
61
invokeai/backend/model_manager/metadata/fetch/fetch_base.py
Normal file
61
invokeai/backend/model_manager/metadata/fetch/fetch_base.py
Normal file
@ -0,0 +1,61 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
This module is the base class for subclasses that fetch metadata from model repositories
|
||||
|
||||
Usage:
|
||||
|
||||
from invokeai.backend.model_manager.metadata.fetch import CivitAIMetadataFetch
|
||||
|
||||
fetcher = CivitaiMetadataFetch()
|
||||
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
|
||||
print(metadata.trained_words)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator
|
||||
|
||||
|
||||
class ModelMetadataFetchBase(ABC):
|
||||
"""Fetch metadata from remote generative model repositories."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, session: Optional[Session] = None):
|
||||
"""
|
||||
Initialize the fetcher with an optional requests.sessions.Session object.
|
||||
|
||||
By providing a configurable Session object, we can support unit tests on
|
||||
this module without an internet connection.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Given a URL to a model repository, return a ModelMetadata object.
|
||||
|
||||
This method will raise a `UnknownMetadataException`
|
||||
in the event that the requested model metadata is not found at the provided location.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def from_id(self, id: str) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Given an ID for a model, return a ModelMetadata object.
|
||||
|
||||
This method will raise a `UnknownMetadataException`
|
||||
in the event that the requested model's metadata is not found at the provided id.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: str) -> AnyModelRepoMetadata:
|
||||
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
|
||||
metadata = AnyModelRepoMetadataValidator.validate_json(json)
|
||||
return metadata
|
92
invokeai/backend/model_manager/metadata/fetch/huggingface.py
Normal file
92
invokeai/backend/model_manager/metadata/fetch/huggingface.py
Normal file
@ -0,0 +1,92 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
This module fetches model metadata objects from the HuggingFace model repository,
|
||||
using either a `repo_id` or the model page URL.
|
||||
|
||||
Usage:
|
||||
|
||||
from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch
|
||||
|
||||
fetcher = HuggingFaceMetadataFetch()
|
||||
metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
|
||||
print(metadata.tags)
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from ..metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
HuggingFaceMetadata,
|
||||
RemoteModelFile,
|
||||
UnknownMetadataException,
|
||||
)
|
||||
from .fetch_base import ModelMetadataFetchBase
|
||||
|
||||
HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"
|
||||
|
||||
|
||||
class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
"""Fetch model metadata from HuggingFace."""
|
||||
|
||||
def __init__(self, session: Optional[Session] = None):
|
||||
"""
|
||||
Initialize the fetcher with an optional requests.sessions.Session object.
|
||||
|
||||
By providing a configurable Session object, we can support unit tests on
|
||||
this module without an internet connection.
|
||||
"""
|
||||
self._requests = session or requests.Session()
|
||||
configure_http_backend(backend_factory=lambda: self._requests)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: str) -> HuggingFaceMetadata:
|
||||
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
|
||||
metadata = HuggingFaceMetadata.model_validate_json(json)
|
||||
return metadata
|
||||
|
||||
def from_id(self, id: str) -> AnyModelRepoMetadata:
|
||||
"""Return a HuggingFaceMetadata object given the model's repo_id."""
|
||||
try:
|
||||
model_info = HfApi().model_info(repo_id=id, files_metadata=True)
|
||||
except RepositoryNotFoundError as excp:
|
||||
raise UnknownMetadataException(f"'{id}' not found. See trace for details.") from excp
|
||||
|
||||
_, name = id.split("/")
|
||||
return HuggingFaceMetadata(
|
||||
id=model_info.id,
|
||||
author=model_info.author,
|
||||
name=name,
|
||||
last_modified=model_info.last_modified,
|
||||
tag_dict=model_info.card_data.to_dict() if model_info.card_data else {},
|
||||
tags=model_info.tags,
|
||||
files=[
|
||||
RemoteModelFile(
|
||||
url=hf_hub_url(id, x.rfilename),
|
||||
path=Path(name, x.rfilename),
|
||||
size=x.size,
|
||||
sha256=x.lfs.get("sha256") if x.lfs else None,
|
||||
)
|
||||
for x in model_info.siblings
|
||||
],
|
||||
)
|
||||
|
||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Return a HuggingFaceMetadata object given the model's web page URL.
|
||||
|
||||
In the case of an invalid or missing URL, raises a ModelNotFound exception.
|
||||
"""
|
||||
if match := re.match(HF_MODEL_RE, str(url), re.IGNORECASE):
|
||||
repo_id = match.group(1)
|
||||
return self.from_id(repo_id)
|
||||
else:
|
||||
raise UnknownMetadataException(f"'{url}' does not look like a HuggingFace model page")
|
202
invokeai/backend/model_manager/metadata/metadata_base.py
Normal file
202
invokeai/backend/model_manager/metadata/metadata_base.py
Normal file
@ -0,0 +1,202 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""This module defines core text-to-image model metadata fields.
|
||||
|
||||
Metadata comprises any descriptive information that is not essential
|
||||
for getting the model to run. For example "author" is metadata, while
|
||||
"type", "base" and "format" are not. The latter fields are part of the
|
||||
model's config, as defined in invokeai.backend.model_manager.config.
|
||||
|
||||
Note that the "name" and "description" are also present in `config`
|
||||
records. This is intentional. The config record fields are intended to
|
||||
be editable by the user as a form of customization. The metadata
|
||||
versions of these fields are intended to be kept in sync with the
|
||||
remote repo.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
|
||||
|
||||
from huggingface_hub import configure_http_backend, hf_hub_url
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
|
||||
from ..util import select_hf_files
|
||||
|
||||
|
||||
class UnknownMetadataException(Exception):
|
||||
"""Raised when no metadata is available for a model."""
|
||||
|
||||
|
||||
class CommercialUsage(str, Enum):
|
||||
"""Type of commercial usage allowed."""
|
||||
|
||||
No = "None"
|
||||
Image = "Image"
|
||||
Rent = "Rent"
|
||||
RentCivit = "RentCivit"
|
||||
Sell = "Sell"
|
||||
|
||||
|
||||
class LicenseRestrictions(BaseModel):
|
||||
"""Broad categories of licensing restrictions."""
|
||||
|
||||
AllowNoCredit: bool = Field(
|
||||
description="if true, model can be redistributed without crediting author", default=False
|
||||
)
|
||||
AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
|
||||
AllowDifferentLicense: bool = Field(
|
||||
description="if true, derivatives of this model be redistributed under a different license", default=False
|
||||
)
|
||||
AllowCommercialUse: CommercialUsage = Field(
|
||||
description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default_factory=set
|
||||
)
|
||||
|
||||
|
||||
class RemoteModelFile(BaseModel):
|
||||
"""Information about a downloadable file that forms part of a model."""
|
||||
|
||||
url: AnyHttpUrl = Field(description="The url to download this model file")
|
||||
path: Path = Field(description="The path to the file, relative to the model root")
|
||||
size: int = Field(description="The size of this file, in bytes")
|
||||
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
|
||||
|
||||
|
||||
class ModelMetadataBase(BaseModel):
|
||||
"""Base class for model metadata information."""
|
||||
|
||||
name: str = Field(description="model's name")
|
||||
author: str = Field(description="model's author")
|
||||
tags: Set[str] = Field(description="tags provided by model source")
|
||||
|
||||
|
||||
class BaseMetadata(ModelMetadataBase):
|
||||
"""Adds typing data for discriminated union."""
|
||||
|
||||
type: Literal["basemetadata"] = "basemetadata"
|
||||
|
||||
|
||||
class ModelMetadataWithFiles(ModelMetadataBase):
|
||||
"""Base class for metadata that contains a list of downloadable model file(s)."""
|
||||
|
||||
files: List[RemoteModelFile] = Field(description="model files and their sizes", default_factory=list)
|
||||
|
||||
def download_urls(
|
||||
self,
|
||||
variant: Optional[ModelRepoVariant] = None,
|
||||
subfolder: Optional[Path] = None,
|
||||
session: Optional[Session] = None,
|
||||
) -> List[RemoteModelFile]:
|
||||
"""
|
||||
Return a list of URLs needed to download the model.
|
||||
|
||||
:param variant: Return files needed to reconstruct the indicated variant (e.g. ModelRepoVariant('fp16'))
|
||||
:param subfolder: Return files in the designated subfolder only
|
||||
:param session: A request.Session object for offline testing
|
||||
|
||||
Note that the "variant" and "subfolder" concepts currently only apply to HuggingFace.
|
||||
However Civitai does have fields for the precision and format of its models, and may
|
||||
provide variant selection criteria in the future.
|
||||
"""
|
||||
return self.files
|
||||
|
||||
|
||||
class CivitaiMetadata(ModelMetadataWithFiles):
|
||||
"""Extended metadata fields provided by Civitai."""
|
||||
|
||||
type: Literal["civitai"] = "civitai"
|
||||
id: int = Field(description="Civitai version identifier")
|
||||
version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
|
||||
version_id: int = Field(description="Civitai model version identifier")
|
||||
created: datetime = Field(description="date the model was created")
|
||||
updated: datetime = Field(description="date the model was last modified")
|
||||
published: datetime = Field(description="date the model was published to Civitai")
|
||||
description: str = Field(description="text description of model; may contain HTML")
|
||||
version_description: str = Field(
|
||||
description="text description of the model's reversion; usually change history; may contain HTML"
|
||||
)
|
||||
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
|
||||
restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
|
||||
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
|
||||
download_url: AnyHttpUrl = Field(description="download URL for this model")
|
||||
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
|
||||
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
|
||||
weight_minmax: Tuple[float, float] = Field(
|
||||
description="minimum and maximum slider values for a LoRA or other secondary model", default=(-1.0, +2.0)
|
||||
) # note: For future use
|
||||
|
||||
@property
|
||||
def credit_required(self) -> bool:
|
||||
"""Return True if you must give credit for derivatives of this model and images generated from it."""
|
||||
return not self.restrictions.AllowNoCredit
|
||||
|
||||
@property
|
||||
def allow_commercial_use(self) -> bool:
|
||||
"""Return True if commercial use is allowed."""
|
||||
return self.restrictions.AllowCommercialUse != CommercialUsage("None")
|
||||
|
||||
@property
|
||||
def allow_derivatives(self) -> bool:
|
||||
"""Return True if derivatives of this model can be redistributed."""
|
||||
return self.restrictions.AllowDerivatives
|
||||
|
||||
@property
|
||||
def allow_different_license(self) -> bool:
|
||||
"""Return true if derivatives of this model can use a different license."""
|
||||
return self.restrictions.AllowDifferentLicense
|
||||
|
||||
|
||||
class HuggingFaceMetadata(ModelMetadataWithFiles):
|
||||
"""Extended metadata fields provided by HuggingFace."""
|
||||
|
||||
type: Literal["huggingface"] = "huggingface"
|
||||
id: str = Field(description="huggingface model id")
|
||||
tag_dict: Dict[str, Any]
|
||||
last_modified: datetime = Field(description="date of last commit to repo")
|
||||
|
||||
def download_urls(
|
||||
self,
|
||||
variant: Optional[ModelRepoVariant] = None,
|
||||
subfolder: Optional[Path] = None,
|
||||
session: Optional[Session] = None,
|
||||
) -> List[RemoteModelFile]:
|
||||
"""
|
||||
Return list of downloadable files, filtering by variant and subfolder, if any.
|
||||
|
||||
:param variant: Return model files needed to reconstruct the indicated variant
|
||||
:param subfolder: Return model files from the designated subfolder only
|
||||
:param session: A request.Session object used for internet-free testing
|
||||
|
||||
Note that there is special variant-filtering behavior here:
|
||||
When the fp16 variant is requested and not available, the
|
||||
full-precision model is returned.
|
||||
"""
|
||||
session = session or Session()
|
||||
configure_http_backend(backend_factory=lambda: session) # used in testing
|
||||
|
||||
paths = select_hf_files.filter_files(
|
||||
[x.path for x in self.files], variant, subfolder
|
||||
) # all files in the model
|
||||
prefix = f"{subfolder}/" if subfolder else ""
|
||||
|
||||
# the next step reads model_index.json to determine which subdirectories belong
|
||||
# to the model
|
||||
if Path(f"{prefix}model_index.json") in paths:
|
||||
url = hf_hub_url(self.id, filename="model_index.json", subfolder=subfolder)
|
||||
resp = session.get(url)
|
||||
resp.raise_for_status()
|
||||
submodels = resp.json()
|
||||
paths = [Path(subfolder or "", x) for x in paths if Path(x).parent.as_posix() in submodels]
|
||||
paths.insert(0, Path(f"{prefix}model_index.json"))
|
||||
|
||||
return [x for x in self.files if x.path in paths]
|
||||
|
||||
|
||||
AnyModelRepoMetadata = Annotated[Union[BaseMetadata, HuggingFaceMetadata, CivitaiMetadata], Field(discriminator="type")]
|
||||
AnyModelRepoMetadataValidator = TypeAdapter(AnyModelRepoMetadata)
|
221
invokeai/backend/model_manager/metadata/metadata_store.py
Normal file
221
invokeai/backend/model_manager/metadata/metadata_store.py
Normal file
@ -0,0 +1,221 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
SQL Storage for Model Metadata
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
from .fetch import ModelMetadataFetchBase
|
||||
from .metadata_base import AnyModelRepoMetadata, UnknownMetadataException
|
||||
|
||||
|
||||
class ModelMetadataStore:
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json()
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_metadata(
|
||||
id,
|
||||
metadata
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
model_key,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||
self._db.conn.rollback()
|
||||
raise UnknownMetadataException from excp
|
||||
except sqlite3.Error as excp:
|
||||
self._db.conn.rollback()
|
||||
raise excp
|
||||
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM model_metadata
|
||||
WHERE id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
return ModelMetadataFetchBase.from_json(rows[0])
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id,metadata FROM model_metadata;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
|
||||
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_metadata
|
||||
SET
|
||||
metadata=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, model_key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_metadata(model_key)
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select tag_text from tags;
|
||||
"""
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
matches: Optional[Set[str]] = None
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.model_id FROM model_tags AS a,
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
if matches is None:
|
||||
matches = model_keys
|
||||
matches = matches.intersection(model_keys)
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return matches if matches else set()
|
||||
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE author=?;
|
||||
""",
|
||||
(author,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE name=?;
|
||||
""",
|
||||
(name,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
@ -496,9 +496,9 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||
with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
|
||||
scheduler_conf = json.load(file)
|
||||
if scheduler_conf["prediction_type"] == "v_prediction":
|
||||
if scheduler_conf.get("prediction_type", "epsilon") == "v_prediction":
|
||||
return SchedulerPredictionType.VPrediction
|
||||
elif scheduler_conf["prediction_type"] == "epsilon":
|
||||
elif scheduler_conf.get("prediction_type", "epsilon") == "epsilon":
|
||||
return SchedulerPredictionType.Epsilon
|
||||
else:
|
||||
raise InvalidModelConfigException("Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")
|
||||
|
132
invokeai/backend/model_manager/util/select_hf_files.py
Normal file
132
invokeai/backend/model_manager/util/select_hf_files.py
Normal file
@ -0,0 +1,132 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Select the files from a HuggingFace repository needed for a particular model variant.
|
||||
|
||||
Usage:
|
||||
```
|
||||
from invokeai.backend.model_manager.util.select_hf_files import select_hf_model_files
|
||||
from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch
|
||||
|
||||
metadata = HuggingFaceMetadataFetch().from_url("https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0")
|
||||
files_to_download = select_hf_model_files(metadata.files, variant='onnx')
|
||||
```
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
from ..config import ModelRepoVariant
|
||||
|
||||
|
||||
def filter_files(
|
||||
files: List[Path],
|
||||
variant: Optional[ModelRepoVariant] = None,
|
||||
subfolder: Optional[Path] = None,
|
||||
) -> List[Path]:
|
||||
"""
|
||||
Take a list of files in a HuggingFace repo root and return paths to files needed to load the model.
|
||||
|
||||
:param files: List of files relative to the repo root.
|
||||
:param subfolder: Filter by the indicated subfolder.
|
||||
:param variant: Filter by files belonging to a particular variant, such as fp16.
|
||||
|
||||
The file list can be obtained from the `files` field of HuggingFaceMetadata,
|
||||
as defined in `invokeai.backend.model_manager.metadata.metadata_base`.
|
||||
"""
|
||||
variant = variant or ModelRepoVariant.DEFAULT
|
||||
paths: List[Path] = []
|
||||
|
||||
# Start by filtering on model file extensions, discarding images, docs, etc
|
||||
for file in files:
|
||||
if file.name.endswith((".json", ".txt")):
|
||||
paths.append(file)
|
||||
elif file.name.endswith(("learned_embeds.bin", "ip_adapter.bin", "lora_weights.safetensors")):
|
||||
paths.append(file)
|
||||
# BRITTLENESS WARNING!!
|
||||
# Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid
|
||||
# downloading random checkpoints that might also be in the repo. However there is no guarantee
|
||||
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
|
||||
# will adhere to this naming convention, so this is an area of brittleness.
|
||||
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
|
||||
paths.append(file)
|
||||
|
||||
# limit search to subfolder if requested
|
||||
if subfolder:
|
||||
paths = [x for x in paths if x.parent == Path(subfolder)]
|
||||
|
||||
# _filter_by_variant uniquifies the paths and returns a set
|
||||
return sorted(_filter_by_variant(paths, variant))
|
||||
|
||||
|
||||
def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
|
||||
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
|
||||
result = set()
|
||||
basenames: Dict[Path, Path] = {}
|
||||
for path in files:
|
||||
if path.suffix == ".onnx":
|
||||
if variant == ModelRepoVariant.ONNX:
|
||||
result.add(path)
|
||||
|
||||
elif "openvino_model" in path.name:
|
||||
if variant == ModelRepoVariant.OPENVINO:
|
||||
result.add(path)
|
||||
|
||||
elif "flax_model" in path.name:
|
||||
if variant == ModelRepoVariant.FLAX:
|
||||
result.add(path)
|
||||
|
||||
elif path.suffix in [".json", ".txt"]:
|
||||
result.add(path)
|
||||
|
||||
elif path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"] and variant in [
|
||||
ModelRepoVariant.FP16,
|
||||
ModelRepoVariant.FP32,
|
||||
ModelRepoVariant.DEFAULT,
|
||||
]:
|
||||
parent = path.parent
|
||||
suffixes = path.suffixes
|
||||
if len(suffixes) == 2:
|
||||
variant_label, suffix = suffixes
|
||||
basename = parent / Path(path.stem).stem
|
||||
else:
|
||||
variant_label = ""
|
||||
suffix = suffixes[0]
|
||||
basename = parent / path.stem
|
||||
|
||||
if previous := basenames.get(basename):
|
||||
if (
|
||||
previous.suffix != ".safetensors" and suffix == ".safetensors"
|
||||
): # replace non-safetensors with safetensors when available
|
||||
basenames[basename] = path
|
||||
if variant_label == f".{variant}":
|
||||
basenames[basename] = path
|
||||
elif not variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]:
|
||||
basenames[basename] = path
|
||||
else:
|
||||
basenames[basename] = path
|
||||
|
||||
else:
|
||||
continue
|
||||
|
||||
for v in basenames.values():
|
||||
result.add(v)
|
||||
|
||||
# If one of the architecture-related variants was specified and no files matched other than
|
||||
# config and text files then we return an empty list
|
||||
if (
|
||||
variant
|
||||
and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OPENVINO, ModelRepoVariant.FLAX]
|
||||
and not any(variant.value in x.name for x in result)
|
||||
):
|
||||
return set()
|
||||
|
||||
# Prune folders that contain just a `config.json`. This happens when
|
||||
# the requested variant (e.g. "onnx") is missing
|
||||
directories: Dict[Path, int] = {}
|
||||
for x in result:
|
||||
if not x.parent:
|
||||
continue
|
||||
directories[x.parent] = directories.get(x.parent, 0) + 1
|
||||
|
||||
return {x for x in result if directories[x.parent] > 1 or x.name != "config.json"}
|
@ -228,6 +228,7 @@ exclude = [
|
||||
[tool.mypy]
|
||||
ignore_missing_imports = true # ignores missing types in third-party libraries
|
||||
strict = true
|
||||
plugins = "pydantic.mypy"
|
||||
exclude = ["tests/*"]
|
||||
|
||||
# overrides for specific modules
|
||||
|
@ -5,11 +5,10 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
|
||||
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
@ -19,8 +18,8 @@ TestAdapter.__test__ = False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session() -> requests.sessions.Session:
|
||||
sess = requests.Session()
|
||||
def session() -> Session:
|
||||
sess = TestSession()
|
||||
for i in ["12345", "9999", "54321"]:
|
||||
content = (
|
||||
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
|
||||
@ -160,7 +159,7 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||
queue.stop()
|
||||
|
||||
|
||||
def test_broken_callbacks(tmp_path: Path, session: requests.sessions.Session, capsys) -> None:
|
||||
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=session,
|
||||
)
|
||||
@ -191,7 +190,7 @@ def test_broken_callbacks(tmp_path: Path, session: requests.sessions.Session, ca
|
||||
queue.stop()
|
||||
|
||||
|
||||
def test_cancel(tmp_path: Path, session: requests.sessions.Session) -> None:
|
||||
def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
event_bus = DummyEventService()
|
||||
|
||||
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
||||
|
@ -2,11 +2,12 @@
|
||||
Test the model installer
|
||||
"""
|
||||
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import ValidationError
|
||||
from pydantic.networks import Url
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
@ -14,104 +15,50 @@ from invokeai.app.services.model_install import (
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelInstallService,
|
||||
ModelInstallServiceBase,
|
||||
URLModelSource,
|
||||
)
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
|
||||
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
|
||||
|
||||
OS = platform.uname().system
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_file(datadir: Path) -> Path:
|
||||
return datadir / "test_embedding.safetensors"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_config(datadir: Path) -> InvokeAIAppConfig:
|
||||
return InvokeAIAppConfig(
|
||||
root=datadir / "root",
|
||||
models_dir=datadir / "root/models",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(
|
||||
app_config: InvokeAIAppConfig,
|
||||
) -> ModelRecordServiceBase:
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
db = create_mock_sqlite_database(app_config, logger)
|
||||
store: ModelRecordServiceBase = ModelRecordServiceSQL(db)
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def installer(app_config: InvokeAIAppConfig, store: ModelRecordServiceBase) -> ModelInstallServiceBase:
|
||||
installer = ModelInstallService(
|
||||
app_config=app_config,
|
||||
record_store=store,
|
||||
event_bus=DummyEventService(),
|
||||
)
|
||||
installer.start()
|
||||
return installer
|
||||
|
||||
|
||||
class DummyEvent(BaseModel):
|
||||
"""Dummy Event to use with Dummy Event service."""
|
||||
|
||||
event_name: str
|
||||
payload: Dict[str, Any]
|
||||
|
||||
|
||||
class DummyEventService(EventServiceBase):
|
||||
"""Dummy event service for testing."""
|
||||
|
||||
events: List[DummyEvent]
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.events = []
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
"""Dispatch an event by appending it to self.events."""
|
||||
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
|
||||
|
||||
|
||||
def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> None:
|
||||
store = installer.record_store
|
||||
def test_registration(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||
store = mm2_installer.record_store
|
||||
matches = store.search_by_attr(model_name="test_embedding")
|
||||
assert len(matches) == 0
|
||||
key = installer.register_path(test_file)
|
||||
key = mm2_installer.register_path(embedding_file)
|
||||
assert key is not None
|
||||
assert len(key) == 32
|
||||
|
||||
|
||||
def test_registration_meta(installer: ModelInstallServiceBase, test_file: Path) -> None:
|
||||
store = installer.record_store
|
||||
key = installer.register_path(test_file)
|
||||
def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||
store = mm2_installer.record_store
|
||||
key = mm2_installer.register_path(embedding_file)
|
||||
model_record = store.get_model(key)
|
||||
assert model_record is not None
|
||||
assert model_record.name == "test_embedding"
|
||||
assert model_record.type == ModelType.TextualInversion
|
||||
assert Path(model_record.path) == test_file
|
||||
assert Path(model_record.path) == embedding_file
|
||||
assert model_record.base == BaseModelType("sd-1")
|
||||
assert model_record.description is not None
|
||||
assert model_record.source is not None
|
||||
assert Path(model_record.source) == test_file
|
||||
assert Path(model_record.source) == embedding_file
|
||||
|
||||
|
||||
def test_registration_meta_override_fail(installer: ModelInstallServiceBase, test_file: Path) -> None:
|
||||
def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||
key = None
|
||||
with pytest.raises(ValidationError):
|
||||
key = installer.register_path(test_file, {"name": "banana_sushi", "type": ModelType("lora")})
|
||||
key = mm2_installer.register_path(embedding_file, {"name": "banana_sushi", "type": ModelType("lora")})
|
||||
assert key is None
|
||||
|
||||
|
||||
def test_registration_meta_override_succeed(installer: ModelInstallServiceBase, test_file: Path) -> None:
|
||||
store = installer.record_store
|
||||
key = installer.register_path(
|
||||
test_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash"}
|
||||
def test_registration_meta_override_succeed(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||
store = mm2_installer.record_store
|
||||
key = mm2_installer.register_path(
|
||||
embedding_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash"}
|
||||
)
|
||||
model_record = store.get_model(key)
|
||||
assert model_record.name == "banana_sushi"
|
||||
@ -119,40 +66,59 @@ def test_registration_meta_override_succeed(installer: ModelInstallServiceBase,
|
||||
assert model_record.current_hash == "New Hash"
|
||||
|
||||
|
||||
def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
|
||||
store = installer.record_store
|
||||
key = installer.install_path(test_file)
|
||||
def test_install(
|
||||
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
|
||||
) -> None:
|
||||
store = mm2_installer.record_store
|
||||
key = mm2_installer.install_path(embedding_file)
|
||||
model_record = store.get_model(key)
|
||||
assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
|
||||
assert model_record.source == test_file.as_posix()
|
||||
assert model_record.source == embedding_file.as_posix()
|
||||
|
||||
|
||||
def test_background_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
|
||||
@pytest.mark.parametrize(
|
||||
"fixture_name,size,destination",
|
||||
[
|
||||
("embedding_file", 15440, "sd-1/embedding/test_embedding.safetensors"),
|
||||
("diffusers_dir", 8241 if OS == "Windows" else 7907, "sdxl/main/test-diffusers-main"), # EOL chars
|
||||
],
|
||||
)
|
||||
def test_background_install(
|
||||
mm2_installer: ModelInstallServiceBase,
|
||||
fixture_name: str,
|
||||
size: int,
|
||||
destination: str,
|
||||
mm2_app_config: InvokeAIAppConfig,
|
||||
request: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""Note: may want to break this down into several smaller unit tests."""
|
||||
path = test_file
|
||||
path: Path = request.getfixturevalue(fixture_name)
|
||||
description = "Test of metadata assignment"
|
||||
source = LocalModelSource(path=path, inplace=False)
|
||||
job = installer.import_model(source, config={"description": description})
|
||||
job = mm2_installer.import_model(source, config={"description": description})
|
||||
assert job is not None
|
||||
assert isinstance(job, ModelInstallJob)
|
||||
|
||||
# See if job is registered properly
|
||||
assert job in installer.get_job(source)
|
||||
assert job in mm2_installer.get_job_by_source(source)
|
||||
|
||||
# test that the job object tracked installation correctly
|
||||
jobs = installer.wait_for_installs()
|
||||
jobs = mm2_installer.wait_for_installs()
|
||||
assert len(jobs) > 0
|
||||
my_job = [x for x in jobs if x.source == source]
|
||||
assert len(my_job) == 1
|
||||
assert my_job[0].status == InstallStatus.COMPLETED
|
||||
assert job == my_job[0]
|
||||
assert job.status == InstallStatus.COMPLETED
|
||||
assert job.total_bytes == size
|
||||
|
||||
# test that the expected events were issued
|
||||
bus = installer.event_bus
|
||||
assert bus is not None # sigh - ruff is a stickler for type checking
|
||||
assert isinstance(bus, DummyEventService)
|
||||
bus = mm2_installer.event_bus
|
||||
assert bus
|
||||
assert hasattr(bus, "events")
|
||||
|
||||
assert len(bus.events) == 2
|
||||
event_names = [x.event_name for x in bus.events]
|
||||
assert "model_install_started" in event_names
|
||||
assert "model_install_running" in event_names
|
||||
assert "model_install_completed" in event_names
|
||||
assert Path(bus.events[0].payload["source"]) == source
|
||||
assert Path(bus.events[1].payload["source"]) == source
|
||||
@ -160,41 +126,134 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path,
|
||||
assert key is not None
|
||||
|
||||
# see if the thing actually got installed at the expected location
|
||||
model_record = installer.record_store.get_model(key)
|
||||
model_record = mm2_installer.record_store.get_model(key)
|
||||
assert model_record is not None
|
||||
assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
|
||||
assert Path(app_config.models_dir / model_record.path).exists()
|
||||
assert model_record.path == destination
|
||||
assert Path(mm2_app_config.models_dir / model_record.path).exists()
|
||||
|
||||
# see if metadata was properly passed through
|
||||
assert model_record.description == description
|
||||
|
||||
# see if job filtering works
|
||||
assert mm2_installer.get_job_by_source(source)[0] == job
|
||||
|
||||
# see if prune works properly
|
||||
installer.prune_jobs()
|
||||
assert not installer.get_job(source)
|
||||
mm2_installer.prune_jobs()
|
||||
assert not mm2_installer.get_job_by_source(source)
|
||||
|
||||
|
||||
def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
|
||||
store = installer.record_store
|
||||
key = installer.install_path(test_file)
|
||||
def test_not_inplace_install(
|
||||
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
|
||||
) -> None:
|
||||
source = LocalModelSource(path=embedding_file, inplace=False)
|
||||
job = mm2_installer.import_model(source)
|
||||
mm2_installer.wait_for_installs()
|
||||
assert job is not None
|
||||
assert job.config_out is not None
|
||||
assert Path(job.config_out.path) != embedding_file
|
||||
assert Path(mm2_app_config.models_dir / job.config_out.path).exists()
|
||||
|
||||
|
||||
def test_inplace_install(
|
||||
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
|
||||
) -> None:
|
||||
source = LocalModelSource(path=embedding_file, inplace=True)
|
||||
job = mm2_installer.import_model(source)
|
||||
mm2_installer.wait_for_installs()
|
||||
assert job is not None
|
||||
assert job.config_out is not None
|
||||
assert Path(job.config_out.path) == embedding_file
|
||||
|
||||
|
||||
def test_delete_install(
|
||||
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
|
||||
) -> None:
|
||||
store = mm2_installer.record_store
|
||||
key = mm2_installer.install_path(embedding_file)
|
||||
model_record = store.get_model(key)
|
||||
assert Path(app_config.models_dir / model_record.path).exists()
|
||||
assert test_file.exists() # original should still be there after installation
|
||||
installer.delete(key)
|
||||
assert Path(mm2_app_config.models_dir / model_record.path).exists()
|
||||
assert embedding_file.exists() # original should still be there after installation
|
||||
mm2_installer.delete(key)
|
||||
assert not Path(
|
||||
app_config.models_dir / model_record.path
|
||||
mm2_app_config.models_dir / model_record.path
|
||||
).exists() # after deletion, installed copy should not exist
|
||||
assert test_file.exists() # but original should still be there
|
||||
assert embedding_file.exists() # but original should still be there
|
||||
with pytest.raises(UnknownModelException):
|
||||
store.get_model(key)
|
||||
|
||||
|
||||
def test_delete_register(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
|
||||
store = installer.record_store
|
||||
key = installer.register_path(test_file)
|
||||
def test_delete_register(
|
||||
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
|
||||
) -> None:
|
||||
store = mm2_installer.record_store
|
||||
key = mm2_installer.register_path(embedding_file)
|
||||
model_record = store.get_model(key)
|
||||
assert Path(app_config.models_dir / model_record.path).exists()
|
||||
assert test_file.exists() # original should still be there after installation
|
||||
installer.delete(key)
|
||||
assert Path(app_config.models_dir / model_record.path).exists()
|
||||
assert Path(mm2_app_config.models_dir / model_record.path).exists()
|
||||
assert embedding_file.exists() # original should still be there after installation
|
||||
mm2_installer.delete(key)
|
||||
assert Path(mm2_app_config.models_dir / model_record.path).exists()
|
||||
with pytest.raises(UnknownModelException):
|
||||
store.get_model(key)
|
||||
|
||||
|
||||
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
|
||||
|
||||
bus = mm2_installer.event_bus
|
||||
store = mm2_installer.record_store
|
||||
assert store is not None
|
||||
assert bus is not None
|
||||
assert hasattr(bus, "events") # the dummy event service has this
|
||||
|
||||
job = mm2_installer.import_model(source)
|
||||
assert job.source == source
|
||||
job_list = mm2_installer.wait_for_installs(timeout=10)
|
||||
assert len(job_list) == 1
|
||||
assert job.complete
|
||||
assert job.config_out
|
||||
|
||||
key = job.config_out.key
|
||||
model_record = store.get_model(key)
|
||||
assert Path(mm2_app_config.models_dir / model_record.path).exists()
|
||||
|
||||
assert len(bus.events) == 3
|
||||
event_names = [x.event_name for x in bus.events]
|
||||
assert event_names == ["model_install_downloading", "model_install_running", "model_install_completed"]
|
||||
|
||||
|
||||
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
|
||||
bus = mm2_installer.event_bus
|
||||
store = mm2_installer.record_store
|
||||
assert isinstance(bus, EventServiceBase)
|
||||
assert store is not None
|
||||
|
||||
job = mm2_installer.import_model(source)
|
||||
job_list = mm2_installer.wait_for_installs(timeout=10)
|
||||
assert len(job_list) == 1
|
||||
assert job.complete
|
||||
assert job.config_out
|
||||
|
||||
key = job.config_out.key
|
||||
model_record = store.get_model(key)
|
||||
assert Path(mm2_app_config.models_dir / model_record.path).exists()
|
||||
assert model_record.type == ModelType.Main
|
||||
assert model_record.format == ModelFormat.Diffusers
|
||||
|
||||
assert hasattr(bus, "events") # the dummyeventservice has this
|
||||
assert len(bus.events) >= 3
|
||||
event_names = {x.event_name for x in bus.events}
|
||||
assert event_names == {"model_install_downloading", "model_install_running", "model_install_completed"}
|
||||
|
||||
|
||||
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
|
||||
job = mm2_installer.import_model(source)
|
||||
mm2_installer.wait_for_installs(timeout=10)
|
||||
assert job.status == InstallStatus.ERROR
|
||||
assert job.errored
|
||||
assert job.error_type == "HTTPError"
|
||||
assert job.error
|
||||
assert "NOT FOUND" in job.error
|
||||
assert "Traceback" in job.error
|
||||
|
@ -1 +0,0 @@
|
||||
This directory is used by pytest-datadir.
|
@ -1 +0,0 @@
|
||||
Dummy file to establish git path.
|
@ -10,6 +10,7 @@ import pytest
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
ModelRecordOrderBy,
|
||||
ModelRecordServiceBase,
|
||||
ModelRecordServiceSQL,
|
||||
UnknownModelException,
|
||||
@ -22,14 +23,16 @@ from invokeai.backend.model_manager.config import (
|
||||
TextualInversionConfig,
|
||||
VaeDiffusersConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import BaseMetadata
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(
|
||||
datadir: Any,
|
||||
) -> ModelRecordServiceBase:
|
||||
) -> ModelRecordServiceSQL:
|
||||
config = InvokeAIAppConfig(root=datadir)
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
db = create_mock_sqlite_database(config, logger)
|
||||
@ -268,3 +271,50 @@ def test_filter_2(store: ModelRecordServiceBase):
|
||||
model_name="dup_name1",
|
||||
)
|
||||
assert len(matches) == 1
|
||||
|
||||
|
||||
def test_summary(mm2_record_store: ModelRecordServiceSQL) -> None:
|
||||
# The fixture provides us with five configs.
|
||||
for x in range(1, 5):
|
||||
key = f"test_config_{x}"
|
||||
name = f"name_{x}"
|
||||
author = f"author_{x}"
|
||||
tags = {f"tag{y}" for y in range(1, x)}
|
||||
mm2_record_store.metadata_store.add_metadata(
|
||||
model_key=key, metadata=BaseMetadata(name=name, author=author, tags=tags)
|
||||
)
|
||||
# sanity check that the tags sent in all right
|
||||
assert mm2_record_store.get_metadata("test_config_3").tags == {"tag1", "tag2"}
|
||||
assert mm2_record_store.get_metadata("test_config_4").tags == {"tag1", "tag2", "tag3"}
|
||||
|
||||
# get summary
|
||||
summary1 = mm2_record_store.list_models(page=0, per_page=100)
|
||||
assert summary1.page == 0
|
||||
assert summary1.pages == 1
|
||||
assert summary1.per_page == 100
|
||||
assert summary1.total == 5
|
||||
assert len(summary1.items) == 5
|
||||
assert summary1.items[0].name == "test5" # lora / sd-1 / diffusers / test5
|
||||
|
||||
# find test_config_3
|
||||
config3 = [x for x in summary1.items if x.key == "test_config_3"][0]
|
||||
assert config3.description == "This is test 3"
|
||||
assert config3.tags == {"tag1", "tag2"}
|
||||
|
||||
# find test_config_5
|
||||
config5 = [x for x in summary1.items if x.key == "test_config_5"][0]
|
||||
assert config5.tags == set()
|
||||
assert config5.description == ""
|
||||
|
||||
# test paging
|
||||
summary2 = mm2_record_store.list_models(page=1, per_page=2)
|
||||
assert summary2.page == 1
|
||||
assert summary2.per_page == 2
|
||||
assert summary2.pages == 3
|
||||
assert summary1.items[2].name == summary2.items[0].name
|
||||
|
||||
# test sorting
|
||||
summary = mm2_record_store.list_models(page=0, per_page=100, order_by=ModelRecordOrderBy.Name)
|
||||
print(summary.items)
|
||||
assert summary.items[0].name == "model1"
|
||||
assert summary.items[-1].name == "test5"
|
||||
|
1
tests/backend/model_manager_2/data/invokeai_root/README
Normal file
1
tests/backend/model_manager_2/data/invokeai_root/README
Normal file
@ -0,0 +1 @@
|
||||
This is an empty invokeai root that is used as a template for model manager tests.
|
@ -0,0 +1 @@
|
||||
This is a template empty invokeai root directory used to test model management.
|
@ -0,0 +1 @@
|
||||
This is a template empty invokeai root directory used to test model management.
|
@ -0,0 +1,34 @@
|
||||
{
|
||||
"_class_name": "StableDiffusionXLPipeline",
|
||||
"_diffusers_version": "0.23.0",
|
||||
"_name_or_path": "stabilityai/sdxl-turbo",
|
||||
"force_zeros_for_empty_prompt": true,
|
||||
"scheduler": [
|
||||
"diffusers",
|
||||
"EulerAncestralDiscreteScheduler"
|
||||
],
|
||||
"text_encoder": [
|
||||
"transformers",
|
||||
"CLIPTextModel"
|
||||
],
|
||||
"text_encoder_2": [
|
||||
"transformers",
|
||||
"CLIPTextModelWithProjection"
|
||||
],
|
||||
"tokenizer": [
|
||||
"transformers",
|
||||
"CLIPTokenizer"
|
||||
],
|
||||
"tokenizer_2": [
|
||||
"transformers",
|
||||
"CLIPTokenizer"
|
||||
],
|
||||
"unet": [
|
||||
"diffusers",
|
||||
"UNet2DConditionModel"
|
||||
],
|
||||
"vae": [
|
||||
"diffusers",
|
||||
"AutoencoderKL"
|
||||
]
|
||||
}
|
@ -0,0 +1,17 @@
|
||||
{
|
||||
"_class_name": "EulerAncestralDiscreteScheduler",
|
||||
"_diffusers_version": "0.23.0",
|
||||
"beta_end": 0.012,
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
"clip_sample": false,
|
||||
"interpolation_type": "linear",
|
||||
"num_train_timesteps": 1000,
|
||||
"prediction_type": "epsilon",
|
||||
"sample_max_value": 1.0,
|
||||
"set_alpha_to_one": false,
|
||||
"skip_prk_steps": true,
|
||||
"steps_offset": 1,
|
||||
"timestep_spacing": "trailing",
|
||||
"trained_betas": null
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
{
|
||||
"_name_or_path": "/home/lstein/.cache/huggingface/hub/models--stabilityai--sdxl-turbo/snapshots/fbda35297a8280789ffe2e25206800702fa5c4c1/text_encoder",
|
||||
"architectures": [
|
||||
"CLIPTextModel"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 0,
|
||||
"dropout": 0.0,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 768,
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"max_position_embeddings": 77,
|
||||
"model_type": "clip_text_model",
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
"pad_token_id": 1,
|
||||
"projection_dim": 768,
|
||||
"torch_dtype": "float16",
|
||||
"transformers_version": "4.35.0",
|
||||
"vocab_size": 49408
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
{
|
||||
"_name_or_path": "/home/lstein/.cache/huggingface/hub/models--stabilityai--sdxl-turbo/snapshots/fbda35297a8280789ffe2e25206800702fa5c4c1/text_encoder_2",
|
||||
"architectures": [
|
||||
"CLIPTextModelWithProjection"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 0,
|
||||
"dropout": 0.0,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_size": 1280,
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 5120,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"max_position_embeddings": 77,
|
||||
"model_type": "clip_text_model",
|
||||
"num_attention_heads": 20,
|
||||
"num_hidden_layers": 32,
|
||||
"pad_token_id": 1,
|
||||
"projection_dim": 1280,
|
||||
"torch_dtype": "float16",
|
||||
"transformers_version": "4.35.0",
|
||||
"vocab_size": 49408
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"49406": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49407": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"bos_token": "<|startoftext|>",
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"do_lower_case": true,
|
||||
"eos_token": "<|endoftext|>",
|
||||
"errors": "replace",
|
||||
"model_max_length": 77,
|
||||
"pad_token": "<|endoftext|>",
|
||||
"tokenizer_class": "CLIPTokenizer",
|
||||
"unk_token": "<|endoftext|>"
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "!",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
@ -0,0 +1,38 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "!",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49406": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49407": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"bos_token": "<|startoftext|>",
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"do_lower_case": true,
|
||||
"eos_token": "<|endoftext|>",
|
||||
"errors": "replace",
|
||||
"model_max_length": 77,
|
||||
"pad_token": "!",
|
||||
"tokenizer_class": "CLIPTokenizer",
|
||||
"unk_token": "<|endoftext|>"
|
||||
}
|
@ -0,0 +1,73 @@
|
||||
{
|
||||
"_class_name": "UNet2DConditionModel",
|
||||
"_diffusers_version": "0.23.0",
|
||||
"_name_or_path": "/home/lstein/.cache/huggingface/hub/models--stabilityai--sdxl-turbo/snapshots/fbda35297a8280789ffe2e25206800702fa5c4c1/unet",
|
||||
"act_fn": "silu",
|
||||
"addition_embed_type": "text_time",
|
||||
"addition_embed_type_num_heads": 64,
|
||||
"addition_time_embed_dim": 256,
|
||||
"attention_head_dim": [
|
||||
5,
|
||||
10,
|
||||
20
|
||||
],
|
||||
"attention_type": "default",
|
||||
"block_out_channels": [
|
||||
320,
|
||||
640,
|
||||
1280
|
||||
],
|
||||
"center_input_sample": false,
|
||||
"class_embed_type": null,
|
||||
"class_embeddings_concat": false,
|
||||
"conv_in_kernel": 3,
|
||||
"conv_out_kernel": 3,
|
||||
"cross_attention_dim": 2048,
|
||||
"cross_attention_norm": null,
|
||||
"down_block_types": [
|
||||
"DownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D"
|
||||
],
|
||||
"downsample_padding": 1,
|
||||
"dropout": 0.0,
|
||||
"dual_cross_attention": false,
|
||||
"encoder_hid_dim": null,
|
||||
"encoder_hid_dim_type": null,
|
||||
"flip_sin_to_cos": true,
|
||||
"freq_shift": 0,
|
||||
"in_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_only_cross_attention": null,
|
||||
"mid_block_scale_factor": 1,
|
||||
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"num_attention_heads": null,
|
||||
"num_class_embeds": null,
|
||||
"only_cross_attention": false,
|
||||
"out_channels": 4,
|
||||
"projection_class_embeddings_input_dim": 2816,
|
||||
"resnet_out_scale_factor": 1.0,
|
||||
"resnet_skip_time_act": false,
|
||||
"resnet_time_scale_shift": "default",
|
||||
"reverse_transformer_layers_per_block": null,
|
||||
"sample_size": 64,
|
||||
"time_cond_proj_dim": null,
|
||||
"time_embedding_act_fn": null,
|
||||
"time_embedding_dim": null,
|
||||
"time_embedding_type": "positional",
|
||||
"timestep_post_act": null,
|
||||
"transformer_layers_per_block": [
|
||||
1,
|
||||
2,
|
||||
10
|
||||
],
|
||||
"up_block_types": [
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"UpBlock2D"
|
||||
],
|
||||
"upcast_attention": null,
|
||||
"use_linear_projection": true
|
||||
}
|
@ -0,0 +1,32 @@
|
||||
{
|
||||
"_class_name": "AutoencoderKL",
|
||||
"_diffusers_version": "0.23.0",
|
||||
"_name_or_path": "/home/lstein/.cache/huggingface/hub/models--stabilityai--sdxl-turbo/snapshots/fbda35297a8280789ffe2e25206800702fa5c4c1/vae",
|
||||
"act_fn": "silu",
|
||||
"block_out_channels": [
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
512
|
||||
],
|
||||
"down_block_types": [
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D"
|
||||
],
|
||||
"force_upcast": true,
|
||||
"in_channels": 3,
|
||||
"latent_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"norm_num_groups": 32,
|
||||
"out_channels": 3,
|
||||
"sample_size": 1024,
|
||||
"scaling_factor": 0.13025,
|
||||
"up_block_types": [
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D"
|
||||
]
|
||||
}
|
265
tests/backend/model_manager_2/model_manager_2_fixtures.py
Normal file
265
tests/backend/model_manager_2/model_manager_2_fixtures.py
Normal file
@ -0,0 +1,265 @@
|
||||
# Fixtures to support testing of the model_manager v2 installer, metadata and record store
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from invokeai.app.services.model_records import ModelRecordServiceSQL
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.backend.model_manager_2.model_metadata.metadata_examples import (
|
||||
RepoCivitaiModelMetadata1,
|
||||
RepoCivitaiVersionMetadata1,
|
||||
RepoHFMetadata1,
|
||||
RepoHFMetadata1_nofp16,
|
||||
RepoHFModelJson1,
|
||||
)
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||
|
||||
|
||||
class DummyEvent(BaseModel):
|
||||
"""Dummy Event to use with Dummy Event service."""
|
||||
|
||||
event_name: str
|
||||
payload: Dict[str, Any]
|
||||
|
||||
|
||||
class DummyEventService(EventServiceBase):
|
||||
"""Dummy event service for testing."""
|
||||
|
||||
events: List[DummyEvent]
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.events = []
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
"""Dispatch an event by appending it to self.events."""
|
||||
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
|
||||
|
||||
|
||||
# Create a temporary directory using the contents of `./data/invokeai_root` as the template
|
||||
@pytest.fixture
|
||||
def mm2_root_dir(tmp_path_factory) -> Path:
|
||||
root_template = Path(__file__).resolve().parent / "data" / "invokeai_root"
|
||||
temp_dir: Path = tmp_path_factory.mktemp("data") / "invokeai_root"
|
||||
shutil.copytree(root_template, temp_dir)
|
||||
return temp_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_model_files(tmp_path_factory) -> Path:
|
||||
root_template = Path(__file__).resolve().parent / "data" / "test_files"
|
||||
temp_dir: Path = tmp_path_factory.mktemp("data") / "test_files"
|
||||
shutil.copytree(root_template, temp_dir)
|
||||
return temp_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def embedding_file(mm2_model_files: Path) -> Path:
|
||||
return mm2_model_files / "test_embedding.safetensors"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def diffusers_dir(mm2_model_files: Path) -> Path:
|
||||
return mm2_model_files / "test-diffusers-main"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
|
||||
app_config = InvokeAIAppConfig(
|
||||
root=mm2_root_dir,
|
||||
models_dir=mm2_root_dir / "models",
|
||||
)
|
||||
return app_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
|
||||
logger = InvokeAILogger.get_logger(config=mm2_app_config)
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
store = ModelRecordServiceSQL(db)
|
||||
# add five simple config records to the database
|
||||
raw1 = {
|
||||
"path": "/tmp/foo1",
|
||||
"format": ModelFormat("diffusers"),
|
||||
"name": "test2",
|
||||
"base": BaseModelType("sd-2"),
|
||||
"type": ModelType("vae"),
|
||||
"original_hash": "111222333444",
|
||||
"source": "stabilityai/sdxl-vae",
|
||||
}
|
||||
raw2 = {
|
||||
"path": "/tmp/foo2.ckpt",
|
||||
"name": "model1",
|
||||
"format": ModelFormat("checkpoint"),
|
||||
"base": BaseModelType("sd-1"),
|
||||
"type": "main",
|
||||
"config": "/tmp/foo.yaml",
|
||||
"variant": "normal",
|
||||
"original_hash": "111222333444",
|
||||
"source": "https://civitai.com/models/206883/split",
|
||||
}
|
||||
raw3 = {
|
||||
"path": "/tmp/foo3",
|
||||
"format": ModelFormat("diffusers"),
|
||||
"name": "test3",
|
||||
"base": BaseModelType("sdxl"),
|
||||
"type": ModelType("main"),
|
||||
"original_hash": "111222333444",
|
||||
"source": "author3/model3",
|
||||
"description": "This is test 3",
|
||||
}
|
||||
raw4 = {
|
||||
"path": "/tmp/foo4",
|
||||
"format": ModelFormat("diffusers"),
|
||||
"name": "test4",
|
||||
"base": BaseModelType("sdxl"),
|
||||
"type": ModelType("lora"),
|
||||
"original_hash": "111222333444",
|
||||
"source": "author4/model4",
|
||||
}
|
||||
raw5 = {
|
||||
"path": "/tmp/foo5",
|
||||
"format": ModelFormat("diffusers"),
|
||||
"name": "test5",
|
||||
"base": BaseModelType("sd-1"),
|
||||
"type": ModelType("lora"),
|
||||
"original_hash": "111222333444",
|
||||
"source": "author4/model5",
|
||||
}
|
||||
store.add_model("test_config_1", raw1)
|
||||
store.add_model("test_config_2", raw2)
|
||||
store.add_model("test_config_3", raw3)
|
||||
store.add_model("test_config_4", raw4)
|
||||
store.add_model("test_config_5", raw5)
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStore:
|
||||
db = mm2_record_store._db # to ensure we are sharing the same database
|
||||
return ModelMetadataStore(db)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
||||
"""This fixtures defines a series of mock URLs for testing download and installation."""
|
||||
sess = TestSession()
|
||||
sess.mount(
|
||||
"https://test.com/missing_model.safetensors",
|
||||
TestAdapter(
|
||||
b"missing",
|
||||
status=404,
|
||||
),
|
||||
)
|
||||
sess.mount(
|
||||
"https://huggingface.co/api/models/stabilityai/sdxl-turbo",
|
||||
TestAdapter(
|
||||
RepoHFMetadata1,
|
||||
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
|
||||
),
|
||||
)
|
||||
sess.mount(
|
||||
"https://huggingface.co/api/models/stabilityai/sdxl-turbo-nofp16",
|
||||
TestAdapter(
|
||||
RepoHFMetadata1_nofp16,
|
||||
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1_nofp16)},
|
||||
),
|
||||
)
|
||||
sess.mount(
|
||||
"https://civitai.com/api/v1/model-versions/242807",
|
||||
TestAdapter(
|
||||
RepoCivitaiVersionMetadata1,
|
||||
headers={
|
||||
"Content-Length": len(RepoCivitaiVersionMetadata1),
|
||||
},
|
||||
),
|
||||
)
|
||||
sess.mount(
|
||||
"https://civitai.com/api/v1/models/215485",
|
||||
TestAdapter(
|
||||
RepoCivitaiModelMetadata1,
|
||||
headers={
|
||||
"Content-Length": len(RepoCivitaiModelMetadata1),
|
||||
},
|
||||
),
|
||||
)
|
||||
sess.mount(
|
||||
"https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/model_index.json",
|
||||
TestAdapter(
|
||||
RepoHFModelJson1,
|
||||
headers={
|
||||
"Content-Length": len(RepoHFModelJson1),
|
||||
},
|
||||
),
|
||||
)
|
||||
with open(embedding_file, "rb") as f:
|
||||
data = f.read() # file is small - just 15K
|
||||
sess.mount(
|
||||
"https://www.test.foo/download/test_embedding.safetensors",
|
||||
TestAdapter(data, headers={"Content-Type": "application/octet-stream", "Content-Length": len(data)}),
|
||||
)
|
||||
sess.mount(
|
||||
"https://huggingface.co/api/models/stabilityai/sdxl-turbo",
|
||||
TestAdapter(
|
||||
RepoHFMetadata1,
|
||||
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
|
||||
),
|
||||
)
|
||||
for root, _, files in os.walk(diffusers_dir):
|
||||
for name in files:
|
||||
path = Path(root, name)
|
||||
url_base = path.relative_to(diffusers_dir).as_posix()
|
||||
url = f"https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/{url_base}"
|
||||
with open(path, "rb") as f:
|
||||
data = f.read()
|
||||
sess.mount(
|
||||
url,
|
||||
TestAdapter(
|
||||
data,
|
||||
headers={
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Content-Length": len(data),
|
||||
},
|
||||
),
|
||||
)
|
||||
return sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> ModelInstallServiceBase:
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
events = DummyEventService()
|
||||
store = ModelRecordServiceSQL(db)
|
||||
metadata_store = ModelMetadataStore(db)
|
||||
|
||||
download_queue = DownloadQueueService(requests_session=mm2_session)
|
||||
download_queue.start()
|
||||
|
||||
installer = ModelInstallService(
|
||||
app_config=mm2_app_config,
|
||||
record_store=store,
|
||||
download_queue=download_queue,
|
||||
metadata_store=metadata_store,
|
||||
event_bus=events,
|
||||
session=mm2_session,
|
||||
)
|
||||
installer.start()
|
||||
return installer
|
File diff suppressed because one or more lines are too long
@ -0,0 +1,201 @@
|
||||
"""
|
||||
Test model metadata fetching and storage.
|
||||
"""
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pydantic.networks import HttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
CivitaiMetadata,
|
||||
CivitaiMetadataFetch,
|
||||
CommercialUsage,
|
||||
HuggingFaceMetadata,
|
||||
HuggingFaceMetadataFetch,
|
||||
ModelMetadataStore,
|
||||
UnknownMetadataException,
|
||||
)
|
||||
from invokeai.backend.model_manager.util import select_hf_files
|
||||
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
|
||||
|
||||
|
||||
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
tags = {"text-to-image", "diffusers"}
|
||||
input_metadata = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
tags=tags,
|
||||
id="stabilityai/sdxl-vae",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
mm2_metadata_store.add_metadata("test_config_1", input_metadata)
|
||||
output_metadata = mm2_metadata_store.get_metadata("test_config_1")
|
||||
assert input_metadata == output_metadata
|
||||
with pytest.raises(UnknownMetadataException):
|
||||
mm2_metadata_store.add_metadata("unknown_key", input_metadata)
|
||||
assert mm2_metadata_store.list_tags() == tags
|
||||
|
||||
|
||||
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
input_metadata = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
tags={"text-to-image", "diffusers"},
|
||||
id="stabilityai/sdxl-vae",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
mm2_metadata_store.add_metadata("test_config_1", input_metadata)
|
||||
input_metadata.name = "new-name"
|
||||
mm2_metadata_store.update_metadata("test_config_1", input_metadata)
|
||||
output_metadata = mm2_metadata_store.get_metadata("test_config_1")
|
||||
assert output_metadata.name == "new-name"
|
||||
assert input_metadata == output_metadata
|
||||
|
||||
|
||||
def test_metadata_search(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
metadata1 = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
tags={"text-to-image", "diffusers"},
|
||||
id="stabilityai/sdxl-vae",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
metadata2 = HuggingFaceMetadata(
|
||||
name="model2",
|
||||
author="stabilityai",
|
||||
tags={"text-to-image", "diffusers", "community-contributed"},
|
||||
id="author2/model2",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
metadata3 = HuggingFaceMetadata(
|
||||
name="model3",
|
||||
author="author3",
|
||||
tags={"text-to-image", "checkpoint", "community-contributed"},
|
||||
id="author3/model3",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
mm2_metadata_store.add_metadata("test_config_1", metadata1)
|
||||
mm2_metadata_store.add_metadata("test_config_2", metadata2)
|
||||
mm2_metadata_store.add_metadata("test_config_3", metadata3)
|
||||
|
||||
matches = mm2_metadata_store.search_by_author("stabilityai")
|
||||
assert len(matches) == 2
|
||||
assert "test_config_1" in matches
|
||||
assert "test_config_2" in matches
|
||||
matches = mm2_metadata_store.search_by_author("Sherlock Holmes")
|
||||
assert not matches
|
||||
|
||||
matches = mm2_metadata_store.search_by_name("model3")
|
||||
assert len(matches) == 1
|
||||
assert "test_config_3" in matches
|
||||
|
||||
matches = mm2_metadata_store.search_by_tag({"text-to-image"})
|
||||
assert len(matches) == 3
|
||||
|
||||
matches = mm2_metadata_store.search_by_tag({"text-to-image", "diffusers"})
|
||||
assert len(matches) == 2
|
||||
assert "test_config_1" in matches
|
||||
assert "test_config_2" in matches
|
||||
|
||||
matches = mm2_metadata_store.search_by_tag({"checkpoint", "community-contributed"})
|
||||
assert len(matches) == 1
|
||||
assert "test_config_3" in matches
|
||||
|
||||
# does the tag table update correctly?
|
||||
matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
|
||||
assert not matches
|
||||
assert mm2_metadata_store.list_tags() == {"text-to-image", "diffusers", "community-contributed", "checkpoint"}
|
||||
metadata3.tags.add("licensed-for-commercial-use")
|
||||
mm2_metadata_store.update_metadata("test_config_3", metadata3)
|
||||
assert mm2_metadata_store.list_tags() == {
|
||||
"text-to-image",
|
||||
"diffusers",
|
||||
"community-contributed",
|
||||
"checkpoint",
|
||||
"licensed-for-commercial-use",
|
||||
}
|
||||
matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
|
||||
assert len(matches) == 1
|
||||
|
||||
|
||||
def test_metadata_civitai_fetch(mm2_session: Session) -> None:
|
||||
fetcher = CivitaiMetadataFetch(mm2_session)
|
||||
metadata = fetcher.from_url(HttpUrl("https://civitai.com/models/215485/SDXL-turbo"))
|
||||
assert isinstance(metadata, CivitaiMetadata)
|
||||
assert metadata.id == 215485
|
||||
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
|
||||
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
|
||||
assert metadata.restrictions.AllowCommercialUse == CommercialUsage("RentCivit")
|
||||
assert metadata.version_id == 242807
|
||||
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
|
||||
|
||||
|
||||
def test_metadata_hf_fetch(mm2_session: Session) -> None:
|
||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
||||
metadata = fetcher.from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
assert isinstance(metadata, HuggingFaceMetadata)
|
||||
assert metadata.author == "test_author" # this is not the same as the original
|
||||
assert metadata.files
|
||||
assert metadata.tags == {
|
||||
"diffusers",
|
||||
"onnx",
|
||||
"safetensors",
|
||||
"text-to-image",
|
||||
"license:other",
|
||||
"has_space",
|
||||
"diffusers:StableDiffusionXLPipeline",
|
||||
"region:us",
|
||||
}
|
||||
|
||||
|
||||
def test_metadata_hf_filter(mm2_session: Session) -> None:
|
||||
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
assert isinstance(metadata, HuggingFaceMetadata)
|
||||
files = [x.path for x in metadata.files]
|
||||
fp16_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16"))
|
||||
assert Path("sdxl-turbo/text_encoder/model.fp16.safetensors") in fp16_files
|
||||
assert Path("sdxl-turbo/text_encoder/model.safetensors") not in fp16_files
|
||||
|
||||
fp32_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp32"))
|
||||
assert Path("sdxl-turbo/text_encoder/model.safetensors") in fp32_files
|
||||
assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in fp32_files
|
||||
|
||||
onnx_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("onnx"))
|
||||
assert Path("sdxl-turbo/text_encoder/model.onnx") in onnx_files
|
||||
assert Path("sdxl-turbo/text_encoder/model.safetensors") not in onnx_files
|
||||
|
||||
default_files = select_hf_files.filter_files(files)
|
||||
assert Path("sdxl-turbo/text_encoder/model.safetensors") in default_files
|
||||
assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in default_files
|
||||
|
||||
openvino_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("openvino"))
|
||||
print(openvino_files)
|
||||
assert len(openvino_files) == 0
|
||||
|
||||
flax_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("flax"))
|
||||
print(flax_files)
|
||||
assert not flax_files
|
||||
|
||||
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(
|
||||
HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo-nofp16")
|
||||
)
|
||||
assert isinstance(metadata, HuggingFaceMetadata)
|
||||
files = [x.path for x in metadata.files]
|
||||
filtered_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16"))
|
||||
assert (
|
||||
Path("sdxl-turbo-nofp16/text_encoder/model.safetensors") in filtered_files
|
||||
) # confirm that default is returned
|
||||
assert Path("sdxl-turbo-nofp16/text_encoder/model.16.safetensors") not in filtered_files
|
||||
|
||||
|
||||
def test_metadata_hf_urls(mm2_session: Session) -> None:
|
||||
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
assert isinstance(metadata, HuggingFaceMetadata)
|
239
tests/backend/model_manager_2/util/test_hf_model_select.py
Normal file
239
tests/backend/model_manager_2/util/test_hf_model_select.py
Normal file
@ -0,0 +1,239 @@
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.util.select_hf_files import filter_files
|
||||
|
||||
|
||||
# This is the full list of model paths returned by the HF API for sdxl-base
|
||||
@pytest.fixture
|
||||
def sdxl_base_files() -> List[Path]:
|
||||
return [
|
||||
Path(x)
|
||||
for x in [
|
||||
".gitattributes",
|
||||
"01.png",
|
||||
"LICENSE.md",
|
||||
"README.md",
|
||||
"comparison.png",
|
||||
"model_index.json",
|
||||
"pipeline.png",
|
||||
"scheduler/scheduler_config.json",
|
||||
"sd_xl_base_1.0.safetensors",
|
||||
"sd_xl_base_1.0_0.9vae.safetensors",
|
||||
"sd_xl_offset_example-lora_1.0.safetensors",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/flax_model.msgpack",
|
||||
"text_encoder/model.fp16.safetensors",
|
||||
"text_encoder/model.onnx",
|
||||
"text_encoder/model.safetensors",
|
||||
"text_encoder/openvino_model.bin",
|
||||
"text_encoder/openvino_model.xml",
|
||||
"text_encoder_2/config.json",
|
||||
"text_encoder_2/flax_model.msgpack",
|
||||
"text_encoder_2/model.fp16.safetensors",
|
||||
"text_encoder_2/model.onnx",
|
||||
"text_encoder_2/model.onnx_data",
|
||||
"text_encoder_2/model.safetensors",
|
||||
"text_encoder_2/openvino_model.bin",
|
||||
"text_encoder_2/openvino_model.xml",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"tokenizer_2/merges.txt",
|
||||
"tokenizer_2/special_tokens_map.json",
|
||||
"tokenizer_2/tokenizer_config.json",
|
||||
"tokenizer_2/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/diffusion_flax_model.msgpack",
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
"unet/model.onnx",
|
||||
"unet/model.onnx_data",
|
||||
"unet/openvino_model.bin",
|
||||
"unet/openvino_model.xml",
|
||||
"vae/config.json",
|
||||
"vae/diffusion_flax_model.msgpack",
|
||||
"vae/diffusion_pytorch_model.fp16.safetensors",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
"vae_1_0/config.json",
|
||||
"vae_1_0/diffusion_pytorch_model.fp16.safetensors",
|
||||
"vae_1_0/diffusion_pytorch_model.safetensors",
|
||||
"vae_decoder/config.json",
|
||||
"vae_decoder/model.onnx",
|
||||
"vae_decoder/openvino_model.bin",
|
||||
"vae_decoder/openvino_model.xml",
|
||||
"vae_encoder/config.json",
|
||||
"vae_encoder/model.onnx",
|
||||
"vae_encoder/openvino_model.bin",
|
||||
"vae_encoder/openvino_model.xml",
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
# This are what we expect to get when various diffusers variants are requested
|
||||
@pytest.mark.parametrize(
|
||||
"variant,expected_list",
|
||||
[
|
||||
(
|
||||
None,
|
||||
[
|
||||
"model_index.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/model.safetensors",
|
||||
"text_encoder_2/config.json",
|
||||
"text_encoder_2/model.safetensors",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"tokenizer_2/merges.txt",
|
||||
"tokenizer_2/special_tokens_map.json",
|
||||
"tokenizer_2/tokenizer_config.json",
|
||||
"tokenizer_2/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
"vae/config.json",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
"vae_1_0/config.json",
|
||||
"vae_1_0/diffusion_pytorch_model.safetensors",
|
||||
],
|
||||
),
|
||||
(
|
||||
ModelRepoVariant.DEFAULT,
|
||||
[
|
||||
"model_index.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/model.safetensors",
|
||||
"text_encoder_2/config.json",
|
||||
"text_encoder_2/model.safetensors",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"tokenizer_2/merges.txt",
|
||||
"tokenizer_2/special_tokens_map.json",
|
||||
"tokenizer_2/tokenizer_config.json",
|
||||
"tokenizer_2/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
"vae/config.json",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
"vae_1_0/config.json",
|
||||
"vae_1_0/diffusion_pytorch_model.safetensors",
|
||||
],
|
||||
),
|
||||
(
|
||||
ModelRepoVariant.OPENVINO,
|
||||
[
|
||||
"model_index.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/openvino_model.bin",
|
||||
"text_encoder/openvino_model.xml",
|
||||
"text_encoder_2/config.json",
|
||||
"text_encoder_2/openvino_model.bin",
|
||||
"text_encoder_2/openvino_model.xml",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"tokenizer_2/merges.txt",
|
||||
"tokenizer_2/special_tokens_map.json",
|
||||
"tokenizer_2/tokenizer_config.json",
|
||||
"tokenizer_2/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/openvino_model.bin",
|
||||
"unet/openvino_model.xml",
|
||||
"vae_decoder/config.json",
|
||||
"vae_decoder/openvino_model.bin",
|
||||
"vae_decoder/openvino_model.xml",
|
||||
"vae_encoder/config.json",
|
||||
"vae_encoder/openvino_model.bin",
|
||||
"vae_encoder/openvino_model.xml",
|
||||
],
|
||||
),
|
||||
(
|
||||
ModelRepoVariant.FP16,
|
||||
[
|
||||
"model_index.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/model.fp16.safetensors",
|
||||
"text_encoder_2/config.json",
|
||||
"text_encoder_2/model.fp16.safetensors",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"tokenizer_2/merges.txt",
|
||||
"tokenizer_2/special_tokens_map.json",
|
||||
"tokenizer_2/tokenizer_config.json",
|
||||
"tokenizer_2/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
"vae/config.json",
|
||||
"vae/diffusion_pytorch_model.fp16.safetensors",
|
||||
"vae_1_0/config.json",
|
||||
"vae_1_0/diffusion_pytorch_model.fp16.safetensors",
|
||||
],
|
||||
),
|
||||
(
|
||||
ModelRepoVariant.ONNX,
|
||||
[
|
||||
"model_index.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/model.onnx",
|
||||
"text_encoder_2/config.json",
|
||||
"text_encoder_2/model.onnx",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"tokenizer_2/merges.txt",
|
||||
"tokenizer_2/special_tokens_map.json",
|
||||
"tokenizer_2/tokenizer_config.json",
|
||||
"tokenizer_2/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/model.onnx",
|
||||
"vae_decoder/config.json",
|
||||
"vae_decoder/model.onnx",
|
||||
"vae_encoder/config.json",
|
||||
"vae_encoder/model.onnx",
|
||||
],
|
||||
),
|
||||
(
|
||||
ModelRepoVariant.FLAX,
|
||||
[
|
||||
"model_index.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/flax_model.msgpack",
|
||||
"text_encoder_2/config.json",
|
||||
"text_encoder_2/flax_model.msgpack",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"tokenizer_2/merges.txt",
|
||||
"tokenizer_2/special_tokens_map.json",
|
||||
"tokenizer_2/tokenizer_config.json",
|
||||
"tokenizer_2/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/diffusion_flax_model.msgpack",
|
||||
"vae/config.json",
|
||||
"vae/diffusion_flax_model.msgpack",
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_select(sdxl_base_files: List[Path], variant: ModelRepoVariant, expected_list: List[Path]) -> None:
|
||||
print(f"testing variant {variant}")
|
||||
filtered_files = filter_files(sdxl_base_files, variant)
|
||||
assert set(filtered_files) == {Path(x) for x in expected_list}
|
@ -1,6 +1,7 @@
|
||||
# conftest.py is a special pytest file. Fixtures defined in this file will be accessible to all tests in this directory
|
||||
# without needing to explicitly import them. (https://docs.pytest.org/en/6.2.x/fixture.html)
|
||||
|
||||
|
||||
# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
|
||||
# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
|
||||
from invokeai.backend.util.test_utils import model_installer, torch_device # noqa: F401
|
||||
|
Loading…
Reference in New Issue
Block a user