all features implemented, docs updated, ready for review

This commit is contained in:
Lincoln Stein 2023-11-26 13:18:21 -05:00
parent dc5c452ef9
commit 8695ad6f59
12 changed files with 542 additions and 469 deletions

model. These are the:

tracks the type of the model, its provenance, and where it can be
found on disk.

* _ModelInstallServiceBase_ A service for installing models to
  disk. It uses `DownloadQueueServiceBase` to download models and
  their metadata, and `ModelRecordServiceBase` to store that
  information. It is also responsible for managing the InvokeAI
  `models` directory and its contents.
* _DownloadQueueServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
  A multithreaded downloader responsible for downloading models from a
  remote source to disk. The download queue has special methods for
  downloading repo_id folders from Hugging Face, as well as
  discriminating among model versions in Civitai, but can be used for
  arbitrary content.
* _ModelLoadServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
  Responsible for loading a model from disk into RAM and VRAM and
  getting it ready for inference.
## Location of the Code

All four of these services can be found in
`invokeai/app/services` in the following directories:

* `invokeai/app/services/model_records/`
* `invokeai/app/services/model_install/`
* `invokeai/app/services/model_loader/` (**under development**)
* `invokeai/app/services/downloads/` (**under development**)

Code related to the FastAPI web API can be found in
`invokeai/app/api/routers/model_records.py`.

***
of the fields, including `name`, `model_type` and `base_model`, are
shared between `ModelConfigBase` and `ModelBase`, and this is a
potential source of confusion.

## Reading and Writing Model Configuration Records

The `ModelRecordService` provides the ability to retrieve model
model and pass its key to `get_model()`.

Several methods allow you to create and update stored model config
records.

#### add_model(key, config) -> AnyModelConfig:

Given a key and a configuration, this will add the model's
configuration record to the database. `config` can either be a subclass of
fields to be updated. This will return an `AnyModelConfig` on success,
or raise `InvalidModelConfigException` or `UnknownModelException`
exceptions on failure.

***
## Model installation
The `ModelInstallService` class implements the
`ModelInstallServiceBase` abstract base class, and provides a one-stop
shop for all your model install needs. It provides the following
functionality:
- Registering a model config record for a model already located on the
local filesystem, without moving it or changing its path.
- Installing a model already located on the local filesystem, by
moving it into the InvokeAI root directory under the
`models` folder (or wherever config parameter `models_dir`
specifies).
- Probing of models to determine their type, base type and other key
information.
- Interface with the InvokeAI event bus to provide status updates on
the download, installation and registration process.
- Downloading a model from an arbitrary URL and installing it in
`models_dir` (_implementation pending_).
- Special handling for Civitai model URLs which allow the user to
paste in a model page's URL or download link (_implementation pending_).
- Special handling for HuggingFace repo_ids to recursively download
the contents of the repository, paying attention to alternative
variants such as fp16. (_implementation pending_)
### Initializing the installer
A default installer is created at InvokeAI API startup time and stored
in `ApiDependencies.invoker.services.model_install`. It can
also be retrieved from an invocation's `context` argument with
`context.services.model_install`.
In the event you wish to create a new installer, you may use the
following initialization pattern:
```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config, logger)
store = ModelRecordServiceSQL(db)
installer = ModelInstallService(config, store)
```
The full form of `ModelInstallService()` takes the following
parameters (`event_bus` is optional):
| **Argument** | **Type** | **Description** |
|------------------|------------------------------|------------------------------|
| `config` | InvokeAIAppConfig | InvokeAI app configuration object |
| `record_store` | ModelRecordServiceBase | Config record storage database |
| `event_bus` | EventServiceBase | Optional event bus to send download/install progress events to |
Once initialized, the installer will provide the following methods:
#### install_job = installer.import_model()
The `import_model()` method is the core of the installer. The
following illustrates basic usage:
```
sources = [
    Path('/opt/models/sushi.safetensors'),   # a local safetensors file
    Path('/opt/models/sushi_diffusers/'),    # a local diffusers folder
    'runwayml/stable-diffusion-v1-5',        # a repo_id
    'runwayml/stable-diffusion-v1-5:vae',    # a subfolder within a repo_id
    'https://civitai.com/api/download/models/63006',         # a civitai direct download link
    'https://civitai.com/models/8765?modelVersionId=10638',  # civitai model page
    'https://s3.amazon.com/fjacks/sd-3.safetensors',         # arbitrary URL
]

for source in sources:
    install_job = installer.import_model(source)

source2job = installer.wait_for_installs()
for source in sources:
    job = source2job[source]
    if job.status == "completed":
        model_config = job.config_out
        model_key = model_config.key
        print(f"{source} installed as {model_key}")
    elif job.status == "error":
        print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")
```
As shown here, the `import_model()` method accepts a variety of
sources, including local safetensors files, local diffusers folders,
HuggingFace repo_ids with and without a subfolder designation,
Civitai model URLs and arbitrary URLs that point to checkpoint files
(but not to folders).
Each call to `import_model()` returns a `ModelInstallJob`,
an object which tracks the progress of the install.
If a remote model is requested, the model's files are downloaded in
parallel across multiple threads using the download
queue. During the download process, the `ModelInstallJob` is updated
to provide status and progress information. After the files (if any)
are downloaded, the remainder of the installation runs in a single
serialized background thread. These are the model probing, file
copying, and config record database update steps.
Multiple install jobs can be queued up. You may block until all
install jobs are completed (or errored) by calling the
`wait_for_installs()` method as shown in the code
example. `wait_for_installs()` will return a `dict` that maps the
requested source to its job. This object can be interrogated
to determine its status. If the job errored out, then the error type
and details can be recovered from `job.error_type` and `job.error`.
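The queue-and-wait pattern described above can be sketched in a few lines. This is a simplified, self-contained illustration of the flow (the `MiniInstaller` and `InstallJob` classes here are hypothetical stand-ins, not InvokeAI code): jobs run in background threads, and `wait_for_installs()` blocks until every submitted job reaches a terminal state, returning a source-to-job mapping.

```
# Minimal sketch of the install-queue pattern (assumed structure only).
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass

@dataclass
class InstallJob:              # simplified stand-in for ModelInstallJob
    source: str
    status: str = "waiting"
    error: str = ""

class MiniInstaller:
    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=4)
        self._jobs = {}        # maps source -> InstallJob
        self._futures = []

    def import_model(self, source: str) -> InstallJob:
        job = InstallJob(source)
        self._jobs[source] = job
        self._futures.append(self._pool.submit(self._run, job))
        return job

    def _run(self, job: InstallJob):
        job.status = "running"
        try:
            # ... download / probe / register would happen here ...
            job.status = "completed"
        except Exception as exc:
            job.status = "error"
            job.error = str(exc)

    def wait_for_installs(self) -> dict:
        for f in self._futures:    # block until all jobs are terminal
            f.result()
        return self._jobs

installer = MiniInstaller()
installer.import_model("/opt/models/example.safetensors")
jobs = installer.wait_for_installs()
```

The real service adds event reporting and database registration on top of this basic shape.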
The full list of arguments to `import_model()` is as follows:
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `source` | Union[str, Path, AnyHttpUrl] | | The source of the model, Path, URL or repo_id |
| `inplace` | bool | True | Leave a local model in its current location |
| `variant` | str | None | Desired variant, such as 'fp16' or 'onnx' (HuggingFace only) |
| `subfolder` | str | None | Repository subfolder (HuggingFace only) |
| `config` | Dict[str, Any] | None | Override all or a portion of model's probed attributes |
| `access_token` | str | None | Provide authorization information needed to download |
The `inplace` field controls how local model Paths are handled. If
True (the default), then the model is simply registered in its current
location by the installer's `ModelConfigRecordService`. Otherwise, a
copy of the model is put into the location specified by the `models_dir`
application configuration parameter.
The `variant` field is used for HuggingFace repo_ids only. If
provided, the repo_id download handler will look for and download
tensors files that follow the convention for the selected variant:
- "fp16" will select files named "*model.fp16.{safetensors,bin}"
- "onnx" will select files ending with the suffix ".onnx"
- "openvino" will select files beginning with "openvino_model"
In the special case of the "fp16" variant, the installer will select
the 32-bit version of the files if the 16-bit version is unavailable.
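The selection rules above, including the fp16-to-32-bit fallback, can be sketched as a simple filename filter. This is an illustrative approximation of the documented behavior, not the actual download-handler code:

```
# Assumed selection logic based on the variant rules described above.
from fnmatch import fnmatch

def select_variant_files(files, variant):
    if variant == "fp16":
        picks = [f for f in files
                 if fnmatch(f, "*model.fp16.safetensors")
                 or fnmatch(f, "*model.fp16.bin")]
        if picks:
            return picks
        # fall back to the 32-bit weights when no fp16 files exist
        return [f for f in files
                if fnmatch(f, "*model.safetensors")
                or fnmatch(f, "*model.bin")]
    if variant == "onnx":
        return [f for f in files if f.endswith(".onnx")]
    if variant == "openvino":
        return [f for f in files
                if f.rsplit("/", 1)[-1].startswith("openvino_model")]
    return list(files)

repo_files = ["unet/model.fp16.safetensors",
              "unet/model.safetensors",
              "vae/model.onnx"]
```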
`subfolder` is used for HuggingFace repo_ids only. If provided, the
model will be downloaded from the designated subfolder rather than the
top-level repository folder. If a subfolder is attached to the repo_id
using the format `repo_owner/repo_name:subfolder`, then the subfolder
specified by the repo_id will override the subfolder argument.
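Splitting the `repo_owner/repo_name:subfolder` form, with the inline subfolder taking precedence over the `subfolder` argument, might look like this (a hypothetical helper, not taken from the codebase):

```
# Assumed parsing of the repo_id:subfolder convention described above.
def parse_repo_id(source, subfolder=None):
    repo_id, _, inline = source.partition(":")
    # an inline subfolder overrides the subfolder argument
    return repo_id, (inline or subfolder or None)
```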
`config` can be used to override all or a portion of the configuration
attributes returned by the model prober. See the section below for
details.
`access_token` is passed to the download queue and used to access
repositories that require it.
#### Monitoring the install job process
When you create an install job with `import_model()`, it launches the
download and installation process in the background and returns a
`ModelInstallJob` object for monitoring the process.
The `ModelInstallJob` class has the following structure:
| **Attribute** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `status`       | `InstallStatus` | An enum with the values "waiting", "running", "completed" and "error" |
| `config_in` | `dict` | Overriding configuration values provided by the caller |
| `config_out` | `AnyModelConfig`| After successful completion, contains the configuration record written to the database |
| `inplace` | `boolean` | True if the caller asked to install the model in place using its local path |
| `source` | `ModelSource` | The local path, remote URL or repo_id of the model to be installed |
| `local_path` | `Path` | If a remote model, holds the path of the model after it is downloaded; if a local model, same as `source` |
| `error_type` | `str` | Name of the exception that led to an error status |
| `error` | `str` | Traceback of the error |
If the `event_bus` argument was provided, events will also be
broadcast to the InvokeAI event bus. The events will appear on the bus
as events of type `EventServiceBase.model_event`, each carrying a
timestamp and one of the following event names:
- `model_install_started`
The payload will contain the keys `timestamp` and `source`. The latter
indicates the requested model source for installation.
- `model_install_progress`
Emitted at regular intervals when downloading a remote model, the
payload will contain the keys `timestamp`, `source`, `current_bytes`
and `total_bytes`. These events are _not_ emitted when a local model
already on the filesystem is imported.
- `model_install_completed`
Issued once at the end of a successful installation. The payload will
contain the keys `timestamp`, `source` and `key`, where `key` is the
ID under which the model has been registered.
- `model_install_error`
Emitted if the installation process fails for some reason. The payload
will contain the keys `timestamp`, `source`, `error_type` and
`error`. `error_type` is a short message indicating the nature of the
error, and `error` is the long traceback to help debug the problem.
#### Model configuration and probing
The install service uses the `invokeai.backend.model_manager.probe`
module during import to determine the model's type, base type, and
other configuration parameters. Among other things, it assigns a
default name and description for the model based on probed
fields.
When downloading remote models is implemented, additional
configuration information, such as list of trigger terms, will be
retrieved from the HuggingFace and Civitai model repositories.
The probed values can be overridden by providing a dictionary in the
optional `config` argument passed to `import_model()`. You may provide
overriding values for any of the model's configuration
attributes. This is typically used to set the model's name and
description, but can also be used to overcome cases in which automatic
probing is unable to (correctly) determine a model's attributes. The
most common example is the `prediction_type` field for sd-2 (and rare
sd-1) models. Here is an example of setting the
`SchedulerPredictionType` and `name` for an sd-2 model:
```
install_job = installer.import_model(
    source='stabilityai/stable-diffusion-2-1',
    variant='fp16',
    config=dict(
        prediction_type=SchedulerPredictionType('v_prediction'),
        name='stable diffusion 2 base model',
    )
)
```
### Other installer methods
This section describes additional methods provided by the installer class.
#### source2job = installer.wait_for_installs()
Block until all pending installs are completed or errored and return a
dictionary that maps the model `source` to the completed
`ModelInstallJob`.
#### jobs = installer.list_jobs([source])
Return a list of all active and complete `ModelInstallJobs`. An
optional `source` argument allows you to filter the returned list by a
model source string pattern using a partial string match.
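The partial-match filtering could be approximated like this (an illustrative sketch with a simplified job representation, not the service's actual implementation):

```
# Assumed substring-based filtering, per the description above.
def list_jobs(jobs, source_pattern=None):
    if source_pattern is None:
        return list(jobs)
    return [j for j in jobs if source_pattern in str(j["source"])]

jobs = [{"source": "runwayml/stable-diffusion-v1-5"},
        {"source": "/opt/models/sushi.safetensors"}]
```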
#### job = installer.get_job(source)
Return the `ModelInstallJob` corresponding to the indicated model source.
#### installer.prune_jobs()
Remove non-pending jobs (completed or errored) from the job list
returned by `list_jobs()` and `get_job()`.
#### installer.app_config, installer.record_store, installer.event_bus
Properties that provide access to the installer's `InvokeAIAppConfig`,
`ModelRecordServiceBase` and `EventServiceBase` objects.
#### key = installer.register_path(model_path, config), key = installer.install_path(model_path, config)
These methods bypass the download queue and directly register or
install the model at the indicated path, returning the unique ID for
the installed model.
Both methods accept a Path object corresponding to a checkpoint or
diffusers folder, and an optional dict of config attributes to use to
override the values derived from model probing.
The difference between `register_path()` and `install_path()` is that
the former creates a model configuration record without changing the
location of the model in the filesystem. The latter makes a copy of
the model inside the InvokeAI models directory before registering
it.
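The contrast can be sketched as follows. This is a deliberately simplified, hypothetical model of the two calls (a dict stands in for the record store, and a fixed key scheme replaces content hashing):

```
# Sketch of register-in-place vs copy-then-register (assumed behavior).
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory

def register_path(records, model_path):
    key = f"key-{len(records)}"          # stand-in for a content hash
    records[key] = model_path            # path is left unchanged
    return key

def install_path(records, model_path, models_dir):
    dest = models_dir / model_path.name  # copy into the models directory
    shutil.copy2(model_path, dest)
    return register_path(records, dest)

with TemporaryDirectory() as tmpdir:
    root = Path(tmpdir)
    (root / "models").mkdir()
    src = root / "sushi.safetensors"
    src.write_bytes(b"weights")
    records = {}
    k1 = register_path(records, src)               # stays where it is
    k2 = install_path(records, src, root / "models")  # copied first
```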
#### installer.unregister(key)
This will remove the model config record for the model at `key`, and is
equivalent to `installer.record_store.del_model(key)`.
#### installer.delete(key)
This is similar to `unregister()` but has the additional effect of
conditionally deleting the underlying model file(s) if they reside
within the InvokeAI models directory.
#### installer.unconditionally_delete(key)
This method is similar to `unregister()`, but also unconditionally
deletes the corresponding model weights file(s), regardless of whether
they are inside or outside the InvokeAI models hierarchy.
#### List[str] = installer.scan_directory(scan_dir: Path, install: bool)
This method will recursively scan the directory indicated in
`scan_dir` for new models and either install them in the models
directory or register them in place, depending on the setting of
`install` (default False).
The return value is the list of keys of the new installed/registered
models.
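The recursive discovery step can be illustrated with a short sketch. The extension set and helper name here are assumptions for illustration; real model probing is far more involved:

```
# Illustrative recursive scan for files that look like model weights.
from pathlib import Path
from tempfile import mkdtemp

MODEL_SUFFIXES = {".safetensors", ".ckpt", ".bin"}   # assumed extension set

def scan_for_models(scan_dir):
    return sorted(p for p in Path(scan_dir).rglob("*")
                  if p.is_file() and p.suffix in MODEL_SUFFIXES)

root = Path(mkdtemp())
(root / "loras").mkdir()
(root / "loras" / "style.safetensors").write_bytes(b"w")
(root / "readme.txt").write_text("not a model")
found = scan_for_models(root)
```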
#### installer.sync_to_config()
This method synchronizes models in the models directory and autoimport
directory to those in the `ModelConfigRecordService` database. New
models are registered and orphan models are unregistered.
#### installer.start(invoker)
The `start` method is called by the API initialization routines when
the API starts up. Its effect is to call `sync_to_config()` to
synchronize the model record store database with what's currently on
disk.
# The remainder of this documentation is provisional, pending implementation of the Download and Load services
## Let's get loaded, the lowdown on ModelLoadService

The `ModelLoadService` is responsible for loading a named model into
other resources that it might have been using.

This will start/pause/cancel all jobs that have been submitted to the
queue and have not yet reached a terminal state.
## Model installation
The `ModelInstallService` class implements the
`ModelInstallServiceBase` abstract base class, and provides a one-stop
shop for all your model install needs. It provides the following
functionality:
- Registering a model config record for a model already located on the
local filesystem, without moving it or changing its path.
- Installing a model alreadiy located on the local filesystem, by
moving it into the InvokeAI root directory under the
`models` folder (or wherever config parameter `models_dir`
specifies).
- Downloading a model from an arbitrary URL and installing it in
`models_dir`.
- Special handling for Civitai model URLs which allow the user to
paste in a model page's URL or download link. Any metadata provided
by Civitai, such as trigger terms, are captured and placed in the
model config record.
- Special handling for HuggingFace repo_ids to recursively download
the contents of the repository, paying attention to alternative
variants such as fp16.
- Probing of models to determine their type, base type and other key
information.
- Interface with the InvokeAI event bus to provide status updates on
the download, installation and registration process.
### Initializing the installer
A default installer is created at InvokeAI api startup time and stored
in `ApiDependencies.invoker.services.model_install_service` and can
also be retrieved from an invocation's `context` argument with
`context.services.model_install_service`.
In the event you wish to create a new installer, you may use the
following initialization pattern:
```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download_manager import DownloadQueueServive
from invokeai.app.services.model_record_service import ModelRecordServiceBase
config = InvokeAI.get_config()
queue = DownloadQueueService()
store = ModelRecordServiceBase.open(config)
installer = ModelInstallService(config=config, queue=queue, store=store)
```
The full form of `ModelInstallService()` takes the following
parameters. Each parameter will default to a reasonable value, but it
is recommended that you set them explicitly as shown in the above example.
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `config` | InvokeAIAppConfig | Use system-wide config | InvokeAI app configuration object |
| `queue` | DownloadQueueServiceBase | Create a new download queue for internal use | Download queue |
| `store` | ModelRecordServiceBase | Use config to select the database to open | Config storage database |
| `event_bus` | EventServiceBase | None | An event bus to send download/install progress events to |
| `event_handlers` | List[DownloadEventHandler] | None | Event handlers for the download queue |
Note that if `store` is not provided, then the class will use
`ModelRecordServiceBase.open(config)` to select the database to use.
Once initialized, the installer will provide the following methods:
#### install_job = installer.install_model()
The `install_model()` method is the core of the installer. The
following illustrates basic usage:
```
sources = [
Path('/opt/models/sushi.safetensors'), # a local safetensors file
Path('/opt/models/sushi_diffusers/'), # a local diffusers folder
'runwayml/stable-diffusion-v1-5', # a repo_id
'runwayml/stable-diffusion-v1-5:vae', # a subfolder within a repo_id
'https://civitai.com/api/download/models/63006', # a civitai direct download link
'https://civitai.com/models/8765?modelVersionId=10638', # civitai model page
'https://s3.amazon.com/fjacks/sd-3.safetensors', # arbitrary URL
]
for source in sources:
install_job = installer.install_model(source)
source2key = installer.wait_for_installs()
for source in sources:
model_key = source2key[source]
print(f"{source} installed as {model_key}")
```
As shown here, the `install_model()` method accepts a variety of
sources, including local safetensors files, local diffusers folders,
HuggingFace repo_ids with and without a subfolder designation,
Civitai model URLs and arbitrary URLs that point to checkpoint files
(but not to folders).
Each call to `install_model()` will return a `ModelInstallJob` job, a
subclass of `DownloadJobBase`. The install job has additional
install-specific fields described in the next section.
Each install job will run in a series of background threads using
the object's download queue. You may block until all install jobs are
completed (or errored) by calling the `wait_for_installs()` method as
shown in the code example. `wait_for_installs()` will return a `dict`
that maps the requested source to the key of the installed model. In
the case that a model fails to download or install, its value in the
dict will be None. The actual cause of the error will be reported in
the corresponding job's `error` field.
Alternatively you may install event handlers and/or listen for events
on the InvokeAI event bus in order to monitor the progress of the
requested installs.
The full list of arguments to `model_install()` is as follows:
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `source` | Union[str, Path, AnyHttpUrl] | | The source of the model, Path, URL or repo_id |
| `inplace` | bool | True | Leave a local model in its current location |
| `variant` | str | None | Desired variant, such as 'fp16' or 'onnx' (HuggingFace only) |
| `subfolder` | str | None | Repository subfolder (HuggingFace only) |
| `probe_override` | Dict[str, Any] | None | Override all or a portion of model's probed attributes |
| `metadata` | ModelSourceMetadata | None | Provide metadata that will be added to model's config |
| `access_token` | str | None | Provide authorization information needed to download |
| `priority` | int | 10 | Download queue priority for the job |
The `inplace` field controls how local model Paths are handled. If
True (the default), then the model is simply registered in its current
location by the installer's `ModelConfigRecordService`. Otherwise, the
model will be moved into the location specified by the `models_dir`
application configuration parameter.
The `variant` field is used for HuggingFace repo_ids only. If
provided, the repo_id download handler will look for and download
tensors files that follow the convention for the selected variant:
- "fp16" will select files named "*model.fp16.{safetensors,bin}"
- "onnx" will select files ending with the suffix ".onnx"
- "openvino" will select files beginning with "openvino_model"
In the special case of the "fp16" variant, the installer will select
the 32-bit version of the files if the 16-bit version is unavailable.
`subfolder` is used for HuggingFace repo_ids only. If provided, the
model will be downloaded from the designated subfolder rather than the
top-level repository folder. If a subfolder is attached to the repo_id
using the format `repo_owner/repo_name:subfolder`, then the subfolder
specified by the repo_id will override the subfolder argument.
`probe_override` can be used to override all or a portion of the
attributes returned by the model prober. This can be used to overcome
cases in which automatic probing is unable to (correctly) determine
the model's attribute. The most common situation is the
`prediction_type` field for sd-2 (and rare sd-1) models. Here is an
example of how it works:
```
install_job = installer.install_model(
source='stabilityai/stable-diffusion-2-1',
variant='fp16',
probe_override=dict(
prediction_type=SchedulerPredictionType('v_prediction')
)
)
```
`metadata` allows you to attach custom metadata to the installed
model. See the next section for details.
`priority` and `access_token` are passed to the download queue and
have the same effect as they do for the DownloadQueueServiceBase.
#### Monitoring the install job process
When you create an install job with `model_install()`, events will be
passed to the list of `DownloadEventHandlers` provided at installer
initialization time. Event handlers can also be added to individual
model install jobs by calling their `add_handler()` method as
described earlier for the `DownloadQueueService`.
If the `event_bus` argument was provided, events will also be
broadcast to the InvokeAI event bus. The events will appear on the bus
as a singular event type named `model_event` with a payload of
`job`. You can then retrieve the job and check its status.
** TO DO: ** consider breaking `model_event` into
`model_install_started`, `model_install_completed`, etc. The event bus
features have not yet been tested with FastAPI/websockets, and it may
turn out that the job object is not serializable.
#### Model metadata and probing
The install service has special handling for HuggingFace and Civitai
URLs that capture metadata from the source and include it in the model
configuration record. For example, fetching the Civitai model 8765
will produce a config record similar to this (using YAML
representation):
```
5abc3ef8600b6c1cc058480eaae3091e:
path: sd-1/lora/to8contrast-1-5.safetensors
name: to8contrast-1-5
base_model: sd-1
model_type: lora
model_format: lycoris
key: 5abc3ef8600b6c1cc058480eaae3091e
hash: 5abc3ef8600b6c1cc058480eaae3091e
description: 'Trigger terms: to8contrast style'
author: theovercomer8
license: allowCommercialUse=Sell; allowDerivatives=True; allowNoCredit=True
source: https://civitai.com/models/8765?modelVersionId=10638
thumbnail_url: null
tags:
- model
- style
- portraits
```
For sources that do not provide model metadata, you can attach custom
fields by providing a `metadata` argument to `model_install()` using
an initialized `ModelSourceMetadata` object (available for import from
`model_install_service.py`):
```
from invokeai.app.services.model_install_service import ModelSourceMetadata
meta = ModelSourceMetadata(
name="my model",
author="Sushi Chef",
description="Highly customized model; trigger with 'sushi',"
license="mit",
thumbnail_url="http://s3.amazon.com/ljack/pics/sushi.png",
tags=list('sfw', 'food')
)
install_job = installer.install_model(
source='sushi_chef/model3',
variant='fp16',
metadata=meta,
)
```
It is not currently recommended to provide custom metadata when
installing from Civitai or HuggingFace source, as the metadata
provided by the source will overwrite the fields you provide. Instead,
after the model is installed you can use
`ModelRecordService.update_model()` to change the desired fields.
** TO DO: ** Change the logic so that the caller's metadata fields take
precedence over those provided by the source.
#### Other installer methods
This section describes additional, less-frequently-used attributes and
methods provided by the installer class.
##### installer.wait_for_installs()
This is equivalent to the `DownloadQueue` `join()` method. It will
block until all the active jobs in the install queue have reached a
terminal state (completed, errored or cancelled).
##### installer.queue, installer.store, installer.config
These attributes provide access to the `DownloadQueueServiceBase`,
`ModelConfigRecordServiceBase`, and `InvokeAIAppConfig` objects that
the installer uses.
For example, to temporarily pause all pending installations, you can
do this:
```
installer.queue.pause_all_jobs()
```
##### key = installer.register_path(model_path, overrides), key = installer.install_path(model_path, overrides)
These methods bypass the download queue and directly register or
install the model at the indicated path, returning the unique ID for
the installed model.
Both methods accept a Path object corresponding to a checkpoint or
diffusers folder, and an optional dict of attributes to use to
override the values derived from model probing.
The difference between `register_path()` and `install_path()` is that
the former will not move the model from its current position, while
the latter will move it into the `models_dir` hierarchy.
##### installer.unregister(key)
This will remove the model config record for the model at key, and is
equivalent to `installer.store.unregister(key)`
##### installer.delete(key)
This is similar to `unregister()` but has the additional effect of
deleting the underlying model file(s) -- even if they were outside the
`models_dir` directory!
##### installer.conditionally_delete(key)
This method will call `unregister()` if the model identified by `key`
is outside the `models_dir` hierarchy, and call `delete()` if the
model is inside.
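The decision rule can be sketched as follows. This is a simplified stand-in:
the real method dispatches to `delete()` or `unregister()` rather than
returning a string, and the paths here are purely illustrative:

```python
from pathlib import Path

def conditionally_delete(model_path: Path, models_dir: Path) -> str:
    # Files under models_dir are managed by InvokeAI and may be deleted;
    # external files are merely forgotten (unregistered).
    inside = models_dir.resolve() in model_path.resolve().parents
    return "delete" if inside else "unregister"

managed = conditionally_delete(
    Path("/data/invokeai/models/sd-1/main/foo.safetensors"),
    Path("/data/invokeai/models"),
)
external = conditionally_delete(
    Path("/home/user/downloads/foo.safetensors"),
    Path("/data/invokeai/models"),
)
```
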
##### List[str] = installer.scan_directory(scan_dir: Path, install: bool)
This method will recursively scan the directory indicated in
`scan_dir` for new models and either install them in the models
directory or register them in place, depending on the setting of
`install` (default False).
The return value is the list of keys of the new installed/registered
models.
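A simplified sketch of the scan is below. The key format is hypothetical and
only `.safetensors` files are matched here; the real implementation uses
`ModelSearch`, recognizes diffusers folders as well, and returns config record
keys:

```python
from pathlib import Path
import tempfile

def scan_directory(scan_dir: Path, known_paths: set[Path]) -> list[str]:
    # Walk the tree; anything not already known gets registered/installed
    new_keys = []
    for path in sorted(scan_dir.rglob("*.safetensors")):
        if path not in known_paths:
            known_paths.add(path)
            new_keys.append(f"key-for-{path.name}")  # hypothetical key format
    return new_keys

root = Path(tempfile.mkdtemp())
(root / "a.safetensors").write_bytes(b"x")
(root / "sub").mkdir()
(root / "sub" / "b.safetensors").write_bytes(b"y")

known = set()
keys = scan_directory(root, known)
```

A second scan with the same `known` set finds nothing new and returns an empty
list.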
##### installer.scan_models_directory()
This method scans the models directory for new models and registers
them in place. Models that are present in the
`ModelConfigRecordService` database whose paths are not found will be
unregistered.
##### installer.sync_to_config()
This method synchronizes models in the models directory and autoimport
directory to those in the `ModelConfigRecordService` database. New
models are registered and orphan models are unregistered.
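The reconciliation amounts to a two-way set difference, sketched here with
plain dicts and strings standing in for the record store and the filesystem
(the key format and paths are illustrative only):

```python
def sync_to_config(records: dict[str, str], on_disk: set[str]) -> dict[str, str]:
    # Drop orphan records whose files are gone...
    synced = {key: path for key, path in records.items() if path in on_disk}
    # ...then register models on disk that have no record yet
    next_key = len(synced)
    for path in sorted(on_disk - set(records.values())):
        synced[f"new-{next_key}"] = path  # hypothetical key format
        next_key += 1
    return synced

records = {"k1": "models/sd-1/main/a", "k2": "models/sd-1/main/gone"}
on_disk = {"models/sd-1/main/a", "models/sd-1/main/b"}
synced = sync_to_config(records, on_disk)
```
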
##### hash = installer.hash(model_path)
This method calls the fasthash algorithm on a model's Path
(either a file or a folder) to generate a unique ID based on the
contents of the model.
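The idea can be sketched with a naive content hash that handles both the
single-file and folder cases. This stand-in reads every file in full; the real
fasthash implementation is optimized for speed and its IDs will differ:

```python
import hashlib
import tempfile
from pathlib import Path

def fast_hash(model_path: Path) -> str:
    # Hash a single checkpoint file, or every file in a diffusers folder
    digest = hashlib.sha256()
    files = (
        [model_path]
        if model_path.is_file()
        else sorted(p for p in model_path.rglob("*") if p.is_file())
    )
    for f in files:
        digest.update(f.read_bytes())  # naive: reads whole files
    return digest.hexdigest()

root = Path(tempfile.mkdtemp())
single = root / "model.ckpt"
single.write_bytes(b"weights")
h1 = fast_hash(single)
h2 = fast_hash(single)  # deterministic: same contents, same ID
```
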
##### installer.start(invoker)
The `start` method is called by the API initialization routines when
the API starts up. Its effect is to call `sync_to_config()` to
synchronize the model record store database with what's currently on
disk.
This method should not ordinarily be called manually.
@ -23,9 +23,9 @@ from ..services.invoker import Invoker
from ..services.item_storage.item_storage_sqlite import SqliteItemStorage from ..services.item_storage.item_storage_sqlite import SqliteItemStorage
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
from ..services.model_install import ModelInstallService
from ..services.model_manager.model_manager_default import ModelManagerService from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL from ..services.model_records import ModelRecordServiceSQL
from ..services.model_install import ModelInstallService
from ..services.names.names_default import SimpleNameService from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
@ -4,7 +4,7 @@
from hashlib import sha1 from hashlib import sha1
from random import randbytes from random import randbytes
from typing import List, Optional, Any, Dict from typing import Any, Dict, List, Optional
from fastapi import Body, Path, Query, Response from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
@ -12,6 +12,7 @@ from pydantic import BaseModel, ConfigDict
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from typing_extensions import Annotated from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ( from invokeai.app.services.model_records import (
DuplicateModelException, DuplicateModelException,
InvalidModelException, InvalidModelException,
@ -22,11 +23,10 @@ from invokeai.backend.model_manager.config import (
BaseModelType, BaseModelType,
ModelType, ModelType,
) )
from invokeai.app.services.model_install import ModelInstallJob, ModelSource
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"]) model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2"])
class ModelsList(BaseModel): class ModelsList(BaseModel):
@ -44,15 +44,16 @@ class ModelsList(BaseModel):
async def list_model_records( async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
) -> ModelsList: ) -> ModelsList:
"""Get a list of models.""" """Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records record_store = ApiDependencies.invoker.services.model_records
found_models: list[AnyModelConfig] = [] found_models: list[AnyModelConfig] = []
if base_models: if base_models:
for base_model in base_models: for base_model in base_models:
found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type)) found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name))
else: else:
found_models.extend(record_store.search_by_attr(model_type=model_type)) found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name))
return ModelsList(models=found_models) return ModelsList(models=found_models)
@ -118,12 +119,17 @@ async def update_model_record(
async def del_model_record( async def del_model_record(
key: str = Path(description="Unique key of model to remove from model registry."), key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response: ) -> Response:
"""Delete Model""" """
Delete model record from database.
The configuration record will be removed. The corresponding weights files will be
deleted as well if they reside within the InvokeAI "models" directory.
"""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: try:
record_store = ApiDependencies.invoker.services.model_records installer = ApiDependencies.invoker.services.model_install
record_store.del_model(key) installer.delete(key)
logger.info(f"Deleted model: {key}") logger.info(f"Deleted model: {key}")
return Response(status_code=204) return Response(status_code=204)
except UnknownModelException as e: except UnknownModelException as e:
@ -181,8 +187,8 @@ async def import_model(
source: ModelSource = Body( source: ModelSource = Body(
description="A model path, repo_id or URL to import. NOTE: only model path is implemented currently!" description="A model path, repo_id or URL to import. NOTE: only model path is implemented currently!"
), ),
metadata: Optional[Dict[str, Any]] = Body( config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values, such as name, description and prediction_type ", description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None, default=None,
), ),
variant: Optional[str] = Body( variant: Optional[str] = Body(
@ -208,8 +214,13 @@ async def import_model(
automatically. To override the default guesses, pass "metadata" automatically. To override the default guesses, pass "metadata"
with a Dict containing the attributes you wish to override. with a Dict containing the attributes you wish to override.
Listen on the event bus for the following events: Installation occurs in the background. Either use list_model_install_jobs()
"model_install_started", "model_install_completed", and "model_install_error." to poll for completion, or listen on the event bus for the following events:
"model_install_started"
"model_install_completed"
"model_install_error"
On successful completion, the event's payload will contain the field "key" On successful completion, the event's payload will contain the field "key"
containing the installed ID of the model. On an error, the event's payload containing the installed ID of the model. On an error, the event's payload
will contain the fields "error_type" and "error" describing the nature of the will contain the fields "error_type" and "error" describing the nature of the
@ -222,11 +233,12 @@ async def import_model(
installer = ApiDependencies.invoker.services.model_install installer = ApiDependencies.invoker.services.model_install
result: ModelInstallJob = installer.import_model( result: ModelInstallJob = installer.import_model(
source, source,
metadata=metadata, config=config,
variant=variant, variant=variant,
subfolder=subfolder, subfolder=subfolder,
access_token=access_token, access_token=access_token,
) )
logger.info(f"Started installation of {source}")
except UnknownModelException as e: except UnknownModelException as e:
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@ -242,7 +254,7 @@ async def import_model(
"/import", "/import",
operation_id="list_model_install_jobs", operation_id="list_model_install_jobs",
) )
async def list_install_jobs( async def list_model_install_jobs(
source: Optional[str] = Query(description="Filter list by install source, partial string match.", source: Optional[str] = Query(description="Filter list by install source, partial string match.",
default=None, default=None,
) )
@ -255,3 +267,36 @@ async def list_install_jobs(
""" """
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs(source) jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs(source)
return jobs return jobs
@model_records_router.patch(
"/import",
operation_id="prune_model_install_jobs",
responses={
204: {"description": "All completed and errored jobs have been pruned"},
400: {"description": "Bad request"},
},
)
async def prune_model_install_jobs(
) -> Response:
"""
Prune all completed and errored jobs from the install job list.
"""
ApiDependencies.invoker.services.model_install.prune_jobs()
return Response(status_code=204)
@model_records_router.patch(
"/sync",
operation_id="sync_models_to_config",
responses={
204: {"description": "Model config record database resynced with files on disk"},
400: {"description": "Bad request"},
},
)
async def sync_models_to_config(
) -> Response:
"""
Traverse the models and autoimport directories. Model files without a corresponding
record in the database are added. Orphan records without a models file are deleted.
"""
ApiDependencies.invoker.services.model_install.sync_to_config()
return Response(status_code=204)
@ -351,6 +351,29 @@ class EventServiceBase:
}, },
) )
def emit_model_install_progress(self,
source: str,
current_bytes: int,
total_bytes: int,
) -> None:
"""
Emitted while the install job is in progress.
(Downloaded models only)
:param source: Source of the model
:param current_bytes: Number of bytes downloaded so far
:param total_bytes: Total bytes to download
"""
self.__emit_model_event(
event_name="model_install_progress",
payload={
"source": source,
"current_bytes": int,
"total_bytes": int,
},
)
def emit_model_install_error(self, def emit_model_install_error(self,
source: str, source: str,
error_type: str, error_type: str,
@ -21,9 +21,9 @@ if TYPE_CHECKING:
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .item_storage.item_storage_base import ItemStorageABC from .item_storage.item_storage_base import ItemStorageABC
from .latents_storage.latents_storage_base import LatentsStorageBase from .latents_storage.latents_storage_base import LatentsStorageBase
from .model_install import ModelInstallServiceBase
from .model_manager.model_manager_base import ModelManagerServiceBase from .model_manager.model_manager_base import ModelManagerServiceBase
from .model_records import ModelRecordServiceBase from .model_records import ModelRecordServiceBase
from .model_install import ModelInstallServiceBase
from .names.names_base import NameServiceBase from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase from .session_queue.session_queue_base import SessionQueueBase
@ -52,7 +52,7 @@ class InvocationServices:
logger: "Logger" logger: "Logger"
model_manager: "ModelManagerServiceBase" model_manager: "ModelManagerServiceBase"
model_records: "ModelRecordServiceBase" model_records: "ModelRecordServiceBase"
model_install: "ModelRecordInstallServiceBase" model_install: "ModelInstallServiceBase"
processor: "InvocationProcessorABC" processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase" performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC" queue: "InvocationQueueABC"
@ -1,6 +1,12 @@
"""Initialization file for model install service package.""" """Initialization file for model install service package."""
from .model_install_base import InstallStatus, ModelInstallServiceBase, ModelInstallJob, UnknownInstallJobException, ModelSource from .model_install_base import (
InstallStatus,
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
UnknownInstallJobException,
)
from .model_install_default import ModelInstallService from .model_install_default import ModelInstallService
__all__ = ['ModelInstallServiceBase', __all__ = ['ModelInstallServiceBase',
@ -9,7 +9,9 @@ from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events import EventServiceBase from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig
class InstallStatus(str, Enum): class InstallStatus(str, Enum):
@ -31,13 +33,13 @@ ModelSource = Union[str, Path, AnyHttpUrl]
class ModelInstallJob(BaseModel): class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request.""" """Object that tracks the current status of an install request."""
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process") status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Configuration metadata to apply to model before installing it") config_in: Dict[str, Any] = Field(default_factory=dict, description="Configuration information (e.g. 'description') to apply to model.")
config_out: Optional[AnyModelConfig] = Field(default=None, description="After successful installation, this will hold the configuration object.")
inplace: bool = Field(default=False, description="Leave model in its current location; otherwise install under models directory") inplace: bool = Field(default=False, description="Leave model in its current location; otherwise install under models directory")
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model") source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source") local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
key: str = Field(default="<NO KEY>", description="After model is installed, this is its config record key") error_type: Optional[str] = Field(default=None, description="Class name of the exception that led to status==ERROR")
error_type: str = Field(default="", description="Class name of the exception that led to status==ERROR") error: Optional[str] = Field(default=None, description="Error traceback") # noqa #501
error: str = Field(default="", description="Error traceback") # noqa #501
def set_error(self, e: Exception) -> None: def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception.""" """Record the error and traceback from an exception."""
@ -64,6 +66,9 @@ class ModelInstallServiceBase(ABC):
:param event_bus: InvokeAI event bus for reporting events to. :param event_bus: InvokeAI event bus for reporting events to.
""" """
def start(self, invoker: Invoker) -> None:
self.sync_to_config()
@property @property
@abstractmethod @abstractmethod
def app_config(self) -> InvokeAIAppConfig: def app_config(self) -> InvokeAIAppConfig:
@ -83,7 +88,7 @@ class ModelInstallServiceBase(ABC):
def register_path( def register_path(
self, self,
model_path: Union[Path, str], model_path: Union[Path, str],
metadata: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None,
) -> str: ) -> str:
""" """
Probe and register the model at model_path. Probe and register the model at model_path.
@ -91,7 +96,7 @@ class ModelInstallServiceBase(ABC):
This keeps the model in its current location. This keeps the model in its current location.
:param model_path: Filesystem Path to the model. :param model_path: Filesystem Path to the model.
:param metadata: Dict of attributes that will override autoassigned values. :param config: Dict of attributes that will override autoassigned values.
:returns id: The string ID of the registered model. :returns id: The string ID of the registered model.
""" """
@ -111,7 +116,7 @@ class ModelInstallServiceBase(ABC):
def install_path( def install_path(
self, self,
model_path: Union[Path, str], model_path: Union[Path, str],
metadata: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None,
) -> str: ) -> str:
""" """
Probe, register and install the model in the models directory. Probe, register and install the model in the models directory.
@ -120,7 +125,7 @@ class ModelInstallServiceBase(ABC):
the models directory handled by InvokeAI. the models directory handled by InvokeAI.
:param model_path: Filesystem Path to the model. :param model_path: Filesystem Path to the model.
:param metadata: Dict of attributes that will override autoassigned values. :param config: Dict of attributes that will override autoassigned values.
:returns id: The string ID of the registered model. :returns id: The string ID of the registered model.
""" """
@ -128,10 +133,10 @@ class ModelInstallServiceBase(ABC):
def import_model( def import_model(
self, self,
source: Union[str, Path, AnyHttpUrl], source: Union[str, Path, AnyHttpUrl],
inplace: bool = True, inplace: bool = False,
variant: Optional[str] = None, variant: Optional[str] = None,
subfolder: Optional[str] = None, subfolder: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None, access_token: Optional[str] = None,
) -> ModelInstallJob: ) -> ModelInstallJob:
"""Install the indicated model. """Install the indicated model.
@ -147,8 +152,9 @@ class ModelInstallServiceBase(ABC):
:param subfolder: When downloading HF repo_ids this can be used to :param subfolder: When downloading HF repo_ids this can be used to
specify a subfolder of the HF repository to download from. specify a subfolder of the HF repository to download from.
:param metadata: Optional dict. Any fields in this dict :param config: Optional dict. Any fields in this dict
will override corresponding autoassigned probe fields. Use it to override will override corresponding autoassigned probe fields in the
model's config record. Use it to override
`name`, `description`, `base_type`, `model_type`, `format`, `name`, `description`, `base_type`, `model_type`, `format`,
`prediction_type`, `image_size`, and/or `ztsnr_training`. `prediction_type`, `image_size`, and/or `ztsnr_training`.
@ -5,26 +5,30 @@ from hashlib import sha256
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
from random import randbytes from random import randbytes
from shutil import move, rmtree from shutil import copyfile, copytree, move, rmtree
from typing import Any, Dict, List, Set, Optional, Union from typing import Any, Dict, List, Optional, Set, Union
from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events import EventServiceBase from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.model_records import ModelRecordServiceBase, DuplicateModelException from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, UnknownModelException
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
AnyModelConfig, AnyModelConfig,
BaseModelType,
InvalidModelConfigException, InvalidModelConfigException,
ModelType,
) )
from invokeai.backend.model_manager.config import ModelType, BaseModelType
from invokeai.backend.model_manager.hash import FastModelHash from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.model_manager.probe import ModelProbe from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import Chdir, InvokeAILogger from invokeai.backend.util import Chdir, InvokeAILogger
from .model_install_base import ModelSource, InstallStatus, ModelInstallJob, ModelInstallServiceBase, UnknownInstallJobException from .model_install_base import (
InstallStatus,
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
UnknownInstallJobException,
)
# marker that the queue is done and that thread should exit # marker that the queue is done and that thread should exit
STOP_JOB = ModelInstallJob(source="stop", local_path=Path("/dev/null")) STOP_JOB = ModelInstallJob(source="stop", local_path=Path("/dev/null"))
@ -91,10 +95,12 @@ class ModelInstallService(ModelInstallServiceBase):
try: try:
self._signal_job_running(job) self._signal_job_running(job)
if job.inplace: if job.inplace:
job.key = self.register_path(job.local_path, job.metadata) key = self.register_path(job.local_path, job.config_in)
else: else:
job.key = self.install_path(job.local_path, job.metadata) key = self.install_path(job.local_path, job.config_in)
job.config_out = self.record_store.get_model(key)
self._signal_job_completed(job) self._signal_job_completed(job)
except (OSError, DuplicateModelException, InvalidModelConfigException) as excp: except (OSError, DuplicateModelException, InvalidModelConfigException) as excp:
self._signal_job_errored(job, excp) self._signal_job_errored(job, excp)
finally: finally:
@ -109,67 +115,73 @@ class ModelInstallService(ModelInstallServiceBase):
job.status = InstallStatus.COMPLETED job.status = InstallStatus.COMPLETED
if self._event_bus: if self._event_bus:
assert job.local_path is not None assert job.local_path is not None
self._event_bus.emit_model_install_completed(str(job.source), job.key) assert job.config_out is not None
key = job.config_out.key
self._event_bus.emit_model_install_completed(str(job.source), key)
def _signal_job_errored(self, job: ModelInstallJob, excp: Exception) -> None: def _signal_job_errored(self, job: ModelInstallJob, excp: Exception) -> None:
job.set_error(excp) job.set_error(excp)
if self._event_bus: if self._event_bus:
self._event_bus.emit_model_install_error(str(job.source), job.error_type, job.error) error_type = job.error_type
error = job.error
assert error_type is not None
assert error is not None
self._event_bus.emit_model_install_error(str(job.source), error_type, error)
def register_path( def register_path(
self, self,
model_path: Union[Path, str], model_path: Union[Path, str],
metadata: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None,
) -> str: # noqa D102 ) -> str: # noqa D102
model_path = Path(model_path) model_path = Path(model_path)
metadata = metadata or {} config = config or {}
if metadata.get('source') is None: if config.get('source') is None:
metadata['source'] = model_path.resolve().as_posix() config['source'] = model_path.resolve().as_posix()
return self._register(model_path, metadata) return self._register(model_path, config)
def install_path( def install_path(
self, self,
model_path: Union[Path, str], model_path: Union[Path, str],
metadata: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None,
) -> str: # noqa D102 ) -> str: # noqa D102
model_path = Path(model_path) model_path = Path(model_path)
metadata = metadata or {} config = config or {}
if metadata.get('source') is None: if config.get('source') is None:
metadata['source'] = model_path.resolve().as_posix() config['source'] = model_path.resolve().as_posix()
info: AnyModelConfig = self._probe_model(Path(model_path), metadata) info: AnyModelConfig = self._probe_model(Path(model_path), config)
old_hash = info.original_hash old_hash = info.original_hash
dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name
new_path = self._move_model(model_path, dest_path) new_path = self._copy_model(model_path, dest_path)
new_hash = FastModelHash.hash(new_path) new_hash = FastModelHash.hash(new_path)
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted." assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
return self._register( return self._register(
new_path, new_path,
metadata, config,
info, info,
) )
def import_model( def import_model(
self, self,
source: ModelSource, source: ModelSource,
inplace: bool = True, inplace: bool = False,
variant: Optional[str] = None, variant: Optional[str] = None,
subfolder: Optional[str] = None, subfolder: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None, access_token: Optional[str] = None,
) -> ModelInstallJob: # noqa D102 ) -> ModelInstallJob: # noqa D102
# Clean up a common source of error. Doesn't work with Paths. # Clean up a common source of error. Doesn't work with Paths.
if isinstance(source, str): if isinstance(source, str):
source = source.strip() source = source.strip()
if not metadata: if not config:
metadata = {} config = {}
# Installing a local path # Installing a local path
if isinstance(source, (str, Path)) and Path(source).exists(): # a path that is already on disk if isinstance(source, (str, Path)) and Path(source).exists(): # a path that is already on disk
job = ModelInstallJob(metadata=metadata, job = ModelInstallJob(config_in=config,
source=source, source=source,
inplace=inplace, inplace=inplace,
local_path=Path(source), local_path=Path(source),
@ -179,7 +191,7 @@ class ModelInstallService(ModelInstallServiceBase):
return job return job
else: # here is where we'd download a URL or repo_id. Implementation pending download queue. else: # here is where we'd download a URL or repo_id. Implementation pending download queue.
raise NotImplementedError raise UnknownModelException("File or directory not found")
def list_jobs(self, source: Optional[ModelSource]=None) -> List[ModelInstallJob]: # noqa D102 def list_jobs(self, source: Optional[ModelSource]=None) -> List[ModelInstallJob]: # noqa D102
jobs = self._install_jobs jobs = self._install_jobs
@ -212,7 +224,9 @@ class ModelInstallService(ModelInstallServiceBase):
self._scan_models_directory() self._scan_models_directory()
if autoimport := self._app_config.autoimport_dir: if autoimport := self._app_config.autoimport_dir:
self._logger.info("Scanning autoimport directory for new models") self._logger.info("Scanning autoimport directory for new models")
self.scan_directory(self._app_config.root_path / autoimport) installed = self.scan_directory(self._app_config.root_path / autoimport)
self._logger.info(f"{len(installed)} new models registered")
self._logger.info("Model installer (re)initialized")
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()} self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
@ -242,7 +256,7 @@ class ModelInstallService(ModelInstallServiceBase):
for key in defunct_models: for key in defunct_models:
self.unregister(key) self.unregister(key)
self._logger.info(f"Scanning {self._app_config.models_path} for new models") self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
for cur_base_model in BaseModelType: for cur_base_model in BaseModelType:
for cur_model_type in ModelType: for cur_model_type in ModelType:
models_dir = Path(cur_base_model.value, cur_model_type.value) models_dir = Path(cur_base_model.value, cur_model_type.value)
@ -328,6 +342,16 @@ class ModelInstallService(ModelInstallServiceBase):
path.unlink() path.unlink()
self.unregister(key) self.unregister(key)
def _copy_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path
new_path.parent.mkdir(parents=True, exist_ok=True)
if old_path.is_dir():
copytree(old_path, new_path)
else:
copyfile(old_path, new_path)
return new_path
def _move_model(self, old_path: Path, new_path: Path) -> Path: def _move_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path: if old_path == new_path:
return old_path return old_path
@ -344,10 +368,10 @@ class ModelInstallService(ModelInstallServiceBase):
move(old_path, new_path) move(old_path, new_path)
return new_path return new_path
def _probe_model(self, model_path: Path, metadata: Optional[Dict[str, Any]] = None) -> AnyModelConfig: def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
info: AnyModelConfig = ModelProbe.probe(Path(model_path)) info: AnyModelConfig = ModelProbe.probe(Path(model_path))
if metadata: # used to override probe fields if config: # used to override probe fields
for key, value in metadata.items(): for key, value in config.items():
setattr(info, key, value) setattr(info, key, value)
return info return info
@ -356,10 +380,10 @@ class ModelInstallService(ModelInstallServiceBase):
def _register(self, def _register(self,
model_path: Path, model_path: Path,
metadata: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None,
info: Optional[AnyModelConfig] = None) -> str: info: Optional[AnyModelConfig] = None) -> str:
info = info or ModelProbe.probe(model_path, metadata) info = info or ModelProbe.probe(model_path, config)
key = self._create_key() key = self._create_key()
model_path = model_path.absolute() model_path = model_path.absolute()
@@ -1,17 +1,17 @@
 """Re-export frequently-used symbols from the Model Manager backend."""

-from .probe import ModelProbe
 from .config import (
+    AnyModelConfig,
+    BaseModelType,
     InvalidModelConfigException,
     ModelConfigFactory,
-    BaseModelType,
-    ModelType,
-    SubModelType,
-    ModelVariantType,
     ModelFormat,
+    ModelType,
+    ModelVariantType,
     SchedulerPredictionType,
-    AnyModelConfig,
+    SubModelType,
 )
+from .probe import ModelProbe
 from .search import ModelSearch

 __all__ = ['ModelProbe', 'ModelSearch',


@@ -11,7 +11,7 @@ from .devices import (  # noqa: F401
     normalize_device,
     torch_dtype,
 )
-from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name  # noqa: F401
 from .logging import InvokeAILogger
+from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name  # noqa: F401

 __all__ = ['Chdir', 'InvokeAILogger', 'choose_precision', 'choose_torch_device']


@@ -164,6 +164,7 @@ nav:
       - Overview: 'contributing/contribution_guides/development.md'
       - New Contributors: 'contributing/contribution_guides/newContributorChecklist.md'
       - InvokeAI Architecture: 'contributing/ARCHITECTURE.md'
+      - Model Manager v2: 'contributing/MODEL_MANAGER.md'
       - Frontend Documentation: 'contributing/contribution_guides/contributingToFrontend.md'
       - Local Development: 'contributing/LOCAL_DEVELOPMENT.md'
       - Adding Tests: 'contributing/TESTS.md'


@@ -127,7 +127,7 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path,
     """Note: may want to break this down into several smaller unit tests."""
     source = test_file
     description = "Test of metadata assignment"
-    job = installer.import_model(source, inplace=False, metadata={"description": description})
+    job = installer.import_model(source, inplace=False, config={"description": description})
     assert job is not None
     assert isinstance(job, ModelInstallJob)
@@ -172,9 +172,10 @@ def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app
     key = installer.install_path(test_file)
     model_record = store.get_model(key)
     assert Path(app_config.models_dir / model_record.path).exists()
-    assert not test_file.exists()  # original should not still be there after installation
+    assert test_file.exists()  # original should still be there after installation
     installer.delete(key)
-    assert not Path(app_config.models_dir / model_record.path).exists()  # but installed copy should not!
+    assert not Path(app_config.models_dir / model_record.path).exists()  # after deletion, installed copy should not exist
+    assert test_file.exists()  # but original should still be there
     with pytest.raises(UnknownModelException):
         store.get_model(key)
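The corrected assertions above encode the install-by-copy contract: installation copies the source into the `models` directory without disturbing the original, and deletion removes only the installed copy. A self-contained sketch of that contract, using hypothetical `install`/`delete` helpers rather than the real installer service:

```python
import tempfile
from pathlib import Path
from shutil import copyfile


def install(source: Path, models_dir: Path) -> Path:
    """Copy (never move) the source into the models directory."""
    dest = models_dir / source.name
    copyfile(source, dest)
    return dest


def delete(installed: Path) -> None:
    """Remove only the installed copy, leaving the source alone."""
    installed.unlink()


with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    src = tmp / "model.safetensors"
    src.write_bytes(b"weights")
    models = tmp / "models"
    models.mkdir()

    installed = install(src, models)
    assert src.exists()            # original still there after installation
    delete(installed)
    assert not installed.exists()  # installed copy gone after deletion
    assert src.exists()            # original untouched throughout
```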