Refactor model manager: model installer component (#5171)

## What type of PR is this? (check all applicable)

- [X] Refactor
- [X] Feature
- [ ] Bug Fix
- [ ] Optimization
- [X] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [X] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [X] Yes
- [ ] No


## Description

This is the next phase of the model manager refactor, as discussed with
@psychedelicious and @RyanJDick. This implements the model installer,
which is responsible for managing model weights on disk and installing
new models.

Currently only installation of local files and directories is supported.
Remote installation will be implemented after the queued download
manager is reviewed and approved.

Please see the documentation located at
[docs/contributing/MODEL_MANAGER.md](8695ad6f59/docs/contributing/MODEL_MANAGER.md (model-installation))
for an explanation of how this module works.

Things that have changed relative to the current implementation:

1. Model importation runs in a background thread. Access to the
installation status is through a ModelInstallJob object returned by the
`import_model()` call. In addition, the installation process generates a
series of `model_install` events on the event bus.
2. `model_install_progress` events are documented, but not currently
issued. These will be issued when background downloading is implemented.
3. The model installer currently runs in parallel to the current model
manager. The frontend continues to use `configs/models.yaml` and ignores
what is in the `model_config` table of `invokeai.db`.
4. When the installer is initialized at app startup time, it
synchronizes its database to the contents of the InvokeAI `models`
directory. The current model manager does this as well, so you will see
two log messages indicating that this directory is being scanned.


## Related Tickets & Documents


- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

You can test using the FastAPI swagger pages at
http://localhost:9090/docs. Use the calls listed under
`model_manager_v2`. Be aware that only installation of local models
(indicated by their file or directory path) is currently supported.

## Added/updated tests?

- [X] Yes -- see
`tests/app/services/model_install/test_model_install.py`
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
This commit is contained in:
Kent Keirsey 2023-12-10 23:16:39 -05:00 committed by GitHub
commit ef807cf63a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 2582 additions and 418 deletions


@ -10,40 +10,36 @@ model. These are the:
tracks the type of the model, its provenance, and where it can be
found on disk.
* _ModelLoadServiceBase_ Responsible for loading a model from disk
into RAM and VRAM and getting it ready for inference.
* _DownloadQueueServiceBase_ A multithreaded downloader responsible
for downloading models from a remote source to disk. The download
queue has special methods for downloading repo_id folders from
Hugging Face, as well as discriminating among model versions in
Civitai, but can be used for arbitrary content.
* _ModelInstallServiceBase_ A service for installing models to
disk. It uses `DownloadQueueServiceBase` to download models and
their metadata, and `ModelRecordServiceBase` to store that
information. It is also responsible for managing the InvokeAI
`models` directory and its contents.
* _DownloadQueueServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
A multithreaded downloader responsible
for downloading models from a remote source to disk. The download
queue has special methods for downloading repo_id folders from
Hugging Face, as well as discriminating among model versions in
Civitai, but can be used for arbitrary content.
* _ModelLoadServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
Responsible for loading a model from disk
into RAM and VRAM and getting it ready for inference.
## Location of the Code
All four of these services can be found in
`invokeai/app/services` in the following directories:
* `invokeai/app/services/model_records/`
* `invokeai/app/services/downloads/`
* `invokeai/app/services/model_loader/`
* `invokeai/app/services/model_install/`
With the exception of the install service, each of these is a thin
shell around a corresponding implementation located in
`invokeai/backend/model_manager`. The main difference between the
modules found in app services and those in the backend folder is that
the former add support for event reporting and are more tied to the
needs of the InvokeAI API.
* `invokeai/app/services/model_loader/` (**under development**)
* `invokeai/app/services/downloads/` (**under development**)
Code related to the FastAPI web API can be found in
`invokeai/app/api/routers/models.py`.
`invokeai/app/api/routers/model_records.py`.
***
@ -165,10 +161,6 @@ of the fields, including `name`, `model_type` and `base_model`, are
shared between `ModelConfigBase` and `ModelBase`, and this is a
potential source of confusion.
** TO DO: ** The `ModelBase` code needs to be revised to reduce the
duplication of similar classes and to support using the `key` as the
primary model identifier.
## Reading and Writing Model Configuration Records
The `ModelRecordService` provides the ability to retrieve model
@ -362,7 +354,7 @@ model and pass its key to `get_model()`.
Several methods allow you to create and update stored model config
records.
#### add_model(key, config) -> ModelConfigBase:
#### add_model(key, config) -> AnyModelConfig:
Given a key and a configuration, this will add the model's
configuration record to the database. `config` can either be a subclass of
@ -386,27 +378,356 @@ fields to be updated. This will return an `AnyModelConfig` on success,
or raise `InvalidModelConfigException` or `UnknownModelException`
exceptions on failure.
***TO DO:*** Investigate why `update_model()` returns an
`AnyModelConfig` while `add_model()` returns a `ModelConfigBase`.
### rename_model(key, new_name) -> ModelConfigBase:
This is a special case of `update_model()` for the use case of
changing the model's name. It is broken out because there are cases in
which the InvokeAI application wants to synchronize the model's name
with its path in the `models` directory after changing the name, type
or base. However, when using the ModelRecordService directly, the call
is equivalent to:
```
store.update_model(key, {'name': 'new_name'})
```
***TO DO:*** Investigate why `rename_model()` is returning a
`ModelConfigBase` while `update_model()` returns a `AnyModelConfig`.
***
## Model installation
The `ModelInstallService` class implements the
`ModelInstallServiceBase` abstract base class, and provides a one-stop
shop for all your model install needs. It provides the following
functionality:
- Registering a model config record for a model already located on the
local filesystem, without moving it or changing its path.
- Installing a model already located on the local filesystem, by
moving it into the InvokeAI root directory under the
`models` folder (or wherever config parameter `models_dir`
specifies).
- Probing of models to determine their type, base type and other key
information.
- Interface with the InvokeAI event bus to provide status updates on
the download, installation and registration process.
- Downloading a model from an arbitrary URL and installing it in
`models_dir` (_implementation pending_).
- Special handling for Civitai model URLs which allow the user to
paste in a model page's URL or download link (_implementation pending_).
- Special handling for HuggingFace repo_ids to recursively download
the contents of the repository, paying attention to alternative
variants such as fp16. (_implementation pending_)
### Initializing the installer
A default installer is created at InvokeAI API startup time and stored
in `ApiDependencies.invoker.services.model_install` and can
also be retrieved from an invocation's `context` argument with
`context.services.model_install`.
In the event you wish to create a new installer, you may use the
following initialization pattern:
```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config, logger)
store = ModelRecordServiceSQL(db)
installer = ModelInstallService(config, store)
```
The full form of `ModelInstallService()` takes the following
parameters (`event_bus` is optional):
| **Argument** | **Type** | **Description** |
|------------------|------------------------------|------------------------------|
| `config` | InvokeAIAppConfig | InvokeAI app configuration object |
| `record_store` | ModelRecordServiceBase | Config record storage database |
| `event_bus` | EventServiceBase | Optional event bus to send download/install progress events to |
Once initialized, the installer will provide the following methods:
#### install_job = installer.import_model()
The `import_model()` method is the core of the installer. The
following illustrates basic usage:
```
from invokeai.app.services.model_install import (
    LocalModelSource,
    HFModelSource,
    URLModelSource,
)

source1 = LocalModelSource(path='/opt/models/sushi.safetensors')   # a local safetensors file
source2 = LocalModelSource(path='/opt/models/sushi_diffusers')     # a local diffusers folder
source3 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5')  # a repo_id
source4 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='vae')  # a subfolder within a repo_id
source5 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', variant='fp16')   # a named variant of a HF model
source6 = URLModelSource(url='https://civitai.com/api/download/models/63006')       # model located at a URL
source7 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token

sources = [source1, source2, source3, source4, source5, source6, source7]

# Queue one install job per source; each call returns immediately.
for source in sources:
    install_job = installer.import_model(source)

# Block until every job has completed or errored, then inspect the results.
source2job = installer.wait_for_installs()
for source in sources:
    job = source2job[source]
    if job.status == "completed":
        model_config = job.config_out
        model_key = model_config.key
        print(f"{source} installed as {model_key}")
    elif job.status == "error":
        print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")
```
As shown here, the `import_model()` method accepts a variety of
sources, including local safetensors files, local diffusers folders,
HuggingFace repo_ids with and without a subfolder designation,
Civitai model URLs and arbitrary URLs that point to checkpoint files
(but not to folders).
Each call to `import_model()` returns a `ModelInstallJob`,
an object which tracks the progress of the install.
If a remote model is requested, the model's files are downloaded in
parallel across multiple threads using the download
queue. During the download process, the `ModelInstallJob` is updated
to provide status and progress information. After the files (if any)
are downloaded, the remainder of the installation runs in a single
serialized background thread. These are the model probing, file
copying, and config record database update steps.
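The serialized part of this flow can be pictured as a single worker thread draining a queue. The following is only a minimal sketch of that pattern using the standard library; `install_queue`, `install_worker` and the `None` shutdown sentinel are illustrative names and are not taken from the InvokeAI code base.

```python
import queue
import threading
from typing import List

# Hypothetical stand-ins: a queue of pending sources and a log of finished ones.
install_queue: queue.Queue = queue.Queue()
finished: List[str] = []


def install_worker() -> None:
    """Process queued sources one at a time, in submission order."""
    while True:
        source = install_queue.get()
        try:
            if source is None:  # sentinel value: stop the worker
                return
            # Probing, copying and the database update would happen here.
            finished.append(source)
        finally:
            install_queue.task_done()


worker = threading.Thread(target=install_worker, daemon=True)
worker.start()

for source in ["model_a", "model_b", "model_c"]:
    install_queue.put(source)

install_queue.join()     # block until every queued source has been handled
install_queue.put(None)  # ask the worker to exit
worker.join()
```

Because only one worker consumes the queue, the install steps never run concurrently with each other, which is the property the paragraph above describes.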
Multiple install jobs can be queued up. You may block until all
install jobs are completed (or errored) by calling the
`wait_for_installs()` method as shown in the code
example. `wait_for_installs()` will return a `dict` that maps the
requested source to its job. This object can be interrogated
to determine its status. If the job errored out, then the error type
and details can be recovered from `job.error_type` and `job.error`.
The full list of arguments to `import_model()` is as follows:
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `source` | Union[str, Path, AnyHttpUrl] | | The source of the model, Path, URL or repo_id |
| `inplace` | bool | True | Leave a local model in its current location |
| `variant` | str | None | Desired variant, such as 'fp16' or 'onnx' (HuggingFace only) |
| `subfolder` | str | None | Repository subfolder (HuggingFace only) |
| `config` | Dict[str, Any] | None | Override all or a portion of model's probed attributes |
| `access_token` | str | None | Provide authorization information needed to download |
The `inplace` field controls how local model Paths are handled. If
True (the default), then the model is simply registered in its current
location by the installer's `ModelConfigRecordService`. Otherwise, a
copy of the model is placed in the location specified by the `models_dir`
application configuration parameter.
The `variant` field is used for HuggingFace repo_ids only. If
provided, the repo_id download handler will look for and download
tensors files that follow the convention for the selected variant:
- "fp16" will select files named "*model.fp16.{safetensors,bin}"
- "onnx" will select files ending with the suffix ".onnx"
- "openvino" will select files beginning with "openvino_model"
In the special case of the "fp16" variant, the installer will select
the 32-bit version of the files if the 16-bit version is unavailable.
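As a rough illustration of these naming conventions, the sketch below filters a list of repository file names by variant. `select_variant_files` is a hypothetical helper, not the installer's actual code, and the pattern used for the 32-bit fallback (`*model.{safetensors,bin}`) is an assumption.

```python
from fnmatch import fnmatchcase
from typing import List, Optional


def _matches(name: str, patterns: List[str]) -> bool:
    return any(fnmatchcase(name, p) for p in patterns)


def select_variant_files(filenames: List[str], variant: Optional[str] = None) -> List[str]:
    """Hypothetical helper: pick the files that follow the variant's naming convention."""
    if variant == "fp16":
        fp16 = [f for f in filenames
                if _matches(f, ["*model.fp16.safetensors", "*model.fp16.bin"])]
        if fp16:
            return fp16
        # Fall back to the 32-bit files (assumed naming) when no fp16 files exist.
        return [f for f in filenames
                if _matches(f, ["*model.safetensors", "*model.bin"])]
    if variant == "onnx":
        return [f for f in filenames if f.endswith(".onnx")]
    if variant == "openvino":
        # Compare against the base name so files in subfolders are found too.
        return [f for f in filenames if f.rsplit("/", 1)[-1].startswith("openvino_model")]
    # No variant requested: return the list unchanged (assumption).
    return list(filenames)
```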
`subfolder` is used for HuggingFace repo_ids only. If provided, the
model will be downloaded from the designated subfolder rather than the
top-level repository folder. If a subfolder is attached to the repo_id
using the format `repo_owner/repo_name:subfolder`, then the subfolder
specified by the repo_id will override the subfolder argument.
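The precedence rule can be sketched as follows. `resolve_subfolder` is a hypothetical helper written for illustration; it only shows how a subfolder embedded in the repo_id could win over the `subfolder` argument.

```python
from typing import Optional, Tuple


def resolve_subfolder(repo_id: str, subfolder: Optional[str] = None) -> Tuple[str, Optional[str]]:
    """Hypothetical helper: split 'owner/name:subfolder' and apply the precedence rule."""
    base, sep, embedded = repo_id.partition(":")
    if sep and embedded:
        # A subfolder embedded in the repo_id overrides the explicit argument.
        return base, embedded
    return base, subfolder
```

For example, `resolve_subfolder('runwayml/stable-diffusion-v1-5:vae', 'unet')` yields the `vae` subfolder, not `unet`.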
`config` can be used to override all or a portion of the configuration
attributes returned by the model prober. See the section below for
details.
`access_token` is passed to the download queue and used to access
repositories that require it.
#### Monitoring the install job process
When you create an install job with `import_model()`, it launches the
download and installation process in the background and returns a
`ModelInstallJob` object for monitoring the process.
The `ModelInstallJob` class has the following structure:
| **Attribute** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `status` | `InstallStatus` | An enum of ["waiting", "running", "completed", "error"] |
| `config_in` | `dict` | Overriding configuration values provided by the caller |
| `config_out` | `AnyModelConfig`| After successful completion, contains the configuration record written to the database |
| `inplace` | `boolean` | True if the caller asked to install the model in place using its local path |
| `source` | `ModelSource` | The local path, remote URL or repo_id of the model to be installed |
| `local_path` | `Path` | If a remote model, holds the path of the model after it is downloaded; if a local model, same as `source` |
| `error_type` | `str` | Name of the exception that led to an error status |
| `error` | `str` | Traceback of the error |
If the `event_bus` argument was provided, events will also be
broadcast to the InvokeAI event bus. They will appear on the bus
as events of type `EventServiceBase.model_event`, each carrying a
timestamp and one of the following event names:
- `model_install_started`
The payload will contain the keys `timestamp` and `source`. The latter
indicates the requested model source for installation.
- `model_install_progress`
Emitted at regular intervals when downloading a remote model, the
payload will contain the keys `timestamp`, `source`, `current_bytes`
and `total_bytes`. These events are _not_ emitted when a local model
already on the filesystem is imported.
- `model_install_completed`
Issued once at the end of a successful installation. The payload will
contain the keys `timestamp`, `source` and `key`, where `key` is the
ID under which the model has been registered.
- `model_install_error`
Emitted if the installation process fails for some reason. The payload
will contain the keys `timestamp`, `source`, `error_type` and
`error`. `error_type` is a short message indicating the nature of the
error, and `error` is the long traceback to help debug the problem.
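A consumer of these events might dispatch on the event name and read the payload keys listed above. The function below is a sketch only: `describe_install_event` is a hypothetical name, and how a listener is attached to the event bus is not shown because it is not covered in this section.

```python
from typing import Any, Dict


def describe_install_event(event_name: str, payload: Dict[str, Any]) -> str:
    """Hypothetical handler: turn an install event into a one-line status message."""
    source = payload["source"]
    if event_name == "model_install_started":
        return f"{source}: install started"
    if event_name == "model_install_progress":
        total = payload["total_bytes"]
        # Guard against a zero or unknown total size.
        percent = 100 * payload["current_bytes"] / total if total else 0.0
        return f"{source}: {percent:.0f}% downloaded"
    if event_name == "model_install_completed":
        return f"{source}: installed with key {payload['key']}"
    if event_name == "model_install_error":
        return f"{source}: failed with {payload['error_type']}"
    return f"{source}: unrecognized event {event_name}"
```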
#### Model configuration and probing
The install service uses the `invokeai.backend.model_manager.probe`
module during import to determine the model's type, base type, and
other configuration parameters. Among other things, it assigns a
default name and description for the model based on probed
fields.
When downloading remote models is implemented, additional
configuration information, such as list of trigger terms, will be
retrieved from the HuggingFace and Civitai model repositories.
The probed values can be overridden by providing a dictionary in the
optional `config` argument passed to `import_model()`. You may provide
overriding values for any of the model's configuration
attributes. This is typically used to set
the model's name and description, but can also be used to overcome
cases in which automatic probing is unable to (correctly) determine
the model's attributes. The most common situation is the
`prediction_type` field for sd-2 (and rare sd-1) models. Here is an
example of setting the `prediction_type` and `name` for an sd-2 model:
```
# SchedulerPredictionType is assumed to have been imported already
install_job = installer.import_model(
    source='stabilityai/stable-diffusion-2-1',
    variant='fp16',
    config=dict(
        prediction_type=SchedulerPredictionType('v_prediction'),
        name='stable diffusion 2 base model',
    )
)
```
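Conceptually, the override amounts to layering the caller's dictionary on top of the probed values. The sketch below shows that idea with plain dictionaries; `apply_overrides` is a hypothetical helper and the real installer may merge the values differently.

```python
from typing import Any, Dict, Optional


def apply_overrides(probed: Dict[str, Any], config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Hypothetical helper: caller-supplied values win over probed ones."""
    merged = dict(probed)        # copy so the probed values are left untouched
    merged.update(config or {})  # overrides replace matching keys only
    return merged


probed = {"name": "v2-1_768", "base": "sd-2", "prediction_type": "epsilon"}
final = apply_overrides(probed, {"prediction_type": "v_prediction"})
```

Keys that the caller does not mention, such as `name` here, keep their probed values.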
### Other installer methods
This section describes additional methods provided by the installer class.
#### jobs = installer.wait_for_installs()
Block until all pending installs are completed or errored, then
return a list of completed jobs.
#### jobs = installer.list_jobs([source])
Return a list of all active and complete `ModelInstallJobs`. An
optional `source` argument allows you to filter the returned list by a
model source string pattern using a partial string match.
#### jobs = installer.get_job(source)
Return a list of `ModelInstallJob` corresponding to the indicated
model source.
#### installer.prune_jobs()
Remove non-pending jobs (completed or errored) from the job list
returned by `list_jobs()` and `get_job()`.
#### installer.app_config, installer.record_store, installer.event_bus
Properties that provide access to the installer's `InvokeAIAppConfig`,
`ModelRecordServiceBase` and `EventServiceBase` objects.
#### key = installer.register_path(model_path, config), key = installer.install_path(model_path, config)
These methods bypass the download queue and directly register or
install the model at the indicated path, returning the unique ID for
the installed model.
Both methods accept a Path object corresponding to a checkpoint or
diffusers folder, and an optional dict of config attributes to use to
override the values derived from model probing.
The difference between `register_path()` and `install_path()` is that
the former creates a model configuration record without changing the
location of the model in the filesystem. The latter makes a copy of
the model inside the InvokeAI models directory before registering
it.
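The difference can be demonstrated with a small stand-alone sketch. Everything here is hypothetical: `register_sketch` and `install_sketch` mimic the two behaviors with a plain dictionary as the record store and the file stem as the key, whereas the real methods probe the model and generate their own keys.

```python
import shutil
import tempfile
from pathlib import Path
from typing import Dict


def register_sketch(model_path: Path, records: Dict[str, Path]) -> str:
    """Record the model where it already lives; nothing on disk changes."""
    key = model_path.stem  # illustrative key only
    records[key] = model_path
    return key


def install_sketch(model_path: Path, models_dir: Path, records: Dict[str, Path]) -> str:
    """Copy the model into models_dir first, then record the copy."""
    dest = models_dir / model_path.name
    if model_path.is_dir():
        shutil.copytree(model_path, dest)  # diffusers-style folder
    else:
        shutil.copy2(model_path, dest)     # single checkpoint file
    return register_sketch(dest, records)


with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "downloads" / "sushi.safetensors"
    src.parent.mkdir()
    src.write_bytes(b"weights")
    models_dir = Path(tmp) / "models"
    models_dir.mkdir()

    records: Dict[str, Path] = {}
    register_sketch(src, records)
    in_place = records["sushi"]            # points at the original file
    install_sketch(src, models_dir, records)
    copied = records["sushi"]              # now points inside models_dir
    source_still_exists = src.exists()     # installing copies, it does not move
```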
#### installer.unregister(key)
This will remove the model config record for the model at key, and is
equivalent to `installer.record_store.del_model(key)`.
#### installer.delete(key)
This is similar to `unregister()` but has the additional effect of
conditionally deleting the underlying model file(s) if they reside
within the InvokeAI models directory.
#### installer.unconditionally_delete(key)
This method is similar to `unregister()`, but also unconditionally
deletes the corresponding model weights file(s), regardless of whether
they are inside or outside the InvokeAI models hierarchy.
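The three removal methods differ only in whether the weights on disk are touched. The following sketch captures that decision; `files_would_be_deleted` is a hypothetical helper for illustration and performs no deletion itself.

```python
from pathlib import Path


def files_would_be_deleted(model_path: Path, models_dir: Path, method: str) -> bool:
    """Hypothetical helper: report whether a removal method deletes the model files."""
    if method == "unregister":
        return False                           # only the config record is removed
    if method == "delete":
        return models_dir in model_path.parents  # only files under models_dir
    if method == "unconditionally_delete":
        return True                            # files are removed wherever they live
    raise ValueError(f"unknown method: {method}")
```

With `models_dir` set to `/invokeai/models`, a model stored at `/opt/models/a.safetensors` would survive `delete` but not `unconditionally_delete`.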
#### List[str]=installer.scan_directory(scan_dir: Path, install: bool)
This method will recursively scan the directory indicated in
`scan_dir` for new models and either install them in the models
directory or register them in place, depending on the setting of
`install` (default False).
The return value is the list of keys of the new installed/registered
models.
#### installer.sync_to_config()
This method synchronizes models in the models directory and autoimport
directory to those in the `ModelConfigRecordService` database. New
models are registered and orphan models are unregistered.
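In set terms, the synchronization is a pair of differences between what is on disk and what is in the database. The sketch below assumes models can be identified by their paths; `plan_sync` is a hypothetical name and not part of the installer API.

```python
from typing import Set, Tuple


def plan_sync(on_disk: Set[str], in_database: Set[str]) -> Tuple[Set[str], Set[str]]:
    """Hypothetical helper: work out which models to register and which to unregister."""
    to_register = on_disk - in_database    # present on disk, missing from the database
    to_unregister = in_database - on_disk  # orphans: recorded but no longer on disk
    return to_register, to_unregister


new, orphans = plan_sync(
    on_disk={"sd-1/main/a", "sd-1/lora/b"},
    in_database={"sd-1/lora/b", "sd-2/main/c"},
)
```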
#### installer.start(invoker)
The `start` method is called by the API initialization routines when
the API starts up. Its effect is to call `sync_to_config()` to
synchronize the model record store database with what's currently on
disk.
# The remainder of this documentation is provisional, pending implementation of the Download and Load services
## Let's get loaded, the lowdown on ModelLoadService
The `ModelLoadService` is responsible for loading a named model into
@ -863,351 +1184,3 @@ other resources that it might have been using.
This will start/pause/cancel all jobs that have been submitted to the
queue and have not yet reached a terminal state.
## Model installation
The `ModelInstallService` class implements the
`ModelInstallServiceBase` abstract base class, and provides a one-stop
shop for all your model install needs. It provides the following
functionality:
- Registering a model config record for a model already located on the
local filesystem, without moving it or changing its path.
- Installing a model already located on the local filesystem, by
moving it into the InvokeAI root directory under the
`models` folder (or wherever config parameter `models_dir`
specifies).
- Downloading a model from an arbitrary URL and installing it in
`models_dir`.
- Special handling for Civitai model URLs which allow the user to
paste in a model page's URL or download link. Any metadata provided
by Civitai, such as trigger terms, are captured and placed in the
model config record.
- Special handling for HuggingFace repo_ids to recursively download
the contents of the repository, paying attention to alternative
variants such as fp16.
- Probing of models to determine their type, base type and other key
information.
- Interface with the InvokeAI event bus to provide status updates on
the download, installation and registration process.
### Initializing the installer
A default installer is created at InvokeAI API startup time and stored
in `ApiDependencies.invoker.services.model_install_service` and can
also be retrieved from an invocation's `context` argument with
`context.services.model_install_service`.
In the event you wish to create a new installer, you may use the
following initialization pattern:
```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download_manager import DownloadQueueService
from invokeai.app.services.model_record_service import ModelRecordServiceBase
config = InvokeAIAppConfig.get_config()
queue = DownloadQueueService()
store = ModelRecordServiceBase.open(config)
installer = ModelInstallService(config=config, queue=queue, store=store)
```
The full form of `ModelInstallService()` takes the following
parameters. Each parameter will default to a reasonable value, but it
is recommended that you set them explicitly as shown in the above example.
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `config` | InvokeAIAppConfig | Use system-wide config | InvokeAI app configuration object |
| `queue` | DownloadQueueServiceBase | Create a new download queue for internal use | Download queue |
| `store` | ModelRecordServiceBase | Use config to select the database to open | Config storage database |
| `event_bus` | EventServiceBase | None | An event bus to send download/install progress events to |
| `event_handlers` | List[DownloadEventHandler] | None | Event handlers for the download queue |
Note that if `store` is not provided, then the class will use
`ModelRecordServiceBase.open(config)` to select the database to use.
Once initialized, the installer will provide the following methods:
#### install_job = installer.install_model()
The `install_model()` method is the core of the installer. The
following illustrates basic usage:
```
sources = [
    Path('/opt/models/sushi.safetensors'),                    # a local safetensors file
    Path('/opt/models/sushi_diffusers/'),                     # a local diffusers folder
    'runwayml/stable-diffusion-v1-5',                         # a repo_id
    'runwayml/stable-diffusion-v1-5:vae',                     # a subfolder within a repo_id
    'https://civitai.com/api/download/models/63006',          # a civitai direct download link
    'https://civitai.com/models/8765?modelVersionId=10638',   # civitai model page
    'https://s3.amazon.com/fjacks/sd-3.safetensors',          # arbitrary URL
]
for source in sources:
    install_job = installer.install_model(source)
source2key = installer.wait_for_installs()
for source in sources:
    model_key = source2key[source]
    print(f"{source} installed as {model_key}")
```
As shown here, the `install_model()` method accepts a variety of
sources, including local safetensors files, local diffusers folders,
HuggingFace repo_ids with and without a subfolder designation,
Civitai model URLs and arbitrary URLs that point to checkpoint files
(but not to folders).
Each call to `install_model()` will return a `ModelInstallJob` job, a
subclass of `DownloadJobBase`. The install job has additional
install-specific fields described in the next section.
Each install job will run in a series of background threads using
the object's download queue. You may block until all install jobs are
completed (or errored) by calling the `wait_for_installs()` method as
shown in the code example. `wait_for_installs()` will return a `dict`
that maps the requested source to the key of the installed model. In
the case that a model fails to download or install, its value in the
dict will be None. The actual cause of the error will be reported in
the corresponding job's `error` field.
Alternatively you may install event handlers and/or listen for events
on the InvokeAI event bus in order to monitor the progress of the
requested installs.
The full list of arguments to `install_model()` is as follows:
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `source` | Union[str, Path, AnyHttpUrl] | | The source of the model, Path, URL or repo_id |
| `inplace` | bool | True | Leave a local model in its current location |
| `variant` | str | None | Desired variant, such as 'fp16' or 'onnx' (HuggingFace only) |
| `subfolder` | str | None | Repository subfolder (HuggingFace only) |
| `probe_override` | Dict[str, Any] | None | Override all or a portion of model's probed attributes |
| `metadata` | ModelSourceMetadata | None | Provide metadata that will be added to model's config |
| `access_token` | str | None | Provide authorization information needed to download |
| `priority` | int | 10 | Download queue priority for the job |
The `inplace` field controls how local model Paths are handled. If
True (the default), then the model is simply registered in its current
location by the installer's `ModelConfigRecordService`. Otherwise, the
model will be moved into the location specified by the `models_dir`
application configuration parameter.
The `variant` field is used for HuggingFace repo_ids only. If
provided, the repo_id download handler will look for and download
tensors files that follow the convention for the selected variant:
- "fp16" will select files named "*model.fp16.{safetensors,bin}"
- "onnx" will select files ending with the suffix ".onnx"
- "openvino" will select files beginning with "openvino_model"
In the special case of the "fp16" variant, the installer will select
the 32-bit version of the files if the 16-bit version is unavailable.
`subfolder` is used for HuggingFace repo_ids only. If provided, the
model will be downloaded from the designated subfolder rather than the
top-level repository folder. If a subfolder is attached to the repo_id
using the format `repo_owner/repo_name:subfolder`, then the subfolder
specified by the repo_id will override the subfolder argument.
`probe_override` can be used to override all or a portion of the
attributes returned by the model prober. This can be used to overcome
cases in which automatic probing is unable to (correctly) determine
the model's attribute. The most common situation is the
`prediction_type` field for sd-2 (and rare sd-1) models. Here is an
example of how it works:
```
install_job = installer.install_model(
    source='stabilityai/stable-diffusion-2-1',
    variant='fp16',
    probe_override=dict(
        prediction_type=SchedulerPredictionType('v_prediction')
    )
)
```
`metadata` allows you to attach custom metadata to the installed
model. See the next section for details.
`priority` and `access_token` are passed to the download queue and
have the same effect as they do for the DownloadQueueServiceBase.
#### Monitoring the install job process
When you create an install job with `install_model()`, events will be
passed to the list of `DownloadEventHandlers` provided at installer
initialization time. Event handlers can also be added to individual
model install jobs by calling their `add_handler()` method as
described earlier for the `DownloadQueueService`.
If the `event_bus` argument was provided, events will also be
broadcast to the InvokeAI event bus. The events will appear on the bus
as a singular event type named `model_event` with a payload of
`job`. You can then retrieve the job and check its status.
** TO DO: ** consider breaking `model_event` into
`model_install_started`, `model_install_completed`, etc. The event bus
features have not yet been tested with FastAPI/websockets, and it may
turn out that the job object is not serializable.
#### Model metadata and probing
The install service has special handling for HuggingFace and Civitai
URLs that capture metadata from the source and include it in the model
configuration record. For example, fetching the Civitai model 8765
will produce a config record similar to this (using YAML
representation):
```
5abc3ef8600b6c1cc058480eaae3091e:
path: sd-1/lora/to8contrast-1-5.safetensors
name: to8contrast-1-5
base_model: sd-1
model_type: lora
model_format: lycoris
key: 5abc3ef8600b6c1cc058480eaae3091e
hash: 5abc3ef8600b6c1cc058480eaae3091e
description: 'Trigger terms: to8contrast style'
author: theovercomer8
license: allowCommercialUse=Sell; allowDerivatives=True; allowNoCredit=True
source: https://civitai.com/models/8765?modelVersionId=10638
thumbnail_url: null
tags:
- model
- style
- portraits
```
For sources that do not provide model metadata, you can attach custom
fields by providing a `metadata` argument to `install_model()` using
an initialized `ModelSourceMetadata` object (available for import from
`model_install_service.py`):
```
from invokeai.app.services.model_install_service import ModelSourceMetadata
meta = ModelSourceMetadata(
    name="my model",
    author="Sushi Chef",
    description="Highly customized model; trigger with 'sushi'",
    license="mit",
    thumbnail_url="http://s3.amazon.com/ljack/pics/sushi.png",
    tags=['sfw', 'food'],
)
install_job = installer.install_model(
    source='sushi_chef/model3',
    variant='fp16',
    metadata=meta,
)
```
It is not currently recommended to provide custom metadata when
installing from Civitai or HuggingFace source, as the metadata
provided by the source will overwrite the fields you provide. Instead,
after the model is installed you can use
`ModelRecordService.update_model()` to change the desired fields.
**TO DO:** Change the logic so that the caller's metadata fields take
precedence over those provided by the source.
#### Other installer methods
This section describes additional, less-frequently-used attributes and
methods provided by the installer class.
##### installer.wait_for_installs()
This is equivalent to the `DownloadQueue` `join()` method. It will
block until all the active jobs in the install queue have reached a
terminal state (completed, errored or cancelled).
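
The blocking behaviour can be pictured with the standard library's `queue.Queue`. This is a minimal sketch, assuming a single worker thread and a `None` sentinel; `run_installs` is a hypothetical name and the "install" step is a placeholder, so it should not be read as the installer's actual code.

```python
import queue
import threading
from typing import List, Optional


def run_installs(sources: List[str]) -> List[str]:
    """Queue some fake install jobs and block until all are processed."""
    jobs: "queue.Queue[Optional[str]]" = queue.Queue()
    completed: List[str] = []

    def worker() -> None:
        while True:
            source = jobs.get()
            try:
                if source is None:  # sentinel: stop the worker
                    return
                completed.append(source)  # placeholder for the real install work
            finally:
                jobs.task_done()  # every get() is matched by a task_done()

    thread = threading.Thread(target=worker, daemon=True)
    thread.start()
    for source in sources:
        jobs.put(source)

    jobs.join()  # blocks until every queued job has been marked done
    jobs.put(None)
    thread.join()
    return completed
```

`jobs.join()` returns only once `task_done()` has been called for each queued item, which is why the worker calls it in a `finally` clause.
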
##### installer.queue, installer.store, installer.config
These attributes provide access to the `DownloadQueueServiceBase`,
`ModelConfigRecordServiceBase`, and `InvokeAIAppConfig` objects that
the installer uses.
For example, to temporarily pause all pending installations, you can
do this:
```
installer.queue.pause_all_jobs()
```
##### key = installer.register_path(model_path, overrides), key = installer.install_path(model_path, overrides)
These methods bypass the download queue and directly register or
install the model at the indicated path, returning the unique ID for
the installed model.
Both methods accept a Path object corresponding to a checkpoint or
diffusers folder, and an optional dict of attributes to use to
override the values derived from model probing.
The difference between `register_path()` and `install_path()` is that
the former will not move the model from its current position, while
the latter will move it into the `models_dir` hierarchy.
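
The filesystem effect of the two calls can be sketched as follows. `register_in_place` and `install_into` are hypothetical helpers written for illustration; the real methods also probe the model and write a config record, which is omitted here.

```python
import shutil
from pathlib import Path


def register_in_place(model_path: Path) -> Path:
    """Leave the model where it is; return the path that would be recorded."""
    return model_path.resolve()


def install_into(model_path: Path, models_dir: Path) -> Path:
    """Move the model under models_dir; return its new location."""
    models_dir.mkdir(parents=True, exist_ok=True)
    dest = models_dir / model_path.name
    shutil.move(str(model_path), str(dest))
    return dest.resolve()
```

After `register_in_place()` the original file is untouched, whereas after `install_into()` it no longer exists at its old location.
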
##### installer.unregister(key)
This will remove the model config record for the model at key, and is
equivalent to `installer.store.unregister(key)`
##### installer.delete(key)
This is similar to `unregister()` but has the additional effect of
deleting the underlying model file(s) -- even if they were outside the
`models_dir` directory!
##### installer.conditionally_delete(key)
This method will call `unregister()` if the model identified by `key`
is outside the `models_dir` hierarchy, and call `delete()` if the
model is inside.
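
The decision described above comes down to a path-containment test. A minimal sketch, assuming the model's path and `models_dir` are both available as `Path` objects (`is_inside` and `removal_action` are hypothetical names, not installer methods):

```python
from pathlib import Path


def is_inside(path: Path, root: Path) -> bool:
    """Return True if path lies within the root directory hierarchy."""
    try:
        path.resolve().relative_to(root.resolve())
        return True
    except ValueError:
        return False


def removal_action(model_path: Path, models_dir: Path) -> str:
    """Pick which installer call would apply to this model."""
    return "delete" if is_inside(model_path, models_dir) else "unregister"
```

Resolving both paths first guards against relative paths or `..` segments defeating the containment check.
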
##### List[str]=installer.scan_directory(scan_dir: Path, install: bool)
This method will recursively scan the directory indicated in
`scan_dir` for new models and either install them in the models
directory or register them in place, depending on the setting of
`install` (default False).
The return value is the list of keys of the new installed/registered
models.
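
A simplified picture of the scan, assuming models are single files recognised by extension. The suffix list and the `find_new_models` helper are assumptions made for illustration; the real scanner also has to recognise diffusers folders, which this sketch ignores.

```python
from pathlib import Path
from typing import List, Set

# Assumed set of checkpoint extensions, for illustration only.
MODEL_SUFFIXES = {".safetensors", ".ckpt", ".pt", ".pth", ".bin"}


def find_new_models(scan_dir: Path, known: Set[Path]) -> List[Path]:
    """Recursively collect model files under scan_dir that are not already known."""
    found: List[Path] = []
    for path in sorted(scan_dir.rglob("*")):
        if not path.is_file() or path.suffix not in MODEL_SUFFIXES:
            continue
        resolved = path.resolve()
        if resolved not in known:
            found.append(resolved)
    return found
```

Each returned path would then be passed to either `register_path()` or `install_path()`, depending on the `install` flag.
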
##### installer.scan_models_directory()
This method scans the models directory for new models and registers
them in place. Models that are present in the
`ModelConfigRecordService` database whose paths are not found will be
unregistered.
##### installer.sync_to_config()
This method synchronizes models in the models directory and autoimport
directory to those in the `ModelConfigRecordService` database. New
models are registered and orphan models are unregistered.
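
Conceptually, the synchronization is a pair of set differences between the paths found on disk and the paths recorded in the database. A minimal sketch under that assumption (`plan_sync` is a hypothetical helper, not part of the installer):

```python
from typing import Set, Tuple


def plan_sync(on_disk: Set[str], in_db: Set[str]) -> Tuple[Set[str], Set[str]]:
    """Return (paths to register, paths to unregister)."""
    to_register = on_disk - in_db    # present on disk, no record yet
    to_unregister = in_db - on_disk  # orphan records whose files are gone
    return to_register, to_unregister
```
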
##### hash=installer.hash(model_path)
This method calls the fasthash algorithm on a model's Path
(either a file or a folder) to generate a unique ID based on the
contents of the model.
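
To illustrate the idea of a content-derived ID that works for both files and folders, here is a plain SHA-256 sketch. It is not the fasthash algorithm, whose details are not described in this section; `content_hash` is a hypothetical helper that simply reads every byte.

```python
import hashlib
from pathlib import Path


def content_hash(model_path: Path) -> str:
    """Hash a single file, or every file in a folder in sorted order."""
    digest = hashlib.sha256()
    is_dir = model_path.is_dir()
    if is_dir:
        files = sorted(p for p in model_path.rglob("*") if p.is_file())
    else:
        files = [model_path]
    for file in files:
        if is_dir:
            # Mix in the relative name so moving a file within the folder changes the ID.
            digest.update(file.relative_to(model_path).as_posix().encode("utf-8"))
        with open(file, "rb") as fh:
            for chunk in iter(lambda: fh.read(1 << 20), b""):
                digest.update(chunk)
    return digest.hexdigest()
```

Two copies of the same model produce the same ID regardless of where they live on disk, which is the property the installer relies on when using the hash as a key.
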
##### installer.start(invoker)
The `start` method is called by the API initialization routines when
the API starts up. Its effect is to call `sync_to_config()` to
synchronize the model record store database with what's currently on
disk.
This method should not ordinarily be called manually.

View File

@ -22,6 +22,7 @@ from ..services.invoker import Invoker
from ..services.item_storage.item_storage_sqlite import SqliteItemStorage
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
from ..services.model_install import ModelInstallService
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
@ -86,6 +87,9 @@ class ApiDependencies:
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)
model_install_service = ModelInstallService(
app_config=config, record_store=model_record_service, event_bus=events
)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor()
@ -112,6 +116,7 @@ class ApiDependencies:
logger=logger,
model_manager=model_manager,
model_records=model_record_service,
model_install=model_install_service,
names=names,
performance_statistics=performance_statistics,
processor=processor,

View File

@ -4,7 +4,7 @@
from hashlib import sha1
from random import randbytes
from typing import List, Optional
from typing import Any, Dict, List, Optional
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
@ -12,6 +12,7 @@ from pydantic import BaseModel, ConfigDict
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
@ -25,7 +26,7 @@ from invokeai.backend.model_manager.config import (
from ..dependencies import ApiDependencies
model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"])
model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2"])
class ModelsList(BaseModel):
@ -43,15 +44,18 @@ class ModelsList(BaseModel):
async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
found_models: list[AnyModelConfig] = []
if base_models:
for base_model in base_models:
found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
found_models.extend(
record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name)
)
else:
found_models.extend(record_store.search_by_attr(model_type=model_type))
found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name))
return ModelsList(models=found_models)
@ -117,12 +121,17 @@ async def update_model_record(
async def del_model_record(
key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
"""Delete Model"""
"""
Delete model record from database.
The configuration record will be removed. The corresponding weights files will be
deleted as well if they reside within the InvokeAI "models" directory.
"""
logger = ApiDependencies.invoker.services.logger
try:
record_store = ApiDependencies.invoker.services.model_records
record_store.del_model(key)
installer = ApiDependencies.invoker.services.model_install
installer.delete(key)
logger.info(f"Deleted model: {key}")
return Response(status_code=204)
except UnknownModelException as e:
@ -162,3 +171,145 @@ async def add_model_record(
# now fetch it out
return record_store.get_model(config.key)
@model_records_router.post(
"/import",
operation_id="import_model_record",
responses={
201: {"description": "The model imported successfully"},
415: {"description": "Unrecognized file/folder format"},
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
)
async def import_model(
source: ModelSource,
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None,
),
) -> ModelInstallJob:
"""Add a model using its local path, repo_id, or remote URL.
Models will be downloaded, probed, configured and installed in a
series of background threads. The return object has a `status` attribute
that can be used to monitor progress.
The source object is a discriminated Union of LocalModelSource,
HFModelSource and URLModelSource. Set the "type" field to the
appropriate value:
* To install a local path using LocalModelSource, pass a source of form:
`{
"type": "local",
"path": "/path/to/model",
"inplace": false
}`
The "inplace" flag, if true, will register the model in place in its
current filesystem location. Otherwise, the model will be copied
into the InvokeAI models directory.
* To install a HuggingFace repo_id using HFModelSource, pass a source of form:
`{
"type": "hf",
"repo_id": "stabilityai/stable-diffusion-2.0",
"variant": "fp16",
"subfolder": "vae",
"access_token": "f5820a918aaf01"
}`
The `variant`, `subfolder` and `access_token` fields are optional.
* To install a remote model using an arbitrary URL, pass:
`{
"type": "generic_url",
"url": "http://www.civitai.com/models/123456",
"access_token": "f5820a918aaf01"
}`
The `access_token` field is optional.
The model's configuration record will be probed and filled in
automatically. To override the default guesses, pass "config"
with a Dict containing the attributes you wish to override.
Installation occurs in the background. Either use list_model_install_jobs()
to poll for completion, or listen on the event bus for the following events:
"model_install_started"
"model_install_completed"
"model_install_error"
On successful completion, the event's payload will contain the field "key"
containing the installed ID of the model. On an error, the event's payload
will contain the fields "error_type" and "error" describing the nature of the
error and its traceback, respectively.
"""
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_install
result: ModelInstallJob = installer.import_model(
source=source,
config=config,
)
logger.info(f"Started installation of {source}")
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=424, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return result
@model_records_router.get(
"/import",
operation_id="list_model_install_jobs",
)
async def list_model_install_jobs() -> List[ModelInstallJob]:
"""
Return list of model install jobs.
If the optional 'source' argument is provided, then the list will be filtered
for partial string matches against the install source.
"""
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs()
return jobs
@model_records_router.patch(
"/import",
operation_id="prune_model_install_jobs",
responses={
204: {"description": "All completed and errored jobs have been pruned"},
400: {"description": "Bad request"},
},
)
async def prune_model_install_jobs() -> Response:
"""
Prune all completed and errored jobs from the install job list.
"""
ApiDependencies.invoker.services.model_install.prune_jobs()
return Response(status_code=204)
@model_records_router.patch(
"/sync",
operation_id="sync_models_to_config",
responses={
204: {"description": "Model config record database resynced with files on disk"},
400: {"description": "Bad request"},
},
)
async def sync_models_to_config() -> Response:
"""
Traverse the models and autoimport directories. Model files without a corresponding
record in the database are added. Orphan records without a models file are deleted.
"""
ApiDependencies.invoker.services.model_install.sync_to_config()
return Response(status_code=204)

View File

@ -20,6 +20,7 @@ class SocketIO:
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
async def _handle_queue_event(self, event: Event):
await self.__sio.emit(
@ -28,10 +29,13 @@ class SocketIO:
room=event[1]["data"]["queue_id"],
)
async def _handle_sub_queue(self, sid, data, *args, **kwargs):
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data:
await self.__sio.enter_room(sid, data["queue_id"])
async def _handle_unsub_queue(self, sid, data, *args, **kwargs):
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data:
await self.__sio.leave_room(sid, data["queue_id"])
async def _handle_model_event(self, event: Event) -> None:
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])

View File

@ -1,6 +1,5 @@
"""
Init file for InvokeAI configure package
"""
"""Init file for InvokeAI configure package."""
from .config_base import PagingArgumentParser # noqa F401
from .config_default import InvokeAIAppConfig, get_invokeai_config # noqa F401
from .config_default import InvokeAIAppConfig, get_invokeai_config
__all__ = ["InvokeAIAppConfig", "get_invokeai_config"]

View File

@ -173,7 +173,7 @@ from __future__ import annotations
import os
from pathlib import Path
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
from omegaconf import DictConfig, OmegaConf
from pydantic import Field, TypeAdapter
@ -334,7 +334,7 @@ class InvokeAIAppConfig(InvokeAISettings):
)
@classmethod
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
def get_config(cls, **kwargs: Dict[str, Any]) -> InvokeAIAppConfig:
"""Return a singleton InvokeAIAppConfig configuration object."""
if (
cls.singleton_config is None
@ -383,17 +383,17 @@ class InvokeAIAppConfig(InvokeAISettings):
return db_dir / DB_FILE
@property
def model_conf_path(self) -> Optional[Path]:
def model_conf_path(self) -> Path:
"""Path to models configuration file."""
return self._resolve(self.conf_path)
@property
def legacy_conf_path(self) -> Optional[Path]:
def legacy_conf_path(self) -> Path:
"""Path to directory of legacy configuration files (e.g. v1-inference.yaml)."""
return self._resolve(self.legacy_conf_dir)
@property
def models_path(self) -> Optional[Path]:
def models_path(self) -> Path:
"""Path to the models directory."""
return self._resolve(self.models_dir)

View File

@ -0,0 +1 @@
from .events_base import EventServiceBase # noqa F401

View File

@ -1,5 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Optional
from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
@ -16,6 +17,7 @@ from invokeai.backend.model_management.models.base import BaseModelType, ModelTy
class EventServiceBase:
queue_event: str = "queue_event"
model_event: str = "model_event"
"""Basic event bus, to have an empty stand-in when not needed"""
@ -30,6 +32,13 @@ class EventServiceBase:
payload={"event": event_name, "data": payload},
)
def __emit_model_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.model_event,
payload={"event": event_name, "data": payload},
)
# Define events here for every event in the system.
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(
@ -313,3 +322,73 @@ class EventServiceBase:
event_name="queue_cleared",
payload={"queue_id": queue_id},
)
def emit_model_install_started(self, source: str) -> None:
"""
Emitted when an install job is started.
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_started",
payload={"source": source},
)
def emit_model_install_completed(self, source: str, key: str) -> None:
"""
Emitted when an install job is completed successfully.
:param source: Source of the model; local path, repo_id or url
:param key: Model config record key
"""
self.__emit_model_event(
event_name="model_install_completed",
payload={
"source": source,
"key": key,
},
)
def emit_model_install_progress(
self,
source: str,
current_bytes: int,
total_bytes: int,
) -> None:
"""
Emitted while the install job is in progress.
(Downloaded models only)
:param source: Source of the model
:param current_bytes: Number of bytes downloaded so far
:param total_bytes: Total bytes to download
"""
self.__emit_model_event(
event_name="model_install_progress",
payload={
"source": source,
"current_bytes": current_bytes,
"total_bytes": total_bytes,
},
)
def emit_model_install_error(
self,
source: str,
error_type: str,
error: str,
) -> None:
"""
Emitted when an install job encounters an exception.
:param source: Source of the model
:param error_type: Class name of the exception that raised the error
:param error: Traceback of the exception
"""
self.__emit_model_event(
event_name="model_install_error",
payload={
"source": source,
"error_type": error_type,
"error": error,
},
)

View File

@ -21,6 +21,7 @@ if TYPE_CHECKING:
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .item_storage.item_storage_base import ItemStorageABC
from .latents_storage.latents_storage_base import LatentsStorageBase
from .model_install import ModelInstallServiceBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .model_records import ModelRecordServiceBase
from .names.names_base import NameServiceBase
@ -50,6 +51,7 @@ class InvocationServices:
logger: "Logger"
model_manager: "ModelManagerServiceBase"
model_records: "ModelRecordServiceBase"
model_install: "ModelInstallServiceBase"
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC"
@ -77,6 +79,7 @@ class InvocationServices:
logger: "Logger",
model_manager: "ModelManagerServiceBase",
model_records: "ModelRecordServiceBase",
model_install: "ModelInstallServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
@ -102,6 +105,7 @@ class InvocationServices:
self.logger = logger
self.model_manager = model_manager
self.model_records = model_records
self.model_install = model_install
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue

View File

@ -0,0 +1,25 @@
"""Initialization file for model install service package."""
from .model_install_base import (
HFModelSource,
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
UnknownInstallJobException,
URLModelSource,
)
from .model_install_default import ModelInstallService
__all__ = [
"ModelInstallServiceBase",
"ModelInstallService",
"InstallStatus",
"ModelInstallJob",
"UnknownInstallJobException",
"ModelSource",
"LocalModelSource",
"HFModelSource",
"URLModelSource",
]

View File

@ -0,0 +1,317 @@
import re
import traceback
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
from fastapi import Body
from pydantic import BaseModel, Field, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
class StringLikeSource(BaseModel):
"""
Base class for model sources, implements functions that let the source be sorted and indexed.
These shenanigans let this stuff work:
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
mydict = {source1: 'model 1'}
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
assert source1 == source2
assert source1 == 'C:/users/mort/foo.safetensors'
"""
def __hash__(self) -> int:
"""Return hash of the path field, for indexing."""
return hash(str(self))
def __lt__(self, other: object) -> bool:
"""Return comparison of the stringified version, for sorting."""
return str(self) < str(other)
def __eq__(self, other: object) -> bool:
"""Return equality on the stringified version."""
if isinstance(other, Path):
return str(self) == other.as_posix()
else:
return str(self) == str(other)
class LocalModelSource(StringLikeSource):
"""A local file or directory path."""
path: str | Path
inplace: Optional[bool] = False
type: Literal["local"] = "local"
# these methods allow the source to be used in a string-like way,
# for example as an index into a dict
def __str__(self) -> str:
"""Return string version of path when string rep needed."""
return Path(self.path).as_posix()
class HFModelSource(StringLikeSource):
"""A HuggingFace repo_id, with optional variant and sub-folder."""
repo_id: str
variant: Optional[str] = None
subfolder: Optional[str | Path] = None
access_token: Optional[str] = None
type: Literal["hf"] = "hf"
@field_validator("repo_id")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
base += f":{self.subfolder}" if self.subfolder else ""
base += f" ({self.variant})" if self.variant else ""
return base
class URLModelSource(StringLikeSource):
"""A generic URL pointing to a checkpoint file."""
url: AnyHttpUrl
access_token: Optional[str] = None
type: Literal["generic_url"] = "generic_url"
def __str__(self) -> str:
"""Return string version of the url when string rep needed."""
return str(self.url)
# Body() is being applied here rather than Field() because otherwise FastAPI will
# refuse to generate a schema. Relevant links:
#
# Model Manager Refactor Phase 1 - SQL-based config storage
# https://github.com/invoke-ai/InvokeAI/pull/5039#discussion_r1389752119 (comment)
# Param: xyz can only be a request body, using Body() when using discriminated unions
# https://github.com/tiangolo/fastapi/discussions/9761
# Body parameter cannot be a pydantic union anymore since v0.95
# https://github.com/tiangolo/fastapi/discussions/9287
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Body(discriminator="type")]
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."
)
inplace: bool = Field(
default=False, description="Leave model in its current location; otherwise install under models directory"
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
error_type: Optional[str] = Field(default=None, description="Class name of the exception that led to status==ERROR")
error: Optional[str] = Field(default=None, description="Error traceback") # noqa #501
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self.error_type = e.__class__.__name__
self.error = "".join(traceback.format_exception(e))
self.status = InstallStatus.ERROR
class ModelInstallServiceBase(ABC):
"""Abstract base class for InvokeAI model installation."""
@abstractmethod
def __init__(
self,
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
event_bus: Optional["EventServiceBase"] = None,
):
"""
Create ModelInstallService object.
:param app_config: Systemwide InvokeAIAppConfig.
:param record_store: Systemwide ModelConfigStore
:param event_bus: InvokeAI event bus for reporting events to.
"""
def start(self, invoker: Invoker) -> None:
"""Call at InvokeAI startup time."""
self.sync_to_config()
@abstractmethod
def stop(self) -> None:
"""Stop the model install service. After this the object can be safely deleted."""
@property
@abstractmethod
def app_config(self) -> InvokeAIAppConfig:
"""Return the appConfig object associated with the installer."""
@property
@abstractmethod
def record_store(self) -> ModelRecordServiceBase:
"""Return the ModelRecordService object associated with the installer."""
@property
@abstractmethod
def event_bus(self) -> Optional[EventServiceBase]:
"""Return the event service base object associated with the installer."""
@abstractmethod
def register_path(
self,
model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None,
) -> str:
"""
Probe and register the model at model_path.
This keeps the model in its current location.
:param model_path: Filesystem Path to the model.
:param config: Dict of attributes that will override autoassigned values.
:returns id: The string ID of the registered model.
"""
@abstractmethod
def unregister(self, key: str) -> None:
"""Remove model with indicated key from the database."""
@abstractmethod
def delete(self, key: str) -> None:
"""Remove model with indicated key from the database. Delete its files only if they are within our models directory."""
@abstractmethod
def unconditionally_delete(self, key: str) -> None:
"""Remove model with indicated key from the database and unconditionally delete weight files from disk."""
@abstractmethod
def install_path(
self,
model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None,
) -> str:
"""
Probe, register and install the model in the models directory.
This moves the model from its current location into
the models directory handled by InvokeAI.
:param model_path: Filesystem Path to the model.
:param config: Dict of attributes that will override autoassigned values.
:returns id: The string ID of the registered model.
"""
@abstractmethod
def import_model(
self,
source: ModelSource,
config: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob:
"""Install the indicated model.
:param source: ModelSource object
:param config: Optional dict. Any fields in this dict
will override corresponding autoassigned probe fields in the
model's config record. Use it to override
`name`, `description`, `base_type`, `model_type`, `format`,
`prediction_type`, `image_size`, and/or `ztsnr_training`.
This will download the model located at `source`,
probe it, and install it into the models directory.
This call is executed asynchronously in a separate
thread and will issue the following events on the event bus:
- model_install_started
- model_install_error
- model_install_completed
The `inplace` flag does not affect the behavior of downloaded
models, which are always moved into the `models` directory.
The call returns a ModelInstallJob object which can be
polled to learn the current status and/or error message.
Variants recognized by HuggingFace currently are:
1. onnx
2. openvino
3. fp16
4. None (usually returns fp32 model)
"""
@abstractmethod
def get_job(self, source: ModelSource) -> List[ModelInstallJob]:
"""Return the ModelInstallJob(s) corresponding to the provided source."""
@abstractmethod
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
"""
List active and complete install jobs.
"""
@abstractmethod
def prune_jobs(self) -> None:
"""Prune all completed and errored jobs."""
@abstractmethod
def wait_for_installs(self) -> List[ModelInstallJob]:
"""
Wait for all pending installs to complete.
This will block until all pending installs have
completed, been cancelled, or errored out. It will
block indefinitely if one or more jobs are in the
paused state.
It will return the current list of jobs.
"""
@abstractmethod
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
"""
Recursively scan directory for new models and register or install them.
:param scan_dir: Path to the directory to scan.
:param install: Install if True, otherwise register in place.
:returns list of IDs: Returns list of IDs of models registered/installed
"""
@abstractmethod
def sync_to_config(self) -> None:
"""Synchronize models on disk to those in the model record database."""

View File

@ -0,0 +1,395 @@
"""Model installation class."""
import threading
from hashlib import sha256
from logging import Logger
from pathlib import Path
from queue import Queue
from random import randbytes
from shutil import copyfile, copytree, move, rmtree
from typing import Any, Dict, List, Optional, Set, Union
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, UnknownModelException
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
InvalidModelConfigException,
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import Chdir, InvokeAILogger
from .model_install_base import (
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
)
# marker that the queue is done and that thread should exit
STOP_JOB = ModelInstallJob(
source=LocalModelSource(path="stop"),
local_path=Path("/dev/null"),
)
class ModelInstallService(ModelInstallServiceBase):
"""class for InvokeAI model installation."""
_app_config: InvokeAIAppConfig
_record_store: ModelRecordServiceBase
_event_bus: Optional[EventServiceBase] = None
_install_queue: Queue[ModelInstallJob]
_install_jobs: List[ModelInstallJob]
_logger: Logger
_cached_model_paths: Set[Path]
_models_installed: Set[str]
def __init__(
self,
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
event_bus: Optional[EventServiceBase] = None,
):
"""
Initialize the installer object.
:param app_config: InvokeAIAppConfig object
:param record_store: Previously-opened ModelRecordService database
:param event_bus: Optional EventService object
"""
self._app_config = app_config
self._record_store = record_store
self._event_bus = event_bus
self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__)
self._install_jobs = []
self._install_queue = Queue()
self._cached_model_paths = set()
self._models_installed = set()
self._start_installer_thread()
@property
def app_config(self) -> InvokeAIAppConfig: # noqa D102
return self._app_config
@property
def record_store(self) -> ModelRecordServiceBase: # noqa D102
return self._record_store
@property
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
return self._event_bus
def stop(self) -> None:
"""Stop the install thread; after this the object can be deleted and garbage collected."""
self._install_queue.put(STOP_JOB)
def _start_installer_thread(self) -> None:
threading.Thread(target=self._install_next_item, daemon=True).start()
def _install_next_item(self) -> None:
done = False
while not done:
job = self._install_queue.get()
if job == STOP_JOB:
done = True
continue
assert job.local_path is not None
try:
self._signal_job_running(job)
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
else:
key = self.install_path(job.local_path, job.config_in)
job.config_out = self.record_store.get_model(key)
self._signal_job_completed(job)
except (OSError, DuplicateModelException, InvalidModelConfigException) as excp:
self._signal_job_errored(job, excp)
finally:
self._install_queue.task_done()
self._logger.info("Install thread exiting")
def _signal_job_running(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.RUNNING
self._logger.info(f"{job.source}: model installation started")
if self._event_bus:
self._event_bus.emit_model_install_started(str(job.source))
def _signal_job_completed(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.COMPLETED
assert job.config_out
self._logger.info(
f"{job.source}: model installation completed. {job.local_path} registered key {job.config_out.key}"
)
if self._event_bus:
assert job.local_path is not None
assert job.config_out is not None
key = job.config_out.key
self._event_bus.emit_model_install_completed(str(job.source), key)
def _signal_job_errored(self, job: ModelInstallJob, excp: Exception) -> None:
job.set_error(excp)
self._logger.info(f"{job.source}: model installation encountered an exception: {job.error_type}")
if self._event_bus:
error_type = job.error_type
error = job.error
assert error_type is not None
assert error is not None
self._event_bus.emit_model_install_error(str(job.source), error_type, error)
def register_path(
self,
model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None,
) -> str: # noqa D102
model_path = Path(model_path)
config = config or {}
if config.get("source") is None:
config["source"] = model_path.resolve().as_posix()
return self._register(model_path, config)
def install_path(
self,
model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None,
) -> str: # noqa D102
model_path = Path(model_path)
config = config or {}
if config.get("source") is None:
config["source"] = model_path.resolve().as_posix()
info: AnyModelConfig = self._probe_model(Path(model_path), config)
old_hash = info.original_hash
dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name
new_path = self._copy_model(model_path, dest_path)
new_hash = FastModelHash.hash(new_path)
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
return self._register(
new_path,
config,
info,
)
def import_model(
self,
source: ModelSource,
config: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob: # noqa D102
if not config:
config = {}
# Installing a local path
if isinstance(source, LocalModelSource) and Path(source.path).exists(): # a path that is already on disk
job = ModelInstallJob(
source=source,
config_in=config,
local_path=Path(source.path),
)
self._install_jobs.append(job)
self._install_queue.put(job)
return job
else: # here is where we'd download a URL or repo_id. Implementation pending download queue.
raise UnknownModelException("File or directory not found")
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
return self._install_jobs
def get_job(self, source: ModelSource) -> List[ModelInstallJob]: # noqa D102
return [x for x in self._install_jobs if x.source == source]
def wait_for_installs(self) -> List[ModelInstallJob]: # noqa D102
self._install_queue.join()
return self._install_jobs
def prune_jobs(self) -> None:
"""Prune all completed and errored jobs."""
unfinished_jobs = [
x for x in self._install_jobs if x.status not in [InstallStatus.COMPLETED, InstallStatus.ERROR]
]
self._install_jobs = unfinished_jobs
def sync_to_config(self) -> None:
"""Synchronize models on disk to those in the config record store database."""
self._scan_models_directory()
if autoimport := self._app_config.autoimport_dir:
self._logger.info("Scanning autoimport directory for new models")
installed = self.scan_directory(self._app_config.root_path / autoimport)
self._logger.info(f"{len(installed)} new models registered")
self._logger.info("Model installer (re)initialized")
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback)
self._models_installed: Set[str] = set()
search.search(scan_dir)
return list(self._models_installed)
def _scan_models_directory(self) -> None:
"""
Scan the models directory for new and missing models.
New models will be added to the storage backend. Missing models
will be deleted.
"""
defunct_models = set()
installed = set()
with Chdir(self._app_config.models_path):
self._logger.info("Checking for models that have been moved or deleted from disk")
for model_config in self.record_store.all_models():
path = Path(model_config.path)
if not path.exists():
self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. Unregistering")
defunct_models.add(model_config.key)
for key in defunct_models:
self.unregister(key)
self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
for cur_base_model in BaseModelType:
for cur_model_type in ModelType:
models_dir = Path(cur_base_model.value, cur_model_type.value)
installed.update(self.scan_directory(models_dir))
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
def _sync_model_path(self, key: str, ignore_hash_change: bool = False) -> AnyModelConfig:
"""
Move model into the location indicated by its basetype, type and name.
Call this after updating a model's attributes in order to move
the model's path into the location indicated by its basetype, type and
name. Applies only to models whose paths are within the root `models_dir`
directory.
May raise an UnknownModelException.
"""
model = self.record_store.get_model(key)
old_path = Path(model.path)
models_dir = self.app_config.models_path
if not old_path.is_relative_to(models_dir):
return model
new_path = models_dir / model.base.value / model.type.value / model.name
self._logger.info(f"Moving {model.name} to {new_path}.")
new_path = self._move_model(old_path, new_path)
new_hash = FastModelHash.hash(new_path)
model.path = new_path.relative_to(models_dir).as_posix()
if model.current_hash != new_hash:
assert (
ignore_hash_change
), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
model.current_hash = new_hash
self._logger.info(f"Model has new hash {model.current_hash}, but will continue to be identified by {key}")
self.record_store.update_model(key, model)
return model
def _scan_register(self, model: Path) -> bool:
if model in self._cached_model_paths:
return True
try:
id = self.register_path(model)
self._sync_model_path(id) # possibly move it to right place in `models`
self._logger.info(f"Registered {model.name} with id {id}")
self._models_installed.add(id)
except DuplicateModelException:
pass
return True
def _scan_install(self, model: Path) -> bool:
if model in self._cached_model_paths:
return True
try:
id = self.install_path(model)
self._logger.info(f"Installed {model} with id {id}")
self._models_installed.add(id)
except DuplicateModelException:
pass
return True
def unregister(self, key: str) -> None: # noqa D102
self.record_store.del_model(key)
def delete(self, key: str) -> None: # noqa D102
"""Unregister the model. Delete its files only if they are within our models directory."""
model = self.record_store.get_model(key)
models_dir = self.app_config.models_path
model_path = models_dir / model.path
if model_path.is_relative_to(models_dir):
self.unconditionally_delete(key)
else:
self.unregister(key)
def unconditionally_delete(self, key: str) -> None: # noqa D102
model = self.record_store.get_model(key)
path = self.app_config.models_path / model.path
if path.is_dir():
rmtree(path)
else:
path.unlink()
self.unregister(key)
def _copy_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path
new_path.parent.mkdir(parents=True, exist_ok=True)
if old_path.is_dir():
copytree(old_path, new_path)
else:
copyfile(old_path, new_path)
return new_path
def _move_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path
new_path.parent.mkdir(parents=True, exist_ok=True)
# if path already exists then we jigger the name to make it unique
counter: int = 1
while new_path.exists():
path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
if not path.exists():
new_path = path
counter += 1
move(old_path, new_path)
return new_path
def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
info: AnyModelConfig = ModelProbe.probe(Path(model_path))
if config: # used to override probe fields
for key, value in config.items():
setattr(info, key, value)
return info
def _create_key(self) -> str:
return sha256(randbytes(100)).hexdigest()[0:32]
def _register(
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
) -> str:
info = info or ModelProbe.probe(model_path, config)
key = self._create_key()
model_path = model_path.absolute()
if model_path.is_relative_to(self.app_config.models_path):
model_path = model_path.relative_to(self.app_config.models_path)
info.path = model_path.as_posix()
# add 'main' specific fields
if hasattr(info, "config"):
# make config relative to our root
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
self.record_store.add_model(key, info)
return key
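
The installer above runs jobs on a background thread that reads from a queue and exits when it dequeues the `STOP_JOB` sentinel. Below is a minimal, stdlib-only sketch of that pattern, not the installer itself. `Job`, `Worker`, `submit` and the "empty name" failure are hypothetical stand-ins for illustration. Two choices differ from the code above and are assumptions of this sketch: the sentinel is matched by identity (`is`), and `task_done()` is also called for the sentinel so that a later `join()` on the queue would not block.

```python
import threading
from dataclasses import dataclass
from queue import Queue
from typing import Optional


@dataclass
class Job:
    """Hypothetical stand-in for ModelInstallJob."""

    name: str
    status: str = "waiting"
    error: Optional[str] = None


# Sentinel: the worker exits when it dequeues this exact object.
STOP = Job(name="stop")


class Worker:
    """Runs submitted jobs one at a time on a daemon thread."""

    def __init__(self) -> None:
        self._queue: "Queue[Job]" = Queue()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def submit(self, name: str) -> Job:
        # Return the job object so the caller can poll its status.
        job = Job(name)
        self._queue.put(job)
        return job

    def wait(self) -> None:
        # Block until every queued job has been marked done.
        self._queue.join()

    def stop(self) -> None:
        self._queue.put(STOP)
        self._thread.join()

    def _run(self) -> None:
        while True:
            job = self._queue.get()
            try:
                if job is STOP:
                    return
                job.status = "running"
                if not job.name:  # hypothetical failure condition
                    raise ValueError("empty job name")
                job.status = "completed"
            except ValueError as exc:
                job.status = "error"
                job.error = str(exc)
            finally:
                # Runs for the sentinel as well, keeping join() balanced.
                self._queue.task_done()
```

A caller would submit jobs, call `wait()` to block until the queue drains, then inspect each job's `status` and `error`, which mirrors how `import_model()` and `wait_for_installs()` are used above.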


@ -6,3 +6,11 @@ from .model_records_base import ( # noqa F401
UnknownModelException,
)
from .model_records_sql import ModelRecordServiceSQL # noqa F401
__all__ = [
"ModelRecordServiceBase",
"ModelRecordServiceSQL",
"DuplicateModelException",
"InvalidModelException",
"UnknownModelException",
]


@ -32,6 +32,8 @@ class ModelProbeInfo(object):
upcast_attention: bool
format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
image_size: int
name: Optional[str] = None
description: Optional[str] = None
class ProbeBase(object):
@ -113,12 +115,16 @@ class ModelProbe(object):
base_type = probe.get_base_type()
variant_type = probe.get_variant_type()
prediction_type = probe.get_scheduler_prediction_type()
name = cls.get_model_name(model_path)
description = f"{base_type.value} {model_type.value} model {name}"
format = probe.get_format()
model_info = ModelProbeInfo(
model_type=model_type,
base_type=base_type,
variant_type=variant_type,
prediction_type=prediction_type,
name=name,
description=description,
upcast_attention=(
base_type == BaseModelType.StableDiffusion2
and prediction_type == SchedulerPredictionType.VPrediction
@ -142,6 +148,13 @@ class ModelProbe(object):
return model_info
@classmethod
def get_model_name(cls, model_path: Path) -> str:
if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
return model_path.stem
else:
return model_path.name
@classmethod
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
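
The `get_model_name()` helper added in this hunk derives a default model name from the path: single-file weights lose their extension, while anything else (typically a diffusers folder) keeps its full name. A standalone sketch of the same rule follows; the function name `model_name_from_path`, the constant `WEIGHT_SUFFIXES` and the example paths are hypothetical.

```python
from pathlib import Path

# Suffixes treated as single-file weights, as in the hunk above.
WEIGHT_SUFFIXES = {".safetensors", ".bin", ".pt", ".ckpt"}


def model_name_from_path(model_path: Path) -> str:
    """Return the stem for known weight files, otherwise the full name."""
    if model_path.suffix in WEIGHT_SUFFIXES:
        return model_path.stem
    return model_path.name


print(model_name_from_path(Path("models/sd-1/main/example.safetensors")))
print(model_name_from_path(Path("models/sd-1/main/example-v1.5")))
```

The second path shows why the suffix is checked against an explicit set: `Path("example-v1.5").suffix` is `".5"`, so stripping every suffix unconditionally would truncate folder names that contain a dot.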


@ -0,0 +1,29 @@
"""Re-export frequently-used symbols from the Model Manager backend."""
from .config import (
AnyModelConfig,
BaseModelType,
InvalidModelConfigException,
ModelConfigFactory,
ModelFormat,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubModelType,
)
from .probe import ModelProbe
from .search import ModelSearch
__all__ = [
"ModelProbe",
"ModelSearch",
"InvalidModelConfigException",
"ModelConfigFactory",
"BaseModelType",
"ModelType",
"SubModelType",
"ModelVariantType",
"ModelFormat",
"SchedulerPredictionType",
"AnyModelConfig",
]


@ -23,7 +23,7 @@ from enum import Enum
from typing import Literal, Optional, Type, Union
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated
from typing_extensions import Annotated, Any, Dict
class InvalidModelConfigException(Exception):
@ -122,7 +122,7 @@ class ModelConfigBase(BaseModel):
validate_assignment=True,
)
def update(self, attributes: dict):
def update(self, attributes: Dict[str, Any]) -> None:
"""Update the object with fields in dict."""
for key, value in attributes.items():
setattr(self, key, value) # may raise a validation error
@ -195,8 +195,6 @@ class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
"""Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main
# Note that we do not need prediction_type or upcast_attention here
# because they are provided in the checkpoint's own config file.
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):


@ -0,0 +1,684 @@
import json
import re
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union
import safetensors.torch
import torch
from picklescan.scanner import scan_file_path
from invokeai.backend.model_management.models.base import read_checkpoint_meta
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
from invokeai.backend.model_management.util import lora_token_vector_length
from invokeai.backend.util.util import SilenceWarnings
from .config import (
AnyModelConfig,
BaseModelType,
InvalidModelConfigException,
ModelConfigFactory,
ModelFormat,
ModelType,
ModelVariantType,
SchedulerPredictionType,
)
from .hash import FastModelHash
CkptType = Dict[str, Any]
LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: {
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
},
ModelVariantType.Inpaint: {
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
},
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "sd_xl_base.yaml",
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml",
},
}
class ProbeBase(object):
"""Base class for probes."""
def __init__(self, model_path: Path):
self.model_path = model_path
def get_base_type(self) -> BaseModelType:
"""Get model base type."""
raise NotImplementedError
def get_format(self) -> ModelFormat:
"""Get model file format."""
raise NotImplementedError
def get_variant_type(self) -> Optional[ModelVariantType]:
"""Get model variant type."""
return None
def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
"""Get model scheduler prediction type."""
return None
class ModelProbe(object):
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
"diffusers": {},
"checkpoint": {},
"onnx": {},
}
CLASS2TYPE = {
"StableDiffusionPipeline": ModelType.Main,
"StableDiffusionInpaintPipeline": ModelType.Main,
"StableDiffusionXLPipeline": ModelType.Main,
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"LatentConsistencyModelPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae,
"AutoencoderTiny": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
"T2IAdapter": ModelType.T2IAdapter,
}
@classmethod
def register_probe(
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
) -> None:
cls.PROBES[format][model_type] = probe_class
@classmethod
def heuristic_probe(
cls,
model_path: Path,
fields: Optional[Dict[str, Any]] = None,
) -> AnyModelConfig:
return cls.probe(model_path, fields)
@classmethod
def probe(
cls,
model_path: Path,
fields: Optional[Dict[str, Any]] = None,
) -> AnyModelConfig:
"""
Probe the model at model_path and return its configuration record.
:param model_path: Path to the model file (checkpoint) or directory (diffusers).
:param fields: An optional dictionary that can be used to override probed
fields. Typically used for fields that don't probe well, such as prediction_type.
Returns: The appropriate model configuration derived from ModelConfigBase.
"""
if fields is None:
fields = {}
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
model_info = None
model_type = None
if format_type == "diffusers":
model_type = cls.get_model_type_from_folder(model_path)
else:
model_type = cls.get_model_type_from_checkpoint(model_path)
format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type
probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class:
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
hash = FastModelHash.hash(model_path)
probe = probe_class(model_path)
fields["path"] = model_path.as_posix()
fields["type"] = fields.get("type") or model_type
fields["base"] = fields.get("base") or probe.get_base_type()
fields["variant"] = fields.get("variant") or probe.get_variant_type()
fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
fields["description"] = (
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["original_hash"] = fields.get("original_hash") or hash
fields["current_hash"] = fields.get("current_hash") or hash
# additional fields needed for main and controlnet models
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
fields["config"] = cls._get_checkpoint_config_path(
model_path,
model_type=fields["type"],
base_type=fields["base"],
variant_type=fields["variant"],
prediction_type=fields["prediction_type"],
).as_posix()
# additional fields needed for main non-checkpoint models
elif fields["type"] == ModelType.Main and fields["format"] in [
ModelFormat.Onnx,
ModelFormat.Olive,
ModelFormat.Diffusers,
]:
fields["upcast_attention"] = fields.get("upcast_attention") or (
fields["base"] == BaseModelType.StableDiffusion2
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
)
model_info = ModelConfigFactory.make_config(fields)
return model_info
@classmethod
def get_model_name(cls, model_path: Path) -> str:
if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
return model_path.stem
else:
return model_path.name
@classmethod
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType] = None) -> ModelType:
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
raise InvalidModelConfigException(f"{model_path}: unrecognized suffix")
if model_path.name == "learned_embeds.bin":
return ModelType.TextualInversion
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
ckpt = ckpt.get("state_dict", ckpt)
for key in ckpt.keys():
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
return ModelType.Main
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
return ModelType.Vae
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
return ModelType.Lora
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
return ModelType.Lora
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
return ModelType.ControlNet
elif key in {"emb_params", "string_to_param"}:
return ModelType.TextualInversion
else:
# diffusers-ti
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
@classmethod
def get_model_type_from_folder(cls, folder_path: Path) -> ModelType:
"""Get the model type of a hugging-face style folder."""
class_name = None
error_hint = None
for suffix in ["bin", "safetensors"]:
if (folder_path / f"learned_embeds.{suffix}").exists():
return ModelType.TextualInversion
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
return ModelType.Lora
if (folder_path / "unet/model.onnx").exists():
return ModelType.ONNX
if (folder_path / "image_encoder.txt").exists():
return ModelType.IPAdapter
i = folder_path / "model_index.json"
c = folder_path / "config.json"
config_path = i if i.exists() else c if c.exists() else None
if config_path:
with open(config_path, "r") as file:
conf = json.load(file)
if "_class_name" in conf:
class_name = conf["_class_name"]
elif "architectures" in conf:
class_name = conf["architectures"][0]
else:
class_name = None
else:
error_hint = f"No model_index.json or config.json found in {folder_path}."
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
return type
else:
error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
# give up
raise InvalidModelConfigException(
f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
)
@classmethod
def _get_checkpoint_config_path(
cls,
model_path: Path,
model_type: ModelType,
base_type: BaseModelType,
variant_type: ModelVariantType,
prediction_type: SchedulerPredictionType,
) -> Path:
# look for a YAML file adjacent to the model file first
possible_conf = model_path.with_suffix(".yaml")
if possible_conf.exists():
return possible_conf.absolute()
if model_type == ModelType.Main:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type]
elif model_type == ModelType.ControlNet:
config_file = (
"../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
)
else:
raise InvalidModelConfigException(
f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}"
)
assert isinstance(config_file, str)
return Path(config_file)
@classmethod
def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType:
with SilenceWarnings():
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
cls._scan_model(model_path.name, model_path)
model = torch.load(model_path)
assert isinstance(model, dict)
return model
else:
return safetensors.torch.load_file(model_path)
@classmethod
def _scan_model(cls, model_name: str, checkpoint: Path) -> None:
"""
Apply picklescan to the indicated checkpoint and raise an exception
if an infected file is identified.
"""
# scan model
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")
# ##################################################
# Checkpoint probing
# ##################################################
class CheckpointProbeBase(ProbeBase):
def __init__(self, model_path: Path):
super().__init__(model_path)
self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
def get_format(self) -> ModelFormat:
return ModelFormat("checkpoint")
def get_variant_type(self) -> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
if model_type != ModelType.Main:
return ModelVariantType.Normal
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
if in_channels == 9:
return ModelVariantType.Inpaint
elif in_channels == 5:
return ModelVariantType.Depth
elif in_channels == 4:
return ModelVariantType.Normal
else:
raise InvalidModelConfigException(
f"Cannot determine variant type (in_channels={in_channels}) at {self.model_path}"
)
class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
return BaseModelType.StableDiffusionXL
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
return BaseModelType.StableDiffusionXLRefiner
else:
raise InvalidModelConfigException("Cannot determine base type")
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
"""Return model prediction type."""
type = self.get_base_type()
if type == BaseModelType.StableDiffusion2:
checkpoint = self.checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
if "global_step" in checkpoint:
if checkpoint["global_step"] == 220000:
return SchedulerPredictionType.Epsilon
elif checkpoint["global_step"] == 110000:
return SchedulerPredictionType.VPrediction
return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts
elif type == BaseModelType.StableDiffusion1:
return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts
else:
return SchedulerPredictionType.Epsilon
class VaeCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
# I can't find any standalone 2.X VAEs to test with!
return BaseModelType.StableDiffusion1
class LoRACheckpointProbe(CheckpointProbeBase):
"""Class for LoRA checkpoints."""
def get_format(self) -> ModelFormat:
return ModelFormat("lycoris")
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
token_vector_length = lora_token_vector_length(checkpoint)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:
return BaseModelType.StableDiffusion2
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"Unknown LoRA type: {self.model_path}")
class TextualInversionCheckpointProbe(CheckpointProbeBase):
"""Class for probing embeddings."""
def get_format(self) -> ModelFormat:
return ModelFormat.EmbeddingFile
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
if "string_to_token" in checkpoint:
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
elif "emb_params" in checkpoint:
token_dim = checkpoint["emb_params"].shape[-1]
elif "clip_g" in checkpoint:
token_dim = checkpoint["clip_g"].shape[-1]
else:
token_dim = list(checkpoint.values())[0].shape[0]
if token_dim == 768:
return BaseModelType.StableDiffusion1
elif token_dim == 1024:
return BaseModelType.StableDiffusion2
elif token_dim == 1280:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type")
class ControlNetCheckpointProbe(CheckpointProbeBase):
"""Class for probing controlnets."""
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
for key_name in (
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
):
if key_name not in checkpoint:
continue
if checkpoint[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
elif checkpoint[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")
class IPAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
########################################################
# classes for probing folders
#######################################################
class FolderProbeBase(ProbeBase):
def get_variant_type(self) -> ModelVariantType:
return ModelVariantType.Normal
def get_format(self) -> ModelFormat:
return ModelFormat("diffusers")
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
with open(self.model_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf["cross_attention_dim"] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf["cross_attention_dim"] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf["cross_attention_dim"] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
scheduler_conf = json.load(file)
if scheduler_conf["prediction_type"] == "v_prediction":
return SchedulerPredictionType.VPrediction
elif scheduler_conf["prediction_type"] == "epsilon":
return SchedulerPredictionType.Epsilon
else:
raise InvalidModelConfigException(f"Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")
def get_variant_type(self) -> ModelVariantType:
# This only works for pipelines! Any kind of
# exception results in our returning the
# "normal" variant type
try:
config_file = self.model_path / "unet" / "config.json"
with open(config_file, "r") as file:
conf = json.load(file)
in_channels = conf["in_channels"]
if in_channels == 9:
return ModelVariantType.Inpaint
elif in_channels == 5:
return ModelVariantType.Depth
elif in_channels == 4:
return ModelVariantType.Normal
except Exception:
pass
return ModelVariantType.Normal
class VaeFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
if self._config_looks_like_sdxl():
return BaseModelType.StableDiffusionXL
elif self._name_looks_like_sdxl():
# SD and SDXL VAEs have the same shape (3-channel RGB to 4-channel float scaled down
# by a factor of 8), so we can't always tell them apart by config hyperparameters;
# fall back to guessing from the name.
return BaseModelType.StableDiffusionXL
else:
return BaseModelType.StableDiffusion1
def _config_looks_like_sdxl(self) -> bool:
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
config_file = self.model_path / "config.json"
if not config_file.exists():
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
with open(config_file, "r") as file:
config = json.load(file)
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
def _name_looks_like_sdxl(self) -> bool:
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
def _guess_name(self) -> str:
name = self.model_path.name
if name == "vae":
name = self.model_path.parent.name
return name
class TextualInversionFolderProbe(FolderProbeBase):
def get_format(self) -> ModelFormat:
return ModelFormat.EmbeddingFolder
def get_base_type(self) -> BaseModelType:
path = self.model_path / "learned_embeds.bin"
if not path.exists():
raise InvalidModelConfigException(
f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file"
)
return TextualInversionCheckpointProbe(path).get_base_type()
class ONNXFolderProbe(FolderProbeBase):
def get_format(self) -> ModelFormat:
return ModelFormat("onnx")
def get_base_type(self) -> BaseModelType:
return BaseModelType.StableDiffusion1
def get_variant_type(self) -> ModelVariantType:
return ModelVariantType.Normal
class ControlNetFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
if not config_file.exists():
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
with open(config_file, "r") as file:
config = json.load(file)
# no obvious way to distinguish between sd2-base and sd2-768
dimension = config["cross_attention_dim"]
base_model = (
BaseModelType.StableDiffusion1
if dimension == 768
else (
BaseModelType.StableDiffusion2
if dimension == 1024
else BaseModelType.StableDiffusionXL
if dimension == 2048
else None
)
)
if not base_model:
raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}")
return base_model
class LoRAFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
model_file = None
for suffix in ["safetensors", "bin"]:
base_file = self.model_path / f"pytorch_lora_weights.{suffix}"
if base_file.exists():
model_file = base_file
break
if not model_file:
raise InvalidModelConfigException("Unknown LoRA format encountered")
return LoRACheckpointProbe(model_file).get_base_type()
class IPAdapterFolderProbe(FolderProbeBase):
def get_format(self) -> IPAdapterModelFormat:
return IPAdapterModelFormat.InvokeAI.value
def get_base_type(self) -> BaseModelType:
model_file = self.model_path / "ip_adapter.bin"
if not model_file.exists():
raise InvalidModelConfigException("Unknown IP-Adapter model format.")
state_dict = torch.load(model_file, map_location="cpu")
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
if cross_attention_dim == 768:
return BaseModelType.StableDiffusion1
elif cross_attention_dim == 1024:
return BaseModelType.StableDiffusion2
elif cross_attention_dim == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(
f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
)
class CLIPVisionFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any
class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
if not config_file.exists():
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
with open(config_file, "r") as file:
config = json.load(file)
adapter_type = config.get("adapter_type", None)
if adapter_type == "full_adapter_xl":
return BaseModelType.StableDiffusionXL
elif adapter_type in ("full_adapter", "light_adapter"):
# I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
return BaseModelType.StableDiffusion1
else:
raise InvalidModelConfigException(
f"Unable to determine base model for '{self.model_path}' (adapter_type = {adapter_type})."
)
############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
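
The registration calls above pair a format string and a model type with a probe class. A minimal, self-contained sketch of that registry pattern follows. It is not InvokeAI's `ModelProbe` implementation: `ProbeBase`, `VaeProbe`, `ProbeRegistry`, and `get_probe_class` are hypothetical names used only for illustration, and plain strings stand in for the `ModelType` enum.

```python
from typing import Dict, Tuple, Type


class ProbeBase:
    """Hypothetical stand-in for the probe base classes in this file."""

    def get_base_type(self) -> str:
        raise NotImplementedError


class VaeProbe(ProbeBase):
    """Illustrative probe that always reports the same base type."""

    def get_base_type(self) -> str:
        return "sd-1"


class ProbeRegistry:
    """Illustrative registry keyed on (format, model type); not InvokeAI's ModelProbe."""

    _probes: Dict[Tuple[str, str], Type[ProbeBase]] = {}

    @classmethod
    def register_probe(cls, fmt: str, model_type: str, probe_class: Type[ProbeBase]) -> None:
        # A later registration for the same key replaces the earlier one.
        cls._probes[(fmt, model_type)] = probe_class

    @classmethod
    def get_probe_class(cls, fmt: str, model_type: str) -> Type[ProbeBase]:
        try:
            return cls._probes[(fmt, model_type)]
        except KeyError:
            raise ValueError(f"No probe registered for format={fmt!r}, type={model_type!r}") from None


ProbeRegistry.register_probe("diffusers", "vae", VaeProbe)
probe = ProbeRegistry.get_probe_class("diffusers", "vae")()
print(probe.get_base_type())
```

Keying on the (format, type) pair keeps lookup to a single dictionary access and makes an unregistered combination fail loudly rather than fall through to a default probe.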

@ -0,0 +1,190 @@
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Abstract base class and implementation for recursive directory search for models.
Example usage:
```
from pathlib import Path
from invokeai.backend.model_manager import ModelSearch, ModelProbe
def find_main_models(model: Path) -> bool:
info = ModelProbe.probe(model)
if info.model_type == 'main' and info.base_type == 'sd-1':
return True
else:
return False
search = ModelSearch(on_model_found=find_main_models)
found = search.search('/tmp/models')
print(found) # list of matching model paths
print(search.stats) # search stats
```
"""
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Optional, Set, Union
from pydantic import BaseModel, Field
from invokeai.backend.util.logging import InvokeAILogger
default_logger = InvokeAILogger.get_logger()
class SearchStats(BaseModel):
items_scanned: int = 0
models_found: int = 0
models_filtered: int = 0
class ModelSearchBase(ABC, BaseModel):
"""
Abstract directory traversal model search class
Usage:
search = ModelSearchBase(
on_search_started = search_started_callback,
on_search_completed = search_completed_callback,
on_model_found = model_found_callback,
)
models_found = search.search('/path/to/directory')
"""
# fmt: off
on_search_started : Optional[Callable[[Path], None]] = Field(default=None, description="Called just before the search starts.") # noqa E221
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
logger : InvokeAILogger = Field(default=default_logger, description="Logger instance.") # noqa E221
# fmt: on
class Config:
arbitrary_types_allowed = True
@abstractmethod
def search_started(self) -> None:
"""
Called before the scan starts.
Passes the root search directory to the Callable `on_search_started`.
"""
pass
@abstractmethod
def model_found(self, model: Path) -> None:
"""
Called when a model is found during search.
:param model: Model to process - could be a directory or checkpoint.
Passes the model's Path to the Callable `on_model_found`.
This Callable receives the path to the model and returns a boolean
to indicate whether the model should be returned in the search
results.
"""
pass
@abstractmethod
def search_completed(self) -> None:
"""
Called after the scan completes.
Passes the Set of found model Paths to the Callable `on_search_completed`.
"""
pass
@abstractmethod
def search(self, directory: Union[Path, str]) -> Set[Path]:
"""
Recursively search for models in `directory` and return a set of model paths.
If provided, the `on_search_started`, `on_model_found` and `on_search_completed`
Callables will be invoked during the search.
"""
pass
class ModelSearch(ModelSearchBase):
"""
Implementation of ModelSearchBase with callbacks.
Usage:
search = ModelSearch()
search.on_model_found = lambda path : 'anime' in path.as_posix()
found = search.search('/tmp/models1')
# returns all models that have 'anime' in the path
"""
models_found: Set[Path] = Field(default=None)
scanned_dirs: Set[Path] = Field(default=None)
pruned_paths: Set[Path] = Field(default=None)
def search_started(self) -> None:
self.models_found = set()
self.scanned_dirs = set()
self.pruned_paths = set()
if self.on_search_started:
self.on_search_started(self._directory)
def model_found(self, model: Path) -> None:
self.stats.models_found += 1
if not self.on_model_found or self.on_model_found(model):
self.stats.models_filtered += 1
self.models_found.add(model)
def search_completed(self) -> None:
if self.on_search_completed:
self.on_search_completed(self.models_found)
def search(self, directory: Union[Path, str]) -> Set[Path]:
self._directory = Path(directory)
self.stats = SearchStats() # zero out
self.search_started() # This will initialize models_found to empty
self._walk_directory(directory)
self.search_completed()
return self.models_found
def _walk_directory(self, path: Union[Path, str]) -> None:
for root, dirs, files in os.walk(path, followlinks=True):
# don't descend into directories that start with a "."
# to avoid the Mac .DS_STORE issue.
if str(Path(root).name).startswith("."):
self.pruned_paths.add(Path(root))
if any(Path(root).is_relative_to(x) for x in self.pruned_paths):
continue
self.stats.items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path.parent in self.scanned_dirs:
self.scanned_dirs.add(path)
continue
if any(
(path / x).exists()
for x in [
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"image_encoder.txt",
]
):
self.scanned_dirs.add(path)
try:
self.model_found(path)
except KeyboardInterrupt:
raise
except Exception as e:
self.logger.warning(str(e))
for f in files:
path = Path(root) / f
if path.parent in self.scanned_dirs:
continue
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
try:
self.model_found(path)
except KeyboardInterrupt:
raise
except Exception as e:
self.logger.warning(str(e))
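
The walk above skips hidden directories by recording pruned paths and testing each root against them. A common alternative, sketched below, prunes `dirs` in place so that `os.walk` never descends into hidden directories at all. This is an illustrative variant rather than the code in this PR: `find_checkpoint_files` is a hypothetical helper, it matches checkpoint files only, and it omits the folder-model detection and callbacks.

```python
import os
from pathlib import Path
from typing import Set, Union

# Same checkpoint suffixes that the walk above looks for.
MODEL_SUFFIXES = {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}


def find_checkpoint_files(root_dir: Union[Path, str]) -> Set[Path]:
    """Return checkpoint-like files under root_dir, skipping hidden subdirectories."""
    found: Set[Path] = set()
    for root, dirs, files in os.walk(root_dir, followlinks=True):
        # Assigning to dirs[:] edits the list os.walk iterates over (top-down mode),
        # so directories whose names start with "." are never visited.
        dirs[:] = [d for d in dirs if not d.startswith(".")]
        for name in files:
            candidate = Path(root) / name
            if candidate.suffix in MODEL_SUFFIXES:
                found.add(candidate)
    return found
```

In-place pruning avoids the per-root `is_relative_to` check against every pruned path, at the cost of no longer keeping a record of which paths were skipped.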

@ -11,4 +11,7 @@ from .devices import ( # noqa: F401
normalize_device,
torch_dtype,
)
from .logging import InvokeAILogger
from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401
__all__ = ["Chdir", "InvokeAILogger", "choose_precision", "choose_torch_device"]

@ -164,6 +164,7 @@ nav:
- Overview: 'contributing/contribution_guides/development.md'
- New Contributors: 'contributing/contribution_guides/newContributorChecklist.md'
- InvokeAI Architecture: 'contributing/ARCHITECTURE.md'
- Model Manager v2: 'contributing/MODEL_MANAGER.md'
- Frontend Documentation: 'contributing/contribution_guides/contributingToFrontend.md'
- Local Development: 'contributing/LOCAL_DEVELOPMENT.md'
- Adding Tests: 'contributing/TESTS.md'

@ -221,6 +221,8 @@ exclude = [
# global mypy config
[tool.mypy]
ignore_missing_imports = true # ignores missing types in third-party libraries
strict = true
exclude = ["tests/*"]
# overrides for specific modules
[[tool.mypy.overrides]]

@ -1,9 +1,11 @@
#!/bin/env python
"""Little command-line utility for probing a model on disk."""
import argparse
from pathlib import Path
from invokeai.backend.model_management.model_probe import ModelProbe
from invokeai.backend.model_manager import InvalidModelConfigException, ModelProbe
parser = argparse.ArgumentParser(description="Probe model type")
parser.add_argument(
@ -14,5 +16,8 @@ parser.add_argument(
args = parser.parse_args()
for path in args.model_path:
info = ModelProbe().probe(path)
print(f"{path}: {info}")
try:
info = ModelProbe.probe(path)
print(f"{path}:{info.model_dump_json(indent=4)}")
except InvalidModelConfigException as exc:
print(exc)

@ -69,6 +69,7 @@ def mock_services() -> InvocationServices:
logger=logging, # type: ignore
model_manager=None, # type: ignore
model_records=None, # type: ignore
model_install=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),
processor=DefaultInvocationProcessor(),

@ -74,6 +74,7 @@ def mock_services() -> InvocationServices:
logger=logging, # type: ignore
model_manager=None, # type: ignore
model_records=None, # type: ignore
model_install=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),
processor=DefaultInvocationProcessor(),

@ -12,7 +12,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
prepare_values_to_insert,
)
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
from tests.nodes.test_nodes import PromptTestInvocation
from tests.aa_nodes.test_nodes import PromptTestInvocation
@pytest.fixture

@ -0,0 +1,195 @@
"""
Test the model installer
"""
from pathlib import Path
from typing import Any, Dict, List
import pytest
from pydantic import BaseModel, ValidationError
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_install import (
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallService,
ModelInstallServiceBase,
)
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.util.logging import InvokeAILogger
@pytest.fixture
def test_file(datadir: Path) -> Path:
return datadir / "test_embedding.safetensors"
@pytest.fixture
def app_config(datadir: Path) -> InvokeAIAppConfig:
return InvokeAIAppConfig(
root=datadir / "root",
models_dir=datadir / "root/models",
)
@pytest.fixture
def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
database = SqliteDatabase(app_config, InvokeAILogger.get_logger(config=app_config))
store: ModelRecordServiceBase = ModelRecordServiceSQL(database)
return store
@pytest.fixture
def installer(app_config: InvokeAIAppConfig, store: ModelRecordServiceBase) -> ModelInstallServiceBase:
return ModelInstallService(
app_config=app_config,
record_store=store,
event_bus=DummyEventService(),
)
class DummyEvent(BaseModel):
"""Dummy Event to use with Dummy Event service."""
event_name: str
payload: Dict[str, Any]
class DummyEventService(EventServiceBase):
"""Dummy event service for testing."""
events: List[DummyEvent]
def __init__(self) -> None:
super().__init__()
self.events = []
def dispatch(self, event_name: str, payload: Any) -> None:
"""Dispatch an event by appending it to self.events."""
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> None:
store = installer.record_store
matches = store.search_by_attr(model_name="test_embedding")
assert len(matches) == 0
key = installer.register_path(test_file)
assert key is not None
assert len(key) == 32
def test_registration_meta(installer: ModelInstallServiceBase, test_file: Path) -> None:
store = installer.record_store
key = installer.register_path(test_file)
model_record = store.get_model(key)
assert model_record is not None
assert model_record.name == "test_embedding"
assert model_record.type == ModelType.TextualInversion
assert Path(model_record.path) == test_file
assert model_record.base == BaseModelType("sd-1")
assert model_record.description is not None
assert model_record.source is not None
assert Path(model_record.source) == test_file
def test_registration_meta_override_fail(installer: ModelInstallServiceBase, test_file: Path) -> None:
key = None
with pytest.raises(ValidationError):
key = installer.register_path(test_file, {"name": "banana_sushi", "type": ModelType("lora")})
assert key is None
def test_registration_meta_override_succeed(installer: ModelInstallServiceBase, test_file: Path) -> None:
store = installer.record_store
key = installer.register_path(
test_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash"}
)
model_record = store.get_model(key)
assert model_record.name == "banana_sushi"
assert model_record.source == "fake/repo_id"
assert model_record.current_hash == "New Hash"
def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
store = installer.record_store
key = installer.install_path(test_file)
model_record = store.get_model(key)
assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
assert model_record.source == test_file.as_posix()
def test_background_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
"""Note: may want to break this down into several smaller unit tests."""
path = test_file
description = "Test of metadata assignment"
source = LocalModelSource(path=path, inplace=False)
job = installer.import_model(source, config={"description": description})
assert job is not None
assert isinstance(job, ModelInstallJob)
# See if job is registered properly
assert job in installer.get_job(source)
# test that the job object tracked installation correctly
jobs = installer.wait_for_installs()
assert len(jobs) > 0
my_job = [x for x in jobs if x.source == source]
assert len(my_job) == 1
assert my_job[0].status == InstallStatus.COMPLETED
# test that the expected events were issued
bus = installer.event_bus
assert bus is not None # sigh - ruff is a stickler for type checking
assert isinstance(bus, DummyEventService)
assert len(bus.events) == 2
event_names = [x.event_name for x in bus.events]
assert "model_install_started" in event_names
assert "model_install_completed" in event_names
assert Path(bus.events[0].payload["source"]) == source
assert Path(bus.events[1].payload["source"]) == source
key = bus.events[1].payload["key"]
assert key is not None
# see if the thing actually got installed at the expected location
model_record = installer.record_store.get_model(key)
assert model_record is not None
assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
assert Path(app_config.models_dir / model_record.path).exists()
# see if metadata was properly passed through
assert model_record.description == description
# see if prune works properly
installer.prune_jobs()
assert not installer.get_job(source)
def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
store = installer.record_store
key = installer.install_path(test_file)
model_record = store.get_model(key)
assert Path(app_config.models_dir / model_record.path).exists()
assert test_file.exists() # original should still be there after installation
installer.delete(key)
assert not Path(
app_config.models_dir / model_record.path
).exists() # after deletion, installed copy should not exist
assert test_file.exists() # but original should still be there
with pytest.raises(UnknownModelException):
store.get_model(key)
def test_delete_register(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
store = installer.record_store
key = installer.register_path(test_file)
model_record = store.get_model(key)
assert Path(app_config.models_dir / model_record.path).exists()
assert test_file.exists() # original should still be there after installation
installer.delete(key)
assert Path(app_config.models_dir / model_record.path).exists()
with pytest.raises(UnknownModelException):
store.get_model(key)

@ -0,0 +1 @@
This directory is used by pytest-datadir.

@ -0,0 +1,79 @@
model:
base_learning_rate: 1.0e-04
target: invokeai.backend.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: invokeai.backend.stable_diffusion.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
personalization_config:
target: invokeai.backend.stable_diffusion.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ['sculpture']
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config:
target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: invokeai.backend.stable_diffusion.encoders.modules.WeightedFrozenCLIPEmbedder

@ -0,0 +1 @@
Dummy file to establish git path.