Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00

Merge branch 'main' into install/release-tools

Commit 713a83e7da
@@ -100,6 +100,8 @@ ENV INVOKEAI_SRC=/opt/invokeai
 ENV VIRTUAL_ENV=/opt/venv/invokeai
 ENV INVOKEAI_ROOT=/invokeai
 ENV PATH="$VIRTUAL_ENV/bin:$INVOKEAI_SRC:$PATH"
+ENV CONTAINER_UID=${CONTAINER_UID:-1000}
+ENV CONTAINER_GID=${CONTAINER_GID:-1000}

 # --link requires buildkit w/ dockerfile syntax 1.4
 COPY --link --from=builder ${INVOKEAI_SRC} ${INVOKEAI_SRC}

@@ -117,7 +119,7 @@ WORKDIR ${INVOKEAI_SRC}
 RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
 RUN python3 -c "from patchmatch import patch_match"

-RUN mkdir -p ${INVOKEAI_ROOT} && chown -R 1000:1000 ${INVOKEAI_ROOT}
+RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}

 COPY docker/docker-entrypoint.sh ./
 ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"]

@@ -10,40 +10,36 @@ model. These are the:
  tracks the type of the model, its provenance, and where it can be
  found on disk.

-* _ModelLoadServiceBase_ Responsible for loading a model from disk
-  into RAM and VRAM and getting it ready for inference.
-
-* _DownloadQueueServiceBase_ A multithreaded downloader responsible
-  for downloading models from a remote source to disk. The download
-  queue has special methods for downloading repo_id folders from
-  Hugging Face, as well as discriminating among model versions in
-  Civitai, but can be used for arbitrary content.
-
* _ModelInstallServiceBase_ A service for installing models to
  disk. It uses `DownloadQueueServiceBase` to download models and
  their metadata, and `ModelRecordServiceBase` to store that
  information. It is also responsible for managing the InvokeAI
  `models` directory and its contents.

+* _DownloadQueueServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
+  A multithreaded downloader responsible for downloading models from a
+  remote source to disk. The download queue has special methods for
+  downloading repo_id folders from Hugging Face, as well as
+  discriminating among model versions in Civitai, but can be used for
+  arbitrary content.
+
+* _ModelLoadServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
+  Responsible for loading a model from disk into RAM and VRAM and
+  getting it ready for inference.

## Location of the Code

All four of these services can be found in
`invokeai/app/services` in the following directories:

* `invokeai/app/services/model_records/`
-* `invokeai/app/services/downloads/`
-* `invokeai/app/services/model_loader/`
* `invokeai/app/services/model_install/`
+* `invokeai/app/services/model_loader/` (**under development**)
+* `invokeai/app/services/downloads/` (**under development**)

-With the exception of the install service, each of these is a thin
-shell around a corresponding implementation located in
-`invokeai/backend/model_manager`. The main difference between the
-modules found in app services and those in the backend folder is that
-the former add support for event reporting and are more tied to the
-needs of the InvokeAI API.

Code related to the FastAPI web API can be found in
-`invokeai/app/api/routers/models.py`.
+`invokeai/app/api/routers/model_records.py`.

***

@@ -165,10 +161,6 @@ of the fields, including `name`, `model_type` and `base_model`, are
shared between `ModelConfigBase` and `ModelBase`, and this is a
potential source of confusion.

-**TO DO:** The `ModelBase` code needs to be revised to reduce the
-duplication of similar classes and to support using the `key` as the
-primary model identifier.
-
## Reading and Writing Model Configuration Records

The `ModelRecordService` provides the ability to retrieve model

@@ -362,7 +354,7 @@ model and pass its key to `get_model()`.
Several methods allow you to create and update stored model config
records.

-#### add_model(key, config) -> ModelConfigBase:
+#### add_model(key, config) -> AnyModelConfig:

Given a key and a configuration, this will add the model's
configuration record to the database. `config` can either be a subclass of

@@ -386,27 +378,356 @@ fields to be updated. This will return an `AnyModelConfig` on success,
or raise `InvalidModelConfigException` or `UnknownModelException`
exceptions on failure.

-***TO DO:*** Investigate why `update_model()` returns an
-`AnyModelConfig` while `add_model()` returns a `ModelConfigBase`.
-
-### rename_model(key, new_name) -> ModelConfigBase:
-
-This is a special case of `update_model()` for the use case of
-changing the model's name. It is broken out because there are cases in
-which the InvokeAI application wants to synchronize the model's name
-with its path in the `models` directory after changing the name, type
-or base. However, when using the ModelRecordService directly, the call
-is equivalent to:
-
-```
-store.rename_model(key, {'name': 'new_name'})
-```
-
-***TO DO:*** Investigate why `rename_model()` is returning a
-`ModelConfigBase` while `update_model()` returns a `AnyModelConfig`.
-
***

## Model installation

The `ModelInstallService` class implements the `ModelInstallServiceBase` abstract base class, and provides a one-stop shop for all your model install needs. It provides the following functionality:

- Registering a model config record for a model already located on the local filesystem, without moving it or changing its path.

- Installing a model already located on the local filesystem, by moving it into the InvokeAI root directory under the `models` folder (or wherever config parameter `models_dir` specifies).

- Probing of models to determine their type, base type and other key information.

- Interface with the InvokeAI event bus to provide status updates on the download, installation and registration process.

- Downloading a model from an arbitrary URL and installing it in `models_dir` (_implementation pending_).

- Special handling for Civitai model URLs which allow the user to paste in a model page's URL or download link (_implementation pending_).

- Special handling for HuggingFace repo_ids to recursively download the contents of the repository, paying attention to alternative variants such as fp16 (_implementation pending_).

### Initializing the installer

A default installer is created at InvokeAI API startup time and stored in `ApiDependencies.invoker.services.model_install`, and can also be retrieved from an invocation's `context` argument with `context.services.model_install`.

In the event you wish to create a new installer, you may use the following initialization pattern:

```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger

config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config, logger)

store = ModelRecordServiceSQL(db)
installer = ModelInstallService(config, store)
```

The full form of `ModelInstallService()` takes the following parameters:

| **Argument**   | **Type**               | **Description**                   |
|----------------|------------------------|-----------------------------------|
| `config`       | InvokeAIAppConfig      | InvokeAI app configuration object |
| `record_store` | ModelRecordServiceBase | Config record storage database    |
| `event_bus`    | EventServiceBase       | Optional event bus to send download/install progress events to |

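For reference, the `ApiDependencies` hunk later in this commit wires the installer up with keyword arguments. A sketch of that keyword form, assuming an `events` object implementing `EventServiceBase` has been created elsewhere during startup:

```
from invokeai.app.services.model_install import ModelInstallService

# Keyword form mirroring the API server wiring shown later in this commit.
installer = ModelInstallService(
    app_config=config,       # InvokeAIAppConfig
    record_store=store,      # ModelRecordServiceBase
    event_bus=events,        # optional; enables progress events on the bus
)
```
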
Once initialized, the installer will provide the following methods:

#### install_job = installer.import_model()

The `import_model()` method is the core of the installer. The following illustrates basic usage:

```
from invokeai.app.services.model_install import (
    LocalModelSource,
    HFModelSource,
    URLModelSource,
)

source1 = LocalModelSource(path='/opt/models/sushi.safetensors')   # a local safetensors file
source2 = LocalModelSource(path='/opt/models/sushi_diffusers')     # a local diffusers folder

source3 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5')                     # a repo_id
source4 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='vae')    # a subfolder within a repo_id
source5 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', variant='fp16')     # a named variant of a HF model

source6 = URLModelSource(url='https://civitai.com/api/download/models/63006')                         # model located at a URL
source7 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token

sources = [source1, source2, source3, source4, source5, source6, source7]

for source in sources:
    install_job = installer.import_model(source)

source2job = installer.wait_for_installs()
for source in sources:
    job = source2job[source]
    if job.status == "completed":
        model_config = job.config_out
        model_key = model_config.key
        print(f"{source} installed as {model_key}")
    elif job.status == "error":
        print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")
```

As shown here, the `import_model()` method accepts a variety of sources, including local safetensors files, local diffusers folders, HuggingFace repo_ids with and without a subfolder designation, Civitai model URLs, and arbitrary URLs that point to checkpoint files (but not to folders).

Each call to `import_model()` returns a `ModelInstallJob`, an object which tracks the progress of the install.

If a remote model is requested, the model's files are downloaded in parallel across multiple threads using the download queue. During the download process, the `ModelInstallJob` is updated to provide status and progress information. After the files (if any) are downloaded, the remainder of the installation (model probing, file copying, and the config record database update) runs in a single serialized background thread.

Multiple install jobs can be queued up. You may block until all install jobs are completed (or errored) by calling the `wait_for_installs()` method as shown in the code example. `wait_for_installs()` will return a `dict` that maps each requested source to its job. The job object can be interrogated to determine its status. If the job errored out, then the error type and details can be recovered from `job.error_type` and `job.error`.

The full list of arguments to `import_model()` is as follows:

| **Argument**   | **Type**                     | **Default** | **Description**                               |
|----------------|------------------------------|-------------|-----------------------------------------------|
| `source`       | Union[str, Path, AnyHttpUrl] |             | The source of the model: Path, URL or repo_id |
| `inplace`      | bool                         | True        | Leave a local model in its current location   |
| `variant`      | str                          | None        | Desired variant, such as 'fp16' or 'onnx' (HuggingFace only) |
| `subfolder`    | str                          | None        | Repository subfolder (HuggingFace only)       |
| `config`       | Dict[str, Any]               | None        | Override all or a portion of the model's probed attributes |
| `access_token` | str                          | None        | Authorization information needed to download  |

The `inplace` field controls how local model Paths are handled. If True (the default), then the model is simply registered in its current location by the installer's `ModelConfigRecordService`. Otherwise, a copy of the model is placed into the location specified by the `models_dir` application configuration parameter.

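A quick sketch of the two behaviors, assuming `inplace` is passed through `import_model()` as listed in the arguments table above:

```
from invokeai.app.services.model_install import LocalModelSource

src = LocalModelSource(path='/opt/models/sushi.safetensors')

# Default: register the model where it lives; nothing is copied.
job = installer.import_model(src)

# Copy the model into the models_dir hierarchy, then register the copy.
job = installer.import_model(src, inplace=False)
```
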
The `variant` field is used for HuggingFace repo_ids only. If provided, the repo_id download handler will look for and download tensors files that follow the convention for the selected variant:

- "fp16" will select files named "*model.fp16.{safetensors,bin}"
- "onnx" will select files ending with the suffix ".onnx"
- "openvino" will select files beginning with "openvino_model"

In the special case of the "fp16" variant, the installer will select the 32-bit version of the files if the 16-bit version is unavailable.

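The selection rules can be made concrete with a small illustrative helper. This is not part of the installer's API; `matches_variant()` is a hypothetical function that simply restates the three conventions above:

```
import re
from typing import Optional

def matches_variant(filename: str, variant: Optional[str]) -> bool:
    """Hypothetical helper restating the variant-selection rules above."""
    if variant == "fp16":
        # e.g. "diffusion_pytorch_model.fp16.safetensors"
        return re.search(r"\.fp16\.(safetensors|bin)$", filename) is not None
    if variant == "onnx":
        return filename.endswith(".onnx")
    if variant == "openvino":
        return filename.startswith("openvino_model")
    return True  # no variant requested: accept every file
```
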
`subfolder` is used for HuggingFace repo_ids only. If provided, the model will be downloaded from the designated subfolder rather than the top-level repository folder. If a subfolder is attached to the repo_id using the format `repo_owner/repo_name:subfolder`, then the subfolder specified by the repo_id will override the subfolder argument.

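For example, the following two calls should be equivalent ways of targeting the `vae` subfolder. This is a sketch that assumes `HFModelSource` accepts the `repo_id:subfolder` shorthand described above:

```
from invokeai.app.services.model_install import HFModelSource

# Explicit subfolder argument.
job = installer.import_model(HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='vae'))

# Shorthand; per the rule above, the ':vae' suffix wins over any subfolder argument.
job = installer.import_model(HFModelSource(repo_id='runwayml/stable-diffusion-v1-5:vae'))
```
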
`config` can be used to override all or a portion of the configuration attributes returned by the model prober. See the section below for details.

`access_token` is passed to the download queue and used to access repositories that require it.

#### Monitoring the install job process

When you create an install job with `import_model()`, it launches the download and installation process in the background and returns a `ModelInstallJob` object for monitoring the process.

The `ModelInstallJob` class has the following structure:

| **Attribute** | **Type**         | **Description**  |
|---------------|------------------|------------------|
| `status`      | `InstallStatus`  | One of "waiting", "running", "completed" or "error" |
| `config_in`   | `dict`           | Overriding configuration values provided by the caller |
| `config_out`  | `AnyModelConfig` | After successful completion, contains the configuration record written to the database |
| `inplace`     | `boolean`        | True if the caller asked to install the model in place using its local path |
| `source`      | `ModelSource`    | The local path, remote URL or repo_id of the model to be installed |
| `local_path`  | `Path`           | If a remote model, holds the path of the model after it is downloaded; if a local model, same as `source` |
| `error_type`  | `str`            | Name of the exception that led to an error status |
| `error`       | `str`            | Traceback of the error |

If the `event_bus` argument was provided, events will also be broadcast to the InvokeAI event bus. The events will appear on the bus as events of type `EventServiceBase.model_event`, each carrying a timestamp and one of the following event names:

- `model_install_started`

  The payload will contain the keys `timestamp` and `source`. The latter indicates the requested model source for installation.

- `model_install_progress`

  Emitted at regular intervals when downloading a remote model, the payload will contain the keys `timestamp`, `source`, `current_bytes` and `total_bytes`. These events are _not_ emitted when a local model already on the filesystem is imported.

- `model_install_completed`

  Issued once at the end of a successful installation. The payload will contain the keys `timestamp`, `source` and `key`, where `key` is the ID under which the model has been registered.

- `model_install_error`

  Emitted if the installation process fails for some reason. The payload will contain the keys `timestamp`, `source`, `error_type` and `error`. `error_type` is a short message indicating the nature of the error, and `error` is the long traceback to help debug the problem.

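A minimal listener sketch, following the `local_handler.register()` pattern used by the API's SocketIO layer (see the hunk later in this commit). The import path for `EventServiceBase` is assumed:

```
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event

from invokeai.app.services.events import EventServiceBase  # import path assumed

@local_handler.register(event_name=EventServiceBase.model_event)
async def on_model_event(event: Event) -> None:
    # Mirrors the payload shape used by the SocketIO handler in this commit.
    name, payload = event[1]["event"], event[1]["data"]
    if name == "model_install_progress" and payload.get("total_bytes"):
        pct = 100 * payload["current_bytes"] / payload["total_bytes"]
        print(f"{payload['source']}: {pct:.0f}% downloaded")
    elif name == "model_install_completed":
        print(f"{payload['source']} registered under key {payload['key']}")
    elif name == "model_install_error":
        print(f"{payload['source']} failed: {payload['error_type']}")
```
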
#### Model configuration and probing

The install service uses the `invokeai.backend.model_manager.probe` module during import to determine the model's type, base type, and other configuration parameters. Among other things, it assigns a default name and description for the model based on probed fields.

When downloading remote models is implemented, additional configuration information, such as the list of trigger terms, will be retrieved from the HuggingFace and Civitai model repositories.

The probed values can be overridden by providing a dictionary in the optional `config` argument passed to `import_model()`. You may provide overriding values for any of the model's configuration attributes. This is typically used to set the model's name and description, but it can also be used to correct cases in which automatic probing is unable to (correctly) determine a model's attributes. The most common situation is the `prediction_type` field for sd-2 (and rare sd-1) models. Here is an example of setting the `SchedulerPredictionType` and `name` for an sd-2 model:

```
from invokeai.backend.model_manager.config import SchedulerPredictionType

install_job = installer.import_model(
    source='stabilityai/stable-diffusion-2-1',
    variant='fp16',
    config=dict(
        prediction_type=SchedulerPredictionType('v_prediction'),
        name='stable diffusion 2 base model',
    ),
)
```

### Other installer methods

This section describes additional methods provided by the installer class.

#### jobs = installer.wait_for_installs()

Blocks until all pending installs are completed or errored, then returns a list of completed jobs.

#### jobs = installer.list_jobs([source])

Returns a list of all active and complete `ModelInstallJob`s. An optional `source` argument allows you to filter the returned list by a model source string pattern using a partial string match.

#### jobs = installer.get_job(source)

Returns a list of `ModelInstallJob`s corresponding to the indicated model source.

#### installer.prune_jobs()

Removes non-pending jobs (completed or errored) from the job list returned by `list_jobs()` and `get_job()`.

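A short sketch of how these fit together, assuming the optional `source` filter shown in the headings above:

```
installer.wait_for_installs()                      # block until the queue settles

for job in installer.list_jobs():                  # everything, active and finished
    print(job.source, job.status)

civitai_jobs = installer.list_jobs('civitai.com')  # partial source match

installer.prune_jobs()                             # drop completed/errored jobs
```
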
#### installer.app_config, installer.record_store, installer.event_bus

Properties that provide access to the installer's `InvokeAIAppConfig`, `ModelRecordServiceBase` and `EventServiceBase` objects.

#### key = installer.register_path(model_path, config), key = installer.install_path(model_path, config)

These methods bypass the download queue and directly register or install the model at the indicated path, returning the unique ID for the installed model.

Both methods accept a Path object corresponding to a checkpoint or diffusers folder, and an optional dict of config attributes to use to override the values derived from model probing.

The difference between `register_path()` and `install_path()` is that the former creates a model configuration record without changing the location of the model in the filesystem. The latter makes a copy of the model inside the InvokeAI models directory before registering it.

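A sketch of the two calls; the `config` override dict is optional in both:

```
from pathlib import Path

# Record only: the file stays where it is.
key = installer.register_path(Path('/opt/models/sushi.safetensors'))

# Copy into the InvokeAI models directory first, overriding the probed name.
key = installer.install_path(Path('/opt/models/sushi.safetensors'), config={'name': 'sushi'})
```
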
#### installer.unregister(key)

This will remove the model config record for the model at `key`, and is equivalent to `installer.record_store.del_model(key)`.

#### installer.delete(key)

This is similar to `unregister()` but has the additional effect of conditionally deleting the underlying model file(s) if they reside within the InvokeAI models directory.

#### installer.unconditionally_delete(key)

This method is similar to `unregister()`, but also unconditionally deletes the corresponding model weights file(s), regardless of whether they are inside or outside the InvokeAI models hierarchy.

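The three removal methods form an escalating ladder; a sketch:

```
installer.unregister(key)              # remove the config record only
installer.delete(key)                  # record + files, if inside the models directory
installer.unconditionally_delete(key)  # record + files, wherever they live
```
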
#### List[str] = installer.scan_directory(scan_dir: Path, install: bool)

This method will recursively scan the directory indicated in `scan_dir` for new models and either install them in the models directory or register them in place, depending on the setting of `install` (default False).

The return value is the list of keys of the newly installed/registered models.

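For example, a sketch that assumes the config records expose the `name` field described earlier:

```
from pathlib import Path

# Register in place; nothing is moved.
new_keys = installer.scan_directory(Path('/opt/models'))

# Or move the newly found models into the models directory.
new_keys = installer.scan_directory(Path('/opt/models'), install=True)

for key in new_keys:
    print(installer.record_store.get_model(key).name)
```
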
#### installer.sync_to_config()

This method synchronizes models in the models directory and autoimport directory to those in the `ModelConfigRecordService` database. New models are registered and orphan models are unregistered.

#### installer.start(invoker)

The `start` method is called by the API initialization routines when the API starts up. Its effect is to call `sync_to_config()` to synchronize the model record store database with what's currently on disk.

# The remainder of this documentation is provisional, pending implementation of the Download and Load services

## Let's get loaded, the lowdown on ModelLoadService

The `ModelLoadService` is responsible for loading a named model into

@@ -863,351 +1184,3 @@ other resources that it might have been using.
This will start/pause/cancel all jobs that have been submitted to the
queue and have not yet reached a terminal state.

-## Model installation
-
-The `ModelInstallService` class implements the `ModelInstallServiceBase` abstract base class, and provides a one-stop shop for all your model install needs. It provides the following functionality:
-
-- Registering a model config record for a model already located on the local filesystem, without moving it or changing its path.
-
-- Installing a model already located on the local filesystem, by moving it into the InvokeAI root directory under the `models` folder (or wherever config parameter `models_dir` specifies).
-
-- Downloading a model from an arbitrary URL and installing it in `models_dir`.
-
-- Special handling for Civitai model URLs which allow the user to paste in a model page's URL or download link. Any metadata provided by Civitai, such as trigger terms, is captured and placed in the model config record.
-
-- Special handling for HuggingFace repo_ids to recursively download the contents of the repository, paying attention to alternative variants such as fp16.
-
-- Probing of models to determine their type, base type and other key information.
-
-- Interface with the InvokeAI event bus to provide status updates on the download, installation and registration process.
-
-### Initializing the installer
-
-A default installer is created at InvokeAI API startup time and stored in `ApiDependencies.invoker.services.model_install_service`, and can also be retrieved from an invocation's `context` argument with `context.services.model_install_service`.
-
-In the event you wish to create a new installer, you may use the following initialization pattern:
-
-```
-from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.download_manager import DownloadQueueService
-from invokeai.app.services.model_record_service import ModelRecordServiceBase
-
-config = InvokeAIAppConfig.get_config()
-queue = DownloadQueueService()
-store = ModelRecordServiceBase.open(config)
-installer = ModelInstallService(config=config, queue=queue, store=store)
-```
-
-The full form of `ModelInstallService()` takes the following parameters. Each parameter will default to a reasonable value, but it is recommended that you set them explicitly as shown in the above example.
-
-| **Argument**     | **Type**                   | **Default**                                  | **Description**                   |
-|------------------|----------------------------|----------------------------------------------|-----------------------------------|
-| `config`         | InvokeAIAppConfig          | Use system-wide config                       | InvokeAI app configuration object |
-| `queue`          | DownloadQueueServiceBase   | Create a new download queue for internal use | Download queue                    |
-| `store`          | ModelRecordServiceBase     | Use config to select the database to open    | Config storage database           |
-| `event_bus`      | EventServiceBase           | None                                         | An event bus to send download/install progress events to |
-| `event_handlers` | List[DownloadEventHandler] | None                                         | Event handlers for the download queue |
-
-Note that if `store` is not provided, then the class will use `ModelRecordServiceBase.open(config)` to select the database to use.
-
-Once initialized, the installer will provide the following methods:
-
-#### install_job = installer.install_model()
-
-The `install_model()` method is the core of the installer. The following illustrates basic usage:
-
-```
-sources = [
-    Path('/opt/models/sushi.safetensors'),                   # a local safetensors file
-    Path('/opt/models/sushi_diffusers/'),                    # a local diffusers folder
-    'runwayml/stable-diffusion-v1-5',                        # a repo_id
-    'runwayml/stable-diffusion-v1-5:vae',                    # a subfolder within a repo_id
-    'https://civitai.com/api/download/models/63006',         # a civitai direct download link
-    'https://civitai.com/models/8765?modelVersionId=10638',  # civitai model page
-    'https://s3.amazon.com/fjacks/sd-3.safetensors',         # arbitrary URL
-]
-
-for source in sources:
-    install_job = installer.install_model(source)
-
-source2key = installer.wait_for_installs()
-for source in sources:
-    model_key = source2key[source]
-    print(f"{source} installed as {model_key}")
-```
-
-As shown here, the `install_model()` method accepts a variety of sources, including local safetensors files, local diffusers folders, HuggingFace repo_ids with and without a subfolder designation, Civitai model URLs and arbitrary URLs that point to checkpoint files (but not to folders).
-
-Each call to `install_model()` will return a `ModelInstallJob` job, a subclass of `DownloadJobBase`. The install job has additional install-specific fields described in the next section.
-
-Each install job will run in a series of background threads using the object's download queue. You may block until all install jobs are completed (or errored) by calling the `wait_for_installs()` method as shown in the code example. `wait_for_installs()` will return a `dict` that maps the requested source to the key of the installed model. In the case that a model fails to download or install, its value in the dict will be None. The actual cause of the error will be reported in the corresponding job's `error` field.
-
-Alternatively you may install event handlers and/or listen for events on the InvokeAI event bus in order to monitor the progress of the requested installs.
-
-The full list of arguments to `install_model()` is as follows:
-
-| **Argument**     | **Type**                     | **Default** | **Description**                               |
-|------------------|------------------------------|-------------|-----------------------------------------------|
-| `source`         | Union[str, Path, AnyHttpUrl] |             | The source of the model: Path, URL or repo_id |
-| `inplace`        | bool                         | True        | Leave a local model in its current location   |
-| `variant`        | str                          | None        | Desired variant, such as 'fp16' or 'onnx' (HuggingFace only) |
-| `subfolder`      | str                          | None        | Repository subfolder (HuggingFace only)       |
-| `probe_override` | Dict[str, Any]               | None        | Override all or a portion of model's probed attributes |
-| `metadata`       | ModelSourceMetadata          | None        | Provide metadata that will be added to model's config |
-| `access_token`   | str                          | None        | Provide authorization information needed to download |
-| `priority`       | int                          | 10          | Download queue priority for the job           |
-
-The `inplace` field controls how local model Paths are handled. If True (the default), then the model is simply registered in its current location by the installer's `ModelConfigRecordService`. Otherwise, the model will be moved into the location specified by the `models_dir` application configuration parameter.
-
-The `variant` field is used for HuggingFace repo_ids only. If provided, the repo_id download handler will look for and download tensors files that follow the convention for the selected variant:
-
-- "fp16" will select files named "*model.fp16.{safetensors,bin}"
-- "onnx" will select files ending with the suffix ".onnx"
-- "openvino" will select files beginning with "openvino_model"
-
-In the special case of the "fp16" variant, the installer will select the 32-bit version of the files if the 16-bit version is unavailable.
-
-`subfolder` is used for HuggingFace repo_ids only. If provided, the model will be downloaded from the designated subfolder rather than the top-level repository folder. If a subfolder is attached to the repo_id using the format `repo_owner/repo_name:subfolder`, then the subfolder specified by the repo_id will override the subfolder argument.
-
-`probe_override` can be used to override all or a portion of the attributes returned by the model prober. This can be used to overcome cases in which automatic probing is unable to (correctly) determine the model's attributes. The most common situation is the `prediction_type` field for sd-2 (and rare sd-1) models. Here is an example of how it works:
-
-```
-install_job = installer.install_model(
-    source='stabilityai/stable-diffusion-2-1',
-    variant='fp16',
-    probe_override=dict(
-        prediction_type=SchedulerPredictionType('v_prediction'),
-    ),
-)
-```
-
-`metadata` allows you to attach custom metadata to the installed model. See the next section for details.
-
-`priority` and `access_token` are passed to the download queue and have the same effect as they do for the DownloadQueueServiceBase.
-
-#### Monitoring the install job process
-
-When you create an install job with `install_model()`, events will be passed to the list of `DownloadEventHandlers` provided at installer initialization time. Event handlers can also be added to individual model install jobs by calling their `add_handler()` method as described earlier for the `DownloadQueueService`.
-
-If the `event_bus` argument was provided, events will also be broadcast to the InvokeAI event bus. The events will appear on the bus as a singular event type named `model_event` with a payload of `job`. You can then retrieve the job and check its status.
-
-**TO DO:** consider breaking `model_event` into `model_install_started`, `model_install_completed`, etc. The event bus features have not yet been tested with FastAPI/websockets, and it may turn out that the job object is not serializable.
-
-#### Model metadata and probing
-
-The install service has special handling for HuggingFace and Civitai URLs that captures metadata from the source and includes it in the model configuration record. For example, fetching the Civitai model 8765 will produce a config record similar to this (using YAML representation):
-
-```
-5abc3ef8600b6c1cc058480eaae3091e:
-  path: sd-1/lora/to8contrast-1-5.safetensors
-  name: to8contrast-1-5
-  base_model: sd-1
-  model_type: lora
-  model_format: lycoris
-  key: 5abc3ef8600b6c1cc058480eaae3091e
-  hash: 5abc3ef8600b6c1cc058480eaae3091e
-  description: 'Trigger terms: to8contrast style'
-  author: theovercomer8
-  license: allowCommercialUse=Sell; allowDerivatives=True; allowNoCredit=True
-  source: https://civitai.com/models/8765?modelVersionId=10638
-  thumbnail_url: null
-  tags:
-    - model
-    - style
-    - portraits
-```
-
-For sources that do not provide model metadata, you can attach custom fields by providing a `metadata` argument to `install_model()` using an initialized `ModelSourceMetadata` object (available for import from `model_install_service.py`):
-
-```
-from invokeai.app.services.model_install_service import ModelSourceMetadata
-
-meta = ModelSourceMetadata(
-    name="my model",
-    author="Sushi Chef",
-    description="Highly customized model; trigger with 'sushi'",
-    license="mit",
-    thumbnail_url="http://s3.amazon.com/ljack/pics/sushi.png",
-    tags=['sfw', 'food'],
-)
-install_job = installer.install_model(
-    source='sushi_chef/model3',
-    variant='fp16',
-    metadata=meta,
-)
-```
-
-It is not currently recommended to provide custom metadata when installing from a Civitai or HuggingFace source, as the metadata provided by the source will overwrite the fields you provide. Instead, after the model is installed you can use `ModelRecordService.update_model()` to change the desired fields.
-
-**TO DO:** Change the logic so that the caller's metadata fields take precedence over those provided by the source.
-
-#### Other installer methods
-
-This section describes additional, less-frequently-used attributes and methods provided by the installer class.
-
-##### installer.wait_for_installs()
-
-This is equivalent to the `DownloadQueue` `join()` method. It will block until all the active jobs in the install queue have reached a terminal state (completed, errored or cancelled).
-
-##### installer.queue, installer.store, installer.config
-
-These attributes provide access to the `DownloadQueueServiceBase`, `ModelConfigRecordServiceBase`, and `InvokeAIAppConfig` objects that the installer uses.
-
-For example, to temporarily pause all pending installations, you can do this:
-
-```
-installer.queue.pause_all_jobs()
-```
-
-##### key = installer.register_path(model_path, overrides), key = installer.install_path(model_path, overrides)
-
-These methods bypass the download queue and directly register or install the model at the indicated path, returning the unique ID for the installed model.
-
-Both methods accept a Path object corresponding to a checkpoint or diffusers folder, and an optional dict of attributes to use to override the values derived from model probing.
-
-The difference between `register_path()` and `install_path()` is that the former will not move the model from its current position, while the latter will move it into the `models_dir` hierarchy.
-
-##### installer.unregister(key)
-
-This will remove the model config record for the model at `key`, and is equivalent to `installer.store.unregister(key)`.
-
-##### installer.delete(key)
-
-This is similar to `unregister()` but has the additional effect of deleting the underlying model file(s) -- even if they were outside the `models_dir` directory!
-
-##### installer.conditionally_delete(key)
-
-This method will call `unregister()` if the model identified by `key` is outside the `models_dir` hierarchy, and call `delete()` if the model is inside.
-
-##### List[str] = installer.scan_directory(scan_dir: Path, install: bool)
-
-This method will recursively scan the directory indicated in `scan_dir` for new models and either install them in the models directory or register them in place, depending on the setting of `install` (default False).
-
-The return value is the list of keys of the newly installed/registered models.
-
-##### installer.scan_models_directory()
-
-This method scans the models directory for new models and registers them in place. Models that are present in the `ModelConfigRecordService` database whose paths are not found will be unregistered.
-
-##### installer.sync_to_config()
-
-This method synchronizes models in the models directory and autoimport directory to those in the `ModelConfigRecordService` database. New models are registered and orphan models are unregistered.
-
-##### hash = installer.hash(model_path)
-
-This method calls the fasthash algorithm on a model's Path (either a file or a folder) to generate a unique ID based on the contents of the model.
-
-##### installer.start(invoker)
-
-The `start` method is called by the API initialization routines when the API starts up. Its effect is to call `sync_to_config()` to synchronize the model record store database with what's currently on disk.
-
-This method should not ordinarily be called manually.

@@ -22,6 +22,7 @@ from ..services.invoker import Invoker
 from ..services.item_storage.item_storage_sqlite import SqliteItemStorage
 from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
 from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
+from ..services.model_install import ModelInstallService
 from ..services.model_manager.model_manager_default import ModelManagerService
 from ..services.model_records import ModelRecordServiceSQL
 from ..services.names.names_default import SimpleNameService
@@ -86,6 +87,9 @@ class ApiDependencies:
         latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
         model_manager = ModelManagerService(config, logger)
         model_record_service = ModelRecordServiceSQL(db=db)
+        model_install_service = ModelInstallService(
+            app_config=config, record_store=model_record_service, event_bus=events
+        )
         names = SimpleNameService()
         performance_statistics = InvocationStatsService()
         processor = DefaultInvocationProcessor()
@@ -112,6 +116,7 @@ class ApiDependencies:
             logger=logger,
             model_manager=model_manager,
             model_records=model_record_service,
+            model_install=model_install_service,
             names=names,
             performance_statistics=performance_statistics,
             processor=processor,

@@ -4,7 +4,7 @@

 from hashlib import sha1
 from random import randbytes
-from typing import List, Optional
+from typing import Any, Dict, List, Optional

 from fastapi import Body, Path, Query, Response
 from fastapi.routing import APIRouter
@@ -12,6 +12,7 @@ from pydantic import BaseModel, ConfigDict
 from starlette.exceptions import HTTPException
 from typing_extensions import Annotated

+from invokeai.app.services.model_install import ModelInstallJob, ModelSource
 from invokeai.app.services.model_records import (
     DuplicateModelException,
     InvalidModelException,
@@ -25,7 +26,7 @@ from invokeai.backend.model_manager.config import (

 from ..dependencies import ApiDependencies

-model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"])
+model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2"])


 class ModelsList(BaseModel):
@@ -43,15 +44,18 @@ class ModelsList(BaseModel):
 async def list_model_records(
     base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
     model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
+    model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
 ) -> ModelsList:
     """Get a list of models."""
     record_store = ApiDependencies.invoker.services.model_records
     found_models: list[AnyModelConfig] = []
     if base_models:
         for base_model in base_models:
-            found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
+            found_models.extend(
+                record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name)
+            )
     else:
-        found_models.extend(record_store.search_by_attr(model_type=model_type))
+        found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name))
     return ModelsList(models=found_models)


@@ -117,12 +121,17 @@ async def update_model_record(
 async def del_model_record(
     key: str = Path(description="Unique key of model to remove from model registry."),
 ) -> Response:
-    """Delete Model"""
+    """
+    Delete model record from database.
+
+    The configuration record will be removed. The corresponding weights files will be
+    deleted as well if they reside within the InvokeAI "models" directory.
+    """
     logger = ApiDependencies.invoker.services.logger

     try:
-        record_store = ApiDependencies.invoker.services.model_records
-        record_store.del_model(key)
+        installer = ApiDependencies.invoker.services.model_install
+        installer.delete(key)
         logger.info(f"Deleted model: {key}")
         return Response(status_code=204)
     except UnknownModelException as e:
@@ -162,3 +171,145 @@ async def add_model_record(

     # now fetch it out
     return record_store.get_model(config.key)
+
+
+@model_records_router.post(
+    "/import",
+    operation_id="import_model_record",
+    responses={
+        201: {"description": "The model imported successfully"},
+        415: {"description": "Unrecognized file/folder format"},
+        424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
+        409: {"description": "There is already a model corresponding to this path or repo_id"},
+    },
+    status_code=201,
+)
+async def import_model(
+    source: ModelSource,
+    config: Optional[Dict[str, Any]] = Body(
+        description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type",
+        default=None,
+    ),
+) -> ModelInstallJob:
+    """Add a model using its local path, repo_id, or remote URL.
+
+    Models will be downloaded, probed, configured and installed in a
+    series of background threads. The return object has a `status` attribute
+    that can be used to monitor progress.
+
+    The source object is a discriminated Union of LocalModelSource,
+    HFModelSource and URLModelSource. Set the "type" field to the
+    appropriate value:
+
+    * To install a local path using LocalModelSource, pass a source of form:
+      `{
+        "type": "local",
+        "path": "/path/to/model",
+        "inplace": false
+      }`
+      The "inplace" flag, if true, will register the model in place in its
+      current filesystem location. Otherwise, the model will be copied
+      into the InvokeAI models directory.
+
+    * To install a HuggingFace repo_id using HFModelSource, pass a source of form:
+      `{
+        "type": "hf",
+        "repo_id": "stabilityai/stable-diffusion-2.0",
+        "variant": "fp16",
+        "subfolder": "vae",
+        "access_token": "f5820a918aaf01"
+      }`
+      The `variant`, `subfolder` and `access_token` fields are optional.
+
+    * To install a remote model using an arbitrary URL, pass:
+      `{
+        "type": "url",
+        "url": "http://www.civitai.com/models/123456",
+        "access_token": "f5820a918aaf01"
+      }`
+      The `access_token` field is optional.
+
+    The model's configuration record will be probed and filled in
+    automatically. To override the default guesses, pass "config"
+    with a Dict containing the attributes you wish to override.
+
+    Installation occurs in the background. Either use list_model_install_jobs()
+    to poll for completion, or listen on the event bus for the following events:
+
+      "model_install_started"
+      "model_install_completed"
+      "model_install_error"
+
+    On successful completion, the event's payload will contain the field "key"
+    containing the installed ID of the model. On an error, the event's payload
+    will contain the fields "error_type" and "error" describing the nature of the
+    error and its traceback, respectively.
+    """
+    logger = ApiDependencies.invoker.services.logger
+
+    try:
+        installer = ApiDependencies.invoker.services.model_install
+        result: ModelInstallJob = installer.import_model(
+            source=source,
+            config=config,
+        )
+        logger.info(f"Started installation of {source}")
+    except UnknownModelException as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=424, detail=str(e))
+    except InvalidModelException as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=415)
+    except ValueError as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=409, detail=str(e))
+    return result
+
+
+@model_records_router.get(
+    "/import",
+    operation_id="list_model_install_jobs",
+)
+async def list_model_install_jobs() -> List[ModelInstallJob]:
+    """Return list of model install jobs."""
+    jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs()
+    return jobs
+
+
+@model_records_router.patch(
+    "/import",
+    operation_id="prune_model_install_jobs",
+    responses={
+        204: {"description": "All completed and errored jobs have been pruned"},
+        400: {"description": "Bad request"},
+    },
+)
+async def prune_model_install_jobs() -> Response:
+    """Prune all completed and errored jobs from the install job list."""
+    ApiDependencies.invoker.services.model_install.prune_jobs()
+    return Response(status_code=204)
+
+
+@model_records_router.patch(
+    "/sync",
+    operation_id="sync_models_to_config",
+    responses={
+        204: {"description": "Model config record database resynced with files on disk"},
+        400: {"description": "Bad request"},
+    },
+)
+async def sync_models_to_config() -> Response:
+    """
+    Traverse the models and autoimport directories.
+
+    Model files without a corresponding record in the database are added. Orphan
+    records without a models file are deleted.
+    """
+    ApiDependencies.invoker.services.model_install.sync_to_config()
+    return Response(status_code=204)

@@ -20,6 +20,7 @@ class SocketIO:
         self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
         self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
         local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
+        local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)

     async def _handle_queue_event(self, event: Event):
         await self.__sio.emit(
@@ -28,10 +29,13 @@ class SocketIO:
             room=event[1]["data"]["queue_id"],
         )

-    async def _handle_sub_queue(self, sid, data, *args, **kwargs):
+    async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None:
         if "queue_id" in data:
             await self.__sio.enter_room(sid, data["queue_id"])

-    async def _handle_unsub_queue(self, sid, data, *args, **kwargs):
+    async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None:
         if "queue_id" in data:
             await self.__sio.leave_room(sid, data["queue_id"])
+
+    async def _handle_model_event(self, event: Event) -> None:
+        await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])
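With the relay in place, a python-socketio client can subscribe to the model events by name. A minimal sketch; the server URL is an assumption, not part of this commit:

    import socketio

    sio = socketio.Client()

    @sio.on("model_install_completed")
    def on_model_installed(data):
        print(f"{data['source']} installed with key {data['key']}")

    sio.connect("http://localhost:9090")  # assumed local InvokeAI server
    sio.wait()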
@@ -1,6 +1,5 @@
-"""
-Init file for InvokeAI configure package
-"""
+"""Init file for InvokeAI configure package."""

 from .config_base import PagingArgumentParser  # noqa F401
-from .config_default import InvokeAIAppConfig, get_invokeai_config
+from .config_default import InvokeAIAppConfig, get_invokeai_config  # noqa F401
+
+__all__ = ["InvokeAIAppConfig", "get_invokeai_config"]
@@ -173,7 +173,7 @@ from __future__ import annotations

 import os
 from pathlib import Path
-from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
+from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_type_hints

 from omegaconf import DictConfig, OmegaConf
 from pydantic import Field, TypeAdapter
@@ -334,7 +334,7 @@ class InvokeAIAppConfig(InvokeAISettings):
     )

     @classmethod
-    def get_config(cls, **kwargs) -> InvokeAIAppConfig:
+    def get_config(cls, **kwargs: Dict[str, Any]) -> InvokeAIAppConfig:
         """Return a singleton InvokeAIAppConfig configuration object."""
         if (
             cls.singleton_config is None
@@ -383,17 +383,17 @@ class InvokeAIAppConfig(InvokeAISettings):
         return db_dir / DB_FILE

     @property
-    def model_conf_path(self) -> Optional[Path]:
+    def model_conf_path(self) -> Path:
         """Path to models configuration file."""
         return self._resolve(self.conf_path)

     @property
-    def legacy_conf_path(self) -> Optional[Path]:
+    def legacy_conf_path(self) -> Path:
         """Path to directory of legacy configuration files (e.g. v1-inference.yaml)."""
         return self._resolve(self.legacy_conf_dir)

     @property
-    def models_path(self) -> Optional[Path]:
+    def models_path(self) -> Path:
         """Path to the models directory."""
         return self._resolve(self.models_dir)
@@ -0,0 +1 @@
+from .events_base import EventServiceBase  # noqa F401
@@ -1,5 +1,6 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)


 from typing import Any, Optional

 from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
@@ -16,6 +17,7 @@ from invokeai.backend.model_management.models.base import BaseModelType, ModelTy

 class EventServiceBase:
     queue_event: str = "queue_event"
+    model_event: str = "model_event"

     """Basic event bus, to have an empty stand-in when not needed"""

@@ -30,6 +32,13 @@ class EventServiceBase:
             payload={"event": event_name, "data": payload},
         )

+    def __emit_model_event(self, event_name: str, payload: dict) -> None:
+        payload["timestamp"] = get_timestamp()
+        self.dispatch(
+            event_name=EventServiceBase.model_event,
+            payload={"event": event_name, "data": payload},
+        )
+
     # Define events here for every event in the system.
     # This will make them easier to integrate until we find a schema generator.
     def emit_generator_progress(
@@ -313,3 +322,73 @@ class EventServiceBase:
             event_name="queue_cleared",
             payload={"queue_id": queue_id},
         )
+
+    def emit_model_install_started(self, source: str) -> None:
+        """
+        Emitted when an install job is started.
+
+        :param source: Source of the model; local path, repo_id or url
+        """
+        self.__emit_model_event(
+            event_name="model_install_started",
+            payload={"source": source},
+        )
+
+    def emit_model_install_completed(self, source: str, key: str) -> None:
+        """
+        Emitted when an install job is completed successfully.
+
+        :param source: Source of the model; local path, repo_id or url
+        :param key: Model config record key
+        """
+        self.__emit_model_event(
+            event_name="model_install_completed",
+            payload={
+                "source": source,
+                "key": key,
+            },
+        )
+
+    def emit_model_install_progress(
+        self,
+        source: str,
+        current_bytes: int,
+        total_bytes: int,
+    ) -> None:
+        """
+        Emitted while the install job is in progress.
+        (Downloaded models only)
+
+        :param source: Source of the model
+        :param current_bytes: Number of bytes downloaded so far
+        :param total_bytes: Total bytes to download
+        """
+        self.__emit_model_event(
+            event_name="model_install_progress",
+            payload={
+                "source": source,
+                "current_bytes": current_bytes,
+                "total_bytes": total_bytes,
+            },
+        )
+
+    def emit_model_install_error(
+        self,
+        source: str,
+        error_type: str,
+        error: str,
+    ) -> None:
+        """
+        Emitted when an install job encounters an exception.
+
+        :param source: Source of the model
+        :param error_type: The class name of the exception that raised the error
+        :param error: A text traceback of the error
+        """
+        self.__emit_model_event(
+            event_name="model_install_error",
+            payload={
+                "source": source,
+                "error_type": error_type,
+                "error": error,
+            },
+        )
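Any in-process consumer can subscribe to the new model_event channel the same way the SocketIO handler earlier in this diff does. A minimal sketch, assuming the fastapi-events import paths used elsewhere in the codebase:

    from fastapi_events.handlers.local import local_handler
    from fastapi_events.typing import Event

    @local_handler.register(event_name=EventServiceBase.model_event)
    async def log_model_event(event: Event) -> None:
        # payload layout follows dispatch() above: {"event": name, "data": payload}
        name, payload = event[1]["event"], event[1]["data"]
        print(f"model event {name}: {payload}")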
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
     from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
     from .item_storage.item_storage_base import ItemStorageABC
     from .latents_storage.latents_storage_base import LatentsStorageBase
+    from .model_install import ModelInstallServiceBase
     from .model_manager.model_manager_base import ModelManagerServiceBase
     from .model_records import ModelRecordServiceBase
     from .names.names_base import NameServiceBase
@@ -50,6 +51,7 @@ class InvocationServices:
     logger: "Logger"
     model_manager: "ModelManagerServiceBase"
     model_records: "ModelRecordServiceBase"
+    model_install: "ModelInstallServiceBase"
     processor: "InvocationProcessorABC"
     performance_statistics: "InvocationStatsServiceBase"
     queue: "InvocationQueueABC"
@@ -77,6 +79,7 @@ class InvocationServices:
         logger: "Logger",
         model_manager: "ModelManagerServiceBase",
         model_records: "ModelRecordServiceBase",
+        model_install: "ModelInstallServiceBase",
         processor: "InvocationProcessorABC",
         performance_statistics: "InvocationStatsServiceBase",
         queue: "InvocationQueueABC",
@@ -102,6 +105,7 @@ class InvocationServices:
         self.logger = logger
         self.model_manager = model_manager
         self.model_records = model_records
+        self.model_install = model_install
         self.processor = processor
         self.performance_statistics = performance_statistics
         self.queue = queue
invokeai/app/services/model_install/__init__.py (new file, 25 lines)
@@ -0,0 +1,25 @@
"""Initialization file for model install service package."""

from .model_install_base import (
    HFModelSource,
    InstallStatus,
    LocalModelSource,
    ModelInstallJob,
    ModelInstallServiceBase,
    ModelSource,
    UnknownInstallJobException,
    URLModelSource,
)
from .model_install_default import ModelInstallService

__all__ = [
    "ModelInstallServiceBase",
    "ModelInstallService",
    "InstallStatus",
    "ModelInstallJob",
    "UnknownInstallJobException",
    "ModelSource",
    "LocalModelSource",
    "HFModelSource",
    "URLModelSource",
]
invokeai/app/services/model_install/model_install_base.py (new file, 317 lines)
@@ -0,0 +1,317 @@
import re
import traceback
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union

from fastapi import Body
from pydantic import BaseModel, Field, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig


class InstallStatus(str, Enum):
    """State of an install job running in the background."""

    WAITING = "waiting"  # waiting to be dequeued
    RUNNING = "running"  # being processed
    COMPLETED = "completed"  # finished running
    ERROR = "error"  # terminated with an error message


class UnknownInstallJobException(Exception):
    """Raised when the status of an unknown job is requested."""


class StringLikeSource(BaseModel):
    """
    Base class for model sources; implements functions that let the source be sorted and indexed.

    These shenanigans let this stuff work:

      source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
      mydict = {source1: 'model 1'}
      assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
      assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'

      source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
      assert source1 == source2
      assert source1 == 'C:/users/mort/foo.safetensors'
    """

    def __hash__(self) -> int:
        """Return hash of the path field, for indexing."""
        return hash(str(self))

    def __lt__(self, other: object) -> int:
        """Return comparison of the stringified version, for sorting."""
        return str(self) < str(other)

    def __eq__(self, other: object) -> bool:
        """Return equality on the stringified version."""
        if isinstance(other, Path):
            return str(self) == other.as_posix()
        else:
            return str(self) == str(other)


class LocalModelSource(StringLikeSource):
    """A local file or directory path."""

    path: str | Path
    inplace: Optional[bool] = False
    type: Literal["local"] = "local"

    # these methods allow the source to be used in a string-like way,
    # for example as an index into a dict
    def __str__(self) -> str:
        """Return string version of path when string rep needed."""
        return Path(self.path).as_posix()


class HFModelSource(StringLikeSource):
    """A HuggingFace repo_id, with optional variant and sub-folder."""

    repo_id: str
    variant: Optional[str] = None
    subfolder: Optional[str | Path] = None
    access_token: Optional[str] = None
    type: Literal["hf"] = "hf"

    @field_validator("repo_id")
    @classmethod
    def proper_repo_id(cls, v: str) -> str:  # noqa D102
        if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
            raise ValueError(f"{v}: invalid repo_id format")
        return v

    def __str__(self) -> str:
        """Return string version of repo_id when string rep needed."""
        base: str = self.repo_id
        base += f":{self.subfolder}" if self.subfolder else ""
        base += f" ({self.variant})" if self.variant else ""
        return base


class URLModelSource(StringLikeSource):
    """A generic URL pointing to a checkpoint file."""

    url: AnyHttpUrl
    access_token: Optional[str] = None
    type: Literal["generic_url"] = "generic_url"

    def __str__(self) -> str:
        """Return string version of the url when string rep needed."""
        return str(self.url)


# Body() is being applied here rather than Field() because otherwise FastAPI will
# refuse to generate a schema. Relevant links:
#
# "Model Manager Refactor Phase 1 - SQL-based config storage"
# https://github.com/invoke-ai/InvokeAI/pull/5039#discussion_r1389752119 (comment)
# "Param: xyz can only be a request body, using Body() when using discriminated unions"
# https://github.com/tiangolo/fastapi/discussions/9761
# "Body parameter cannot be a pydantic union anymore since v0.95"
# https://github.com/tiangolo/fastapi/discussions/9287

ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Body(discriminator="type")]


class ModelInstallJob(BaseModel):
    """Object that tracks the current status of an install request."""

    status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
    config_in: Dict[str, Any] = Field(
        default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
    )
    config_out: Optional[AnyModelConfig] = Field(
        default=None, description="After successful installation, this will hold the configuration object."
    )
    inplace: bool = Field(
        default=False, description="Leave model in its current location; otherwise install under models directory"
    )
    source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
    local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
    error_type: Optional[str] = Field(default=None, description="Class name of the exception that led to status==ERROR")
    error: Optional[str] = Field(default=None, description="Error traceback")  # noqa #501

    def set_error(self, e: Exception) -> None:
        """Record the error and traceback from an exception."""
        self.error_type = e.__class__.__name__
        self.error = "".join(traceback.format_exception(e))
        self.status = InstallStatus.ERROR


class ModelInstallServiceBase(ABC):
    """Abstract base class for InvokeAI model installation."""

    @abstractmethod
    def __init__(
        self,
        app_config: InvokeAIAppConfig,
        record_store: ModelRecordServiceBase,
        event_bus: Optional["EventServiceBase"] = None,
    ):
        """
        Create ModelInstallService object.

        :param config: Systemwide InvokeAIAppConfig.
        :param store: Systemwide ModelConfigStore
        :param event_bus: InvokeAI event bus for reporting events to.
        """

    def start(self, invoker: Invoker) -> None:
        """Call at InvokeAI startup time."""
        self.sync_to_config()

    @abstractmethod
    def stop(self) -> None:
        """Stop the model install service. After this the object can be safely deleted."""

    @property
    @abstractmethod
    def app_config(self) -> InvokeAIAppConfig:
        """Return the appConfig object associated with the installer."""

    @property
    @abstractmethod
    def record_store(self) -> ModelRecordServiceBase:
        """Return the ModelRecordService object associated with the installer."""

    @property
    @abstractmethod
    def event_bus(self) -> Optional[EventServiceBase]:
        """Return the event service base object associated with the installer."""

    @abstractmethod
    def register_path(
        self,
        model_path: Union[Path, str],
        config: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Probe and register the model at model_path.

        This keeps the model in its current location.

        :param model_path: Filesystem Path to the model.
        :param config: Dict of attributes that will override autoassigned values.
        :returns id: The string ID of the registered model.
        """

    @abstractmethod
    def unregister(self, key: str) -> None:
        """Remove model with indicated key from the database."""

    @abstractmethod
    def delete(self, key: str) -> None:
        """Remove model with indicated key from the database. Delete its files only if they are within our models directory."""

    @abstractmethod
    def unconditionally_delete(self, key: str) -> None:
        """Remove model with indicated key from the database and unconditionally delete weight files from disk."""

    @abstractmethod
    def install_path(
        self,
        model_path: Union[Path, str],
        config: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Probe, register and install the model in the models directory.

        This moves the model from its current location into
        the models directory handled by InvokeAI.

        :param model_path: Filesystem Path to the model.
        :param config: Dict of attributes that will override autoassigned values.
        :returns id: The string ID of the registered model.
        """

    @abstractmethod
    def import_model(
        self,
        source: ModelSource,
        config: Optional[Dict[str, Any]] = None,
    ) -> ModelInstallJob:
        """Install the indicated model.

        :param source: ModelSource object

        :param config: Optional dict. Any fields in this dict
         will override corresponding autoassigned probe fields in the
         model's config record. Use it to override
         `name`, `description`, `base_type`, `model_type`, `format`,
         `prediction_type`, `image_size`, and/or `ztsnr_training`.

        This will download the model located at `source`,
        probe it, and install it into the models directory.
        This call is executed asynchronously in a separate
        thread and will issue the following events on the event bus:

             - model_install_started
             - model_install_error
             - model_install_completed

        The `inplace` flag does not affect the behavior of downloaded
        models, which are always moved into the `models` directory.

        The call returns a ModelInstallJob object which can be
        polled to learn the current status and/or error message.

        Variants recognized by HuggingFace currently are:
        1. onnx
        2. openvino
        3. fp16
        4. None (usually returns fp32 model)
        """

    @abstractmethod
    def get_job(self, source: ModelSource) -> List[ModelInstallJob]:
        """Return the ModelInstallJob(s) corresponding to the provided source."""

    @abstractmethod
    def list_jobs(self) -> List[ModelInstallJob]:  # noqa D102
        """List active and complete install jobs."""

    @abstractmethod
    def prune_jobs(self) -> None:
        """Prune all completed and errored jobs."""

    @abstractmethod
    def wait_for_installs(self) -> List[ModelInstallJob]:
        """
        Wait for all pending installs to complete.

        This will block until all pending installs have
        completed, been cancelled, or errored out. It will
        block indefinitely if one or more jobs are in the
        paused state.

        It will return the current list of jobs.
        """

    @abstractmethod
    def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
        """
        Recursively scan directory for new models and register or install them.

        :param scan_dir: Path to the directory to scan.
        :param install: Install if True, otherwise register in place.
        :returns list of IDs: Returns list of IDs of models registered/installed
        """

    @abstractmethod
    def sync_to_config(self) -> None:
        """Synchronize models on disk to those in the model record database."""
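Two brief sketches of the machinery above, neither part of this commit. First, how the "type" discriminator picks a concrete source class during validation; pydantic's own Field stands in here for the FastAPI Body used in ModelSource, purely for offline validation:

    from typing import Union

    from pydantic import Field, TypeAdapter
    from typing_extensions import Annotated

    from invokeai.app.services.model_install import HFModelSource, LocalModelSource, URLModelSource

    adapter = TypeAdapter(
        Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
    )
    src = adapter.validate_python({"type": "hf", "repo_id": "runwayml/stable-diffusion-v1-5"})
    assert isinstance(src, HFModelSource)

Second, the string form HFModelSource produces for indexing and logging:

    src = HFModelSource(repo_id="stabilityai/sdxl-turbo", subfolder="unet", variant="fp16")
    assert str(src) == "stabilityai/sdxl-turbo:unet (fp16)"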
invokeai/app/services/model_install/model_install_default.py (new file, 395 lines)
@@ -0,0 +1,395 @@
"""Model installation class."""

import threading
from hashlib import sha256
from logging import Logger
from pathlib import Path
from queue import Queue
from random import randbytes
from shutil import copyfile, copytree, move, rmtree
from typing import Any, Dict, List, Optional, Set, Union

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, UnknownModelException
from invokeai.backend.model_manager.config import (
    AnyModelConfig,
    BaseModelType,
    InvalidModelConfigException,
    ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import Chdir, InvokeAILogger

from .model_install_base import (
    InstallStatus,
    LocalModelSource,
    ModelInstallJob,
    ModelInstallServiceBase,
    ModelSource,
)

# marker that the queue is done and that thread should exit
STOP_JOB = ModelInstallJob(
    source=LocalModelSource(path="stop"),
    local_path=Path("/dev/null"),
)


class ModelInstallService(ModelInstallServiceBase):
    """Class for InvokeAI model installation."""

    _app_config: InvokeAIAppConfig
    _record_store: ModelRecordServiceBase
    _event_bus: Optional[EventServiceBase] = None
    _install_queue: Queue[ModelInstallJob]
    _install_jobs: List[ModelInstallJob]
    _logger: Logger
    _cached_model_paths: Set[Path]
    _models_installed: Set[str]

    def __init__(
        self,
        app_config: InvokeAIAppConfig,
        record_store: ModelRecordServiceBase,
        event_bus: Optional[EventServiceBase] = None,
    ):
        """
        Initialize the installer object.

        :param app_config: InvokeAIAppConfig object
        :param record_store: Previously-opened ModelRecordService database
        :param event_bus: Optional EventService object
        """
        self._app_config = app_config
        self._record_store = record_store
        self._event_bus = event_bus
        self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__)
        self._install_jobs = []
        self._install_queue = Queue()
        self._cached_model_paths = set()
        self._models_installed = set()
        self._start_installer_thread()

    @property
    def app_config(self) -> InvokeAIAppConfig:  # noqa D102
        return self._app_config

    @property
    def record_store(self) -> ModelRecordServiceBase:  # noqa D102
        return self._record_store

    @property
    def event_bus(self) -> Optional[EventServiceBase]:  # noqa D102
        return self._event_bus

    def stop(self) -> None:
        """Stop the install thread; after this the object can be deleted and garbage collected."""
        self._install_queue.put(STOP_JOB)

    def _start_installer_thread(self) -> None:
        threading.Thread(target=self._install_next_item, daemon=True).start()

    def _install_next_item(self) -> None:
        done = False
        while not done:
            job = self._install_queue.get()
            if job == STOP_JOB:
                done = True
                continue

            assert job.local_path is not None
            try:
                self._signal_job_running(job)
                if job.inplace:
                    key = self.register_path(job.local_path, job.config_in)
                else:
                    key = self.install_path(job.local_path, job.config_in)
                job.config_out = self.record_store.get_model(key)
                self._signal_job_completed(job)

            except (OSError, DuplicateModelException, InvalidModelConfigException) as excp:
                self._signal_job_errored(job, excp)
            finally:
                self._install_queue.task_done()
        self._logger.info("Install thread exiting")

    def _signal_job_running(self, job: ModelInstallJob) -> None:
        job.status = InstallStatus.RUNNING
        self._logger.info(f"{job.source}: model installation started")
        if self._event_bus:
            self._event_bus.emit_model_install_started(str(job.source))

    def _signal_job_completed(self, job: ModelInstallJob) -> None:
        job.status = InstallStatus.COMPLETED
        assert job.config_out
        self._logger.info(
            f"{job.source}: model installation completed. {job.local_path} registered key {job.config_out.key}"
        )
        if self._event_bus:
            assert job.local_path is not None
            assert job.config_out is not None
            key = job.config_out.key
            self._event_bus.emit_model_install_completed(str(job.source), key)

    def _signal_job_errored(self, job: ModelInstallJob, excp: Exception) -> None:
        job.set_error(excp)
        self._logger.info(f"{job.source}: model installation encountered an exception: {job.error_type}")
        if self._event_bus:
            error_type = job.error_type
            error = job.error
            assert error_type is not None
            assert error is not None
            self._event_bus.emit_model_install_error(str(job.source), error_type, error)

    def register_path(
        self,
        model_path: Union[Path, str],
        config: Optional[Dict[str, Any]] = None,
    ) -> str:  # noqa D102
        model_path = Path(model_path)
        config = config or {}
        if config.get("source") is None:
            config["source"] = model_path.resolve().as_posix()
        return self._register(model_path, config)

    def install_path(
        self,
        model_path: Union[Path, str],
        config: Optional[Dict[str, Any]] = None,
    ) -> str:  # noqa D102
        model_path = Path(model_path)
        config = config or {}
        if config.get("source") is None:
            config["source"] = model_path.resolve().as_posix()

        info: AnyModelConfig = self._probe_model(Path(model_path), config)
        old_hash = info.original_hash
        dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name
        new_path = self._copy_model(model_path, dest_path)
        new_hash = FastModelHash.hash(new_path)
        assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."

        return self._register(
            new_path,
            config,
            info,
        )

    def import_model(
        self,
        source: ModelSource,
        config: Optional[Dict[str, Any]] = None,
    ) -> ModelInstallJob:  # noqa D102
        if not config:
            config = {}

        # Installing a local path
        if isinstance(source, LocalModelSource) and Path(source.path).exists():  # a path that is already on disk
            job = ModelInstallJob(
                source=source,
                config_in=config,
                local_path=Path(source.path),
            )
            self._install_jobs.append(job)
            self._install_queue.put(job)
            return job

        else:  # here is where we'd download a URL or repo_id. Implementation pending download queue.
            raise UnknownModelException("File or directory not found")

    def list_jobs(self) -> List[ModelInstallJob]:  # noqa D102
        return self._install_jobs

    def get_job(self, source: ModelSource) -> List[ModelInstallJob]:  # noqa D102
        return [x for x in self._install_jobs if x.source == source]

    def wait_for_installs(self) -> List[ModelInstallJob]:  # noqa D102
        self._install_queue.join()
        return self._install_jobs

    def prune_jobs(self) -> None:
        """Prune all completed and errored jobs."""
        unfinished_jobs = [
            x for x in self._install_jobs if x.status not in [InstallStatus.COMPLETED, InstallStatus.ERROR]
        ]
        self._install_jobs = unfinished_jobs

    def sync_to_config(self) -> None:
        """Synchronize models on disk to those in the config record store database."""
        self._scan_models_directory()
        if autoimport := self._app_config.autoimport_dir:
            self._logger.info("Scanning autoimport directory for new models")
            installed = self.scan_directory(self._app_config.root_path / autoimport)
            self._logger.info(f"{len(installed)} new models registered")
        self._logger.info("Model installer (re)initialized")

    def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:  # noqa D102
        self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
        callback = self._scan_install if install else self._scan_register
        search = ModelSearch(on_model_found=callback)
        self._models_installed: Set[str] = set()
        search.search(scan_dir)
        return list(self._models_installed)

    def _scan_models_directory(self) -> None:
        """
        Scan the models directory for new and missing models.

        New models will be added to the storage backend. Missing models
        will be deleted.
        """
        defunct_models = set()
        installed = set()

        with Chdir(self._app_config.models_path):
            self._logger.info("Checking for models that have been moved or deleted from disk")
            for model_config in self.record_store.all_models():
                path = Path(model_config.path)
                if not path.exists():
                    self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. Unregistering")
                    defunct_models.add(model_config.key)
            for key in defunct_models:
                self.unregister(key)

            self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
            for cur_base_model in BaseModelType:
                for cur_model_type in ModelType:
                    models_dir = Path(cur_base_model.value, cur_model_type.value)
                    installed.update(self.scan_directory(models_dir))
            self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")

    def _sync_model_path(self, key: str, ignore_hash_change: bool = False) -> AnyModelConfig:
        """
        Move model into the location indicated by its basetype, type and name.

        Call this after updating a model's attributes in order to move
        the model's path into the location indicated by its basetype, type and
        name. Applies only to models whose paths are within the root `models_dir`
        directory.

        May raise an UnknownModelException.
        """
        model = self.record_store.get_model(key)
        old_path = Path(model.path)
        models_dir = self.app_config.models_path

        if not old_path.is_relative_to(models_dir):
            return model

        new_path = models_dir / model.base.value / model.type.value / model.name
        self._logger.info(f"Moving {model.name} to {new_path}.")
        new_path = self._move_model(old_path, new_path)
        new_hash = FastModelHash.hash(new_path)
        model.path = new_path.relative_to(models_dir).as_posix()
        if model.current_hash != new_hash:
            assert (
                ignore_hash_change
            ), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
            model.current_hash = new_hash
            self._logger.info(f"Model has new hash {model.current_hash}, but will continue to be identified by {key}")
        self.record_store.update_model(key, model)
        return model

    def _scan_register(self, model: Path) -> bool:
        if model in self._cached_model_paths:
            return True
        try:
            id = self.register_path(model)
            self._sync_model_path(id)  # possibly move it to right place in `models`
            self._logger.info(f"Registered {model.name} with id {id}")
            self._models_installed.add(id)
        except DuplicateModelException:
            pass
        return True

    def _scan_install(self, model: Path) -> bool:
        if model in self._cached_model_paths:
            return True
        try:
            id = self.install_path(model)
            self._logger.info(f"Installed {model} with id {id}")
            self._models_installed.add(id)
        except DuplicateModelException:
            pass
        return True

    def unregister(self, key: str) -> None:  # noqa D102
        self.record_store.del_model(key)

    def delete(self, key: str) -> None:  # noqa D102
        """Unregister the model. Delete its files only if they are within our models directory."""
        model = self.record_store.get_model(key)
        models_dir = self.app_config.models_path
        model_path = models_dir / model.path
        if model_path.is_relative_to(models_dir):
            self.unconditionally_delete(key)
        else:
            self.unregister(key)

    def unconditionally_delete(self, key: str) -> None:  # noqa D102
        model = self.record_store.get_model(key)
        path = self.app_config.models_path / model.path
        if path.is_dir():
            rmtree(path)
        else:
            path.unlink()
        self.unregister(key)

    def _copy_model(self, old_path: Path, new_path: Path) -> Path:
        if old_path == new_path:
            return old_path
        new_path.parent.mkdir(parents=True, exist_ok=True)
        if old_path.is_dir():
            copytree(old_path, new_path)
        else:
            copyfile(old_path, new_path)
        return new_path

    def _move_model(self, old_path: Path, new_path: Path) -> Path:
        if old_path == new_path:
            return old_path

        new_path.parent.mkdir(parents=True, exist_ok=True)

        # if path already exists then we jigger the name to make it unique
        counter: int = 1
        while new_path.exists():
            path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
            if not path.exists():
                new_path = path
            counter += 1
        move(old_path, new_path)
        return new_path

    def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
        info: AnyModelConfig = ModelProbe.probe(Path(model_path))
        if config:  # used to override probe fields
            for key, value in config.items():
                setattr(info, key, value)
        return info

    def _create_key(self) -> str:
        return sha256(randbytes(100)).hexdigest()[0:32]

    def _register(
        self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
    ) -> str:
        info = info or ModelProbe.probe(model_path, config)
        key = self._create_key()

        model_path = model_path.absolute()
        if model_path.is_relative_to(self.app_config.models_path):
            model_path = model_path.relative_to(self.app_config.models_path)

        info.path = model_path.as_posix()

        # add 'main' specific fields
        if hasattr(info, "config"):
            # make config relative to our root
            legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
            info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
        self.record_store.add_model(key, info)
        return key
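A hedged wiring sketch for the new service; the record-store construction is elided because its constructor does not appear in this diff:

    from invokeai.app.services.config import InvokeAIAppConfig
    from invokeai.app.services.model_install import ModelInstallService

    app_config = InvokeAIAppConfig.get_config()
    record_store = ...  # a ModelRecordServiceSQL opened elsewhere in the app
    installer = ModelInstallService(app_config=app_config, record_store=record_store)
    installer.sync_to_config()         # reconcile records with files on disk
    jobs = installer.wait_for_installs()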
@@ -6,3 +6,11 @@ from .model_records_base import (  # noqa F401
     UnknownModelException,
 )
 from .model_records_sql import ModelRecordServiceSQL  # noqa F401
+
+__all__ = [
+    "ModelRecordServiceBase",
+    "ModelRecordServiceSQL",
+    "DuplicateModelException",
+    "InvalidModelException",
+    "UnknownModelException",
+]
@@ -32,6 +32,8 @@ class ModelProbeInfo(object):
     upcast_attention: bool
     format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
     image_size: int
+    name: Optional[str] = None
+    description: Optional[str] = None


 class ProbeBase(object):
@@ -113,12 +115,16 @@ class ModelProbe(object):
         base_type = probe.get_base_type()
         variant_type = probe.get_variant_type()
         prediction_type = probe.get_scheduler_prediction_type()
+        name = cls.get_model_name(model_path)
+        description = f"{base_type.value} {model_type.value} model {name}"
         format = probe.get_format()
         model_info = ModelProbeInfo(
             model_type=model_type,
             base_type=base_type,
             variant_type=variant_type,
             prediction_type=prediction_type,
+            name=name,
+            description=description,
             upcast_attention=(
                 base_type == BaseModelType.StableDiffusion2
                 and prediction_type == SchedulerPredictionType.VPrediction
@@ -142,6 +148,13 @@ class ModelProbe(object):

         return model_info

+    @classmethod
+    def get_model_name(cls, model_path: Path) -> str:
+        if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
+            return model_path.stem
+        else:
+            return model_path.name
+
     @classmethod
     def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
         if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
invokeai/backend/model_manager/__init__.py (new file, 29 lines)
@@ -0,0 +1,29 @@
"""Re-export frequently-used symbols from the Model Manager backend."""

from .config import (
    AnyModelConfig,
    BaseModelType,
    InvalidModelConfigException,
    ModelConfigFactory,
    ModelFormat,
    ModelType,
    ModelVariantType,
    SchedulerPredictionType,
    SubModelType,
)
from .probe import ModelProbe
from .search import ModelSearch

__all__ = [
    "ModelProbe",
    "ModelSearch",
    "InvalidModelConfigException",
    "ModelConfigFactory",
    "BaseModelType",
    "ModelType",
    "SubModelType",
    "ModelVariantType",
    "ModelFormat",
    "SchedulerPredictionType",
    "AnyModelConfig",
]
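A quick sketch of what the re-export enables, purely illustrative (the model path is hypothetical):

    from pathlib import Path

    from invokeai.backend.model_manager import ModelProbe

    config = ModelProbe.probe(Path("/path/to/model.safetensors"))
    print(config.name, config.base, config.type)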
@@ -23,7 +23,7 @@ from enum import Enum
 from typing import Literal, Optional, Type, Union

 from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
-from typing_extensions import Annotated
+from typing_extensions import Annotated, Any, Dict


 class InvalidModelConfigException(Exception):
@@ -122,7 +122,7 @@ class ModelConfigBase(BaseModel):
         validate_assignment=True,
     )

-    def update(self, attributes: dict):
+    def update(self, attributes: Dict[str, Any]) -> None:
         """Update the object with fields in dict."""
         for key, value in attributes.items():
             setattr(self, key, value)  # may raise a validation error
@@ -195,8 +195,6 @@ class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
     """Model config for main checkpoint models."""

     type: Literal[ModelType.Main] = ModelType.Main
-    # Note that we do not need prediction_type or upcast_attention here
-    # because they are provided in the checkpoint's own config file.


 class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
684
invokeai/backend/model_manager/probe.py
Normal file
684
invokeai/backend/model_manager/probe.py
Normal file
@ -0,0 +1,684 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Literal, Optional, Union
|
||||||
|
|
||||||
|
import safetensors.torch
|
||||||
|
import torch
|
||||||
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
|
from invokeai.backend.model_management.models.base import read_checkpoint_meta
|
||||||
|
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
|
||||||
|
from invokeai.backend.model_management.util import lora_token_vector_length
|
||||||
|
from invokeai.backend.util.util import SilenceWarnings
|
||||||
|
|
||||||
|
from .config import (
|
||||||
|
AnyModelConfig,
|
||||||
|
BaseModelType,
|
||||||
|
InvalidModelConfigException,
|
||||||
|
ModelConfigFactory,
|
||||||
|
ModelFormat,
|
||||||
|
ModelType,
|
||||||
|
ModelVariantType,
|
||||||
|
SchedulerPredictionType,
|
||||||
|
)
|
||||||
|
from .hash import FastModelHash
|
||||||
|
|
||||||
|
CkptType = Dict[str, Any]
|
||||||
|
|
||||||
|
LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = {
|
||||||
|
BaseModelType.StableDiffusion1: {
|
||||||
|
ModelVariantType.Normal: "v1-inference.yaml",
|
||||||
|
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
|
||||||
|
},
|
||||||
|
BaseModelType.StableDiffusion2: {
|
||||||
|
ModelVariantType.Normal: {
|
||||||
|
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
|
||||||
|
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
|
||||||
|
},
|
||||||
|
ModelVariantType.Inpaint: {
|
||||||
|
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
|
||||||
|
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
BaseModelType.StableDiffusionXL: {
|
||||||
|
ModelVariantType.Normal: "sd_xl_base.yaml",
|
||||||
|
},
|
||||||
|
BaseModelType.StableDiffusionXLRefiner: {
|
||||||
|
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ProbeBase(object):
|
||||||
|
"""Base class for probes."""
|
||||||
|
|
||||||
|
def __init__(self, model_path: Path):
|
||||||
|
self.model_path = model_path
|
||||||
|
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
"""Get model base type."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_format(self) -> ModelFormat:
|
||||||
|
"""Get model file format."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_variant_type(self) -> Optional[ModelVariantType]:
|
||||||
|
"""Get model variant type."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
|
||||||
|
"""Get model scheduler prediction type."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class ModelProbe(object):
|
||||||
|
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
|
||||||
|
"diffusers": {},
|
||||||
|
"checkpoint": {},
|
||||||
|
"onnx": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
CLASS2TYPE = {
|
||||||
|
"StableDiffusionPipeline": ModelType.Main,
|
||||||
|
"StableDiffusionInpaintPipeline": ModelType.Main,
|
||||||
|
"StableDiffusionXLPipeline": ModelType.Main,
|
||||||
|
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||||
|
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||||
|
"LatentConsistencyModelPipeline": ModelType.Main,
|
||||||
|
"AutoencoderKL": ModelType.Vae,
|
||||||
|
"AutoencoderTiny": ModelType.Vae,
|
||||||
|
"ControlNetModel": ModelType.ControlNet,
|
||||||
|
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
||||||
|
"T2IAdapter": ModelType.T2IAdapter,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_probe(
|
||||||
|
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
|
||||||
|
) -> None:
|
||||||
|
cls.PROBES[format][model_type] = probe_class
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def heuristic_probe(
|
||||||
|
cls,
|
||||||
|
model_path: Path,
|
||||||
|
fields: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> AnyModelConfig:
|
||||||
|
return cls.probe(model_path, fields)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def probe(
|
||||||
|
cls,
|
||||||
|
model_path: Path,
|
||||||
|
fields: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> AnyModelConfig:
|
||||||
|
"""
|
||||||
|
Probe the model at model_path and return its configuration record.
|
||||||
|
|
||||||
|
:param model_path: Path to the model file (checkpoint) or directory (diffusers).
|
||||||
|
:param fields: An optional dictionary that can be used to override probed
|
||||||
|
fields. Typically used for fields that don't probe well, such as prediction_type.
|
||||||
|
|
||||||
|
Returns: The appropriate model configuration derived from ModelConfigBase.
|
||||||
|
"""
|
||||||
|
if fields is None:
|
||||||
|
fields = {}
|
||||||
|
|
||||||
|
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
||||||
|
model_info = None
|
||||||
|
model_type = None
|
||||||
|
if format_type == "diffusers":
|
||||||
|
model_type = cls.get_model_type_from_folder(model_path)
|
||||||
|
else:
|
||||||
|
model_type = cls.get_model_type_from_checkpoint(model_path)
|
||||||
|
format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type
|
||||||
|
|
||||||
|
probe_class = cls.PROBES[format_type].get(model_type)
|
||||||
|
if not probe_class:
|
||||||
|
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
|
||||||
|
|
||||||
|
hash = FastModelHash.hash(model_path)
|
||||||
|
probe = probe_class(model_path)
|
||||||
|
|
||||||
|
fields["path"] = model_path.as_posix()
|
||||||
|
fields["type"] = fields.get("type") or model_type
|
||||||
|
fields["base"] = fields.get("base") or probe.get_base_type()
|
||||||
|
fields["variant"] = fields.get("variant") or probe.get_variant_type()
|
||||||
|
fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
|
||||||
|
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
|
||||||
|
fields["description"] = (
|
||||||
|
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
||||||
|
)
|
||||||
|
fields["format"] = fields.get("format") or probe.get_format()
|
||||||
|
fields["original_hash"] = fields.get("original_hash") or hash
|
||||||
|
fields["current_hash"] = fields.get("current_hash") or hash
|
||||||
|
|
||||||
|
# additional fields needed for main and controlnet models
|
||||||
|
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
|
||||||
|
fields["config"] = cls._get_checkpoint_config_path(
|
||||||
|
model_path,
|
||||||
|
model_type=fields["type"],
|
||||||
|
base_type=fields["base"],
|
||||||
|
variant_type=fields["variant"],
|
||||||
|
prediction_type=fields["prediction_type"],
|
||||||
|
).as_posix()
|
||||||
|
|
||||||
|
# additional fields needed for main non-checkpoint models
|
||||||
|
elif fields["type"] == ModelType.Main and fields["format"] in [
|
||||||
|
ModelFormat.Onnx,
|
||||||
|
ModelFormat.Olive,
|
||||||
|
ModelFormat.Diffusers,
|
||||||
|
]:
|
||||||
|
fields["upcast_attention"] = fields.get("upcast_attention") or (
|
||||||
|
fields["base"] == BaseModelType.StableDiffusion2
|
||||||
|
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
|
||||||
|
)
|
||||||
|
|
||||||
|
model_info = ModelConfigFactory.make_config(fields)
|
||||||
|
return model_info
|
||||||
|
|
||||||
|
    @classmethod
    def get_model_name(cls, model_path: Path) -> str:
        if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
            return model_path.stem
        else:
            return model_path.name

    @classmethod
    def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType] = None) -> ModelType:
        if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
            raise InvalidModelConfigException(f"{model_path}: unrecognized suffix")

        if model_path.name == "learned_embeds.bin":
            return ModelType.TextualInversion

        ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
        ckpt = ckpt.get("state_dict", ckpt)

        for key in ckpt.keys():
            if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
                return ModelType.Main
            elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
                return ModelType.Vae
            elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
                return ModelType.Lora
            elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
                return ModelType.Lora
            elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
                return ModelType.ControlNet
            elif key in {"emb_params", "string_to_param"}:
                return ModelType.TextualInversion

        else:
            # diffusers-ti
            if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
                return ModelType.TextualInversion

        raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")

    @classmethod
    def get_model_type_from_folder(cls, folder_path: Path) -> ModelType:
        """Get the model type of a hugging-face style folder."""
        class_name = None
        error_hint = None
        for suffix in ["bin", "safetensors"]:
            if (folder_path / f"learned_embeds.{suffix}").exists():
                return ModelType.TextualInversion
            if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
                return ModelType.Lora
        if (folder_path / "unet/model.onnx").exists():
            return ModelType.ONNX
        if (folder_path / "image_encoder.txt").exists():
            return ModelType.IPAdapter

        i = folder_path / "model_index.json"
        c = folder_path / "config.json"
        config_path = i if i.exists() else c if c.exists() else None

        if config_path:
            with open(config_path, "r") as file:
                conf = json.load(file)
            if "_class_name" in conf:
                class_name = conf["_class_name"]
            elif "architectures" in conf:
                class_name = conf["architectures"][0]
            else:
                class_name = None
        else:
            error_hint = f"No model_index.json or config.json found in {folder_path}."

        if class_name and (type := cls.CLASS2TYPE.get(class_name)):
            return type
        else:
            error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"

        # give up
        raise InvalidModelConfigException(
            f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
        )

    @classmethod
    def _get_checkpoint_config_path(
        cls,
        model_path: Path,
        model_type: ModelType,
        base_type: BaseModelType,
        variant_type: ModelVariantType,
        prediction_type: SchedulerPredictionType,
    ) -> Path:
        # look for a YAML file adjacent to the model file first
        possible_conf = model_path.with_suffix(".yaml")
        if possible_conf.exists():
            return possible_conf.absolute()

        if model_type == ModelType.Main:
            config_file = LEGACY_CONFIGS[base_type][variant_type]
            if isinstance(config_file, dict):  # need another tier for sd-2.x models
                config_file = config_file[prediction_type]
        elif model_type == ModelType.ControlNet:
            config_file = (
                "../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
            )
        else:
            raise InvalidModelConfigException(
                f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}"
            )
        assert isinstance(config_file, str)
        return Path(config_file)

    @classmethod
    def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType:
        with SilenceWarnings():
            if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
                cls._scan_model(model_path.name, model_path)
                model = torch.load(model_path)
                assert isinstance(model, dict)
                return model
            else:
                return safetensors.torch.load_file(model_path)

    @classmethod
    def _scan_model(cls, model_name: str, checkpoint: Path) -> None:
        """
        Apply picklescanner to the indicated checkpoint and raise an exception
        if an infected file is identified.
        """
        # scan model
        scan_result = scan_file_path(checkpoint)
        if scan_result.infected_files != 0:
            raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")


# ##################################################
# Checkpoint probing
# ##################################################

class CheckpointProbeBase(ProbeBase):
    def __init__(self, model_path: Path):
        super().__init__(model_path)
        self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)

    def get_format(self) -> ModelFormat:
        return ModelFormat("checkpoint")

    def get_variant_type(self) -> ModelVariantType:
        model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
        if model_type != ModelType.Main:
            return ModelVariantType.Normal
        state_dict = self.checkpoint.get("state_dict") or self.checkpoint
        in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
        if in_channels == 9:
            return ModelVariantType.Inpaint
        elif in_channels == 5:
            return ModelVariantType.Depth
        elif in_channels == 4:
            return ModelVariantType.Normal
        else:
            raise InvalidModelConfigException(
                f"Cannot determine variant type (in_channels={in_channels}) at {self.model_path}"
            )

class PipelineCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        state_dict = self.checkpoint.get("state_dict") or checkpoint
        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
            return BaseModelType.StableDiffusion1
        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
            return BaseModelType.StableDiffusion2
        key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
            return BaseModelType.StableDiffusionXL
        elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
            return BaseModelType.StableDiffusionXLRefiner
        else:
            raise InvalidModelConfigException("Cannot determine base type")

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        """Return model prediction type."""
        type = self.get_base_type()
        if type == BaseModelType.StableDiffusion2:
            checkpoint = self.checkpoint
            state_dict = self.checkpoint.get("state_dict") or checkpoint
            key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
            if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
                if "global_step" in checkpoint:
                    if checkpoint["global_step"] == 220000:
                        return SchedulerPredictionType.Epsilon
                    elif checkpoint["global_step"] == 110000:
                        return SchedulerPredictionType.VPrediction
            return SchedulerPredictionType.VPrediction  # a guess for sd2 ckpts

        elif type == BaseModelType.StableDiffusion1:
            return SchedulerPredictionType.Epsilon  # a reasonable guess for sd1 ckpts
        else:
            return SchedulerPredictionType.Epsilon

class VaeCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        # I can't find any standalone 2.X VAEs to test with!
        return BaseModelType.StableDiffusion1

class LoRACheckpointProbe(CheckpointProbeBase):
    """Class for LoRA checkpoints."""

    def get_format(self) -> ModelFormat:
        return ModelFormat("lycoris")

    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        token_vector_length = lora_token_vector_length(checkpoint)

        if token_vector_length == 768:
            return BaseModelType.StableDiffusion1
        elif token_vector_length == 1024:
            return BaseModelType.StableDiffusion2
        elif token_vector_length == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelConfigException(f"Unknown LoRA type: {self.model_path}")

class TextualInversionCheckpointProbe(CheckpointProbeBase):
    """Class for probing embeddings."""

    def get_format(self) -> ModelFormat:
        return ModelFormat.EmbeddingFile

    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        if "string_to_token" in checkpoint:
            token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
        elif "emb_params" in checkpoint:
            token_dim = checkpoint["emb_params"].shape[-1]
        elif "clip_g" in checkpoint:
            token_dim = checkpoint["clip_g"].shape[-1]
        else:
            token_dim = list(checkpoint.values())[0].shape[0]
        if token_dim == 768:
            return BaseModelType.StableDiffusion1
        elif token_dim == 1024:
            return BaseModelType.StableDiffusion2
        elif token_dim == 1280:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type")

class ControlNetCheckpointProbe(CheckpointProbeBase):
    """Class for probing controlnets."""

    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        for key_name in (
            "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
            "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
        ):
            if key_name not in checkpoint:
                continue
            if checkpoint[key_name].shape[-1] == 768:
                return BaseModelType.StableDiffusion1
            elif checkpoint[key_name].shape[-1] == 1024:
                return BaseModelType.StableDiffusion2
        raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")

class IPAdapterCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()


class CLIPVisionCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()


class T2IAdapterCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()

########################################################
# classes for probing folders
########################################################
class FolderProbeBase(ProbeBase):
    def get_variant_type(self) -> ModelVariantType:
        return ModelVariantType.Normal

    def get_format(self) -> ModelFormat:
        return ModelFormat("diffusers")

class PipelineFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        with open(self.model_path / "unet" / "config.json", "r") as file:
            unet_conf = json.load(file)
        if unet_conf["cross_attention_dim"] == 768:
            return BaseModelType.StableDiffusion1
        elif unet_conf["cross_attention_dim"] == 1024:
            return BaseModelType.StableDiffusion2
        elif unet_conf["cross_attention_dim"] == 1280:
            return BaseModelType.StableDiffusionXLRefiner
        elif unet_conf["cross_attention_dim"] == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
            scheduler_conf = json.load(file)
        if scheduler_conf["prediction_type"] == "v_prediction":
            return SchedulerPredictionType.VPrediction
        elif scheduler_conf["prediction_type"] == "epsilon":
            return SchedulerPredictionType.Epsilon
        else:
            raise InvalidModelConfigException(
                f"Unknown scheduler prediction type: {scheduler_conf['prediction_type']}"
            )

    def get_variant_type(self) -> ModelVariantType:
        # This only works for pipelines! Any kind of exception
        # results in our returning the "normal" variant type.
        try:
            config_file = self.model_path / "unet" / "config.json"
            with open(config_file, "r") as file:
                conf = json.load(file)

            in_channels = conf["in_channels"]
            if in_channels == 9:
                return ModelVariantType.Inpaint
            elif in_channels == 5:
                return ModelVariantType.Depth
            elif in_channels == 4:
                return ModelVariantType.Normal
        except Exception:
            pass
        return ModelVariantType.Normal

class VaeFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        if self._config_looks_like_sdxl():
            return BaseModelType.StableDiffusionXL
        elif self._name_looks_like_sdxl():
            # SD and SDXL VAEs have the same shape (3-channel RGB to 4-channel float,
            # scaled down by a factor of 8), so we can't necessarily tell them apart
            # by config hyperparameters; fall back on the name.
            return BaseModelType.StableDiffusionXL
        else:
            return BaseModelType.StableDiffusion1

    def _config_looks_like_sdxl(self) -> bool:
        # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
        config_file = self.model_path / "config.json"
        if not config_file.exists():
            raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
        with open(config_file, "r") as file:
            config = json.load(file)
        return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]

    def _name_looks_like_sdxl(self) -> bool:
        return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))

    def _guess_name(self) -> str:
        name = self.model_path.name
        if name == "vae":
            name = self.model_path.parent.name
        return name

class TextualInversionFolderProbe(FolderProbeBase):
    def get_format(self) -> ModelFormat:
        return ModelFormat.EmbeddingFolder

    def get_base_type(self) -> BaseModelType:
        path = self.model_path / "learned_embeds.bin"
        if not path.exists():
            raise InvalidModelConfigException(
                f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file"
            )
        return TextualInversionCheckpointProbe(path).get_base_type()

class ONNXFolderProbe(FolderProbeBase):
    def get_format(self) -> ModelFormat:
        return ModelFormat("onnx")

    def get_base_type(self) -> BaseModelType:
        return BaseModelType.StableDiffusion1

    def get_variant_type(self) -> ModelVariantType:
        return ModelVariantType.Normal

class ControlNetFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        config_file = self.model_path / "config.json"
        if not config_file.exists():
            raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
        with open(config_file, "r") as file:
            config = json.load(file)
        # no obvious way to distinguish between sd2-base and sd2-768
        dimension = config["cross_attention_dim"]
        base_model = (
            BaseModelType.StableDiffusion1
            if dimension == 768
            else (
                BaseModelType.StableDiffusion2
                if dimension == 1024
                else BaseModelType.StableDiffusionXL
                if dimension == 2048
                else None
            )
        )
        if not base_model:
            raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}")
        return base_model

class LoRAFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        model_file = None
        for suffix in ["safetensors", "bin"]:
            base_file = self.model_path / f"pytorch_lora_weights.{suffix}"
            if base_file.exists():
                model_file = base_file
                break
        if not model_file:
            raise InvalidModelConfigException("Unknown LoRA format encountered")
        return LoRACheckpointProbe(model_file).get_base_type()

class IPAdapterFolderProbe(FolderProbeBase):
    def get_format(self) -> IPAdapterModelFormat:
        return IPAdapterModelFormat.InvokeAI.value

    def get_base_type(self) -> BaseModelType:
        model_file = self.model_path / "ip_adapter.bin"
        if not model_file.exists():
            raise InvalidModelConfigException("Unknown IP-Adapter model format.")

        state_dict = torch.load(model_file, map_location="cpu")
        cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
        if cross_attention_dim == 768:
            return BaseModelType.StableDiffusion1
        elif cross_attention_dim == 1024:
            return BaseModelType.StableDiffusion2
        elif cross_attention_dim == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelConfigException(
                f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
            )

class CLIPVisionFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        return BaseModelType.Any

class T2IAdapterFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        config_file = self.model_path / "config.json"
        if not config_file.exists():
            raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
        with open(config_file, "r") as file:
            config = json.load(file)

        adapter_type = config.get("adapter_type", None)
        if adapter_type == "full_adapter_xl":
            return BaseModelType.StableDiffusionXL
        elif adapter_type in ("full_adapter", "light_adapter"):
            # I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
            return BaseModelType.StableDiffusion1
        else:
            raise InvalidModelConfigException(
                f"Unable to determine base model for '{self.model_path}' (adapter_type = {adapter_type})."
            )

############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)

ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)

ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
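With the registrations above in place, ModelProbe.probe is the single entry
point for every supported model layout. A minimal usage sketch (path and
printed values illustrative):

    from pathlib import Path
    from invokeai.backend.model_manager import ModelProbe

    info = ModelProbe.probe(Path("/tmp/models/v1-5-pruned-emaonly.safetensors"))
    print(info.base, info.type, info.format)  # e.g. sd-1 / main / checkpoint
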
190  invokeai/backend/model_manager/search.py  Normal file
@ -0,0 +1,190 @@
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Abstract base class and implementation for recursive directory search for models.

Example usage:
```
from invokeai.backend.model_manager import ModelSearch, ModelProbe

def find_main_models(model: Path) -> bool:
    info = ModelProbe.probe(model)
    if info.model_type == 'main' and info.base_type == 'sd-1':
        return True
    else:
        return False

search = ModelSearch(on_model_found=find_main_models)
found = search.search('/tmp/models')
print(found)   # list of matching model paths
print(search.stats)  # search stats
```
"""

import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Optional, Set, Union

from pydantic import BaseModel, Field

from invokeai.backend.util.logging import InvokeAILogger

default_logger = InvokeAILogger.get_logger()

class SearchStats(BaseModel):
    items_scanned: int = 0
    models_found: int = 0
    models_filtered: int = 0

class ModelSearchBase(ABC, BaseModel):
    """
    Abstract directory traversal model search class

    Usage:
       search = ModelSearchBase(
            on_search_started = search_started_callback,
            on_search_completed = search_completed_callback,
            on_model_found = model_found_callback,
       )
       models_found = search.search('/path/to/directory')
    """

    # fmt: off
    on_search_started   : Optional[Callable[[Path], None]]      = Field(default=None, description="Called just before the search starts.")  # noqa E221
    on_model_found      : Optional[Callable[[Path], bool]]      = Field(default=None, description="Called when a model is found.")  # noqa E221
    on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.")  # noqa E221
    stats               : SearchStats                           = Field(default_factory=SearchStats, description="Summary statistics after search")  # noqa E221
    logger              : InvokeAILogger                        = Field(default=default_logger, description="Logger instance.")  # noqa E221
    # fmt: on

    class Config:
        arbitrary_types_allowed = True

    @abstractmethod
    def search_started(self) -> None:
        """
        Called before the scan starts.

        Passes the root search directory to the Callable `on_search_started`.
        """
        pass

    @abstractmethod
    def model_found(self, model: Path) -> None:
        """
        Called when a model is found during search.

        :param model: Model to process - could be a directory or checkpoint.

        Passes the model's Path to the Callable `on_model_found`.
        This Callable receives the path to the model and returns a boolean
        to indicate whether the model should be returned in the search
        results.
        """
        pass

    @abstractmethod
    def search_completed(self) -> None:
        """
        Called when the scan completes.

        Passes the Set of found model Paths to the Callable `on_search_completed`.
        """
        pass

    @abstractmethod
    def search(self, directory: Union[Path, str]) -> Set[Path]:
        """
        Recursively search for models in `directory` and return a set of model paths.

        If provided, the `on_search_started`, `on_model_found` and `on_search_completed`
        Callables will be invoked during the search.
        """
        pass

class ModelSearch(ModelSearchBase):
    """
    Implementation of ModelSearch with callbacks.
    Usage:
       search = ModelSearch()
       search.on_model_found = lambda path : 'anime' in path.as_posix()
       found = search.search('/tmp/models')
       # returns all models that have 'anime' in the path
    """

    models_found: Set[Path] = Field(default=None)
    scanned_dirs: Set[Path] = Field(default=None)
    pruned_paths: Set[Path] = Field(default=None)

    def search_started(self) -> None:
        self.models_found = set()
        self.scanned_dirs = set()
        self.pruned_paths = set()
        if self.on_search_started:
            self.on_search_started(self._directory)

    def model_found(self, model: Path) -> None:
        self.stats.models_found += 1
        if not self.on_model_found or self.on_model_found(model):
            self.stats.models_filtered += 1
            self.models_found.add(model)

    def search_completed(self) -> None:
        if self.on_search_completed:
            self.on_search_completed(self.models_found)

    def search(self, directory: Union[Path, str]) -> Set[Path]:
        self._directory = Path(directory)
        self.stats = SearchStats()  # zero out
        self.search_started()  # This will initialize models_found to empty
        self._walk_directory(directory)
        self.search_completed()
        return self.models_found

    def _walk_directory(self, path: Union[Path, str]) -> None:
        for root, dirs, files in os.walk(path, followlinks=True):
            # don't descend into directories that start with a "."
            # to avoid the Mac .DS_STORE issue.
            if str(Path(root).name).startswith("."):
                self.pruned_paths.add(Path(root))
            if any(Path(root).is_relative_to(x) for x in self.pruned_paths):
                continue

            self.stats.items_scanned += len(dirs) + len(files)
            for d in dirs:
                path = Path(root) / d
                if path.parent in self.scanned_dirs:
                    self.scanned_dirs.add(path)
                    continue
                if any(
                    (path / x).exists()
                    for x in [
                        "config.json",
                        "model_index.json",
                        "learned_embeds.bin",
                        "pytorch_lora_weights.bin",
                        "image_encoder.txt",
                    ]
                ):
                    self.scanned_dirs.add(path)
                    try:
                        self.model_found(path)
                    except KeyboardInterrupt:
                        raise
                    except Exception as e:
                        self.logger.warning(str(e))

            for f in files:
                path = Path(root) / f
                if path.parent in self.scanned_dirs:
                    continue
                if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
                    try:
                        self.model_found(path)
                    except KeyboardInterrupt:
                        raise
                    except Exception as e:
                        self.logger.warning(str(e))
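A sketch wiring all three callbacks together (callbacks and paths illustrative):

    from invokeai.backend.model_manager import ModelSearch

    search = ModelSearch(
        on_search_started=lambda root: print(f"scanning {root}"),
        on_model_found=lambda path: path.suffix == ".safetensors",  # keep only safetensors
        on_search_completed=lambda found: print(f"kept {len(found)} of {search.stats.models_found} models"),
    )
    models = search.search("/tmp/models")
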
@ -11,4 +11,7 @@ from .devices import (  # noqa: F401
     normalize_device,
     torch_dtype,
 )
+from .logging import InvokeAILogger
 from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name  # noqa: F401
+
+__all__ = ["Chdir", "InvokeAILogger", "choose_precision", "choose_torch_device"]
@ -164,6 +164,7 @@ nav:
       - Overview: 'contributing/contribution_guides/development.md'
       - New Contributors: 'contributing/contribution_guides/newContributorChecklist.md'
       - InvokeAI Architecture: 'contributing/ARCHITECTURE.md'
+      - Model Manager v2: 'contributing/MODEL_MANAGER.md'
       - Frontend Documentation: 'contributing/contribution_guides/contributingToFrontend.md'
       - Local Development: 'contributing/LOCAL_DEVELOPMENT.md'
       - Adding Tests: 'contributing/TESTS.md'
@ -221,6 +221,8 @@ exclude = [
 # global mypy config
 [tool.mypy]
 ignore_missing_imports = true  # ignores missing types in third-party libraries
+strict = true
+exclude = ["tests/*"]

 # overrides for specific modules
 [[tool.mypy.overrides]]
@ -1,9 +1,11 @@
 #!/bin/env python
+
+"""Little command-line utility for probing a model on disk."""
+
 import argparse
 from pathlib import Path

-from invokeai.backend.model_management.model_probe import ModelProbe
+from invokeai.backend.model_manager import InvalidModelConfigException, ModelProbe

 parser = argparse.ArgumentParser(description="Probe model type")
 parser.add_argument(
@ -14,5 +16,8 @@ parser.add_argument(
 args = parser.parse_args()

 for path in args.model_path:
-    info = ModelProbe().probe(path)
-    print(f"{path}: {info}")
+    try:
+        info = ModelProbe.probe(path)
+        print(f"{path}:{info.model_dump_json(indent=4)}")
+    except InvalidModelConfigException as exc:
+        print(exc)
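Because of the try/except, a batch of paths now keeps going past bad models,
and each good model prints its full config as JSON. A run might look like this
(script name and output illustrative, abridged):

    $ python probe-model.py /tmp/models/test_embedding.safetensors
    /tmp/models/test_embedding.safetensors:{
        "path": "/tmp/models/test_embedding.safetensors",
        "name": "test_embedding",
        "base": "sd-1",
        "type": "embedding",
        ...
    }
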
@ -69,6 +69,7 @@ def mock_services() -> InvocationServices:
         logger=logging,  # type: ignore
         model_manager=None,  # type: ignore
         model_records=None,  # type: ignore
+        model_install=None,  # type: ignore
         names=None,  # type: ignore
         performance_statistics=InvocationStatsService(),
         processor=DefaultInvocationProcessor(),
@ -74,6 +74,7 @@ def mock_services() -> InvocationServices:
         logger=logging,  # type: ignore
         model_manager=None,  # type: ignore
         model_records=None,  # type: ignore
+        model_install=None,  # type: ignore
         names=None,  # type: ignore
         performance_statistics=InvocationStatsService(),
         processor=DefaultInvocationProcessor(),
@ -12,7 +12,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
     prepare_values_to_insert,
 )
 from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
-from tests.nodes.test_nodes import PromptTestInvocation
+from tests.aa_nodes.test_nodes import PromptTestInvocation


 @pytest.fixture
195  tests/app/services/model_install/test_model_install.py  Normal file
@ -0,0 +1,195 @@
"""
|
||||||
|
Test the model installer
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
|
from invokeai.app.services.model_install import (
|
||||||
|
InstallStatus,
|
||||||
|
LocalModelSource,
|
||||||
|
ModelInstallJob,
|
||||||
|
ModelInstallService,
|
||||||
|
ModelInstallServiceBase,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
|
||||||
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
|
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_file(datadir: Path) -> Path:
|
||||||
|
return datadir / "test_embedding.safetensors"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app_config(datadir: Path) -> InvokeAIAppConfig:
|
||||||
|
return InvokeAIAppConfig(
|
||||||
|
root=datadir / "root",
|
||||||
|
models_dir=datadir / "root/models",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
||||||
|
database = SqliteDatabase(app_config, InvokeAILogger.get_logger(config=app_config))
|
||||||
|
store: ModelRecordServiceBase = ModelRecordServiceSQL(database)
|
||||||
|
return store
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def installer(app_config: InvokeAIAppConfig, store: ModelRecordServiceBase) -> ModelInstallServiceBase:
|
||||||
|
return ModelInstallService(
|
||||||
|
app_config=app_config,
|
||||||
|
record_store=store,
|
||||||
|
event_bus=DummyEventService(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyEvent(BaseModel):
|
||||||
|
"""Dummy Event to use with Dummy Event service."""
|
||||||
|
|
||||||
|
event_name: str
|
||||||
|
payload: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class DummyEventService(EventServiceBase):
|
||||||
|
"""Dummy event service for testing."""
|
||||||
|
|
||||||
|
events: List[DummyEvent]
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.events = []
|
||||||
|
|
||||||
|
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||||
|
"""Dispatch an event by appending it to self.events."""
|
||||||
|
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> None:
|
||||||
|
store = installer.record_store
|
||||||
|
matches = store.search_by_attr(model_name="test_embedding")
|
||||||
|
assert len(matches) == 0
|
||||||
|
key = installer.register_path(test_file)
|
||||||
|
assert key is not None
|
||||||
|
assert len(key) == 32
|
||||||
|
|
||||||
|
|
||||||
|
def test_registration_meta(installer: ModelInstallServiceBase, test_file: Path) -> None:
|
||||||
|
store = installer.record_store
|
||||||
|
key = installer.register_path(test_file)
|
||||||
|
model_record = store.get_model(key)
|
||||||
|
assert model_record is not None
|
||||||
|
assert model_record.name == "test_embedding"
|
||||||
|
assert model_record.type == ModelType.TextualInversion
|
||||||
|
assert Path(model_record.path) == test_file
|
||||||
|
assert model_record.base == BaseModelType("sd-1")
|
||||||
|
assert model_record.description is not None
|
||||||
|
assert model_record.source is not None
|
||||||
|
assert Path(model_record.source) == test_file
|
||||||
|
|
||||||
|
|
||||||
|
def test_registration_meta_override_fail(installer: ModelInstallServiceBase, test_file: Path) -> None:
|
||||||
|
key = None
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
key = installer.register_path(test_file, {"name": "banana_sushi", "type": ModelType("lora")})
|
||||||
|
assert key is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_registration_meta_override_succeed(installer: ModelInstallServiceBase, test_file: Path) -> None:
|
||||||
|
store = installer.record_store
|
||||||
|
key = installer.register_path(
|
||||||
|
test_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash"}
|
||||||
|
)
|
||||||
|
model_record = store.get_model(key)
|
||||||
|
assert model_record.name == "banana_sushi"
|
||||||
|
assert model_record.source == "fake/repo_id"
|
||||||
|
assert model_record.current_hash == "New Hash"
|
||||||
|
|
||||||
|
|
||||||
|
def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
|
||||||
|
store = installer.record_store
|
||||||
|
key = installer.install_path(test_file)
|
||||||
|
model_record = store.get_model(key)
|
||||||
|
assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
|
||||||
|
assert model_record.source == test_file.as_posix()
|
||||||
|
|
||||||
|
|
||||||
|
def test_background_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
|
||||||
|
"""Note: may want to break this down into several smaller unit tests."""
|
||||||
|
path = test_file
|
||||||
|
description = "Test of metadata assignment"
|
||||||
|
source = LocalModelSource(path=path, inplace=False)
|
||||||
|
job = installer.import_model(source, config={"description": description})
|
||||||
|
assert job is not None
|
||||||
|
assert isinstance(job, ModelInstallJob)
|
||||||
|
|
||||||
|
# See if job is registered properly
|
||||||
|
assert job in installer.get_job(source)
|
||||||
|
|
||||||
|
# test that the job object tracked installation correctly
|
||||||
|
jobs = installer.wait_for_installs()
|
||||||
|
assert len(jobs) > 0
|
||||||
|
my_job = [x for x in jobs if x.source == source]
|
||||||
|
assert len(my_job) == 1
|
||||||
|
assert my_job[0].status == InstallStatus.COMPLETED
|
||||||
|
|
||||||
|
# test that the expected events were issued
|
||||||
|
bus = installer.event_bus
|
||||||
|
assert bus is not None # sigh - ruff is a stickler for type checking
|
||||||
|
assert isinstance(bus, DummyEventService)
|
||||||
|
assert len(bus.events) == 2
|
||||||
|
event_names = [x.event_name for x in bus.events]
|
||||||
|
assert "model_install_started" in event_names
|
||||||
|
assert "model_install_completed" in event_names
|
||||||
|
assert Path(bus.events[0].payload["source"]) == source
|
||||||
|
assert Path(bus.events[1].payload["source"]) == source
|
||||||
|
key = bus.events[1].payload["key"]
|
||||||
|
assert key is not None
|
||||||
|
|
||||||
|
# see if the thing actually got installed at the expected location
|
||||||
|
model_record = installer.record_store.get_model(key)
|
||||||
|
assert model_record is not None
|
||||||
|
assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
|
||||||
|
assert Path(app_config.models_dir / model_record.path).exists()
|
||||||
|
|
||||||
|
# see if metadata was properly passed through
|
||||||
|
assert model_record.description == description
|
||||||
|
|
||||||
|
# see if prune works properly
|
||||||
|
installer.prune_jobs()
|
||||||
|
assert not installer.get_job(source)
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
|
||||||
|
store = installer.record_store
|
||||||
|
key = installer.install_path(test_file)
|
||||||
|
model_record = store.get_model(key)
|
||||||
|
assert Path(app_config.models_dir / model_record.path).exists()
|
||||||
|
assert test_file.exists() # original should still be there after installation
|
||||||
|
installer.delete(key)
|
||||||
|
assert not Path(
|
||||||
|
app_config.models_dir / model_record.path
|
||||||
|
).exists() # after deletion, installed copy should not exist
|
||||||
|
assert test_file.exists() # but original should still be there
|
||||||
|
with pytest.raises(UnknownModelException):
|
||||||
|
store.get_model(key)
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_register(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
|
||||||
|
store = installer.record_store
|
||||||
|
key = installer.register_path(test_file)
|
||||||
|
model_record = store.get_model(key)
|
||||||
|
assert Path(app_config.models_dir / model_record.path).exists()
|
||||||
|
assert test_file.exists() # original should still be there after installation
|
||||||
|
installer.delete(key)
|
||||||
|
assert Path(app_config.models_dir / model_record.path).exists()
|
||||||
|
with pytest.raises(UnknownModelException):
|
||||||
|
store.get_model(key)
|
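These tests rely on the pytest-datadir plugin to supply the `datadir` fixture
(see the data-directory note below), so a typical invocation is simply:

    pytest tests/app/services/model_install/test_model_install.py
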
@ -0,0 +1 @@
This directory is used by pytest-datadir.
@ -0,0 +1,79 @@
model:
  base_learning_rate: 1.0e-04
  target: invokeai.backend.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: invokeai.backend.stable_diffusion.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    personalization_config:
      target: invokeai.backend.stable_diffusion.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ['sculpture']
        per_image_tokens: false
        num_vectors_per_token: 1
        progressive_words: False

    unet_config:
      target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: invokeai.backend.stable_diffusion.encoders.modules.WeightedFrozenCLIPEmbedder
@ -0,0 +1 @@
Dummy file to establish git path.
Binary file not shown.