mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
70 Commits
ryan/dense
...
separate-g
Author | SHA1 | Date | |
---|---|---|---|
cd3f5f30dc | |||
71ee28ac12 | |||
46c904d08a | |||
7d5a88b69d | |||
afa4df1991 | |||
e30cb4b52f | |||
ba1f6bf926 | |||
4a9cca6c2d | |||
b0275700b3 | |||
8319aca5f9 | |||
51a604f907 | |||
7515d73628 | |||
2c453aa531 | |||
2cca6e4c76 | |||
ef171e890a | |||
caafbf2f0d | |||
2db5eaf907 | |||
f234bf6256 | |||
cfa78b4052 | |||
ba1dd4b02b | |||
bcf58cac59 | |||
e866d90ab2 | |||
e8797787cf | |||
0082ecb22b | |||
656839fcd1 | |||
99407c899f | |||
48119d9010 | |||
7c9128b253 | |||
4f9bb00275 | |||
78895b3e80 | |||
3030a34b88 | |||
58fa9c2fac | |||
a8b6635050 | |||
6829610a71 | |||
5551cf8ac4 | |||
37b969d339 | |||
c953e61294 | |||
93dd3c848e | |||
02bde7bb75 | |||
3391c19926 | |||
0f60b1ced4 | |||
44c40d7d1a | |||
0b9a212363 | |||
c3aa985c93 | |||
7cb0da1f66 | |||
3534366146 | |||
f2b5f8753f | |||
f13f5984c0 | |||
94e1e64296 | |||
2411bf53c0 | |||
9378e47a06 | |||
4471ea8ad1 | |||
2c835fd550 | |||
61b737bb9f | |||
a8cd3dfc99 | |||
0cce582f2f | |||
4347d1c7f7 | |||
bd4fd9693d | |||
9b40c28144 | |||
16a5d718bf | |||
76cbc745e1 | |||
0a614943f6 | |||
e426096d32 | |||
c561cd751f | |||
af9298f0ef | |||
5b74117836 | |||
38474c9797 | |||
b880a31039 | |||
dd31bc4586 | |||
316573df2d |
@ -32,7 +32,6 @@ model. These are the:
|
||||
Responsible for loading a model from disk
|
||||
into RAM and VRAM and getting it ready for inference.
|
||||
|
||||
|
||||
## Location of the Code
|
||||
|
||||
The four main services can be found in
|
||||
@ -63,23 +62,21 @@ provides the following fields:
|
||||
|----------------|-----------------|------------------|
|
||||
| `key` | str | Unique identifier for the model |
|
||||
| `name` | str | Name of the model (not unique) |
|
||||
| `model_type` | ModelType | The type of the model |
|
||||
| `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
|
||||
| `base_model` | BaseModelType | The base model that the model is compatible with |
|
||||
| `model_type` | ModelType | The type of the model |
|
||||
| `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
|
||||
| `base_model` | BaseModelType | The base model that the model is compatible with |
|
||||
| `path` | str | Location of model on disk |
|
||||
| `original_hash` | str | Hash of the model when it was first installed |
|
||||
| `current_hash` | str | Most recent hash of the model's contents |
|
||||
| `hash` | str | Hash of the model |
|
||||
| `description` | str | Human-readable description of the model (optional) |
|
||||
| `source` | str | Model's source URL or repo id (optional) |
|
||||
|
||||
The `key` is a unique 32-character random ID which was generated at
|
||||
install time. The `original_hash` field stores a hash of the model's
|
||||
install time. The `hash` field stores a hash of the model's
|
||||
contents at install time obtained by sampling several parts of the
|
||||
model's files using the `imohash` library. Over the course of the
|
||||
model's lifetime it may be transformed in various ways, such as
|
||||
changing its precision or converting it from a .safetensors to a
|
||||
diffusers model. When this happens, `original_hash` is unchanged, but
|
||||
`current_hash` is updated to indicate the current contents.
|
||||
diffusers model.
|
||||
|
||||
`ModelType`, `ModelFormat` and `BaseModelType` are string enums that
|
||||
are defined in `invokeai.backend.model_manager.config`. They are also
|
||||
@ -94,7 +91,6 @@ The `path` field can be absolute or relative. If relative, it is taken
|
||||
to be relative to the `models_dir` setting in the user's
|
||||
`invokeai.yaml` file.
|
||||
|
||||
|
||||
### CheckpointConfig
|
||||
|
||||
This adds support for checkpoint configurations, and adds the
|
||||
@ -174,7 +170,7 @@ store = context.services.model_manager.store
|
||||
or from elsewhere in the code by accessing
|
||||
`ApiDependencies.invoker.services.model_manager.store`.
|
||||
|
||||
### Creating a `ModelRecordService`
|
||||
### Creating a `ModelRecordService`
|
||||
|
||||
To create a new `ModelRecordService` database or open an existing one,
|
||||
you can directly create either a `ModelRecordServiceSQL` or a
|
||||
@ -217,27 +213,27 @@ for use in the InvokeAI web server. Its signature is:
|
||||
```
|
||||
def open(
|
||||
cls,
|
||||
config: InvokeAIAppConfig,
|
||||
conn: Optional[sqlite3.Connection] = None,
|
||||
lock: Optional[threading.Lock] = None
|
||||
config: InvokeAIAppConfig,
|
||||
conn: Optional[sqlite3.Connection] = None,
|
||||
lock: Optional[threading.Lock] = None
|
||||
) -> Union[ModelRecordServiceSQL, ModelRecordServiceFile]:
|
||||
```
|
||||
|
||||
The way it works is as follows:
|
||||
|
||||
1. Retrieve the value of the `model_config_db` option from the user's
|
||||
`invokeai.yaml` config file.
|
||||
`invokeai.yaml` config file.
|
||||
2. If `model_config_db` is `auto` (the default), then:
|
||||
- Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object
|
||||
opened on the passed connection and lock.
|
||||
- Open up a new connection to `databases/invokeai.db` if `conn`
|
||||
* Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object
|
||||
opened on the passed connection and lock.
|
||||
* Open up a new connection to `databases/invokeai.db` if `conn`
|
||||
and/or `lock` are missing (see note below).
|
||||
3. If `model_config_db` is a Path, then use `from_db_file`
|
||||
to return the appropriate type of ModelRecordService.
|
||||
4. If `model_config_db` is None, then retrieve the legacy
|
||||
`conf_path` option from `invokeai.yaml` and use the Path
|
||||
indicated there. This will default to `configs/models.yaml`.
|
||||
|
||||
|
||||
So a typical startup pattern would be:
|
||||
|
||||
```
|
||||
@ -255,7 +251,7 @@ store = ModelRecordServiceBase.open(config, db_conn, lock)
|
||||
|
||||
Configurations can be retrieved in several ways.
|
||||
|
||||
#### get_model(key) -> AnyModelConfig:
|
||||
#### get_model(key) -> AnyModelConfig
|
||||
|
||||
The basic functionality is to call the record store object's
|
||||
`get_model()` method with the desired model's unique key. It returns
|
||||
@ -272,28 +268,28 @@ print(model_conf.path)
|
||||
If the key is unrecognized, this call raises an
|
||||
`UnknownModelException`.
|
||||
|
||||
#### exists(key) -> AnyModelConfig:
|
||||
#### exists(key) -> AnyModelConfig
|
||||
|
||||
Returns True if a model with the given key exists in the databsae.
|
||||
|
||||
#### search_by_path(path) -> AnyModelConfig:
|
||||
#### search_by_path(path) -> AnyModelConfig
|
||||
|
||||
Returns the configuration of the model whose path is `path`. The path
|
||||
is matched using a simple string comparison and won't correctly match
|
||||
models referred to by different paths (e.g. using symbolic links).
|
||||
|
||||
#### search_by_name(name, base, type) -> List[AnyModelConfig]:
|
||||
#### search_by_name(name, base, type) -> List[AnyModelConfig]
|
||||
|
||||
This method searches for models that match some combination of `name`,
|
||||
`BaseType` and `ModelType`. Calling without any arguments will return
|
||||
all the models in the database.
|
||||
|
||||
#### all_models() -> List[AnyModelConfig]:
|
||||
#### all_models() -> List[AnyModelConfig]
|
||||
|
||||
Return all the model configs in the database. Exactly equivalent to
|
||||
calling `search_by_name()` with no arguments.
|
||||
|
||||
#### search_by_tag(tags) -> List[AnyModelConfig]:
|
||||
#### search_by_tag(tags) -> List[AnyModelConfig]
|
||||
|
||||
`tags` is a list of strings. This method returns a list of model
|
||||
configs that contain all of the given tags. Examples:
|
||||
@ -312,11 +308,11 @@ commercializable_models = [x for x in store.all_models() \
|
||||
if x.license.contains('allowCommercialUse=Sell')]
|
||||
```
|
||||
|
||||
#### version() -> str:
|
||||
#### version() -> str
|
||||
|
||||
Returns the version of the database, currently at `3.2`
|
||||
|
||||
#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase:
|
||||
#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase
|
||||
|
||||
This method exists to ease the transition from the previous version of
|
||||
the model manager, in which `get_model()` took the three arguments
|
||||
@ -337,7 +333,7 @@ model and pass its key to `get_model()`.
|
||||
Several methods allow you to create and update stored model config
|
||||
records.
|
||||
|
||||
#### add_model(key, config) -> AnyModelConfig:
|
||||
#### add_model(key, config) -> AnyModelConfig
|
||||
|
||||
Given a key and a configuration, this will add the model's
|
||||
configuration record to the database. `config` can either be a subclass of
|
||||
@ -352,7 +348,7 @@ model with the same key is already in the database, or an
|
||||
`InvalidModelConfigException` if a dict was passed and Pydantic
|
||||
experienced a parse or validation error.
|
||||
|
||||
### update_model(key, config) -> AnyModelConfig:
|
||||
### update_model(key, config) -> AnyModelConfig
|
||||
|
||||
Given a key and a configuration, this will update the model
|
||||
configuration record in the database. `config` can be either a
|
||||
@ -370,31 +366,31 @@ The `ModelInstallService` class implements the
|
||||
shop for all your model install needs. It provides the following
|
||||
functionality:
|
||||
|
||||
- Registering a model config record for a model already located on the
|
||||
* Registering a model config record for a model already located on the
|
||||
local filesystem, without moving it or changing its path.
|
||||
|
||||
- Installing a model alreadiy located on the local filesystem, by
|
||||
* Installing a model alreadiy located on the local filesystem, by
|
||||
moving it into the InvokeAI root directory under the
|
||||
`models` folder (or wherever config parameter `models_dir`
|
||||
specifies).
|
||||
|
||||
- Probing of models to determine their type, base type and other key
|
||||
|
||||
* Probing of models to determine their type, base type and other key
|
||||
information.
|
||||
|
||||
- Interface with the InvokeAI event bus to provide status updates on
|
||||
* Interface with the InvokeAI event bus to provide status updates on
|
||||
the download, installation and registration process.
|
||||
|
||||
- Downloading a model from an arbitrary URL and installing it in
|
||||
* Downloading a model from an arbitrary URL and installing it in
|
||||
`models_dir`.
|
||||
|
||||
- Special handling for Civitai model URLs which allow the user to
|
||||
* Special handling for Civitai model URLs which allow the user to
|
||||
paste in a model page's URL or download link
|
||||
|
||||
- Special handling for HuggingFace repo_ids to recursively download
|
||||
* Special handling for HuggingFace repo_ids to recursively download
|
||||
the contents of the repository, paying attention to alternative
|
||||
variants such as fp16.
|
||||
|
||||
- Saving tags and other metadata about the model into the invokeai database
|
||||
* Saving tags and other metadata about the model into the invokeai database
|
||||
when fetching from a repo that provides that type of information,
|
||||
(currently only Civitai and HuggingFace).
|
||||
|
||||
@ -427,8 +423,8 @@ queue.start()
|
||||
|
||||
installer = ModelInstallService(app_config=config,
|
||||
record_store=record_store,
|
||||
download_queue=queue
|
||||
)
|
||||
download_queue=queue
|
||||
)
|
||||
installer.start()
|
||||
```
|
||||
|
||||
@ -443,7 +439,6 @@ required parameters:
|
||||
| `metadata_store` | Optional[ModelMetadataStore] | Metadata storage object |
|
||||
|`session` | Optional[requests.Session] | Swap in a different Session object (usually for debugging) |
|
||||
|
||||
|
||||
Once initialized, the installer will provide the following methods:
|
||||
|
||||
#### install_job = installer.heuristic_import(source, [config], [access_token])
|
||||
@ -457,15 +452,15 @@ The `source` is a string that can be any of these forms
|
||||
1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`)
|
||||
2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
|
||||
3. A HuggingFace repo_id with any of the following formats:
|
||||
- `model/name` -- entire model
|
||||
- `model/name:fp32` -- entire model, using the fp32 variant
|
||||
- `model/name:fp16:vae` -- vae submodel, using the fp16 variant
|
||||
- `model/name::vae` -- vae submodel, using default precision
|
||||
- `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
|
||||
- `model/name::path/to/model.safetensors` -- an individual model file, default variant
|
||||
* `model/name` -- entire model
|
||||
* `model/name:fp32` -- entire model, using the fp32 variant
|
||||
* `model/name:fp16:vae` -- vae submodel, using the fp16 variant
|
||||
* `model/name::vae` -- vae submodel, using default precision
|
||||
* `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
|
||||
* `model/name::path/to/model.safetensors` -- an individual model file, default variant
|
||||
|
||||
Note that by specifying a relative path to the top of the HuggingFace
|
||||
repo, you can download and install arbitrary models files.
|
||||
repo, you can download and install arbitrary models files.
|
||||
|
||||
The variant, if not provided, will be automatically filled in with
|
||||
`fp32` if the user has requested full precision, and `fp16`
|
||||
@ -491,9 +486,9 @@ following illustrates basic usage:
|
||||
|
||||
```
|
||||
from invokeai.app.services.model_install import (
|
||||
LocalModelSource,
|
||||
HFModelSource,
|
||||
URLModelSource,
|
||||
LocalModelSource,
|
||||
HFModelSource,
|
||||
URLModelSource,
|
||||
)
|
||||
|
||||
source1 = LocalModelSource(path='/opt/models/sushi.safetensors') # a local safetensors file
|
||||
@ -513,13 +508,13 @@ for source in [source1, source2, source3, source4, source5, source6, source7]:
|
||||
source2job = installer.wait_for_installs(timeout=120)
|
||||
for source in sources:
|
||||
job = source2job[source]
|
||||
if job.complete:
|
||||
model_config = job.config_out
|
||||
model_key = model_config.key
|
||||
print(f"{source} installed as {model_key}")
|
||||
elif job.errored:
|
||||
print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")
|
||||
|
||||
if job.complete:
|
||||
model_config = job.config_out
|
||||
model_key = model_config.key
|
||||
print(f"{source} installed as {model_key}")
|
||||
elif job.errored:
|
||||
print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")
|
||||
|
||||
```
|
||||
|
||||
As shown here, the `import_model()` method accepts a variety of
|
||||
@ -528,7 +523,7 @@ HuggingFace repo_ids with and without a subfolder designation,
|
||||
Civitai model URLs and arbitrary URLs that point to checkpoint files
|
||||
(but not to folders).
|
||||
|
||||
Each call to `import_model()` return a `ModelInstallJob` job,
|
||||
Each call to `import_model()` return a `ModelInstallJob` job,
|
||||
an object which tracks the progress of the install.
|
||||
|
||||
If a remote model is requested, the model's files are downloaded in
|
||||
@ -555,7 +550,7 @@ The full list of arguments to `import_model()` is as follows:
|
||||
| `config` | Dict[str, Any] | None | Override all or a portion of model's probed attributes |
|
||||
|
||||
The next few sections describe the various types of ModelSource that
|
||||
can be passed to `import_model()`.
|
||||
can be passed to `import_model()`.
|
||||
|
||||
`config` can be used to override all or a portion of the configuration
|
||||
attributes returned by the model prober. See the section below for
|
||||
@ -566,7 +561,6 @@ details.
|
||||
This is used for a model that is located on a locally-accessible Posix
|
||||
filesystem, such as a local disk or networked fileshare.
|
||||
|
||||
|
||||
| **Argument** | **Type** | **Default** | **Description** |
|
||||
|------------------|------------------------------|-------------|-------------------------------------------|
|
||||
| `path` | str | Path | None | Path to the model file or directory |
|
||||
@ -625,7 +619,6 @@ HuggingFace has the most complicated `ModelSource` structure:
|
||||
| `subfolder` | Path | None | Look for the model in a subfolder of the repo. |
|
||||
| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |
|
||||
|
||||
|
||||
The `repo_id` is the repository ID, such as `stabilityai/sdxl-turbo`.
|
||||
|
||||
The `variant` is one of the various diffusers formats that HuggingFace
|
||||
@ -661,7 +654,6 @@ in. To download these files, you must provide an
|
||||
`HfFolder.get_token()` will be called to fill it in with the cached
|
||||
one.
|
||||
|
||||
|
||||
#### Monitoring the install job process
|
||||
|
||||
When you create an install job with `import_model()`, it launches the
|
||||
@ -675,14 +667,13 @@ The `ModelInstallJob` class has the following structure:
|
||||
| `id` | `int` | Integer ID for this job |
|
||||
| `status` | `InstallStatus` | An enum of [`waiting`, `downloading`, `running`, `completed`, `error` and `cancelled`]|
|
||||
| `config_in` | `dict` | Overriding configuration values provided by the caller |
|
||||
| `config_out` | `AnyModelConfig`| After successful completion, contains the configuration record written to the database |
|
||||
| `inplace` | `boolean` | True if the caller asked to install the model in place using its local path |
|
||||
| `source` | `ModelSource` | The local path, remote URL or repo_id of the model to be installed |
|
||||
| `config_out` | `AnyModelConfig`| After successful completion, contains the configuration record written to the database |
|
||||
| `inplace` | `boolean` | True if the caller asked to install the model in place using its local path |
|
||||
| `source` | `ModelSource` | The local path, remote URL or repo_id of the model to be installed |
|
||||
| `local_path` | `Path` | If a remote model, holds the path of the model after it is downloaded; if a local model, same as `source` |
|
||||
| `error_type` | `str` | Name of the exception that led to an error status |
|
||||
| `error` | `str` | Traceback of the error |
|
||||
|
||||
|
||||
If the `event_bus` argument was provided, events will also be
|
||||
broadcast to the InvokeAI event bus. The events will appear on the bus
|
||||
as an event of type `EventServiceBase.model_event`, a timestamp and
|
||||
@ -702,14 +693,13 @@ following keys:
|
||||
| `total_bytes` | int | Total size of all the files that make up the model |
|
||||
| `parts` | List[Dict]| Information on the progress of the individual files that make up the model |
|
||||
|
||||
|
||||
The parts is a list of dictionaries that give information on each of
|
||||
the components pieces of the download. The dictionary's keys are
|
||||
`source`, `local_path`, `bytes` and `total_bytes`, and correspond to
|
||||
the like-named keys in the main event.
|
||||
|
||||
Note that downloading events will not be issued for local models, and
|
||||
that downloading events occur *before* the running event.
|
||||
that downloading events occur _before_ the running event.
|
||||
|
||||
##### `model_install_running`
|
||||
|
||||
@ -752,14 +742,13 @@ properties: `waiting`, `downloading`, `running`, `complete`, `errored`
|
||||
and `cancelled`, as well as `in_terminal_state`. The last will return
|
||||
True if the job is in the complete, errored or cancelled states.
|
||||
|
||||
|
||||
#### Model configuration and probing
|
||||
|
||||
The install service uses the `invokeai.backend.model_manager.probe`
|
||||
module during import to determine the model's type, base type, and
|
||||
other configuration parameters. Among other things, it assigns a
|
||||
default name and description for the model based on probed
|
||||
fields.
|
||||
fields.
|
||||
|
||||
When downloading remote models is implemented, additional
|
||||
configuration information, such as list of trigger terms, will be
|
||||
@ -774,11 +763,11 @@ attributes. Here is an example of setting the
|
||||
```
|
||||
install_job = installer.import_model(
|
||||
source=HFModelSource(repo_id='stabilityai/stable-diffusion-2-1',variant='fp32'),
|
||||
config=dict(
|
||||
prediction_type=SchedulerPredictionType('v_prediction')
|
||||
name='stable diffusion 2 base model',
|
||||
)
|
||||
)
|
||||
config=dict(
|
||||
prediction_type=SchedulerPredictionType('v_prediction')
|
||||
name='stable diffusion 2 base model',
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
### Other installer methods
|
||||
@ -862,7 +851,6 @@ This method is similar to `unregister()`, but also unconditionally
|
||||
deletes the corresponding model weights file(s), regardless of whether
|
||||
they are inside or outside the InvokeAI models hierarchy.
|
||||
|
||||
|
||||
#### path = installer.download_and_cache(remote_source, [access_token], [timeout])
|
||||
|
||||
This utility routine will download the model file located at source,
|
||||
@ -953,7 +941,7 @@ following fields:
|
||||
|
||||
When you create a job, you can assign it a `priority`. If multiple
|
||||
jobs are queued, the job with the lowest priority runs first. (Don't
|
||||
blame me! The Unix developers came up with this convention.)
|
||||
blame me! The Unix developers came up with this convention.)
|
||||
|
||||
Every job has a `source` and a `destination`. `source` is a string in
|
||||
the base class, but subclassses redefine it more specifically.
|
||||
@ -974,7 +962,7 @@ is in its lifecycle. Values are defined in the string enum
|
||||
`DownloadJobStatus`, a symbol available from
|
||||
`invokeai.app.services.download_manager`. Possible values are:
|
||||
|
||||
| **Value** | **String Value** | ** Description ** |
|
||||
| **Value** | **String Value** | **Description** |
|
||||
|--------------|---------------------|-------------------|
|
||||
| `IDLE` | idle | Job created, but not submitted to the queue |
|
||||
| `ENQUEUED` | enqueued | Job is patiently waiting on the queue |
|
||||
@ -991,7 +979,7 @@ debugging and performance testing.
|
||||
|
||||
In case of an error, the Exception that caused the error will be
|
||||
placed in the `error` field, and the job's status will be set to
|
||||
`DownloadJobStatus.ERROR`.
|
||||
`DownloadJobStatus.ERROR`.
|
||||
|
||||
After an error occurs, any partially downloaded files will be deleted
|
||||
from disk, unless `preserve_partial_downloads` was set to True at job
|
||||
@ -1040,11 +1028,11 @@ While a job is being downloaded, the queue will emit events at
|
||||
periodic intervals. A typical series of events during a successful
|
||||
download session will look like this:
|
||||
|
||||
- enqueued
|
||||
- running
|
||||
- running
|
||||
- running
|
||||
- completed
|
||||
* enqueued
|
||||
* running
|
||||
* running
|
||||
* running
|
||||
* completed
|
||||
|
||||
There will be a single enqueued event, followed by one or more running
|
||||
events, and finally one `completed`, `error` or `cancelled`
|
||||
@ -1053,12 +1041,12 @@ events.
|
||||
It is possible for a caller to pause download temporarily, in which
|
||||
case the events may look something like this:
|
||||
|
||||
- enqueued
|
||||
- running
|
||||
- running
|
||||
- paused
|
||||
- running
|
||||
- completed
|
||||
* enqueued
|
||||
* running
|
||||
* running
|
||||
* paused
|
||||
* running
|
||||
* completed
|
||||
|
||||
The download queue logs when downloads start and end (unless `quiet`
|
||||
is set to True at initialization time) but doesn't log any progress
|
||||
@ -1120,11 +1108,11 @@ A typical initialization sequence will look like:
|
||||
from invokeai.app.services.download_manager import DownloadQueueService
|
||||
|
||||
def log_download_event(job: DownloadJobBase):
|
||||
logger.info(f'job={job.id}: status={job.status}')
|
||||
logger.info(f'job={job.id}: status={job.status}')
|
||||
|
||||
queue = DownloadQueueService(
|
||||
event_handlers=[log_download_event]
|
||||
)
|
||||
event_handlers=[log_download_event]
|
||||
)
|
||||
```
|
||||
|
||||
Event handlers can be provided to the queue at initialization time as
|
||||
@ -1155,9 +1143,9 @@ To use the former method, follow this example:
|
||||
```
|
||||
job = DownloadJobRemoteSource(
|
||||
source='http://www.civitai.com/models/13456',
|
||||
destination='/tmp/models/',
|
||||
event_handlers=[my_handler1, my_handler2], # if desired
|
||||
)
|
||||
destination='/tmp/models/',
|
||||
event_handlers=[my_handler1, my_handler2], # if desired
|
||||
)
|
||||
queue.submit_download_job(job, start=True)
|
||||
```
|
||||
|
||||
@ -1172,13 +1160,13 @@ To have the queue create the job for you, follow this example instead:
|
||||
```
|
||||
job = queue.create_download_job(
|
||||
source='http://www.civitai.com/models/13456',
|
||||
destdir='/tmp/models/',
|
||||
filename='my_model.safetensors',
|
||||
event_handlers=[my_handler1, my_handler2], # if desired
|
||||
start=True,
|
||||
)
|
||||
destdir='/tmp/models/',
|
||||
filename='my_model.safetensors',
|
||||
event_handlers=[my_handler1, my_handler2], # if desired
|
||||
start=True,
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
The `filename` argument forces the downloader to use the specified
|
||||
name for the file rather than the name provided by the remote source,
|
||||
and is equivalent to manually specifying a destination of
|
||||
@ -1187,7 +1175,6 @@ and is equivalent to manually specifying a destination of
|
||||
Here is the full list of arguments that can be provided to
|
||||
`create_download_job()`:
|
||||
|
||||
|
||||
| **Argument** | **Type** | **Default** | **Description** |
|
||||
|------------------|------------------------------|-------------|-------------------------------------------|
|
||||
| `source` | Union[str, Path, AnyHttpUrl] | | Download remote or local source |
|
||||
@ -1200,7 +1187,7 @@ Here is the full list of arguments that can be provided to
|
||||
|
||||
Internally, `create_download_job()` has a little bit of internal logic
|
||||
that looks at the type of the source and selects the right subclass of
|
||||
`DownloadJobBase` to create and enqueue.
|
||||
`DownloadJobBase` to create and enqueue.
|
||||
|
||||
**TODO**: move this logic into its own method for overriding in
|
||||
subclasses.
|
||||
@ -1275,7 +1262,7 @@ for getting the model to run. For example "author" is metadata, while
|
||||
"type", "base" and "format" are not. The latter fields are part of the
|
||||
model's config, as defined in `invokeai.backend.model_manager.config`.
|
||||
|
||||
### Example Usage:
|
||||
### Example Usage
|
||||
|
||||
```
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
@ -1328,7 +1315,6 @@ This is the common base class for metadata:
|
||||
| `author` | str | Model's author |
|
||||
| `tags` | Set[str] | Model tags |
|
||||
|
||||
|
||||
Note that the model config record also has a `name` field. It is
|
||||
intended that the config record version be locally customizable, while
|
||||
the metadata version is read-only. However, enforcing this is expected
|
||||
@ -1348,7 +1334,6 @@ This descends from `ModelMetadataBase` and adds the following fields:
|
||||
| `last_modified`| datetime | Date of last commit of this model to the repo |
|
||||
| `files` | List[Path] | List of the files in the model repo |
|
||||
|
||||
|
||||
#### `CivitaiMetadata`
|
||||
|
||||
This descends from `ModelMetadataBase` and adds the following fields:
|
||||
@ -1415,7 +1400,6 @@ testing suite to avoid hitting the internet.
|
||||
The HuggingFace and Civitai fetcher subclasses add additional
|
||||
repo-specific fetching methods:
|
||||
|
||||
|
||||
#### HuggingFaceMetadataFetch
|
||||
|
||||
This overrides its base class `from_json()` method to return a
|
||||
@ -1434,13 +1418,12 @@ retrieves its metadata. Functionally equivalent to `from_id()`, the
|
||||
only difference is that it returna a `CivitaiMetadata` object rather
|
||||
than an `AnyModelRepoMetadata`.
|
||||
|
||||
|
||||
### Metadata Storage
|
||||
|
||||
The `ModelMetadataStore` provides a simple facility to store model
|
||||
metadata in the `invokeai.db` database. The data is stored as a JSON
|
||||
blob, with a few common fields (`name`, `author`, `tags`) broken out
|
||||
to be searchable.
|
||||
to be searchable.
|
||||
|
||||
When a metadata object is saved to the database, it is identified
|
||||
using the model key, _and this key must correspond to an existing
|
||||
@ -1535,16 +1518,16 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegist
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
ram_cache = ModelCache(
|
||||
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
|
||||
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
|
||||
)
|
||||
convert_cache = ModelConvertCache(
|
||||
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
|
||||
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
|
||||
)
|
||||
loader = ModelLoadService(
|
||||
app_config=config,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache,
|
||||
registry=ModelLoaderRegistry
|
||||
app_config=config,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache,
|
||||
registry=ModelLoaderRegistry
|
||||
)
|
||||
```
|
||||
|
||||
@ -1567,7 +1550,6 @@ The returned `LoadedModel` object contains a copy of the configuration
|
||||
record returned by the model record `get_model()` method, as well as
|
||||
the in-memory loaded model:
|
||||
|
||||
|
||||
| **Attribute Name** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
|
||||
@ -1581,7 +1563,6 @@ return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
|
||||
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
|
||||
models. The others are obvious.
|
||||
|
||||
|
||||
`LoadedModel` acts as a context manager. The context loads the model
|
||||
into the execution device (e.g. VRAM on CUDA systems), locks the model
|
||||
in the execution device for the duration of the context, and returns
|
||||
@ -1590,14 +1571,14 @@ the model. Use it like this:
|
||||
```
|
||||
model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
|
||||
with model_info as vae:
|
||||
image = vae.decode(latents)[0]
|
||||
image = vae.decode(latents)[0]
|
||||
```
|
||||
|
||||
`get_model_by_key()` may raise any of the following exceptions:
|
||||
|
||||
- `UnknownModelException` -- key not in database
|
||||
- `ModelNotFoundException` -- key in database but model not found at path
|
||||
- `NotImplementedException` -- the loader doesn't know how to load this type of model
|
||||
* `UnknownModelException` -- key not in database
|
||||
* `ModelNotFoundException` -- key in database but model not found at path
|
||||
* `NotImplementedException` -- the loader doesn't know how to load this type of model
|
||||
|
||||
### Emitting model loading events
|
||||
|
||||
@ -1609,15 +1590,15 @@ following payload:
|
||||
|
||||
```
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_key=model_key,
|
||||
submodel_type=submodel,
|
||||
hash=model_info.hash,
|
||||
location=str(model_info.location),
|
||||
precision=str(model_info.precision),
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_key=model_key,
|
||||
submodel_type=submodel,
|
||||
hash=model_info.hash,
|
||||
location=str(model_info.location),
|
||||
precision=str(model_info.precision),
|
||||
)
|
||||
```
|
||||
|
||||
@ -1724,6 +1705,7 @@ object, or in `context.services.model_manager` from within an
|
||||
invocation.
|
||||
|
||||
In the examples below, we have retrieved the manager using:
|
||||
|
||||
```
|
||||
mm = ApiDependencies.invoker.services.model_manager
|
||||
```
|
||||
|
@ -26,7 +26,6 @@ from ..services.invocation_services import InvocationServices
|
||||
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||
from ..services.model_metadata import ModelMetadataStoreSQL
|
||||
from ..services.model_records import ModelRecordServiceSQL
|
||||
from ..services.names.names_default import SimpleNameService
|
||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
@ -93,10 +92,9 @@ class ApiDependencies:
|
||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||
)
|
||||
download_queue_service = DownloadQueueService(event_bus=events)
|
||||
model_metadata_service = ModelMetadataStoreSQL(db=db)
|
||||
model_manager = ModelManagerService.build_model_manager(
|
||||
app_config=configuration,
|
||||
model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
|
||||
model_record_service=ModelRecordServiceSQL(db=db),
|
||||
download_queue=download_queue_service,
|
||||
events=events,
|
||||
)
|
||||
|
@ -3,9 +3,7 @@
|
||||
|
||||
import pathlib
|
||||
import shutil
|
||||
from hashlib import sha1
|
||||
from random import randbytes
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
@ -14,15 +12,11 @@ from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallJob
|
||||
from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelRecordOrderBy,
|
||||
ModelSummary,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
@ -31,9 +25,6 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
@ -49,15 +40,6 @@ class ModelsList(BaseModel):
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
class ModelTagSet(BaseModel):
|
||||
"""Return tags for a set of models."""
|
||||
|
||||
key: str
|
||||
name: str
|
||||
author: str
|
||||
tags: Set[str]
|
||||
|
||||
|
||||
##############################################################################
|
||||
# These are example inputs and outputs that are used in places where Swagger
|
||||
# is unable to generate a correct example.
|
||||
@ -68,19 +50,16 @@ example_model_config = {
|
||||
"base": "sd-1",
|
||||
"type": "main",
|
||||
"format": "checkpoint",
|
||||
"config": "string",
|
||||
"config_path": "string",
|
||||
"key": "string",
|
||||
"original_hash": "string",
|
||||
"current_hash": "string",
|
||||
"hash": "string",
|
||||
"description": "string",
|
||||
"source": "string",
|
||||
"last_modified": 0,
|
||||
"vae": "string",
|
||||
"converted_at": 0,
|
||||
"variant": "normal",
|
||||
"prediction_type": "epsilon",
|
||||
"repo_variant": "fp16",
|
||||
"upcast_attention": False,
|
||||
"ztsnr_training": False,
|
||||
}
|
||||
|
||||
example_model_input = {
|
||||
@ -89,50 +68,12 @@ example_model_input = {
|
||||
"base": "sd-1",
|
||||
"type": "main",
|
||||
"format": "checkpoint",
|
||||
"config": "configs/stable-diffusion/v1-inference.yaml",
|
||||
"config_path": "configs/stable-diffusion/v1-inference.yaml",
|
||||
"description": "Model description",
|
||||
"vae": None,
|
||||
"variant": "normal",
|
||||
}
|
||||
|
||||
example_model_metadata = {
|
||||
"name": "ip_adapter_sd_image_encoder",
|
||||
"author": "InvokeAI",
|
||||
"tags": [
|
||||
"transformers",
|
||||
"safetensors",
|
||||
"clip_vision_model",
|
||||
"endpoints_compatible",
|
||||
"region:us",
|
||||
"has_space",
|
||||
"license:apache-2.0",
|
||||
],
|
||||
"files": [
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
|
||||
"path": "ip_adapter_sd_image_encoder/README.md",
|
||||
"size": 628,
|
||||
"sha256": None,
|
||||
},
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
|
||||
"path": "ip_adapter_sd_image_encoder/config.json",
|
||||
"size": 560,
|
||||
"sha256": None,
|
||||
},
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
|
||||
"path": "ip_adapter_sd_image_encoder/model.safetensors",
|
||||
"size": 2528373448,
|
||||
"sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
|
||||
},
|
||||
],
|
||||
"type": "huggingface",
|
||||
"id": "InvokeAI/ip_adapter_sd_image_encoder",
|
||||
"tag_dict": {"license": "apache-2.0"},
|
||||
"last_modified": "2023-09-23T17:33:25Z",
|
||||
}
|
||||
|
||||
##############################################################################
|
||||
# ROUTES
|
||||
##############################################################################
|
||||
@ -212,89 +153,16 @@ async def get_model_record(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_manager_router.get("/summary", operation_id="list_model_summary")
|
||||
async def list_model_summary(
|
||||
page: int = Query(default=0, description="The page to get"),
|
||||
per_page: int = Query(default=10, description="The number of models per page"),
|
||||
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Gets a page of model summary data."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
return results
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/i/{key}/metadata",
|
||||
operation_id="get_model_metadata",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model metadata was retrieved successfully",
|
||||
"content": {"application/json": {"example": example_model_metadata}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def get_model_metadata(
|
||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||
) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Get a model metadata object."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/i/{key}/metadata",
|
||||
operation_id="update_model_metadata",
|
||||
responses={
|
||||
201: {
|
||||
"description": "The model metadata was updated successfully",
|
||||
"content": {"application/json": {"example": example_model_metadata}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def update_model_metadata(
|
||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||
changes: ModelMetadataChanges = Body(description="The changes"),
|
||||
) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Updates or creates a model metadata object."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
|
||||
|
||||
try:
|
||||
original_metadata = record_store.get_metadata(key)
|
||||
if original_metadata:
|
||||
if changes.default_settings:
|
||||
original_metadata.default_settings = changes.default_settings
|
||||
|
||||
metadata_store.update_metadata(key, original_metadata)
|
||||
else:
|
||||
metadata_store.add_metadata(
|
||||
key, BaseMetadata(name="", author="", default_settings=changes.default_settings)
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An error occurred while updating the model metadata: {e}",
|
||||
)
|
||||
|
||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/tags",
|
||||
operation_id="list_tags",
|
||||
)
|
||||
async def list_tags() -> Set[str]:
|
||||
"""Get a unique set of all the model tags."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Set[str] = record_store.list_tags()
|
||||
return result
|
||||
# @model_manager_router.get("/summary", operation_id="list_model_summary")
|
||||
# async def list_model_summary(
|
||||
# page: int = Query(default=0, description="The page to get"),
|
||||
# per_page: int = Query(default=10, description="The number of models per page"),
|
||||
# order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||
# ) -> PaginatedResults[ModelSummary]:
|
||||
# """Gets a page of model summary data."""
|
||||
# record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
# results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
# return results
|
||||
|
||||
|
||||
class FoundModel(BaseModel):
|
||||
@ -366,19 +234,6 @@ async def scan_for_models(
|
||||
return scan_results
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/tags/search",
|
||||
operation_id="search_by_metadata_tags",
|
||||
)
|
||||
async def search_by_metadata_tags(
|
||||
tags: Set[str] = Query(default=None, description="Tags to search for"),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
results = record_store.search_by_metadata_tag(tags)
|
||||
return ModelsList(models=results)
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/i/{key}",
|
||||
operation_id="update_model_record",
|
||||
@ -395,15 +250,13 @@ async def search_by_metadata_tags(
|
||||
)
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
info: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
],
|
||||
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
|
||||
) -> AnyModelConfig:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
"""Update a model's config."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
model_response: AnyModelConfig = record_store.update_model(key, config=info)
|
||||
model_response: AnyModelConfig = record_store.update_model(key, changes=changes)
|
||||
logger.info(f"Updated model: {key}")
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@ -415,14 +268,14 @@ async def update_model_record(
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/i/{key}",
|
||||
operation_id="del_model_record",
|
||||
operation_id="delete_model",
|
||||
responses={
|
||||
204: {"description": "Model deleted successfully"},
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code=204,
|
||||
)
|
||||
async def del_model_record(
|
||||
async def delete_model(
|
||||
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||
) -> Response:
|
||||
"""
|
||||
@ -443,42 +296,39 @@ async def del_model_record(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/i/",
|
||||
operation_id="add_model_record",
|
||||
responses={
|
||||
201: {
|
||||
"description": "The model added successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def add_model_record(
|
||||
config: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
],
|
||||
) -> AnyModelConfig:
|
||||
"""Add a model using the configuration information appropriate for its type."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
if config.key == "<NOKEY>":
|
||||
config.key = sha1(randbytes(100)).hexdigest()
|
||||
logger.info(f"Created model {config.key} for {config.name}")
|
||||
try:
|
||||
record_store.add_model(config.key, config)
|
||||
except DuplicateModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
# @model_manager_router.post(
|
||||
# "/i/",
|
||||
# operation_id="add_model_record",
|
||||
# responses={
|
||||
# 201: {
|
||||
# "description": "The model added successfully",
|
||||
# "content": {"application/json": {"example": example_model_config}},
|
||||
# },
|
||||
# 409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
# 415: {"description": "Unrecognized file/folder format"},
|
||||
# },
|
||||
# status_code=201,
|
||||
# )
|
||||
# async def add_model_record(
|
||||
# config: Annotated[
|
||||
# AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
# ],
|
||||
# ) -> AnyModelConfig:
|
||||
# """Add a model using the configuration information appropriate for its type."""
|
||||
# logger = ApiDependencies.invoker.services.logger
|
||||
# record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
# try:
|
||||
# record_store.add_model(config)
|
||||
# except DuplicateModelException as e:
|
||||
# logger.error(str(e))
|
||||
# raise HTTPException(status_code=409, detail=str(e))
|
||||
# except InvalidModelException as e:
|
||||
# logger.error(str(e))
|
||||
# raise HTTPException(status_code=415)
|
||||
|
||||
# now fetch it out
|
||||
result: AnyModelConfig = record_store.get_model(config.key)
|
||||
return result
|
||||
# # now fetch it out
|
||||
# result: AnyModelConfig = record_store.get_model(config.key)
|
||||
# return result
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
@ -553,10 +403,10 @@ async def install_model(
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/import",
|
||||
operation_id="list_model_install_jobs",
|
||||
"/install",
|
||||
operation_id="list_model_installs",
|
||||
)
|
||||
async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
async def list_model_installs() -> List[ModelInstallJob]:
|
||||
"""Return the list of model install jobs.
|
||||
|
||||
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
|
||||
@ -570,9 +420,8 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
* "cancelled" -- Job was cancelled before completion.
|
||||
|
||||
Once completed, information about the model such as its size, base
|
||||
model, type, and metadata can be retrieved from the `config_out`
|
||||
field. For multi-file models such as diffusers, information on individual files
|
||||
can be retrieved from `download_parts`.
|
||||
model and type can be retrieved from the `config_out` field. For multi-file models such as diffusers,
|
||||
information on individual files can be retrieved from `download_parts`.
|
||||
|
||||
See the example and schema below for more information.
|
||||
"""
|
||||
@ -581,7 +430,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/import/{id}",
|
||||
"/install/{id}",
|
||||
operation_id="get_model_install_job",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
@ -601,7 +450,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/import/{id}",
|
||||
"/install/{id}",
|
||||
operation_id="cancel_model_install_job",
|
||||
responses={
|
||||
201: {"description": "The job was cancelled successfully"},
|
||||
@ -619,8 +468,8 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
|
||||
installer.cancel_job(job)
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/import",
|
||||
@model_manager_router.delete(
|
||||
"/install",
|
||||
operation_id="prune_model_install_jobs",
|
||||
responses={
|
||||
204: {"description": "All completed and errored jobs have been pruned"},
|
||||
@ -699,7 +548,8 @@ async def convert_model(
|
||||
# temporarily rename the original safetensors file so that there is no naming conflict
|
||||
original_name = model_config.name
|
||||
model_config.name = f"{original_name}.DELETE"
|
||||
store.update_model(key, config=model_config)
|
||||
changes = ModelRecordChanges(name=model_config.name)
|
||||
store.update_model(key, changes=changes)
|
||||
|
||||
# install the diffusers
|
||||
try:
|
||||
@ -708,7 +558,7 @@ async def convert_model(
|
||||
config={
|
||||
"name": original_name,
|
||||
"description": model_config.description,
|
||||
"original_hash": model_config.original_hash,
|
||||
"hash": model_config.hash,
|
||||
"source": model_config.source,
|
||||
},
|
||||
)
|
||||
@ -716,10 +566,6 @@ async def convert_model(
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
# get the original metadata
|
||||
if orig_metadata := store.get_metadata(key):
|
||||
store.metadata_store.add_metadata(new_key, orig_metadata)
|
||||
|
||||
# delete the original safetensors file
|
||||
installer.delete(key)
|
||||
|
||||
@ -731,66 +577,66 @@ async def convert_model(
|
||||
return new_config
|
||||
|
||||
|
||||
@model_manager_router.put(
|
||||
"/merge",
|
||||
operation_id="merge",
|
||||
responses={
|
||||
200: {
|
||||
"description": "Model converted successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "Model not found"},
|
||||
409: {"description": "There is already a model registered at this location"},
|
||||
},
|
||||
)
|
||||
async def merge(
|
||||
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
force: bool = Body(
|
||||
description="Force merging of models created with different versions of diffusers",
|
||||
default=False,
|
||||
),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
|
||||
merge_dest_directory: Optional[str] = Body(
|
||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
default=None,
|
||||
),
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
|
||||
```
|
||||
Argument Description [default]
|
||||
-------- ----------------------
|
||||
keys List of 2-3 model keys to merge together. All models must use the same base type.
|
||||
merged_model_name Name for the merged model [Concat model names]
|
||||
alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
|
||||
force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
|
||||
interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
|
||||
merge_dest_directory Specify a directory to store the merged model in [models directory]
|
||||
```
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
merger = ModelMerger(installer)
|
||||
model_names = [installer.record_store.get_model(x).name for x in keys]
|
||||
response = merger.merge_diffusion_models_and_save(
|
||||
model_keys=keys,
|
||||
merged_model_name=merged_model_name or "+".join(model_names),
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=dest,
|
||||
)
|
||||
except UnknownModelException:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"One or more of the models '{keys}' not found",
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
# @model_manager_router.put(
|
||||
# "/merge",
|
||||
# operation_id="merge",
|
||||
# responses={
|
||||
# 200: {
|
||||
# "description": "Model converted successfully",
|
||||
# "content": {"application/json": {"example": example_model_config}},
|
||||
# },
|
||||
# 400: {"description": "Bad request"},
|
||||
# 404: {"description": "Model not found"},
|
||||
# 409: {"description": "There is already a model registered at this location"},
|
||||
# },
|
||||
# )
|
||||
# async def merge(
|
||||
# keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
||||
# merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||
# alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
# force: bool = Body(
|
||||
# description="Force merging of models created with different versions of diffusers",
|
||||
# default=False,
|
||||
# ),
|
||||
# interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
|
||||
# merge_dest_directory: Optional[str] = Body(
|
||||
# description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
# default=None,
|
||||
# ),
|
||||
# ) -> AnyModelConfig:
|
||||
# """
|
||||
# Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
|
||||
# ```
|
||||
# Argument Description [default]
|
||||
# -------- ----------------------
|
||||
# keys List of 2-3 model keys to merge together. All models must use the same base type.
|
||||
# merged_model_name Name for the merged model [Concat model names]
|
||||
# alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
|
||||
# force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
|
||||
# interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
|
||||
# merge_dest_directory Specify a directory to store the merged model in [models directory]
|
||||
# ```
|
||||
# """
|
||||
# logger = ApiDependencies.invoker.services.logger
|
||||
# try:
|
||||
# logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
# dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
# installer = ApiDependencies.invoker.services.model_manager.install
|
||||
# merger = ModelMerger(installer)
|
||||
# model_names = [installer.record_store.get_model(x).name for x in keys]
|
||||
# response = merger.merge_diffusion_models_and_save(
|
||||
# model_keys=keys,
|
||||
# merged_model_name=merged_model_name or "+".join(model_names),
|
||||
# alpha=alpha,
|
||||
# interp=interp,
|
||||
# force=force,
|
||||
# merge_dest_directory=dest,
|
||||
# )
|
||||
# except UnknownModelException:
|
||||
# raise HTTPException(
|
||||
# status_code=404,
|
||||
# detail=f"One or more of the models '{keys}' not found",
|
||||
# )
|
||||
# except ValueError as e:
|
||||
# raise HTTPException(status_code=400, detail=str(e))
|
||||
# return response
|
||||
|
@ -133,7 +133,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
submodel_type=SubModelType.VAE,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
@ -85,7 +85,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
submodel_type=SubModelType.VAE,
|
||||
),
|
||||
),
|
||||
)
|
||||
@ -142,7 +142,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
submodel_type=SubModelType.VAE,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
@ -256,6 +256,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
profile_graphs : bool = Field(default=False, description="Enable graph profiling", json_schema_extra=Categories.Development)
|
||||
profile_prefix : Optional[str] = Field(default=None, description="An optional prefix for profile output files.", json_schema_extra=Categories.Development)
|
||||
profiles_dir : Path = Field(default=Path('profiles'), description="Directory for graph profiles", json_schema_extra=Categories.Development)
|
||||
skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce startup time.", json_schema_extra=Categories.Development)
|
||||
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
|
||||
|
||||
|
@ -18,10 +18,9 @@ from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
"""State of an install job running in the background."""
|
||||
@ -151,6 +150,13 @@ ModelSource = Annotated[
|
||||
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
|
||||
]
|
||||
|
||||
MODEL_SOURCE_TO_TYPE_MAP = {
|
||||
URLModelSource: ModelSourceType.Url,
|
||||
HFModelSource: ModelSourceType.HFRepoID,
|
||||
CivitaiModelSource: ModelSourceType.CivitAI,
|
||||
LocalModelSource: ModelSourceType.Path,
|
||||
}
|
||||
|
||||
|
||||
class ModelInstallJob(BaseModel):
|
||||
"""Object that tracks the current status of an install request."""
|
||||
@ -260,7 +266,6 @@ class ModelInstallServiceBase(ABC):
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
metadata_store: ModelMetadataStoreBase,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
):
|
||||
"""
|
||||
@ -347,6 +352,7 @@ class ModelInstallServiceBase(ABC):
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
r"""Install the indicated model using heuristics to interpret user intentions.
|
||||
|
||||
@ -392,7 +398,7 @@ class ModelInstallServiceBase(ABC):
|
||||
will override corresponding autoassigned probe fields in the
|
||||
model's config record. Use it to override
|
||||
`name`, `description`, `base_type`, `model_type`, `format`,
|
||||
`prediction_type`, `image_size`, and/or `ztsnr_training`.
|
||||
`prediction_type`, and/or `image_size`.
|
||||
|
||||
This will download the model located at `source`,
|
||||
probe it, and install it into the models directory.
|
||||
|
@ -20,12 +20,15 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
CheckpointConfigBase,
|
||||
InvalidModelConfigException,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
@ -35,12 +38,14 @@ from invokeai.backend.model_manager.metadata import (
|
||||
ModelMetadataWithFiles,
|
||||
RemoteModelFile,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import CivitaiMetadata, HuggingFaceMetadata
|
||||
from invokeai.backend.model_manager.probe import ModelProbe
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.util import Chdir, InvokeAILogger
|
||||
from invokeai.backend.util.devices import choose_precision, choose_torch_device
|
||||
|
||||
from .model_install_base import (
|
||||
MODEL_SOURCE_TO_TYPE_MAP,
|
||||
CivitaiModelSource,
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
@ -90,7 +95,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._running = False
|
||||
self._session = session
|
||||
self._next_job_id = 0
|
||||
self._metadata_store = record_store.metadata_store # for convenience
|
||||
|
||||
@property
|
||||
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
||||
@ -139,6 +143,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
config["source_type"] = ModelSourceType.Path
|
||||
return self._register(model_path, config)
|
||||
|
||||
def install_path(
|
||||
@ -148,11 +153,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
config["key"] = config.get("key", uuid_string())
|
||||
|
||||
info: AnyModelConfig = self._probe_model(Path(model_path), config)
|
||||
if self._app_config.skip_model_hash:
|
||||
config["hash"] = uuid_string()
|
||||
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
|
||||
|
||||
if preferred_name := config.get("name"):
|
||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||
@ -178,7 +183,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: bool = False,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||
@ -374,15 +379,18 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job.bytes = job.total_bytes
|
||||
self._signal_job_running(job)
|
||||
job.config_in["source"] = str(job.source)
|
||||
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||
# enter the metadata, if there is any
|
||||
if isinstance(job.source_metadata, (CivitaiMetadata, HuggingFaceMetadata)):
|
||||
job.config_in["source_api_response"] = job.source_metadata.api_response
|
||||
if isinstance(job.source_metadata, CivitaiMetadata) and job.source_metadata.trigger_phrases:
|
||||
job.config_in["trigger_phrases"] = job.source_metadata.trigger_phrases
|
||||
|
||||
if job.inplace:
|
||||
key = self.register_path(job.local_path, job.config_in)
|
||||
else:
|
||||
key = self.install_path(job.local_path, job.config_in)
|
||||
job.config_out = self.record_store.get_model(key)
|
||||
|
||||
# enter the metadata, if there is any
|
||||
if job.source_metadata:
|
||||
self._metadata_store.add_metadata(key, job.source_metadata)
|
||||
self._signal_job_completed(job)
|
||||
|
||||
except InvalidModelConfigException as excp:
|
||||
@ -468,7 +476,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||
new_path = self._move_model(old_path, new_path)
|
||||
model.path = new_path.relative_to(models_dir).as_posix()
|
||||
self.record_store.update_model(key, model)
|
||||
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
|
||||
return model
|
||||
|
||||
def _scan_register(self, model: Path) -> bool:
|
||||
@ -520,24 +528,15 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
move(old_path, new_path)
|
||||
return new_path
|
||||
|
||||
def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path))
|
||||
if config: # used to override probe fields
|
||||
for key, value in config.items():
|
||||
setattr(info, key, value)
|
||||
return info
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
# Note that we may be passed a pre-populated AnyModelConfig object,
|
||||
# in which case the key field should have been populated by the caller (e.g. in `install_path`).
|
||||
config["key"] = config.get("key", uuid_string())
|
||||
info = info or ModelProbe.probe(model_path, config)
|
||||
override_key: Optional[str] = config.get("key") if config else None
|
||||
config = config or {}
|
||||
|
||||
assert info.original_hash # always assigned by probe()
|
||||
info.key = override_key or info.original_hash
|
||||
if self._app_config.skip_model_hash:
|
||||
config["hash"] = uuid_string()
|
||||
|
||||
info = info or ModelProbe.probe(model_path, config)
|
||||
|
||||
model_path = model_path.absolute()
|
||||
if model_path.is_relative_to(self.app_config.models_path):
|
||||
@ -546,11 +545,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
info.path = model_path.as_posix()
|
||||
|
||||
# add 'main' specific fields
|
||||
if hasattr(info, "config"):
|
||||
if isinstance(info, CheckpointConfigBase):
|
||||
# make config relative to our root
|
||||
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
|
||||
info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
|
||||
self.record_store.add_model(info.key, info)
|
||||
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config_path).resolve()
|
||||
info.config_path = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
|
||||
self.record_store.add_model(info)
|
||||
return info.key
|
||||
|
||||
def _next_id(self) -> int:
|
||||
@ -571,13 +570,15 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source=source,
|
||||
config_in=config or {},
|
||||
local_path=Path(source.path),
|
||||
inplace=source.inplace,
|
||||
inplace=source.inplace or False,
|
||||
)
|
||||
|
||||
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
if not source.access_token:
|
||||
self._logger.info("No Civitai access token provided; some models may not be downloadable.")
|
||||
metadata = CivitaiMetadataFetch(self._session).from_id(str(source.version_id))
|
||||
metadata = CivitaiMetadataFetch(self._session, self.app_config.get_config().civitai_api_key).from_id(
|
||||
str(source.version_id)
|
||||
)
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
remote_files = metadata.download_urls(session=self._session)
|
||||
return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)
|
||||
@ -605,15 +606,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
# URLs from Civitai or HuggingFace will be handled specially
|
||||
url_patterns = {
|
||||
r"^https?://civitai.com/": CivitaiMetadataFetch,
|
||||
r"^https?://huggingface.co/[^/]+/[^/]+$": HuggingFaceMetadataFetch,
|
||||
}
|
||||
metadata = None
|
||||
for pattern, fetcher in url_patterns.items():
|
||||
if re.match(pattern, str(source.url), re.IGNORECASE):
|
||||
metadata = fetcher(self._session).from_url(source.url)
|
||||
break
|
||||
fetcher = None
|
||||
try:
|
||||
fetcher = self.get_fetcher_from_url(str(source.url))
|
||||
except ValueError:
|
||||
pass
|
||||
kwargs: dict[str, Any] = {"session": self._session}
|
||||
if fetcher is CivitaiMetadataFetch:
|
||||
kwargs["api_key"] = self._app_config.get_config().civitai_api_key
|
||||
if fetcher is not None:
|
||||
metadata = fetcher(**kwargs).from_url(source.url)
|
||||
self._logger.debug(f"metadata={metadata}")
|
||||
if metadata and isinstance(metadata, ModelMetadataWithFiles):
|
||||
remote_files = metadata.download_urls(session=self._session)
|
||||
@ -628,7 +631,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def _import_remote_model(
|
||||
self,
|
||||
source: ModelSource,
|
||||
source: HFModelSource | CivitaiModelSource | URLModelSource,
|
||||
remote_files: List[RemoteModelFile],
|
||||
metadata: Optional[AnyModelRepoMetadata],
|
||||
config: Optional[Dict[str, Any]],
|
||||
@ -656,7 +659,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# In the event that there is a subfolder specified in the source,
|
||||
# we need to remove it from the destination path in order to avoid
|
||||
# creating unwanted subfolders
|
||||
if hasattr(source, "subfolder") and source.subfolder:
|
||||
if isinstance(source, HFModelSource) and source.subfolder:
|
||||
root = Path(remote_files[0].path.parts[0])
|
||||
subfolder = root / source.subfolder
|
||||
else:
|
||||
@ -843,3 +846,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.info(f"{job.source}: model installation was cancelled")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_cancelled(str(job.source))
|
||||
|
||||
@staticmethod
|
||||
def get_fetcher_from_url(url: str):
|
||||
if re.match(r"^https?://civitai.com/", url.lower()):
|
||||
return CivitaiMetadataFetch
|
||||
elif re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
||||
return HuggingFaceMetadataFetch
|
||||
raise ValueError(f"Unsupported model source: '{url}'")
|
||||
|
@ -1,9 +0,0 @@
|
||||
"""Init file for ModelMetadataStoreService module."""
|
||||
|
||||
from .metadata_store_base import ModelMetadataStoreBase
|
||||
from .metadata_store_sql import ModelMetadataStoreSQL
|
||||
|
||||
__all__ = [
|
||||
"ModelMetadataStoreBase",
|
||||
"ModelMetadataStoreSQL",
|
||||
]
|
@ -1,81 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Storage for Model Metadata
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings
|
||||
|
||||
|
||||
class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
|
||||
"""A set of changes to apply to model metadata.
|
||||
Only limited changes are valid:
|
||||
- `default_settings`: the user-configured default settings for this model
|
||||
"""
|
||||
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||
default=None, description="The user-configured default settings for this model"
|
||||
)
|
||||
"""The user-configured default settings for this model"""
|
||||
|
||||
|
||||
class ModelMetadataStoreBase(ABC):
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
@abstractmethod
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
|
||||
@abstractmethod
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
|
||||
@abstractmethod
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
@ -1,223 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
SQL Storage for Model Metadata
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
|
||||
|
||||
from .metadata_store_base import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json()
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_metadata(
|
||||
id,
|
||||
metadata
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
model_key,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||
self._db.conn.rollback()
|
||||
raise UnknownMetadataException from excp
|
||||
except sqlite3.Error as excp:
|
||||
self._db.conn.rollback()
|
||||
raise excp
|
||||
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM model_metadata
|
||||
WHERE id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
return ModelMetadataFetchBase.from_json(rows[0])
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id,metadata FROM model_metadata;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
|
||||
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_metadata
|
||||
SET
|
||||
metadata=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, model_key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_metadata(model_key)
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select tag_text from tags;
|
||||
"""
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
matches: Optional[Set[str]] = None
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.model_id FROM model_tags AS a,
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
if matches is None:
|
||||
matches = model_keys
|
||||
matches = matches.intersection(model_keys)
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return matches if matches else set()
|
||||
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE author=?;
|
||||
""",
|
||||
(author,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE name=?;
|
||||
""",
|
||||
(name,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
if tags:
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
@ -6,20 +6,19 @@ Abstract base class for storing and retrieving model configuration records.
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase
|
||||
from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
@ -60,11 +59,33 @@ class ModelSummary(BaseModel):
|
||||
tags: Set[str] = Field(description="tags associated with model")
|
||||
|
||||
|
||||
class ModelRecordChanges(BaseModelExcludeNull):
|
||||
"""A set of changes to apply to a model."""
|
||||
|
||||
# Changes applicable to all models
|
||||
name: Optional[str] = Field(description="Name of the model.", default=None)
|
||||
path: Optional[str] = Field(description="Path to the model.", default=None)
|
||||
description: Optional[str] = Field(description="Model description", default=None)
|
||||
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
|
||||
# Checkpoint-specific changes
|
||||
# TODO(MM2): Should we expose these? Feels footgun-y...
|
||||
variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
|
||||
prediction_type: Optional[SchedulerPredictionType] = Field(
|
||||
description="The prediction type of the model.", default=None
|
||||
)
|
||||
upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
|
||||
|
||||
|
||||
class ModelRecordServiceBase(ABC):
|
||||
"""Abstract base class for storage and retrieval of model configs."""
|
||||
|
||||
@abstractmethod
|
||||
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
@ -88,13 +109,12 @@ class ModelRecordServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
:param key: Unique key for the model to be updated
|
||||
:param config: Model configuration record. Either a dict with the
|
||||
required fields, or a ModelConfigBase instance.
|
||||
:param key: Unique key for the model to be updated.
|
||||
:param changes: A set of changes to apply to this model. Changes are validated before being written.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -109,40 +129,6 @@ class ModelRecordServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
Retrieve metadata (if any) from when model was downloaded from a repo.
|
||||
|
||||
:param key: Model key
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
||||
"""List metadata for all models that have it."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Search model metadata for ones with all listed tags and return their corresponding configs.
|
||||
|
||||
:param tags: Set of tags to search for. All tags must be present.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return a unique set of all the model tags in the metadata database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_models(
|
||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||
@ -217,21 +203,3 @@ class ModelRecordServiceBase(ABC):
|
||||
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
|
||||
)
|
||||
return model_configs[0]
|
||||
|
||||
def rename_model(
|
||||
self,
|
||||
key: str,
|
||||
new_name: str,
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Rename the indicated model. Just a special case of update_model().
|
||||
|
||||
In some implementations, renaming the model may involve changing where
|
||||
it is stored on the filesystem. So this is broken out.
|
||||
|
||||
:param key: Model key
|
||||
:param new_name: New name for model
|
||||
"""
|
||||
config = self.get_model(key)
|
||||
config.name = new_name
|
||||
return self.update_model(key, config)
|
||||
|
@ -43,7 +43,7 @@ import json
|
||||
import sqlite3
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@ -53,12 +53,11 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from .model_records_base import (
|
||||
DuplicateModelException,
|
||||
ModelRecordChanges,
|
||||
ModelRecordOrderBy,
|
||||
ModelRecordServiceBase,
|
||||
ModelSummary,
|
||||
@ -69,7 +68,7 @@ from .model_records_base import (
|
||||
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase):
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
@ -78,14 +77,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = db.conn.cursor()
|
||||
self._metadata_store = metadata_store
|
||||
|
||||
@property
|
||||
def db(self) -> SqliteDatabase:
|
||||
"""Return the underlying database."""
|
||||
return self._db
|
||||
|
||||
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
@ -95,23 +93,19 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_config (
|
||||
INSERT INTO models (
|
||||
id,
|
||||
original_hash,
|
||||
config
|
||||
)
|
||||
VALUES (?,?,?);
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
key,
|
||||
record.original_hash,
|
||||
json_serialized,
|
||||
config.key,
|
||||
config.model_dump_json(),
|
||||
),
|
||||
)
|
||||
self._db.conn.commit()
|
||||
@ -119,12 +113,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
except sqlite3.IntegrityError as e:
|
||||
self._db.conn.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "model_config.path" in str(e):
|
||||
msg = f"A model with path '{record.path}' is already installed"
|
||||
elif "model_config.name" in str(e):
|
||||
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
|
||||
if "models.path" in str(e):
|
||||
msg = f"A model with path '{config.path}' is already installed"
|
||||
elif "models.name" in str(e):
|
||||
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
|
||||
else:
|
||||
msg = f"A model with key '{key}' is already installed"
|
||||
msg = f"A model with key '{config.key}' is already installed"
|
||||
raise DuplicateModelException(msg) from e
|
||||
else:
|
||||
raise e
|
||||
@ -132,7 +126,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(key)
|
||||
return self.get_model(config.key)
|
||||
|
||||
def del_model(self, key: str) -> None:
|
||||
"""
|
||||
@ -146,7 +140,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_config
|
||||
DELETE FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@ -158,21 +152,20 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
record = self.get_model(key)
|
||||
|
||||
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
|
||||
for field_name in changes.model_fields_set:
|
||||
setattr(record, field_name, getattr(changes, field_name))
|
||||
|
||||
json_serialized = record.model_dump_json()
|
||||
|
||||
:param key: Unique key for the model to be updated
|
||||
:param config: Model configuration record. Either a dict with the
|
||||
required fields, or a ModelConfigBase instance.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_config
|
||||
UPDATE models
|
||||
SET
|
||||
config=?
|
||||
WHERE id=?;
|
||||
@ -199,7 +192,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@ -220,7 +213,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM model_config
|
||||
select count(*) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@ -246,9 +239,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
"""
|
||||
results = []
|
||||
where_clause = []
|
||||
bindings = []
|
||||
where_clause: list[str] = []
|
||||
bindings: list[str] = []
|
||||
if model_name:
|
||||
where_clause.append("name=?")
|
||||
bindings.append(model_name)
|
||||
@ -265,14 +257,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
select config, strftime('%s',updated_at) FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
{where};
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
results = [
|
||||
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
||||
]
|
||||
result = self._cursor.fetchall()
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result]
|
||||
return results
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
||||
@ -281,7 +272,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
@ -292,13 +283,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
return results
|
||||
|
||||
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated original_hash."""
|
||||
"""Return models with the indicated hash."""
|
||||
results = []
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
WHERE original_hash=?;
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
@ -307,83 +298,35 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
]
|
||||
return results
|
||||
|
||||
@property
|
||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
return self._metadata_store
|
||||
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
Retrieve metadata (if any) from when model was downloaded from a repo.
|
||||
|
||||
:param key: Model key
|
||||
"""
|
||||
store = self.metadata_store
|
||||
try:
|
||||
metadata = store.get_metadata(key)
|
||||
return metadata
|
||||
except UnknownMetadataException:
|
||||
return None
|
||||
|
||||
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Search model metadata for ones with all listed tags and return their corresponding configs.
|
||||
|
||||
:param tags: Set of tags to search for. All tags must be present.
|
||||
"""
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
keys = store.search_by_tag(tags)
|
||||
return [self.get_model(x) for x in keys]
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return a unique set of all the model tags in the metadata database."""
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
return store.list_tags()
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
||||
"""List metadata for all models that have it."""
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
return store.list_all_metadata()
|
||||
|
||||
def list_models(
|
||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Return a paginated summary listing of each model in the database."""
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name",
|
||||
ModelRecordOrderBy.Type: "a.type",
|
||||
ModelRecordOrderBy.Base: "a.base",
|
||||
ModelRecordOrderBy.Name: "a.name",
|
||||
ModelRecordOrderBy.Format: "a.format",
|
||||
ModelRecordOrderBy.Default: "type, base, format, name",
|
||||
ModelRecordOrderBy.Type: "type",
|
||||
ModelRecordOrderBy.Base: "base",
|
||||
ModelRecordOrderBy.Name: "name",
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
}
|
||||
|
||||
def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]:
|
||||
"""Fix up results so that there are no null values."""
|
||||
result: Dict[str, Union[str, int, Set[str]]] = {}
|
||||
for key, item in summary.items():
|
||||
result[key] = item or ""
|
||||
result["tags"] = set(json.loads(summary["tags"] or "[]"))
|
||||
return result
|
||||
|
||||
# Lock so that the database isn't updated while we're doing the two queries.
|
||||
with self._db.lock:
|
||||
# query1: get the total number of model configs
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) from model_config;
|
||||
select count(*) from models;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
total = int(self._cursor.fetchone()[0])
|
||||
|
||||
# query2: fetch key fields from the join of model_config and model_metadata
|
||||
# query2: fetch key fields
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT a.id as key, a.type, a.base, a.format, a.name,
|
||||
json_extract(a.config, '$.description') as description,
|
||||
json_extract(b.metadata, '$.tags') as tags
|
||||
FROM model_config AS a
|
||||
LEFT JOIN model_metadata AS b on a.id=b.id
|
||||
SELECT config
|
||||
FROM models
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
||||
LIMIT ?
|
||||
OFFSET ?;
|
||||
@ -394,7 +337,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
items = [ModelSummary.model_validate(_fixup(dict(x))) for x in rows]
|
||||
items = [ModelSummary.model_validate(dict(x)) for x in rows]
|
||||
return PaginatedResults(
|
||||
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
|
||||
)
|
||||
|
@ -1,6 +1,35 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from threading import Event
|
||||
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
|
||||
|
||||
class SessionRunnerBase(ABC):
|
||||
"""
|
||||
Base class for session runner.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def start(self, services: InvocationServices, cancel_event: Event) -> None:
|
||||
"""Starts the session runner"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run(self, queue_item: SessionQueueItem) -> None:
|
||||
"""Runs the session"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def complete(self, queue_item: SessionQueueItem) -> None:
|
||||
"""Completes the session"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
|
||||
"""Runs an already prepared node on the session"""
|
||||
pass
|
||||
|
||||
|
||||
class SessionProcessorBase(ABC):
|
||||
|
@ -2,13 +2,14 @@ import traceback
|
||||
from contextlib import suppress
|
||||
from threading import BoundedSemaphore, Thread
|
||||
from threading import Event as ThreadEvent
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
@ -16,15 +17,164 @@ from invokeai.app.services.shared.invocation_context import InvocationContextDat
|
||||
from invokeai.app.util.profiler import Profiler
|
||||
|
||||
from ..invoker import Invoker
|
||||
from .session_processor_base import SessionProcessorBase
|
||||
from .session_processor_base import SessionProcessorBase, SessionRunnerBase
|
||||
from .session_processor_common import SessionProcessorStatus
|
||||
|
||||
|
||||
class DefaultSessionRunner(SessionRunnerBase):
|
||||
"""Processes a single session's invocations"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
||||
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
||||
):
|
||||
self.on_before_run_node = on_before_run_node
|
||||
self.on_after_run_node = on_after_run_node
|
||||
|
||||
def start(self, services: InvocationServices, cancel_event: ThreadEvent):
|
||||
"""Start the session runner"""
|
||||
self.services = services
|
||||
self.cancel_event = cancel_event
|
||||
|
||||
def run(self, queue_item: SessionQueueItem):
|
||||
"""Run the graph"""
|
||||
if not queue_item.session:
|
||||
raise ValueError("Queue item has no session")
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
|
||||
# Prepare the next node
|
||||
invocation = queue_item.session.next()
|
||||
if invocation is None:
|
||||
# If there are no more invocations, complete the graph
|
||||
break
|
||||
# Build invocation context (the node-facing API
|
||||
self.run_node(invocation.id, queue_item)
|
||||
self.complete(queue_item)
|
||||
|
||||
def complete(self, queue_item: SessionQueueItem):
|
||||
"""Complete the graph"""
|
||||
self.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
)
|
||||
|
||||
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
"""Run before a node is executed"""
|
||||
# Send starting event
|
||||
self.services.events.emit_invocation_started(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session_id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
)
|
||||
if self.on_before_run_node is not None:
|
||||
self.on_before_run_node(invocation, queue_item)
|
||||
|
||||
def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
"""Run after a node is executed"""
|
||||
if self.on_after_run_node is not None:
|
||||
self.on_after_run_node(invocation, queue_item)
|
||||
|
||||
def run_node(self, node_id: str, queue_item: SessionQueueItem):
|
||||
"""Run a single node in the graph"""
|
||||
# If this error raises a NodeNotFoundError that's handled by the processor
|
||||
invocation = queue_item.session.execution_graph.get_node(node_id)
|
||||
try:
|
||||
self._on_before_run_node(invocation, queue_item)
|
||||
data = InvocationContextData(
|
||||
invocation=invocation,
|
||||
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
queue_item=queue_item,
|
||||
)
|
||||
|
||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self.services,
|
||||
cancel_event=self.cancel_event,
|
||||
)
|
||||
|
||||
# Invoke the node
|
||||
outputs = invocation.invoke_internal(context=context, services=self.services)
|
||||
|
||||
# Save outputs and history
|
||||
queue_item.session.complete(invocation.id, outputs)
|
||||
|
||||
self._on_after_run_node(invocation, queue_item)
|
||||
# Send complete event on successful runs
|
||||
self.services.events.emit_invocation_complete(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=data.source_invocation_id,
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
pass
|
||||
except CanceledException:
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
#
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
|
||||
# Save error
|
||||
queue_item.session.set_node_error(invocation.id, error)
|
||||
self.services.logger.error(
|
||||
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
|
||||
)
|
||||
self.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self.services.events.emit_invocation_error(
|
||||
queue_batch_id=queue_item.session_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
|
||||
|
||||
class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
|
||||
"""Processes sessions from the session queue"""
|
||||
|
||||
def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None:
|
||||
super().__init__()
|
||||
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
|
||||
|
||||
def start(
|
||||
self,
|
||||
invoker: Invoker,
|
||||
thread_limit: int = 1,
|
||||
polling_interval: int = 1,
|
||||
on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
|
||||
on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
|
||||
) -> None:
|
||||
self._invoker: Invoker = invoker
|
||||
self._queue_item: Optional[SessionQueueItem] = None
|
||||
self._invocation: Optional[BaseInvocation] = None
|
||||
self.on_before_run_session = on_before_run_session
|
||||
self.on_after_run_session = on_after_run_session
|
||||
|
||||
self._resume_event = ThreadEvent()
|
||||
self._stop_event = ThreadEvent()
|
||||
@ -59,6 +209,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
"cancel_event": self._cancel_event,
|
||||
},
|
||||
)
|
||||
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
@ -117,131 +268,34 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||
cancel_event.clear()
|
||||
|
||||
# If we have a on_before_run_session callback, call it
|
||||
if self.on_before_run_session is not None:
|
||||
self.on_before_run_session(self._queue_item)
|
||||
|
||||
# If profiling is enabled, start the profiler
|
||||
if self._profiler is not None:
|
||||
self._profiler.start(profile_id=self._queue_item.session_id)
|
||||
|
||||
# Prepare invocations and take the first
|
||||
self._invocation = self._queue_item.session.next()
|
||||
# Run the graph
|
||||
self.session_runner.run(queue_item=self._queue_item)
|
||||
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
while self._invocation is not None and not cancel_event.is_set():
|
||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
|
||||
|
||||
# Send starting event
|
||||
self._invoker.services.events.emit_invocation_started(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session_id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if self._profiler:
|
||||
profile_path = self._profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self._invoker.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
||||
)
|
||||
|
||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||
try:
|
||||
with self._invoker.services.performance_statistics.collect_stats(
|
||||
self._invocation, self._queue_item.session.id
|
||||
):
|
||||
# Build invocation context (the node-facing API)
|
||||
data = InvocationContextData(
|
||||
invocation=self._invocation,
|
||||
source_invocation_id=source_invocation_id,
|
||||
queue_item=self._queue_item,
|
||||
)
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self._invoker.services,
|
||||
cancel_event=self._cancel_event,
|
||||
)
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
|
||||
self._invoker.services.performance_statistics.reset_stats()
|
||||
|
||||
# Invoke the node
|
||||
outputs = self._invocation.invoke_internal(
|
||||
context=context, services=self._invoker.services
|
||||
)
|
||||
|
||||
# Save outputs and history
|
||||
self._queue_item.session.complete(self._invocation.id, outputs)
|
||||
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_invocation_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
pass
|
||||
|
||||
except CanceledException:
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
#
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
|
||||
# Save error
|
||||
self._queue_item.session.set_node_error(self._invocation.id, error)
|
||||
self._invoker.services.logger.error(
|
||||
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
|
||||
)
|
||||
self._invoker.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self._invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=self._queue_item.session_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
pass
|
||||
|
||||
# The session is complete if the all invocations are complete or there was an error
|
||||
if self._queue_item.session.is_complete() or cancel_event.is_set():
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
)
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if self._profiler:
|
||||
profile_path = self._profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self._invoker.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
||||
)
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
|
||||
self._invoker.services.performance_statistics.reset_stats()
|
||||
|
||||
# Set the invocation to None to prepare for the next session
|
||||
self._invocation = None
|
||||
else:
|
||||
# Prepare the next invocation
|
||||
self._invocation = self._queue_item.session.next()
|
||||
# If we have a on_after_run_session callback, call it
|
||||
if self.on_after_run_session is not None:
|
||||
self.on_after_run_session(self._queue_item)
|
||||
|
||||
# The session is complete, immediately poll for next session
|
||||
self._queue_item = None
|
||||
@ -275,3 +329,4 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
poll_now_event.clear()
|
||||
self._queue_item = None
|
||||
self._thread_semaphore.release()
|
||||
self._invoker.services.logger.debug("Session processor stopped")
|
||||
|
@ -9,6 +9,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@ -35,6 +36,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_4())
|
||||
migrator.register_migration(build_migration_5())
|
||||
migrator.register_migration(build_migration_6())
|
||||
migrator.register_migration(build_migration_7())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
@ -0,0 +1,88 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration7Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._create_models_table(cursor)
|
||||
self._drop_old_models_tables(cursor)
|
||||
|
||||
def _drop_old_models_tables(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Drops the old model_records, model_metadata, model_tags and tags tables."""
|
||||
|
||||
tables = ["model_records", "model_metadata", "model_tags", "tags"]
|
||||
|
||||
for table in tables:
|
||||
cursor.execute(f"DROP TABLE IF EXISTS {table};")
|
||||
|
||||
def _create_models_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Creates the v4.0.0 models table."""
|
||||
|
||||
tables = [
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS models (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
hash TEXT GENERATED ALWAYS as (json_extract(config, '$.hash')) VIRTUAL NOT NULL,
|
||||
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
|
||||
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
|
||||
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
|
||||
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
|
||||
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
|
||||
description TEXT GENERATED ALWAYS as (json_extract(config, '$.description')) VIRTUAL,
|
||||
source TEXT GENERATED ALWAYS as (json_extract(config, '$.source')) VIRTUAL NOT NULL,
|
||||
source_type TEXT GENERATED ALWAYS as (json_extract(config, '$.source_type')) VIRTUAL NOT NULL,
|
||||
source_api_response TEXT GENERATED ALWAYS as (json_extract(config, '$.source_api_response')) VIRTUAL,
|
||||
trigger_phrases TEXT GENERATED ALWAYS as (json_extract(config, '$.trigger_phrases')) VIRTUAL,
|
||||
-- Serialized JSON representation of the whole config object, which will contain additional fields from subclasses
|
||||
config TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- unique constraint on combo of name, base and type
|
||||
UNIQUE(name, base, type)
|
||||
);
|
||||
"""
|
||||
]
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
triggers = [
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS models_updated_at
|
||||
AFTER UPDATE
|
||||
ON models FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE models SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE id = old.id;
|
||||
END;
|
||||
"""
|
||||
]
|
||||
|
||||
# Add indexes for searchable fields
|
||||
indices = [
|
||||
"CREATE INDEX IF NOT EXISTS base_index ON models(base);",
|
||||
"CREATE INDEX IF NOT EXISTS type_index ON models(type);",
|
||||
"CREATE INDEX IF NOT EXISTS name_index ON models(name);",
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON models(path);",
|
||||
]
|
||||
|
||||
for stmt in tables + indices + triggers:
|
||||
cursor.execute(stmt)
|
||||
|
||||
|
||||
def build_migration_7() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 6 to 7.
|
||||
|
||||
This migration does the following:
|
||||
- Adds the new models table
|
||||
- Drops the old model_records, model_metadata, model_tags and tags tables.
|
||||
- TODO(MM2): Migrates model names and descriptions from `models.yaml` to the new table (?).
|
||||
"""
|
||||
migration_7 = Migration(
|
||||
from_version=6,
|
||||
to_version=7,
|
||||
callback=Migration7Callback(),
|
||||
)
|
||||
|
||||
return migration_7
|
@ -150,7 +150,7 @@ class MigrateModelYamlToDb1:
|
||||
""",
|
||||
(
|
||||
key,
|
||||
record.original_hash,
|
||||
record.hash,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
|
@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
|
||||
See :class:`Migration` for an example.
|
||||
"""
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None: ...
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
...
|
||||
|
||||
|
||||
class MigrationError(RuntimeError):
|
||||
|
@ -1,55 +0,0 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from invokeai.app.services.shared.graph import Edge
|
||||
|
||||
|
||||
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
|
||||
"""
|
||||
Parses raw session string, returning a dict of the graph.
|
||||
|
||||
Only the general graph shape is validated; none of the fields are validated.
|
||||
|
||||
Any `metadata_accumulator` nodes and edges are removed.
|
||||
|
||||
Any validation failure will return None.
|
||||
"""
|
||||
|
||||
graph = json.loads(session_raw).get("graph", None)
|
||||
|
||||
# sanity check make sure the graph is at least reasonably shaped
|
||||
if (
|
||||
not isinstance(graph, dict)
|
||||
or "nodes" not in graph
|
||||
or not isinstance(graph["nodes"], dict)
|
||||
or "edges" not in graph
|
||||
or not isinstance(graph["edges"], list)
|
||||
):
|
||||
# something has gone terribly awry, return an empty dict
|
||||
return None
|
||||
|
||||
try:
|
||||
# delete the `metadata_accumulator` node
|
||||
del graph["nodes"]["metadata_accumulator"]
|
||||
except KeyError:
|
||||
# no accumulator node, all good
|
||||
pass
|
||||
|
||||
# delete any edges to or from it
|
||||
for i, edge in enumerate(graph["edges"]):
|
||||
try:
|
||||
# try to parse the edge
|
||||
Edge(**edge)
|
||||
except ValidationError:
|
||||
# something has gone terribly awry, return an empty dict
|
||||
return None
|
||||
|
||||
if (
|
||||
edge["source"]["node_id"] == "metadata_accumulator"
|
||||
or edge["destination"]["node_id"] == "metadata_accumulator"
|
||||
):
|
||||
del graph["edges"][i]
|
||||
|
||||
return graph
|
@ -25,10 +25,13 @@ from enum import Enum
|
||||
from typing import Literal, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
from diffusers import ModelMixin
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
from ..raw_model import RawModel
|
||||
|
||||
# ModelMixin is the base class for all diffusers and transformers models
|
||||
@ -56,8 +59,8 @@ class ModelType(str, Enum):
|
||||
|
||||
ONNX = "onnx"
|
||||
Main = "main"
|
||||
Vae = "vae"
|
||||
Lora = "lora"
|
||||
VAE = "vae"
|
||||
LoRA = "lora"
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
TextualInversion = "embedding"
|
||||
IPAdapter = "ip_adapter"
|
||||
@ -73,9 +76,9 @@ class SubModelType(str, Enum):
|
||||
TextEncoder2 = "text_encoder_2"
|
||||
Tokenizer = "tokenizer"
|
||||
Tokenizer2 = "tokenizer_2"
|
||||
Vae = "vae"
|
||||
VaeDecoder = "vae_decoder"
|
||||
VaeEncoder = "vae_encoder"
|
||||
VAE = "vae"
|
||||
VAEDecoder = "vae_decoder"
|
||||
VAEEncoder = "vae_encoder"
|
||||
Scheduler = "scheduler"
|
||||
SafetyChecker = "safety_checker"
|
||||
|
||||
@ -93,8 +96,8 @@ class ModelFormat(str, Enum):
|
||||
|
||||
Diffusers = "diffusers"
|
||||
Checkpoint = "checkpoint"
|
||||
Lycoris = "lycoris"
|
||||
Onnx = "onnx"
|
||||
LyCORIS = "lycoris"
|
||||
ONNX = "onnx"
|
||||
Olive = "olive"
|
||||
EmbeddingFile = "embedding_file"
|
||||
EmbeddingFolder = "embedding_folder"
|
||||
@ -112,128 +115,187 @@ class SchedulerPredictionType(str, Enum):
|
||||
class ModelRepoVariant(str, Enum):
|
||||
"""Various hugging face variants on the diffusers format."""
|
||||
|
||||
DEFAULT = "" # model files without "fp16" or other qualifier - empty str
|
||||
Default = "" # model files without "fp16" or other qualifier - empty str
|
||||
FP16 = "fp16"
|
||||
FP32 = "fp32"
|
||||
ONNX = "onnx"
|
||||
OPENVINO = "openvino"
|
||||
FLAX = "flax"
|
||||
OpenVINO = "openvino"
|
||||
Flax = "flax"
|
||||
|
||||
|
||||
class ModelSourceType(str, Enum):
|
||||
"""Model source type."""
|
||||
|
||||
Path = "path"
|
||||
Url = "url"
|
||||
HFRepoID = "hf_repo_id"
|
||||
CivitAI = "civitai"
|
||||
|
||||
|
||||
class ModelDefaultSettings(BaseModel):
|
||||
vae: str | None
|
||||
vae_precision: str | None
|
||||
scheduler: SCHEDULER_NAME_VALUES | None
|
||||
steps: int | None
|
||||
cfg_scale: float | None
|
||||
cfg_rescale_multiplier: float | None
|
||||
|
||||
|
||||
class ModelConfigBase(BaseModel):
|
||||
"""Base class for model configuration information."""
|
||||
|
||||
path: str = Field(description="filesystem path to the model file or directory")
|
||||
name: str = Field(description="model name")
|
||||
base: BaseModelType = Field(description="base model")
|
||||
type: ModelType = Field(description="type of the model")
|
||||
format: ModelFormat = Field(description="model format")
|
||||
key: str = Field(description="unique key for model", default="<NOKEY>")
|
||||
original_hash: Optional[str] = Field(
|
||||
description="original fasthash of model contents", default=None
|
||||
) # this is assigned at install time and will not change
|
||||
current_hash: Optional[str] = Field(
|
||||
description="current fasthash of model contents", default=None
|
||||
) # if model is converted or otherwise modified, this will hold updated hash
|
||||
description: Optional[str] = Field(description="human readable description of the model", default=None)
|
||||
source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
|
||||
last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time)
|
||||
key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
|
||||
hash: str = Field(description="The hash of the model file(s).")
|
||||
path: str = Field(
|
||||
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
|
||||
)
|
||||
name: str = Field(description="Name of the model.")
|
||||
base: BaseModelType = Field(description="The base model.")
|
||||
description: Optional[str] = Field(description="Model description", default=None)
|
||||
source: str = Field(description="The original source of the model (path, URL or repo_id).")
|
||||
source_type: ModelSourceType = Field(description="The type of source")
|
||||
source_api_response: Optional[str] = Field(
|
||||
description="The original API response from the source, as stringified JSON.", default=None
|
||||
)
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
schema["required"].extend(
|
||||
["key", "base", "type", "format", "original_hash", "current_hash", "source", "last_modified"]
|
||||
)
|
||||
schema["required"].extend(["key", "type", "format"])
|
||||
|
||||
model_config = ConfigDict(
|
||||
use_enum_values=False,
|
||||
validate_assignment=True,
|
||||
json_schema_extra=json_schema_extra,
|
||||
)
|
||||
|
||||
def update(self, attributes: Dict[str, Any]) -> None:
|
||||
"""Update the object with fields in dict."""
|
||||
for key, value in attributes.items():
|
||||
setattr(self, key, value) # may raise a validation error
|
||||
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
|
||||
|
||||
|
||||
class _CheckpointConfig(ModelConfigBase):
|
||||
class CheckpointConfigBase(ModelConfigBase):
|
||||
"""Model config for checkpoint-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
config: str = Field(description="path to the checkpoint model config file")
|
||||
config_path: str = Field(description="path to the checkpoint model config file")
|
||||
converted_at: Optional[float] = Field(
|
||||
description="When this model was last converted to diffusers", default_factory=time.time
|
||||
)
|
||||
|
||||
|
||||
class _DiffusersConfig(ModelConfigBase):
|
||||
class DiffusersConfigBase(ModelConfigBase):
|
||||
"""Model config for diffusers-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
|
||||
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
|
||||
|
||||
|
||||
class LoRAConfig(ModelConfigBase):
|
||||
class LoRALyCORISConfig(ModelConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
|
||||
type: Literal[ModelType.Lora] = ModelType.Lora
|
||||
format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
|
||||
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
||||
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
|
||||
|
||||
|
||||
class VaeCheckpointConfig(ModelConfigBase):
|
||||
class LoRADiffusersConfig(ModelConfigBase):
|
||||
"""Model config for LoRA/Diffusers models."""
|
||||
|
||||
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class VAECheckpointConfig(CheckpointConfigBase):
|
||||
"""Model config for standalone VAE models."""
|
||||
|
||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||
type: Literal[ModelType.VAE] = ModelType.VAE
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
class VaeDiffusersConfig(ModelConfigBase):
|
||||
|
||||
class VAEDiffusersConfig(ModelConfigBase):
|
||||
"""Model config for standalone VAE models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||
type: Literal[ModelType.VAE] = ModelType.VAE
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
class ControlNetDiffusersConfig(_DiffusersConfig):
|
||||
|
||||
class ControlNetDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
class ControlNetCheckpointConfig(_CheckpointConfig):
|
||||
|
||||
class ControlNetCheckpointConfig(CheckpointConfigBase):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
class TextualInversionConfig(ModelConfigBase):
|
||||
|
||||
class TextualInversionFileConfig(ModelConfigBase):
|
||||
"""Model config for textual inversion embeddings."""
|
||||
|
||||
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
|
||||
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
|
||||
|
||||
|
||||
class _MainConfig(ModelConfigBase):
|
||||
"""Model config for main models."""
|
||||
class TextualInversionFolderConfig(ModelConfigBase):
|
||||
"""Model config for textual inversion embeddings."""
|
||||
|
||||
vae: Optional[str] = Field(default=None)
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
ztsnr_training: bool = False
|
||||
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
|
||||
|
||||
|
||||
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
|
||||
class MainCheckpointConfig(CheckpointConfigBase):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
||||
class MainDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for main diffusers models."""
|
||||
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class IPAdapterConfig(ModelConfigBase):
|
||||
"""Model config for IP Adaptor format models."""
|
||||
@ -242,6 +304,10 @@ class IPAdapterConfig(ModelConfigBase):
|
||||
image_encoder_model_id: str
|
||||
format: Literal[ModelFormat.InvokeAI]
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
|
||||
|
||||
|
||||
class CLIPVisionDiffusersConfig(ModelConfigBase):
|
||||
"""Model config for ClipVision."""
|
||||
@ -249,58 +315,65 @@ class CLIPVisionDiffusersConfig(ModelConfigBase):
|
||||
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
class T2IConfig(ModelConfigBase):
|
||||
|
||||
class T2IAdapterConfig(ModelConfigBase):
|
||||
"""Model config for T2I."""
|
||||
|
||||
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
_ControlNetConfig = Annotated[
|
||||
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
|
||||
Field(discriminator="format"),
|
||||
]
|
||||
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
|
||||
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
|
||||
|
||||
AnyModelConfig = Union[
|
||||
_MainModelConfig,
|
||||
_VaeConfig,
|
||||
_ControlNetConfig,
|
||||
# ModelConfigBase,
|
||||
LoRAConfig,
|
||||
TextualInversionConfig,
|
||||
IPAdapterConfig,
|
||||
CLIPVisionDiffusersConfig,
|
||||
T2IConfig,
|
||||
def get_model_discriminator_value(v: Any) -> str:
|
||||
"""
|
||||
Computes the discriminator value for a model config.
|
||||
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
|
||||
"""
|
||||
format_ = None
|
||||
type_ = None
|
||||
if isinstance(v, dict):
|
||||
format_ = v.get("format")
|
||||
if isinstance(format_, Enum):
|
||||
format_ = format_.value
|
||||
type_ = v.get("type")
|
||||
if isinstance(type_, Enum):
|
||||
type_ = type_.value
|
||||
else:
|
||||
format_ = v.format.value
|
||||
type_ = v.type.value
|
||||
v = f"{type_}.{format_}"
|
||||
return v
|
||||
|
||||
|
||||
AnyModelConfig = Annotated[
|
||||
Union[
|
||||
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
||||
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
|
||||
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
|
||||
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
|
||||
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
||||
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
|
||||
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
|
||||
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
||||
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
|
||||
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
|
||||
Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
|
||||
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
|
||||
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
||||
],
|
||||
Discriminator(get_model_discriminator_value),
|
||||
]
|
||||
|
||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
|
||||
# IMPLEMENTATION NOTE:
|
||||
# The preferred alternative to the above is a discriminated Union as shown
|
||||
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
|
||||
# This is a known issue. Please see:
|
||||
# https://github.com/tiangolo/fastapi/discussions/9761 and
|
||||
# https://github.com/tiangolo/fastapi/discussions/9287
|
||||
# AnyModelConfig = Annotated[
|
||||
# Union[
|
||||
# _MainModelConfig,
|
||||
# _ONNXConfig,
|
||||
# _VaeConfig,
|
||||
# _ControlNetConfig,
|
||||
# LoRAConfig,
|
||||
# TextualInversionConfig,
|
||||
# IPAdapterConfig,
|
||||
# CLIPVisionDiffusersConfig,
|
||||
# T2IConfig,
|
||||
# ],
|
||||
# Field(discriminator="type"),
|
||||
# ]
|
||||
|
||||
|
||||
class ModelConfigFactory(object):
|
||||
"""Class for parsing config dicts into StableDiffusion Config obects."""
|
||||
|
||||
@ -332,6 +405,6 @@ class ModelConfigFactory(object):
|
||||
assert model is not None
|
||||
if key:
|
||||
model.key = key
|
||||
if timestamp:
|
||||
model.last_modified = timestamp
|
||||
if isinstance(model, CheckpointConfigBase) and timestamp is not None:
|
||||
model.converted_at = timestamp
|
||||
return model # type: ignore
|
||||
|
@ -13,6 +13,7 @@ from invokeai.backend.model_manager import (
|
||||
ModelRepoVariant,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
@ -50,7 +51,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
:param submodel_type: an ModelType enum indicating the portion of
|
||||
the model to retrieve (e.g. ModelType.Vae)
|
||||
"""
|
||||
if model_config.type == "main" and not submodel_type:
|
||||
if model_config.type is ModelType.Main and not submodel_type:
|
||||
raise InvalidModelConfigException("submodel_type is required when loading a main model")
|
||||
|
||||
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
|
||||
@ -80,7 +81,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
||||
return self._convert_model(config, model_path, cache_path)
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
return False
|
||||
|
||||
def _load_if_needed(
|
||||
@ -119,7 +120,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
return calc_model_size_by_fs(
|
||||
model_path=model_path,
|
||||
subfolder=submodel_type.value if submodel_type else None,
|
||||
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
|
||||
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
|
||||
)
|
||||
|
||||
# This needs to be implemented in subclasses that handle checkpoints
|
||||
|
@ -15,10 +15,8 @@ Use like this:
|
||||
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Optional, Tuple, Type
|
||||
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
|
||||
|
||||
from ..config import (
|
||||
AnyModelConfig,
|
||||
@ -27,8 +25,6 @@ from ..config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
VaeCheckpointConfig,
|
||||
VaeDiffusersConfig,
|
||||
)
|
||||
from . import ModelLoaderBase
|
||||
|
||||
@ -61,6 +57,9 @@ class ModelLoaderRegistryBase(ABC):
|
||||
"""
|
||||
|
||||
|
||||
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
|
||||
|
||||
|
||||
class ModelLoaderRegistry:
|
||||
"""
|
||||
This class allows model loaders to register their type, base and format.
|
||||
@ -71,10 +70,10 @@ class ModelLoaderRegistry:
|
||||
@classmethod
|
||||
def register(
|
||||
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
|
||||
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
|
||||
) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
|
||||
"""Define a decorator which registers the subclass of loader."""
|
||||
|
||||
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
|
||||
def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
|
||||
key = cls._to_registry_key(base, type, format)
|
||||
if key in cls._registry:
|
||||
raise Exception(
|
||||
@ -90,33 +89,15 @@ class ModelLoaderRegistry:
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
||||
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
||||
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
|
||||
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
|
||||
|
||||
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
|
||||
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
|
||||
key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type
|
||||
key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
|
||||
implementation = cls._registry.get(key1) or cls._registry.get(key2)
|
||||
if not implementation:
|
||||
raise NotImplementedError(
|
||||
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
|
||||
)
|
||||
return implementation, conf2, submodel_type
|
||||
|
||||
@classmethod
|
||||
def _handle_subtype_overrides(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
|
||||
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
|
||||
model_path = Path(config.vae)
|
||||
config_class = (
|
||||
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
|
||||
)
|
||||
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
|
||||
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
|
||||
submodel_type = None
|
||||
else:
|
||||
new_conf = config
|
||||
return new_conf, submodel_type
|
||||
return implementation, config, submodel_type
|
||||
|
||||
@staticmethod
|
||||
def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
|
||||
|
@ -3,8 +3,8 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from safetensors.torch import load_file as safetensors_load_file
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
@ -12,6 +12,7 @@ from invokeai.backend.model_manager import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
@ -20,15 +21,15 @@ from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
|
||||
class ControlnetLoader(GenericDiffusersLoader):
|
||||
class ControlNetLoader(GenericDiffusersLoader):
|
||||
"""Class to load ControlNet models."""
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if config.format != ModelFormat.Checkpoint:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||
):
|
||||
return False
|
||||
@ -37,13 +38,13 @@ class ControlnetLoader(GenericDiffusersLoader):
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||
raise Exception(f"Vae conversion not supported for model type: {config.base}")
|
||||
raise Exception(f"ControlNet conversion not supported for model type: {config.base}")
|
||||
else:
|
||||
assert hasattr(config, "config")
|
||||
config_file = config.config
|
||||
assert isinstance(config, CheckpointConfigBase)
|
||||
config_file = config.config_path
|
||||
|
||||
if model_path.suffix == ".safetensors":
|
||||
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
|
||||
checkpoint = safetensors_load_file(model_path, device="cpu")
|
||||
else:
|
||||
checkpoint = torch.load(model_path, map_location="cpu")
|
||||
|
||||
|
@ -3,9 +3,10 @@
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
@ -41,6 +42,7 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
# TO DO: Add exception handling
|
||||
def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
|
||||
"""Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load."""
|
||||
result = None
|
||||
if submodel_type:
|
||||
try:
|
||||
config = self._load_diffusers_config(model_path, config_name="model_index.json")
|
||||
@ -64,6 +66,7 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
|
||||
except KeyError as e:
|
||||
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
||||
assert result is not None
|
||||
return result
|
||||
|
||||
# TO DO: Add exception handling
|
||||
@ -75,7 +78,7 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
result: ModelMixin = getattr(res_type, class_name)
|
||||
return result
|
||||
|
||||
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
|
||||
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> dict[str, Any]:
|
||||
return ConfigLoader.load_config(model_path, config_name=config_name)
|
||||
|
||||
|
||||
@ -83,8 +86,8 @@ class ConfigLoader(ConfigMixin):
|
||||
"""Subclass of ConfigMixin for loading diffusers configuration files."""
|
||||
|
||||
@classmethod
|
||||
def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
def load_config(cls, *args: Any, **kwargs: Any) -> dict[str, Any]: # pyright: ignore [reportIncompatibleMethodOverride]
|
||||
"""Load a diffusrs ConfigMixin configuration."""
|
||||
cls.config_name = kwargs.pop("config_name")
|
||||
# Diffusers doesn't provide typing info
|
||||
# TODO(psyche): the types on this diffusers method are not correct
|
||||
return super().load_config(*args, **kwargs) # type: ignore
|
||||
|
@ -31,7 +31,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in an IP-Adapter model.")
|
||||
model = build_ip_adapter(
|
||||
ip_adapter_ckpt_path=model_path / "ip_adapter.bin",
|
||||
ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
|
||||
device=torch.device("cpu"),
|
||||
dtype=self._torch_dtype,
|
||||
)
|
||||
|
@ -22,8 +22,8 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
|
||||
class LoraLoader(ModelLoader):
|
||||
"""Class to load LoRA models."""
|
||||
|
||||
|
@ -18,7 +18,7 @@ from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.ONNX)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
|
||||
class OnnyxDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load onnx models."""
|
||||
|
@ -4,7 +4,8 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
@ -16,7 +17,7 @@ from invokeai.backend.model_manager import (
|
||||
ModelVariantType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import MainCheckpointConfig
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
@ -54,11 +55,11 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
return result
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if config.format != ModelFormat.Checkpoint:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
and (dest_path / "model_index.json").stat().st_mtime >= (config.last_modified or 0.0)
|
||||
and (dest_path / "model_index.json").stat().st_mtime >= (config.converted_at or 0.0)
|
||||
and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||
):
|
||||
return False
|
||||
@ -73,7 +74,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
|
||||
)
|
||||
|
||||
config_file = config.config
|
||||
config_file = config.config_path
|
||||
|
||||
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||
convert_ckpt_to_diffusers(
|
||||
|
@ -3,9 +3,9 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from safetensors.torch import load_file as safetensors_load_file
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
@ -13,24 +13,25 @@ from invokeai.backend.model_manager import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
class VaeLoader(GenericDiffusersLoader):
|
||||
"""Class to load VAE models."""
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if config.format != ModelFormat.Checkpoint:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||
):
|
||||
return False
|
||||
@ -38,16 +39,15 @@ class VaeLoader(GenericDiffusersLoader):
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
# TO DO: check whether sdxl VAE models convert.
|
||||
# TODO(MM2): check whether sdxl VAE models convert.
|
||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||
raise Exception(f"Vae conversion not supported for model type: {config.base}")
|
||||
raise Exception(f"VAE conversion not supported for model type: {config.base}")
|
||||
else:
|
||||
config_file = (
|
||||
"v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
|
||||
)
|
||||
assert isinstance(config, CheckpointConfigBase)
|
||||
config_file = config.config_path
|
||||
|
||||
if model_path.suffix == ".safetensors":
|
||||
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
|
||||
checkpoint = safetensors_load_file(model_path, device="cpu")
|
||||
else:
|
||||
checkpoint = torch.load(model_path, map_location="cpu")
|
||||
|
||||
@ -55,7 +55,7 @@ class VaeLoader(GenericDiffusersLoader):
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
|
||||
ckpt_config = OmegaConf.load(self._app_config.root_path / config_file)
|
||||
assert isinstance(ckpt_config, DictConfig)
|
||||
|
||||
vae_model = convert_ldm_vae_to_diffusers(
|
||||
|
@ -16,6 +16,7 @@ from diffusers import AutoPipelineForText2Image
|
||||
from diffusers.utils import logging as dlogging
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||
|
||||
from . import (
|
||||
@ -117,7 +118,6 @@ class ModelMerger(object):
|
||||
config = self._installer.app_config
|
||||
store = self._installer.record_store
|
||||
base_models: Set[BaseModelType] = set()
|
||||
vae = None
|
||||
variant = None if self._installer.app_config.full_precision else "fp16"
|
||||
|
||||
assert (
|
||||
@ -134,10 +134,6 @@ class ModelMerger(object):
|
||||
"normal"
|
||||
), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
|
||||
|
||||
# pick up the first model's vae
|
||||
if key == model_keys[0]:
|
||||
vae = info.vae
|
||||
|
||||
# tally base models used
|
||||
base_models.add(info.base)
|
||||
model_paths.extend([config.models_path / info.path])
|
||||
@ -163,12 +159,10 @@ class ModelMerger(object):
|
||||
|
||||
# update model's config
|
||||
model_config = self._installer.record_store.get_model(key)
|
||||
model_config.update(
|
||||
{
|
||||
"name": merged_model_name,
|
||||
"description": f"Merge of models {', '.join(model_names)}",
|
||||
"vae": vae,
|
||||
}
|
||||
model_config.name = merged_model_name
|
||||
model_config.description = f"Merge of models {', '.join(model_names)}"
|
||||
|
||||
self._installer.record_store.update_model(
|
||||
key, ModelRecordChanges(name=model_config.name, description=model_config.description)
|
||||
)
|
||||
self._installer.record_store.update_model(key, model_config)
|
||||
return model_config
|
||||
|
@ -25,9 +25,7 @@ from .metadata_base import (
|
||||
AnyModelRepoMetadataValidator,
|
||||
BaseMetadata,
|
||||
CivitaiMetadata,
|
||||
CommercialUsage,
|
||||
HuggingFaceMetadata,
|
||||
LicenseRestrictions,
|
||||
ModelMetadataWithFiles,
|
||||
RemoteModelFile,
|
||||
UnknownMetadataException,
|
||||
@ -38,10 +36,8 @@ __all__ = [
|
||||
"AnyModelRepoMetadataValidator",
|
||||
"CivitaiMetadata",
|
||||
"CivitaiMetadataFetch",
|
||||
"CommercialUsage",
|
||||
"HuggingFaceMetadata",
|
||||
"HuggingFaceMetadataFetch",
|
||||
"LicenseRestrictions",
|
||||
"ModelMetadataFetchBase",
|
||||
"BaseMetadata",
|
||||
"ModelMetadataWithFiles",
|
||||
|
@ -23,22 +23,21 @@ metadata = fetcher.from_url("https://civitai.com/models/206883/split")
|
||||
print(metadata.trained_words)
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||
|
||||
from ..metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
CivitaiMetadata,
|
||||
CommercialUsage,
|
||||
LicenseRestrictions,
|
||||
RemoteModelFile,
|
||||
UnknownMetadataException,
|
||||
)
|
||||
@ -52,10 +51,13 @@ CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
|
||||
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
|
||||
|
||||
|
||||
StringSetAdapter = TypeAdapter(set[str])
|
||||
|
||||
|
||||
class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
"""Fetch model metadata from Civitai."""
|
||||
|
||||
def __init__(self, session: Optional[Session] = None):
|
||||
def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None):
|
||||
"""
|
||||
Initialize the fetcher with an optional requests.sessions.Session object.
|
||||
|
||||
@ -63,6 +65,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
this module without an internet connection.
|
||||
"""
|
||||
self._requests = session or requests.Session()
|
||||
self._api_key = api_key
|
||||
|
||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
@ -102,22 +105,21 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
May raise an `UnknownMetadataException`.
|
||||
"""
|
||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||
model_json = self._requests.get(model_url).json()
|
||||
return self._from_model_json(model_json)
|
||||
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
|
||||
return self._from_api_response(model_json)
|
||||
|
||||
def _from_model_json(self, model_json: Dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
try:
|
||||
version_id = version_id or model_json["modelVersions"][0]["id"]
|
||||
version_id = version_id or api_response["modelVersions"][0]["id"]
|
||||
except TypeError as excp:
|
||||
raise UnknownMetadataException from excp
|
||||
|
||||
# loop till we find the section containing the version requested
|
||||
version_sections = [x for x in model_json["modelVersions"] if x["id"] == version_id]
|
||||
version_sections = [x for x in api_response["modelVersions"] if x["id"] == version_id]
|
||||
if not version_sections:
|
||||
raise UnknownMetadataException(f"Version {version_id} not found in model metadata")
|
||||
|
||||
version_json = version_sections[0]
|
||||
safe_thumbnails = [x["url"] for x in version_json["images"] if x["nsfw"] == "None"]
|
||||
|
||||
# Civitai has one "primary" file plus others such as VAEs. We only fetch the primary.
|
||||
primary = [x for x in version_json["files"] if x.get("primary")]
|
||||
@ -134,36 +136,23 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
url = url + f"?type={primary_file['type']}{metadata_string}"
|
||||
model_files = [
|
||||
RemoteModelFile(
|
||||
url=url,
|
||||
url=self._get_url_with_api_key(url),
|
||||
path=Path(primary_file["name"]),
|
||||
size=int(primary_file["sizeKB"] * 1024),
|
||||
sha256=primary_file["hashes"]["SHA256"],
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
trigger_phrases = StringSetAdapter.validate_python(version_json.get("trainedWords"))
|
||||
except ValidationError:
|
||||
trigger_phrases: set[str] = set()
|
||||
|
||||
return CivitaiMetadata(
|
||||
id=model_json["id"],
|
||||
name=version_json["name"],
|
||||
version_id=version_json["id"],
|
||||
version_name=version_json["name"],
|
||||
created=datetime.fromisoformat(_fix_timezone(version_json["createdAt"])),
|
||||
updated=datetime.fromisoformat(_fix_timezone(version_json["updatedAt"])),
|
||||
published=datetime.fromisoformat(_fix_timezone(version_json["publishedAt"])),
|
||||
base_model_trained_on=version_json["baseModel"], # note - need a dictionary to turn into a BaseModelType
|
||||
files=model_files,
|
||||
download_url=version_json["downloadUrl"],
|
||||
thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
|
||||
author=model_json["creator"]["username"],
|
||||
description=model_json["description"],
|
||||
version_description=version_json["description"] or "",
|
||||
tags=model_json["tags"],
|
||||
trained_words=version_json["trainedWords"],
|
||||
nsfw=model_json["nsfw"],
|
||||
restrictions=LicenseRestrictions(
|
||||
AllowNoCredit=model_json["allowNoCredit"],
|
||||
AllowCommercialUse={CommercialUsage(x) for x in model_json["allowCommercialUse"]},
|
||||
AllowDerivatives=model_json["allowDerivatives"],
|
||||
AllowDifferentLicense=model_json["allowDifferentLicense"],
|
||||
),
|
||||
trigger_phrases=trigger_phrases,
|
||||
api_response=json.dumps(version_json),
|
||||
)
|
||||
|
||||
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
@ -174,14 +163,14 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
"""
|
||||
if model_id is None:
|
||||
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
|
||||
version = self._requests.get(version_url).json()
|
||||
version = self._requests.get(self._get_url_with_api_key(version_url)).json()
|
||||
if error := version.get("error"):
|
||||
raise UnknownMetadataException(error)
|
||||
model_id = version["modelId"]
|
||||
|
||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||
model_json = self._requests.get(model_url).json()
|
||||
return self._from_model_json(model_json, version_id)
|
||||
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
|
||||
return self._from_api_response(model_json, version_id)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: str) -> CivitaiMetadata:
|
||||
@ -189,6 +178,11 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
metadata = CivitaiMetadata.model_validate_json(json)
|
||||
return metadata
|
||||
|
||||
def _get_url_with_api_key(self, url: str) -> str:
|
||||
if not self._api_key:
|
||||
return url
|
||||
|
||||
def _fix_timezone(date: str) -> str:
|
||||
return re.sub(r"Z$", "+00:00", date)
|
||||
if "?" in url:
|
||||
return f"{url}&token={self._api_key}"
|
||||
|
||||
return f"{url}?token={self._api_key}"
|
||||
|
@ -13,6 +13,7 @@ metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
|
||||
print(metadata.tags)
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@ -23,7 +24,7 @@ from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFo
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||
|
||||
from ..metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
@ -60,6 +61,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
# Little loop which tries fetching a revision corresponding to the selected variant.
|
||||
# If not available, then set variant to None and get the default.
|
||||
# If this too fails, raise exception.
|
||||
|
||||
model_info = None
|
||||
while not model_info:
|
||||
try:
|
||||
@ -72,23 +74,24 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
else:
|
||||
variant = None
|
||||
|
||||
files: list[RemoteModelFile] = []
|
||||
|
||||
_, name = id.split("/")
|
||||
return HuggingFaceMetadata(
|
||||
id=model_info.id,
|
||||
author=model_info.author,
|
||||
name=name,
|
||||
last_modified=model_info.last_modified,
|
||||
tag_dict=model_info.card_data.to_dict() if model_info.card_data else {},
|
||||
tags=model_info.tags,
|
||||
files=[
|
||||
|
||||
for s in model_info.siblings or []:
|
||||
assert s.rfilename is not None
|
||||
assert s.size is not None
|
||||
files.append(
|
||||
RemoteModelFile(
|
||||
url=hf_hub_url(id, x.rfilename, revision=variant),
|
||||
path=Path(name, x.rfilename),
|
||||
size=x.size,
|
||||
sha256=x.lfs.get("sha256") if x.lfs else None,
|
||||
url=hf_hub_url(id, s.rfilename, revision=variant),
|
||||
path=Path(name, s.rfilename),
|
||||
size=s.size,
|
||||
sha256=s.lfs.get("sha256") if s.lfs else None,
|
||||
)
|
||||
for x in model_info.siblings
|
||||
],
|
||||
)
|
||||
|
||||
return HuggingFaceMetadata(
|
||||
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str)
|
||||
)
|
||||
|
||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||
|
@ -14,10 +14,8 @@ versions of these fields are intended to be kept in sync with the
|
||||
remote repo.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from huggingface_hub import configure_http_backend, hf_hub_url
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
@ -25,7 +23,6 @@ from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
|
||||
from ..util import select_hf_files
|
||||
@ -35,31 +32,6 @@ class UnknownMetadataException(Exception):
|
||||
"""Raised when no metadata is available for a model."""
|
||||
|
||||
|
||||
class CommercialUsage(str, Enum):
|
||||
"""Type of commercial usage allowed."""
|
||||
|
||||
No = "None"
|
||||
Image = "Image"
|
||||
Rent = "Rent"
|
||||
RentCivit = "RentCivit"
|
||||
Sell = "Sell"
|
||||
|
||||
|
||||
class LicenseRestrictions(BaseModel):
|
||||
"""Broad categories of licensing restrictions."""
|
||||
|
||||
AllowNoCredit: bool = Field(
|
||||
description="if true, model can be redistributed without crediting author", default=False
|
||||
)
|
||||
AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
|
||||
AllowDifferentLicense: bool = Field(
|
||||
description="if true, derivatives of this model be redistributed under a different license", default=False
|
||||
)
|
||||
AllowCommercialUse: Optional[Set[CommercialUsage] | CommercialUsage] = Field(
|
||||
description="Type of commercial use allowed if no commercial use is allowed.", default=None
|
||||
)
|
||||
|
||||
|
||||
class RemoteModelFile(BaseModel):
|
||||
"""Information about a downloadable file that forms part of a model."""
|
||||
|
||||
@ -69,24 +41,10 @@ class RemoteModelFile(BaseModel):
|
||||
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
|
||||
|
||||
|
||||
class ModelDefaultSettings(BaseModel):
|
||||
vae: str | None
|
||||
vae_precision: str | None
|
||||
scheduler: SCHEDULER_NAME_VALUES | None
|
||||
steps: int | None
|
||||
cfg_scale: float | None
|
||||
cfg_rescale_multiplier: float | None
|
||||
|
||||
|
||||
class ModelMetadataBase(BaseModel):
|
||||
"""Base class for model metadata information."""
|
||||
|
||||
name: str = Field(description="model's name")
|
||||
author: str = Field(description="model's author")
|
||||
tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None)
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||
description="default settings for this model", default=None
|
||||
)
|
||||
|
||||
|
||||
class BaseMetadata(ModelMetadataBase):
|
||||
@ -124,60 +82,16 @@ class CivitaiMetadata(ModelMetadataWithFiles):
|
||||
"""Extended metadata fields provided by Civitai."""
|
||||
|
||||
type: Literal["civitai"] = "civitai"
|
||||
id: int = Field(description="Civitai version identifier")
|
||||
version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
|
||||
version_id: int = Field(description="Civitai model version identifier")
|
||||
created: datetime = Field(description="date the model was created")
|
||||
updated: datetime = Field(description="date the model was last modified")
|
||||
published: datetime = Field(description="date the model was published to Civitai")
|
||||
description: str = Field(description="text description of model; may contain HTML")
|
||||
version_description: str = Field(
|
||||
description="text description of the model's reversion; usually change history; may contain HTML"
|
||||
)
|
||||
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
|
||||
restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
|
||||
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
|
||||
download_url: AnyHttpUrl = Field(description="download URL for this model")
|
||||
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
|
||||
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
|
||||
weight_minmax: Tuple[float, float] = Field(
|
||||
description="minimum and maximum slider values for a LoRA or other secondary model", default=(-1.0, +2.0)
|
||||
) # note: For future use
|
||||
|
||||
@property
|
||||
def credit_required(self) -> bool:
|
||||
"""Return True if you must give credit for derivatives of this model and images generated from it."""
|
||||
return not self.restrictions.AllowNoCredit
|
||||
|
||||
@property
|
||||
def allow_commercial_use(self) -> bool:
|
||||
"""Return True if commercial use is allowed."""
|
||||
if self.restrictions.AllowCommercialUse is None:
|
||||
return False
|
||||
else:
|
||||
# accommodate schema change
|
||||
acu = self.restrictions.AllowCommercialUse
|
||||
commercial_usage = acu if isinstance(acu, set) else {acu}
|
||||
return CommercialUsage.No not in commercial_usage
|
||||
|
||||
@property
|
||||
def allow_derivatives(self) -> bool:
|
||||
"""Return True if derivatives of this model can be redistributed."""
|
||||
return self.restrictions.AllowDerivatives
|
||||
|
||||
@property
|
||||
def allow_different_license(self) -> bool:
|
||||
"""Return true if derivatives of this model can use a different license."""
|
||||
return self.restrictions.AllowDifferentLicense
|
||||
trigger_phrases: set[str] = Field(description="Trigger phrases extracted from the API response")
|
||||
api_response: Optional[str] = Field(description="Response from the Civitai API as stringified JSON", default=None)
|
||||
|
||||
|
||||
class HuggingFaceMetadata(ModelMetadataWithFiles):
|
||||
"""Extended metadata fields provided by HuggingFace."""
|
||||
|
||||
type: Literal["huggingface"] = "huggingface"
|
||||
id: str = Field(description="huggingface model id")
|
||||
tag_dict: Dict[str, Any]
|
||||
last_modified: datetime = Field(description="date of last commit to repo")
|
||||
id: str = Field(description="The HF model id")
|
||||
api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
|
||||
|
||||
def download_urls(
|
||||
self,
|
||||
@ -206,7 +120,7 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
|
||||
# the next step reads model_index.json to determine which subdirectories belong
|
||||
# to the model
|
||||
if Path(f"{prefix}model_index.json") in paths:
|
||||
url = hf_hub_url(self.id, filename="model_index.json", subfolder=subfolder)
|
||||
url = hf_hub_url(self.id, filename="model_index.json", subfolder=str(subfolder) if subfolder else None)
|
||||
resp = session.get(url)
|
||||
resp.raise_for_status()
|
||||
submodels = resp.json()
|
||||
|
@ -1,221 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
SQL Storage for Model Metadata
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
from .fetch import ModelMetadataFetchBase
|
||||
from .metadata_base import AnyModelRepoMetadata, UnknownMetadataException
|
||||
|
||||
|
||||
class ModelMetadataStore:
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json()
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_metadata(
|
||||
id,
|
||||
metadata
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
model_key,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||
self._db.conn.rollback()
|
||||
raise UnknownMetadataException from excp
|
||||
except sqlite3.Error as excp:
|
||||
self._db.conn.rollback()
|
||||
raise excp
|
||||
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM model_metadata
|
||||
WHERE id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
return ModelMetadataFetchBase.from_json(rows[0])
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id,metadata FROM model_metadata;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
|
||||
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_metadata
|
||||
SET
|
||||
metadata=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, model_key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_metadata(model_key)
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select tag_text from tags;
|
||||
"""
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
matches: Optional[Set[str]] = None
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.model_id FROM model_tags AS a,
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
if matches is None:
|
||||
matches = model_keys
|
||||
matches = matches.intersection(model_keys)
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return matches if matches else set()
|
||||
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE author=?;
|
||||
""",
|
||||
(author,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE name=?;
|
||||
""",
|
||||
(name,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
@ -8,6 +8,7 @@ import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.util.util import SilenceWarnings
|
||||
|
||||
from .config import (
|
||||
@ -17,6 +18,7 @@ from .config import (
|
||||
ModelConfigFactory,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
@ -95,8 +97,8 @@ class ModelProbe(object):
|
||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||
"LatentConsistencyModelPipeline": ModelType.Main,
|
||||
"AutoencoderKL": ModelType.Vae,
|
||||
"AutoencoderTiny": ModelType.Vae,
|
||||
"AutoencoderKL": ModelType.VAE,
|
||||
"AutoencoderTiny": ModelType.VAE,
|
||||
"ControlNetModel": ModelType.ControlNet,
|
||||
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
||||
"T2IAdapter": ModelType.T2IAdapter,
|
||||
@ -108,14 +110,6 @@ class ModelProbe(object):
|
||||
) -> None:
|
||||
cls.PROBES[format][model_type] = probe_class
|
||||
|
||||
@classmethod
|
||||
def heuristic_probe(
|
||||
cls,
|
||||
model_path: Path,
|
||||
fields: Optional[Dict[str, Any]] = None,
|
||||
) -> AnyModelConfig:
|
||||
return cls.probe(model_path, fields)
|
||||
|
||||
@classmethod
|
||||
def probe(
|
||||
cls,
|
||||
@ -137,19 +131,21 @@ class ModelProbe(object):
|
||||
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
||||
model_info = None
|
||||
model_type = None
|
||||
if format_type == "diffusers":
|
||||
if format_type is ModelFormat.Diffusers:
|
||||
model_type = cls.get_model_type_from_folder(model_path)
|
||||
else:
|
||||
model_type = cls.get_model_type_from_checkpoint(model_path)
|
||||
format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type
|
||||
format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type
|
||||
|
||||
probe_class = cls.PROBES[format_type].get(model_type)
|
||||
if not probe_class:
|
||||
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
|
||||
|
||||
hash = ModelHash().hash(model_path)
|
||||
probe = probe_class(model_path)
|
||||
|
||||
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
|
||||
fields["source"] = fields.get("source") or model_path.as_posix()
|
||||
fields["key"] = fields.get("key", uuid_string())
|
||||
fields["path"] = model_path.as_posix()
|
||||
fields["type"] = fields.get("type") or model_type
|
||||
fields["base"] = fields.get("base") or probe.get_base_type()
|
||||
@ -161,15 +157,17 @@ class ModelProbe(object):
|
||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
||||
)
|
||||
fields["format"] = fields.get("format") or probe.get_format()
|
||||
fields["original_hash"] = fields.get("original_hash") or hash
|
||||
fields["current_hash"] = fields.get("current_hash") or hash
|
||||
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
|
||||
|
||||
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
|
||||
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
|
||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||
|
||||
# additional fields needed for main and controlnet models
|
||||
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
|
||||
fields["config"] = cls._get_checkpoint_config_path(
|
||||
if (
|
||||
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
|
||||
and fields["format"] is ModelFormat.Checkpoint
|
||||
):
|
||||
fields["config_path"] = cls._get_checkpoint_config_path(
|
||||
model_path,
|
||||
model_type=fields["type"],
|
||||
base_type=fields["base"],
|
||||
@ -179,7 +177,7 @@ class ModelProbe(object):
|
||||
|
||||
# additional fields needed for main non-checkpoint models
|
||||
elif fields["type"] == ModelType.Main and fields["format"] in [
|
||||
ModelFormat.Onnx,
|
||||
ModelFormat.ONNX,
|
||||
ModelFormat.Olive,
|
||||
ModelFormat.Diffusers,
|
||||
]:
|
||||
@ -213,11 +211,11 @@ class ModelProbe(object):
|
||||
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
||||
return ModelType.Main
|
||||
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
||||
return ModelType.Vae
|
||||
return ModelType.VAE
|
||||
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||
return ModelType.Lora
|
||||
return ModelType.LoRA
|
||||
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
||||
return ModelType.Lora
|
||||
return ModelType.LoRA
|
||||
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
||||
return ModelType.ControlNet
|
||||
elif key in {"emb_params", "string_to_param"}:
|
||||
@ -239,7 +237,7 @@ class ModelProbe(object):
|
||||
if (folder_path / f"learned_embeds.{suffix}").exists():
|
||||
return ModelType.TextualInversion
|
||||
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
|
||||
return ModelType.Lora
|
||||
return ModelType.LoRA
|
||||
if (folder_path / "unet/model.onnx").exists():
|
||||
return ModelType.ONNX
|
||||
if (folder_path / "image_encoder.txt").exists():
|
||||
@ -285,13 +283,21 @@ class ModelProbe(object):
|
||||
if possible_conf.exists():
|
||||
return possible_conf.absolute()
|
||||
|
||||
if model_type == ModelType.Main:
|
||||
if model_type is ModelType.Main:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
config_file = config_file[prediction_type]
|
||||
elif model_type == ModelType.ControlNet:
|
||||
elif model_type is ModelType.ControlNet:
|
||||
config_file = (
|
||||
"../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
|
||||
"../controlnet/cldm_v15.yaml"
|
||||
if base_type is BaseModelType.StableDiffusion1
|
||||
else "../controlnet/cldm_v21.yaml"
|
||||
)
|
||||
elif model_type is ModelType.VAE:
|
||||
config_file = (
|
||||
"../stable-diffusion/v1-inference.yaml"
|
||||
if base_type is BaseModelType.StableDiffusion1
|
||||
else "../stable-diffusion/v2-inference.yaml"
|
||||
)
|
||||
else:
|
||||
raise InvalidModelConfigException(
|
||||
@ -497,12 +503,12 @@ class FolderProbeBase(ProbeBase):
|
||||
if ".fp16" in x.suffixes:
|
||||
return ModelRepoVariant.FP16
|
||||
if "openvino_model" in x.name:
|
||||
return ModelRepoVariant.OPENVINO
|
||||
return ModelRepoVariant.OpenVINO
|
||||
if "flax_model" in x.name:
|
||||
return ModelRepoVariant.FLAX
|
||||
return ModelRepoVariant.Flax
|
||||
if x.suffix == ".onnx":
|
||||
return ModelRepoVariant.ONNX
|
||||
return ModelRepoVariant.DEFAULT
|
||||
return ModelRepoVariant.Default
|
||||
|
||||
|
||||
class PipelineFolderProbe(FolderProbeBase):
|
||||
@ -708,8 +714,8 @@ class T2IAdapterFolderProbe(FolderProbeBase):
|
||||
|
||||
############## register probe classes ######
|
||||
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||
@ -717,8 +723,8 @@ ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderPro
|
||||
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||
|
@ -13,6 +13,7 @@ files_to_download = select_hf_model_files(metadata.files, variant='onnx')
|
||||
"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
@ -34,7 +35,7 @@ def filter_files(
|
||||
The file list can be obtained from the `files` field of HuggingFaceMetadata,
|
||||
as defined in `invokeai.backend.model_manager.metadata.metadata_base`.
|
||||
"""
|
||||
variant = variant or ModelRepoVariant.DEFAULT
|
||||
variant = variant or ModelRepoVariant.Default
|
||||
paths: List[Path] = []
|
||||
root = files[0].parts[0]
|
||||
|
||||
@ -73,64 +74,81 @@ def filter_files(
|
||||
return sorted(_filter_by_variant(paths, variant))
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubfolderCandidate:
|
||||
path: Path
|
||||
score: int
|
||||
|
||||
|
||||
def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
|
||||
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
|
||||
result = set()
|
||||
basenames: Dict[Path, Path] = {}
|
||||
result: set[Path] = set()
|
||||
subfolder_weights: dict[Path, list[SubfolderCandidate]] = {}
|
||||
for path in files:
|
||||
if path.suffix in [".onnx", ".pb", ".onnx_data"]:
|
||||
if variant == ModelRepoVariant.ONNX:
|
||||
result.add(path)
|
||||
|
||||
elif "openvino_model" in path.name:
|
||||
if variant == ModelRepoVariant.OPENVINO:
|
||||
if variant == ModelRepoVariant.OpenVINO:
|
||||
result.add(path)
|
||||
|
||||
elif "flax_model" in path.name:
|
||||
if variant == ModelRepoVariant.FLAX:
|
||||
if variant == ModelRepoVariant.Flax:
|
||||
result.add(path)
|
||||
|
||||
elif path.suffix in [".json", ".txt"]:
|
||||
result.add(path)
|
||||
|
||||
elif path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"] and variant in [
|
||||
elif variant in [
|
||||
ModelRepoVariant.FP16,
|
||||
ModelRepoVariant.FP32,
|
||||
ModelRepoVariant.DEFAULT,
|
||||
]:
|
||||
parent = path.parent
|
||||
suffixes = path.suffixes
|
||||
if len(suffixes) == 2:
|
||||
variant_label, suffix = suffixes
|
||||
basename = parent / Path(path.stem).stem
|
||||
else:
|
||||
variant_label = ""
|
||||
suffix = suffixes[0]
|
||||
basename = parent / path.stem
|
||||
ModelRepoVariant.Default,
|
||||
] and path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"]:
|
||||
# For weights files, we want to select the best one for each subfolder. For example, we may have multiple
|
||||
# text encoders:
|
||||
#
|
||||
# - text_encoder/model.fp16.safetensors
|
||||
# - text_encoder/model.safetensors
|
||||
# - text_encoder/pytorch_model.bin
|
||||
# - text_encoder/pytorch_model.fp16.bin
|
||||
#
|
||||
# We prefer safetensors over other file formats and an exact variant match. We'll score each file based on
|
||||
# variant and format and select the best one.
|
||||
|
||||
if previous := basenames.get(basename):
|
||||
if (
|
||||
previous.suffix != ".safetensors" and suffix == ".safetensors"
|
||||
): # replace non-safetensors with safetensors when available
|
||||
basenames[basename] = path
|
||||
if variant_label == f".{variant}":
|
||||
basenames[basename] = path
|
||||
elif not variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]:
|
||||
basenames[basename] = path
|
||||
else:
|
||||
basenames[basename] = path
|
||||
parent = path.parent
|
||||
score = 0
|
||||
|
||||
if path.suffix == ".safetensors":
|
||||
score += 1
|
||||
|
||||
candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None
|
||||
|
||||
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant
|
||||
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
|
||||
if candidate_variant_label == f".{variant}" or (
|
||||
not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]
|
||||
):
|
||||
score += 1
|
||||
|
||||
if parent not in subfolder_weights:
|
||||
subfolder_weights[parent] = []
|
||||
|
||||
subfolder_weights[parent].append(SubfolderCandidate(path=path, score=score))
|
||||
|
||||
else:
|
||||
continue
|
||||
|
||||
for v in basenames.values():
|
||||
result.add(v)
|
||||
for candidate_list in subfolder_weights.values():
|
||||
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
|
||||
if highest_score_candidate:
|
||||
result.add(highest_score_candidate.path)
|
||||
|
||||
# If one of the architecture-related variants was specified and no files matched other than
|
||||
# config and text files then we return an empty list
|
||||
if (
|
||||
variant
|
||||
and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OPENVINO, ModelRepoVariant.FLAX]
|
||||
and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OpenVINO, ModelRepoVariant.Flax]
|
||||
and not any(variant.value in x.name for x in result)
|
||||
):
|
||||
return set()
|
||||
|
@ -858,9 +858,9 @@ def do_textual_inversion_training(
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
|
||||
orig_embeds_params[index_no_updates]
|
||||
)
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
|
@ -144,7 +144,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
|
||||
self.nextrely = top_of_table
|
||||
self.lora_models = self.add_model_widgets(
|
||||
model_type=ModelType.Lora,
|
||||
model_type=ModelType.LoRA,
|
||||
window_width=window_width,
|
||||
)
|
||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||
|
@ -30,7 +30,7 @@
|
||||
"lint:prettier": "prettier --check .",
|
||||
"lint:tsc": "tsc --noEmit",
|
||||
"lint": "concurrently -g -c red,green,yellow,blue,magenta pnpm:lint:*",
|
||||
"fix": "knip --fix && eslint --fix . && prettier --log-level warn --write .",
|
||||
"fix": "eslint --fix . && prettier --log-level warn --write .",
|
||||
"preinstall": "npx only-allow pnpm",
|
||||
"storybook": "storybook dev -p 6006",
|
||||
"build-storybook": "storybook build",
|
||||
|
@ -304,6 +304,12 @@
|
||||
"method": "High Resolution Fix Method"
|
||||
}
|
||||
},
|
||||
"prompt": {
|
||||
"addPromptTrigger": "Add Prompt Trigger",
|
||||
"compatibleEmbeddings": "Compatible Embeddings",
|
||||
"noPromptTriggers": "No triggers available",
|
||||
"noMatchingTriggers": "No matching triggers"
|
||||
},
|
||||
"embedding": {
|
||||
"addEmbedding": "Add Embedding",
|
||||
"incompatibleModel": "Incompatible base model:",
|
||||
|
@ -153,7 +153,7 @@ addFirstListImagesListener(startAppListening);
|
||||
// Ad-hoc upscale workflwo
|
||||
addUpscaleRequestedListener(startAppListening);
|
||||
|
||||
// Dynamic prompts
|
||||
// Prompts
|
||||
addDynamicPromptsListener(startAppListening);
|
||||
|
||||
addSetDefaultSettingsListener(startAppListening);
|
||||
|
@ -7,8 +7,10 @@ import {
|
||||
selectAllT2IAdapters,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { loraRemoved } from 'features/lora/store/loraSlice';
|
||||
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
|
||||
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
|
||||
import { heightChanged, modelChanged, vaeSelected, widthChanged } from 'features/parameters/store/generationSlice';
|
||||
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
|
||||
import { forEach, some } from 'lodash-es';
|
||||
import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models';
|
||||
@ -24,7 +26,9 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
|
||||
const log = logger('models');
|
||||
log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`);
|
||||
|
||||
const currentModel = getState().generation.model;
|
||||
const state = getState();
|
||||
|
||||
const currentModel = state.generation.model;
|
||||
const models = mainModelsAdapterSelectors.selectAll(action.payload);
|
||||
|
||||
if (models.length === 0) {
|
||||
@ -39,6 +43,29 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
|
||||
return;
|
||||
}
|
||||
|
||||
const defaultModel = state.config.sd.defaultModel;
|
||||
const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;
|
||||
|
||||
if (defaultModelInList) {
|
||||
const result = zParameterModel.safeParse(defaultModelInList);
|
||||
if (result.success) {
|
||||
dispatch(modelChanged(defaultModelInList, currentModel));
|
||||
|
||||
const optimalDimension = getOptimalDimension(defaultModelInList);
|
||||
if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
|
||||
return;
|
||||
}
|
||||
const { width, height } = calculateNewSize(
|
||||
state.generation.aspectRatio.value,
|
||||
optimalDimension * optimalDimension
|
||||
);
|
||||
|
||||
dispatch(widthChanged(width));
|
||||
dispatch(heightChanged(height));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const result = zParameterModel.safeParse(models[0]);
|
||||
|
||||
if (!result.success) {
|
||||
|
@ -34,13 +34,13 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
|
||||
return;
|
||||
}
|
||||
|
||||
const metadata = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(currentModel.key)).unwrap();
|
||||
const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();
|
||||
|
||||
if (!metadata || !metadata.default_settings) {
|
||||
if (!modelConfig || !modelConfig.default_settings) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = metadata.default_settings;
|
||||
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = modelConfig.default_settings;
|
||||
|
||||
if (vae) {
|
||||
// we store this as "default" within default settings
|
||||
|
@ -14,7 +14,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
||||
const { bytes, total_bytes, id } = action.payload.data;
|
||||
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.bytes = bytes;
|
||||
@ -33,7 +33,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
||||
const { id } = action.payload.data;
|
||||
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'completed';
|
||||
@ -41,7 +41,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelConfig' }]));
|
||||
dispatch(api.util.invalidateTags(['Model']));
|
||||
},
|
||||
});
|
||||
|
||||
@ -51,7 +51,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
||||
const { id, error, error_type } = action.payload.data;
|
||||
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'error';
|
||||
|
@ -1,21 +0,0 @@
|
||||
import type { Meta, StoryObj } from '@storybook/react';
|
||||
|
||||
import { EmbeddingSelect } from './EmbeddingSelect';
|
||||
import type { EmbeddingSelectProps } from './types';
|
||||
|
||||
const meta: Meta<typeof EmbeddingSelect> = {
|
||||
title: 'Feature/Prompt/EmbeddingSelect',
|
||||
tags: ['autodocs'],
|
||||
component: EmbeddingSelect,
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof EmbeddingSelect>;
|
||||
|
||||
const Component = (props: EmbeddingSelectProps) => {
|
||||
return <EmbeddingSelect {...props}>Invoke</EmbeddingSelect>;
|
||||
};
|
||||
|
||||
export const Default: Story = {
|
||||
render: Component,
|
||||
};
|
@ -1,67 +0,0 @@
|
||||
import type { ChakraProps } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import type { EmbeddingSelectProps } from 'features/embedding/types';
|
||||
import { t } from 'i18next';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { TextualInversionModelConfig } from 'services/api/types';
|
||||
|
||||
const noOptionsMessage = () => t('embedding.noMatchingEmbedding');
|
||||
|
||||
export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
|
||||
const getIsDisabled = useCallback(
|
||||
(embedding: TextualInversionModelConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === embedding.base;
|
||||
const hasMainModel = Boolean(currentBaseModel);
|
||||
return !hasMainModel || !isCompatible;
|
||||
},
|
||||
[currentBaseModel]
|
||||
);
|
||||
const { data, isLoading } = useGetTextualInversionModelsQuery();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(embedding: TextualInversionModelConfig | null) => {
|
||||
if (!embedding) {
|
||||
return;
|
||||
}
|
||||
onSelect(embedding.name);
|
||||
},
|
||||
[onSelect]
|
||||
);
|
||||
|
||||
const { options, onChange } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
getIsDisabled,
|
||||
onChange: _onChange,
|
||||
});
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<Combobox
|
||||
placeholder={isLoading ? t('common.loading') : t('embedding.addEmbedding')}
|
||||
defaultMenuIsOpen
|
||||
autoFocus
|
||||
value={null}
|
||||
options={options}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
onChange={onChange}
|
||||
onMenuClose={onClose}
|
||||
data-testid="add-embedding"
|
||||
sx={selectStyles}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
EmbeddingSelect.displayName = 'EmbeddingSelect';
|
||||
|
||||
const selectStyles: ChakraProps['sx'] = {
|
||||
w: 'full',
|
||||
};
|
@ -1,228 +0,0 @@
|
||||
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import BaseModelSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect';
|
||||
import BooleanSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect';
|
||||
import ModelFormatSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect';
|
||||
import ModelTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect';
|
||||
import ModelVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect';
|
||||
import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect';
|
||||
import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { isNil, omitBy } from 'lodash-es';
|
||||
import { useCallback, useEffect } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useInstallModelMutation } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
export const AdvancedImport = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [installModel] = useInstallModelMutation();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const {
|
||||
register,
|
||||
handleSubmit,
|
||||
control,
|
||||
formState: { errors },
|
||||
setValue,
|
||||
resetField,
|
||||
reset,
|
||||
watch,
|
||||
} = useForm<AnyModelConfig>({
|
||||
defaultValues: {
|
||||
name: '',
|
||||
base: 'sd-1',
|
||||
type: 'main',
|
||||
path: '',
|
||||
description: '',
|
||||
format: 'diffusers',
|
||||
vae: '',
|
||||
variant: 'normal',
|
||||
},
|
||||
mode: 'onChange',
|
||||
});
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
|
||||
(values) => {
|
||||
installModel({
|
||||
source: values.path,
|
||||
config: omitBy(values, isNil),
|
||||
})
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelAdded', {
|
||||
modelName: values.name,
|
||||
}),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
reset();
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.modelAddFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
},
|
||||
[installModel, dispatch, t, reset]
|
||||
);
|
||||
|
||||
const watchedModelType = watch('type');
|
||||
const watchedModelFormat = watch('format');
|
||||
|
||||
useEffect(() => {
|
||||
if (watchedModelType === 'main') {
|
||||
setValue('format', 'diffusers');
|
||||
setValue('repo_variant', '');
|
||||
setValue('variant', 'normal');
|
||||
}
|
||||
if (watchedModelType === 'lora') {
|
||||
setValue('format', 'lycoris');
|
||||
} else if (watchedModelType === 'embedding') {
|
||||
setValue('format', 'embedding_file');
|
||||
} else if (watchedModelType === 'ip_adapter') {
|
||||
setValue('format', 'invokeai');
|
||||
} else {
|
||||
setValue('format', 'diffusers');
|
||||
}
|
||||
resetField('upcast_attention');
|
||||
resetField('ztsnr_training');
|
||||
resetField('vae');
|
||||
resetField('config');
|
||||
resetField('prediction_type');
|
||||
resetField('image_encoder_model_id');
|
||||
}, [watchedModelType, resetField, setValue]);
|
||||
|
||||
return (
|
||||
<ScrollableContent>
|
||||
<form onSubmit={handleSubmit(onSubmit)}>
|
||||
<Flex flexDirection="column" gap={4} width="100%" pb={10}>
|
||||
<Flex alignItems="flex-end" gap="4">
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.modelType')}</FormLabel>
|
||||
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
|
||||
</FormControl>
|
||||
<Text px="2" fontSize="xs" textAlign="center">
|
||||
{t('modelManager.advancedImportInfo')}
|
||||
</Text>
|
||||
</Flex>
|
||||
|
||||
<Flex p={4} borderRadius={4} bg="base.850" height="100%" direction="column" gap="3">
|
||||
<FormControl isInvalid={Boolean(errors.name)}>
|
||||
<Flex direction="column" width="full">
|
||||
<FormLabel>{t('modelManager.name')}</FormLabel>
|
||||
<Input
|
||||
{...register('name', {
|
||||
validate: (value) => value.trim().length >= 3 || 'Must be at least 3 characters',
|
||||
})}
|
||||
/>
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
<Flex>
|
||||
<FormControl>
|
||||
<Flex direction="column" width="full">
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Textarea size="sm" {...register('description')} />
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||
<BaseModelSelect control={control} name="base" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('common.format')}</FormLabel>
|
||||
<ModelFormatSelect control={control} name="format" />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
|
||||
<FormLabel>{t('modelManager.path')}</FormLabel>
|
||||
<Input
|
||||
{...register('path', {
|
||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
||||
})}
|
||||
/>
|
||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
</Flex>
|
||||
{watchedModelType === 'main' && (
|
||||
<>
|
||||
<Flex gap={4}>
|
||||
{watchedModelFormat === 'diffusers' && (
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.repoVariant')}</FormLabel>
|
||||
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
|
||||
</FormControl>
|
||||
)}
|
||||
{watchedModelFormat === 'checkpoint' && (
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
||||
<Input {...register('config')} />
|
||||
</FormControl>
|
||||
)}
|
||||
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
|
||||
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
|
||||
<Input {...register('vae')} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
{watchedModelType === 'ip_adapter' && (
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
|
||||
<Input {...register('image_encoder_model_id')} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
)}
|
||||
<Button mt={2} type="submit">
|
||||
{t('modelManager.addModel')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</form>
|
||||
</ScrollableContent>
|
||||
);
|
||||
};
|
@ -12,7 +12,7 @@ type SimpleImportModelConfig = {
|
||||
location: string;
|
||||
};
|
||||
|
||||
export const SimpleImport = () => {
|
||||
export const InstallModelForm = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [installModel, { isLoading }] = useInstallModelMutation();
|
@ -5,19 +5,19 @@ import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models';
|
||||
import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import { ImportQueueItem } from './ImportQueueItem';
|
||||
import { ModelInstallQueueItem } from './ModelInstallQueueItem';
|
||||
|
||||
export const ImportQueue = () => {
|
||||
export const ModelInstallQueue = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { data } = useGetModelImportsQuery();
|
||||
const { data } = useListModelInstallsQuery();
|
||||
|
||||
const [pruneModelImports] = usePruneModelImportsMutation();
|
||||
const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation();
|
||||
|
||||
const pruneQueue = useCallback(() => {
|
||||
pruneModelImports()
|
||||
const pruneCompletedModelInstalls = useCallback(() => {
|
||||
_pruneCompletedModelInstalls()
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
@ -41,7 +41,7 @@ export const ImportQueue = () => {
|
||||
);
|
||||
}
|
||||
});
|
||||
}, [pruneModelImports, dispatch]);
|
||||
}, [_pruneCompletedModelInstalls, dispatch]);
|
||||
|
||||
const pruneAvailable = useMemo(() => {
|
||||
return data?.some(
|
||||
@ -53,14 +53,19 @@ export const ImportQueue = () => {
|
||||
<Flex flexDir="column" p={3} h="full">
|
||||
<Flex justifyContent="space-between" alignItems="center">
|
||||
<Text>{t('modelManager.importQueue')}</Text>
|
||||
<Button size="sm" isDisabled={!pruneAvailable} onClick={pruneQueue} tooltip={t('modelManager.pruneTooltip')}>
|
||||
<Button
|
||||
size="sm"
|
||||
isDisabled={!pruneAvailable}
|
||||
onClick={pruneCompletedModelInstalls}
|
||||
tooltip={t('modelManager.pruneTooltip')}
|
||||
>
|
||||
{t('modelManager.prune')}
|
||||
</Button>
|
||||
</Flex>
|
||||
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
|
||||
<ScrollableContent>
|
||||
<Flex flexDir="column-reverse" gap="2">
|
||||
{data?.map((model) => <ImportQueueItem key={model.id} model={model} />)}
|
||||
{data?.map((model) => <ModelInstallQueueItem key={model.id} installJob={model} />)}
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
</Box>
|
@ -6,17 +6,24 @@ import type { ModelInstallStatus } from 'services/api/types';
|
||||
const STATUSES = {
|
||||
waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' },
|
||||
downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
|
||||
downloads_done: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
|
||||
running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
|
||||
completed: { colorScheme: 'green', translationKey: 'queue.completed' },
|
||||
error: { colorScheme: 'red', translationKey: 'queue.failed' },
|
||||
cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
|
||||
};
|
||||
|
||||
const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus; errorReason?: string | null }) => {
|
||||
const ModelInstallQueueBadge = ({
|
||||
status,
|
||||
errorReason,
|
||||
}: {
|
||||
status?: ModelInstallStatus;
|
||||
errorReason?: string | null;
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (!status || !Object.keys(STATUSES).includes(status)) {
|
||||
return <></>;
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
@ -25,4 +32,4 @@ const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
export default memo(ImportQueueBadge);
|
||||
export default memo(ModelInstallQueueBadge);
|
@ -3,15 +3,16 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
|
||||
import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
|
||||
import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types';
|
||||
|
||||
import ImportQueueBadge from './ImportQueueBadge';
|
||||
import ModelInstallQueueBadge from './ModelInstallQueueBadge';
|
||||
|
||||
type ModelListItemProps = {
|
||||
model: ModelInstallJob;
|
||||
installJob: ModelInstallJob;
|
||||
};
|
||||
|
||||
const formatBytes = (bytes: number) => {
|
||||
@ -26,26 +27,26 @@ const formatBytes = (bytes: number) => {
|
||||
return `${bytes.toFixed(2)} ${units[i]}`;
|
||||
};
|
||||
|
||||
export const ImportQueueItem = (props: ModelListItemProps) => {
|
||||
const { model } = props;
|
||||
export const ModelInstallQueueItem = (props: ModelListItemProps) => {
|
||||
const { installJob } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [deleteImportModel] = useDeleteModelImportMutation();
|
||||
const [deleteImportModel] = useCancelModelInstallMutation();
|
||||
|
||||
const source = useMemo(() => {
|
||||
if (model.source.type === 'hf') {
|
||||
return model.source as HFModelSource;
|
||||
} else if (model.source.type === 'local') {
|
||||
return model.source as LocalModelSource;
|
||||
} else if (model.source.type === 'url') {
|
||||
return model.source as URLModelSource;
|
||||
if (installJob.source.type === 'hf') {
|
||||
return installJob.source as HFModelSource;
|
||||
} else if (installJob.source.type === 'local') {
|
||||
return installJob.source as LocalModelSource;
|
||||
} else if (installJob.source.type === 'url') {
|
||||
return installJob.source as URLModelSource;
|
||||
} else {
|
||||
return model.source as LocalModelSource;
|
||||
return installJob.source as LocalModelSource;
|
||||
}
|
||||
}, [model.source]);
|
||||
}, [installJob.source]);
|
||||
|
||||
const handleDeleteModelImport = useCallback(() => {
|
||||
deleteImportModel(model.id)
|
||||
deleteImportModel(installJob.id)
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
@ -69,7 +70,7 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
|
||||
);
|
||||
}
|
||||
});
|
||||
}, [deleteImportModel, model, dispatch]);
|
||||
}, [deleteImportModel, installJob, dispatch]);
|
||||
|
||||
const modelName = useMemo(() => {
|
||||
switch (source.type) {
|
||||
@ -85,19 +86,23 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
|
||||
}, [source]);
|
||||
|
||||
const progressValue = useMemo(() => {
|
||||
if (model.bytes === undefined || model.total_bytes === undefined) {
|
||||
if (isNil(installJob.bytes) || isNil(installJob.total_bytes)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (installJob.total_bytes === 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return (model.bytes / model.total_bytes) * 100;
|
||||
}, [model.bytes, model.total_bytes]);
|
||||
return (installJob.bytes / installJob.total_bytes) * 100;
|
||||
}, [installJob.bytes, installJob.total_bytes]);
|
||||
|
||||
const progressString = useMemo(() => {
|
||||
if (model.status !== 'downloading' || model.bytes === undefined || model.total_bytes === undefined) {
|
||||
if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
|
||||
return '';
|
||||
}
|
||||
return `${formatBytes(model.bytes)} / ${formatBytes(model.total_bytes)}`;
|
||||
}, [model.bytes, model.total_bytes, model.status]);
|
||||
return `${formatBytes(installJob.bytes)} / ${formatBytes(installJob.total_bytes)}`;
|
||||
}, [installJob.bytes, installJob.total_bytes, installJob.status]);
|
||||
|
||||
return (
|
||||
<Flex gap="2" w="full" alignItems="center">
|
||||
@ -109,19 +114,21 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
|
||||
<Flex flexDir="column" flex={1}>
|
||||
<Tooltip label={progressString}>
|
||||
<Progress
|
||||
value={progressValue}
|
||||
isIndeterminate={progressValue === undefined}
|
||||
value={progressValue ?? 0}
|
||||
isIndeterminate={progressValue === null}
|
||||
aria-label={t('accessibility.invokeProgressBar')}
|
||||
h={2}
|
||||
/>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
<Box minW="100px" textAlign="center">
|
||||
<ImportQueueBadge status={model.status} errorReason={model.error_reason} />
|
||||
<ModelInstallQueueBadge status={installJob.status} errorReason={installJob.error_reason} />
|
||||
</Box>
|
||||
|
||||
<Box minW="20px">
|
||||
{(model.status === 'downloading' || model.status === 'waiting' || model.status === 'running') && (
|
||||
{(installJob.status === 'downloading' ||
|
||||
installJob.status === 'waiting' ||
|
||||
installJob.status === 'running') && (
|
||||
<IconButton
|
||||
isRound={true}
|
||||
size="xs"
|
@ -2,24 +2,24 @@ import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useLazyScanModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useLazyScanFolderQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { ScanModelsResults } from './ScanModelsResults';
|
||||
import { ScanModelsResults } from './ScanFolderResults';
|
||||
|
||||
export const ScanModelsForm = () => {
|
||||
const [scanPath, setScanPath] = useState('');
|
||||
const [errorMessage, setErrorMessage] = useState('');
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [_scanModels, { isLoading, data }] = useLazyScanModelsQuery();
|
||||
const [_scanFolder, { isLoading, data }] = useLazyScanFolderQuery();
|
||||
|
||||
const handleSubmitScan = useCallback(async () => {
|
||||
_scanModels({ scan_path: scanPath }).catch((error) => {
|
||||
const scanFolder = useCallback(async () => {
|
||||
_scanFolder({ scan_path: scanPath }).catch((error) => {
|
||||
if (error) {
|
||||
setErrorMessage(error.data.detail);
|
||||
}
|
||||
});
|
||||
}, [_scanModels, scanPath]);
|
||||
}, [_scanFolder, scanPath]);
|
||||
|
||||
const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
|
||||
setScanPath(e.target.value);
|
||||
@ -36,7 +36,7 @@ export const ScanModelsForm = () => {
|
||||
<Input value={scanPath} onChange={handleSetScanPath} />
|
||||
</Flex>
|
||||
|
||||
<Button onClick={handleSubmitScan} isLoading={isLoading} isDisabled={scanPath.length === 0}>
|
||||
<Button onClick={scanFolder} isLoading={isLoading} isDisabled={scanPath.length === 0}>
|
||||
{t('modelManager.scanFolder')}
|
||||
</Button>
|
||||
</Flex>
|
@ -18,7 +18,7 @@ import { useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import { ScanModelResultItem } from './ScanModelResultItem';
|
||||
import { ScanModelResultItem } from './ScanFolderResultItem';
|
||||
|
||||
type ScanModelResultsProps = {
|
||||
results: ScanFolderResponse;
|
@ -1,12 +1,11 @@
|
||||
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { AdvancedImport } from './AddModelPanel/AdvancedImport';
|
||||
import { ImportQueue } from './AddModelPanel/ImportQueue/ImportQueue';
|
||||
import { ScanModelsForm } from './AddModelPanel/ScanModels/ScanModelsForm';
|
||||
import { SimpleImport } from './AddModelPanel/SimpleImport';
|
||||
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
|
||||
import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
|
||||
import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
|
||||
|
||||
export const ImportModels = () => {
|
||||
export const InstallModels = () => {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Flex layerStyle="first" p={3} borderRadius="base" w="full" h="full" flexDir="column" gap={2}>
|
||||
@ -17,15 +16,11 @@ export const ImportModels = () => {
|
||||
<Tabs variant="collapse" height="100%">
|
||||
<TabList>
|
||||
<Tab>{t('common.simple')}</Tab>
|
||||
<Tab>{t('modelManager.advanced')}</Tab>
|
||||
<Tab>{t('modelManager.scan')}</Tab>
|
||||
</TabList>
|
||||
<TabPanels p={3} height="100%">
|
||||
<TabPanel>
|
||||
<SimpleImport />
|
||||
</TabPanel>
|
||||
<TabPanel height="100%">
|
||||
<AdvancedImport />
|
||||
<InstallModelForm />
|
||||
</TabPanel>
|
||||
<TabPanel height="100%">
|
||||
<ScanModelsForm />
|
||||
@ -34,7 +29,7 @@ export const ImportModels = () => {
|
||||
</Tabs>
|
||||
</Box>
|
||||
<Box layerStyle="second" borderRadius="base" w="full" h="50%">
|
||||
<ImportQueue />
|
||||
<ModelInstallQueue />
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
@ -5,7 +5,7 @@ import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoFilter } from 'react-icons/io5';
|
||||
|
||||
export const MODEL_TYPE_LABELS: { [key: string]: string } = {
|
||||
const MODEL_TYPE_LABELS: { [key: string]: string } = {
|
||||
main: 'Main',
|
||||
lora: 'LoRA',
|
||||
embedding: 'Textual Inversion',
|
||||
|
@ -1,14 +1,14 @@
|
||||
import { Box } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
|
||||
import { ImportModels } from './ImportModels';
|
||||
import { InstallModels } from './InstallModels';
|
||||
import { Model } from './ModelPanel/Model';
|
||||
|
||||
export const ModelPane = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
return (
|
||||
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
|
||||
{selectedModelKey ? <Model key={selectedModelKey} /> : <ImportModels />}
|
||||
{selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
@ -1,11 +1,12 @@
|
||||
import { Text } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import Loading from 'common/components/Loading/Loading';
|
||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm';
|
||||
|
||||
@ -23,8 +24,9 @@ const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config)
|
||||
|
||||
export const DefaultSettings = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { data, isLoading } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
|
||||
useAppSelector(initialStatesSelector);
|
||||
|
||||
@ -59,7 +61,7 @@ export const DefaultSettings = () => {
|
||||
]);
|
||||
|
||||
if (isLoading) {
|
||||
return <Loading />;
|
||||
return <Text>{t('common.loading')}</Text>;
|
||||
}
|
||||
|
||||
return <DefaultSettingsForm defaultSettingsDefaults={defaultSettingsDefaults} />;
|
||||
|
@ -8,7 +8,7 @@ import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoPencil } from 'react-icons/io5';
|
||||
import { useUpdateModelMetadataMutation } from 'services/api/endpoints/models';
|
||||
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
|
||||
import { DefaultCfgScale } from './DefaultCfgScale';
|
||||
@ -41,7 +41,7 @@ export const DefaultSettingsForm = ({
|
||||
const { t } = useTranslation();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
|
||||
const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation();
|
||||
const [updateModel, { isLoading }] = useUpdateModelMutation();
|
||||
|
||||
const { handleSubmit, control, formState } = useForm<DefaultSettingsFormData>({
|
||||
defaultValues: defaultSettingsDefaults,
|
||||
@ -62,7 +62,7 @@ export const DefaultSettingsForm = ({
|
||||
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
|
||||
};
|
||||
|
||||
editModelMetadata({
|
||||
updateModel({
|
||||
key: selectedModelKey,
|
||||
body: { default_settings: body },
|
||||
})
|
||||
@ -90,7 +90,7 @@ export const DefaultSettingsForm = ({
|
||||
}
|
||||
});
|
||||
},
|
||||
[selectedModelKey, dispatch, editModelMetadata, t]
|
||||
[selectedModelKey, dispatch, updateModel, t]
|
||||
);
|
||||
|
||||
return (
|
||||
|
@ -3,9 +3,9 @@ import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import type { Control } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
||||
@ -14,8 +14,12 @@ const options: ComboboxOption[] = [
|
||||
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
|
||||
];
|
||||
|
||||
const BaseModelSelect = (props: UseControllerProps<AnyModelConfig>) => {
|
||||
const { field } = useController(props);
|
||||
type Props = {
|
||||
control: Control<UpdateModelArg['body']>;
|
||||
};
|
||||
|
||||
const BaseModelSelect = ({ control }: Props) => {
|
||||
const { field } = useController({ control, name: 'base' });
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
|
@ -1,27 +0,0 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'none', label: '-' },
|
||||
{ value: 'true', label: 'True' },
|
||||
{ value: 'false', label: 'False' },
|
||||
];
|
||||
|
||||
const BooleanSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value === 'true');
|
||||
},
|
||||
[field]
|
||||
);
|
||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||
};
|
||||
|
||||
export default typedMemo(BooleanSelect);
|
@ -1,47 +0,0 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController, useWatch } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const ModelFormatSelect = (props: UseControllerProps<AnyModelConfig>) => {
|
||||
const { field, formState } = useController(props);
|
||||
const type = useWatch({ control: props.control, name: 'type' });
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
field.onChange(v?.value);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const options: ComboboxOption[] = useMemo(() => {
|
||||
const modelType = type || formState.defaultValues?.type;
|
||||
if (modelType === 'lora') {
|
||||
return [
|
||||
{ value: 'lycoris', label: 'LyCORIS' },
|
||||
{ value: 'diffusers', label: 'Diffusers' },
|
||||
];
|
||||
} else if (modelType === 'embedding') {
|
||||
return [
|
||||
{ value: 'embedding_file', label: 'Embedding File' },
|
||||
{ value: 'embedding_folder', label: 'Embedding Folder' },
|
||||
];
|
||||
} else if (modelType === 'ip_adapter') {
|
||||
return [{ value: 'invokeai', label: 'invokeai' }];
|
||||
} else {
|
||||
return [
|
||||
{ value: 'diffusers', label: 'Diffusers' },
|
||||
{ value: 'checkpoint', label: 'Checkpoint' },
|
||||
];
|
||||
}
|
||||
}, [type, formState.defaultValues?.type]);
|
||||
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [options, field.value]);
|
||||
|
||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||
};
|
||||
|
||||
export default typedMemo(ModelFormatSelect);
|
@ -1,32 +0,0 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { MODEL_TYPE_LABELS } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'main', label: MODEL_TYPE_LABELS['main'] as string },
|
||||
{ value: 'lora', label: MODEL_TYPE_LABELS['lora'] as string },
|
||||
{ value: 'embedding', label: MODEL_TYPE_LABELS['embedding'] as string },
|
||||
{ value: 'vae', label: MODEL_TYPE_LABELS['vae'] as string },
|
||||
{ value: 'controlnet', label: MODEL_TYPE_LABELS['controlnet'] as string },
|
||||
{ value: 'ip_adapter', label: MODEL_TYPE_LABELS['ip_adapter'] as string },
|
||||
{ value: 't2i_adapater', label: MODEL_TYPE_LABELS['t2i_adapter'] as string },
|
||||
] as const;
|
||||
|
||||
const ModelTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
field.onChange(v?.value);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||
};
|
||||
|
||||
export default typedMemo(ModelTypeSelect);
|
@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import type { Control } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'normal', label: 'Normal' },
|
||||
@ -12,8 +12,12 @@ const options: ComboboxOption[] = [
|
||||
{ value: 'depth', label: 'Depth' },
|
||||
];
|
||||
|
||||
const ModelVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
type Props = {
|
||||
control: Control<UpdateModelArg['body']>;
|
||||
};
|
||||
|
||||
const ModelVariantSelect = ({ control }: Props) => {
|
||||
const { field } = useController({ control, name: 'variant' });
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
|
@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import type { Control } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'none', label: '-' },
|
||||
@ -13,8 +13,12 @@ const options: ComboboxOption[] = [
|
||||
{ value: 'sample', label: 'sample' },
|
||||
];
|
||||
|
||||
const PredictionTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
type Props = {
|
||||
control: Control<UpdateModelArg['body']>;
|
||||
};
|
||||
|
||||
const PredictionTypeSelect = ({ control }: Props) => {
|
||||
const { field } = useController({ control, name: 'prediction_type' });
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
|
@ -1,27 +0,0 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'none', label: '-' },
|
||||
{ value: 'fp16', label: 'fp16' },
|
||||
{ value: 'fp32', label: 'fp32' },
|
||||
];
|
||||
|
||||
const RepoVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||
};
|
||||
|
||||
export default typedMemo(RepoVariantSelect);
|
@ -1,18 +1,41 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
|
||||
import { useMemo } from 'react';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
import type { ModelType } from 'services/api/types';
|
||||
|
||||
import { TriggerPhrases } from './TriggerPhrases';
|
||||
|
||||
const MODEL_TYPE_TRIGGER_PHRASE: ModelType[] = ['main', 'lora'];
|
||||
|
||||
export const ModelMetadata = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
||||
const { data } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
const shouldShowTriggerPhraseSettings = useMemo(() => {
|
||||
if (!data?.type) {
|
||||
return false;
|
||||
}
|
||||
return MODEL_TYPE_TRIGGER_PHRASE.includes(data.type);
|
||||
}, [data]);
|
||||
|
||||
const apiResponseFormatted = useMemo(() => {
|
||||
if (!data?.source_api_response) {
|
||||
return {};
|
||||
}
|
||||
return JSON.parse(data.source_api_response);
|
||||
}, [data?.source_api_response]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex flexDir="column" height="full" gap="3">
|
||||
<DataViewer label="metadata" data={metadata || {}} />
|
||||
</Flex>
|
||||
</>
|
||||
<Flex flexDir="column" height="full" gap="3">
|
||||
{shouldShowTriggerPhraseSettings && (
|
||||
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||
<TriggerPhrases />
|
||||
</Box>
|
||||
)}
|
||||
<DataViewer label="metadata" data={apiResponseFormatted} />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -0,0 +1,106 @@
|
||||
import {
|
||||
Button,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
Input,
|
||||
Tag,
|
||||
TagCloseButton,
|
||||
TagLabel,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { ModelListHeader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelListHeader';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
|
||||
export const TriggerPhrases = () => {
|
||||
const { t } = useTranslation();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data: modelConfig } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
const [phrase, setPhrase] = useState('');
|
||||
|
||||
const [updateModel, { isLoading }] = useUpdateModelMutation();
|
||||
|
||||
const handlePhraseChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setPhrase(e.target.value);
|
||||
}, []);
|
||||
|
||||
const triggerPhrases = useMemo(() => {
|
||||
return modelConfig?.trigger_phrases || [];
|
||||
}, [modelConfig?.trigger_phrases]);
|
||||
|
||||
const errors = useMemo(() => {
|
||||
const errors = [];
|
||||
|
||||
if (phrase.length && triggerPhrases.includes(phrase)) {
|
||||
errors.push('Phrase is already in list');
|
||||
}
|
||||
|
||||
return errors;
|
||||
}, [phrase, triggerPhrases]);
|
||||
|
||||
const addTriggerPhrase = useCallback(async () => {
|
||||
if (!selectedModelKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!phrase.length || triggerPhrases.includes(phrase)) {
|
||||
return;
|
||||
}
|
||||
|
||||
await updateModel({
|
||||
key: selectedModelKey,
|
||||
body: { trigger_phrases: [...triggerPhrases, phrase] },
|
||||
}).unwrap();
|
||||
setPhrase('');
|
||||
}, [updateModel, selectedModelKey, phrase, triggerPhrases]);
|
||||
|
||||
const removeTriggerPhrase = useCallback(
|
||||
async (phraseToRemove: string) => {
|
||||
if (!selectedModelKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
const filteredPhrases = triggerPhrases.filter((p) => p !== phraseToRemove);
|
||||
|
||||
await updateModel({ key: selectedModelKey, body: { trigger_phrases: filteredPhrases } }).unwrap();
|
||||
},
|
||||
[updateModel, selectedModelKey, triggerPhrases]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" w="full" gap="5">
|
||||
<ModelListHeader title={t('modelManager.triggerPhrases')} />
|
||||
<form>
|
||||
<FormControl w="full" isInvalid={Boolean(errors.length)}>
|
||||
<Flex flexDir="column" w="full">
|
||||
<Flex gap="3" alignItems="center" w="full">
|
||||
<Input value={phrase} onChange={handlePhraseChange} placeholder={t('modelManager.typePhraseHere')} />
|
||||
<Button
|
||||
type="submit"
|
||||
onClick={addTriggerPhrase}
|
||||
isDisabled={Boolean(errors.length)}
|
||||
isLoading={isLoading}
|
||||
>
|
||||
{t('common.add')}
|
||||
</Button>
|
||||
</Flex>
|
||||
{!!errors.length && errors.map((error) => <FormErrorMessage key={error}>{error}</FormErrorMessage>)}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
</form>
|
||||
|
||||
<Flex gap="4" flexWrap="wrap" mt="3" mb="3">
|
||||
{triggerPhrases.map((phrase, index) => (
|
||||
<Tag size="md" key={index}>
|
||||
<TagLabel>{phrase}</TagLabel>
|
||||
<TagCloseButton onClick={removeTriggerPhrase.bind(null, phrase)} isDisabled={isLoading} />
|
||||
</Tag>
|
||||
))}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
@ -13,7 +13,7 @@ import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import { useConvertModelMutation } from 'services/api/endpoints/models';
|
||||
import type { CheckpointModelConfig } from 'services/api/types';
|
||||
|
||||
interface ModelConvertProps {
|
||||
@ -24,7 +24,7 @@ export const ModelConvert = (props: ModelConvertProps) => {
|
||||
const { model } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const [convertModel, { isLoading }] = useConvertMainModelsMutation();
|
||||
const [convertModel, { isLoading }] = useConvertModelMutation();
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
|
||||
const modelConvertHandler = useCallback(() => {
|
||||
|
@ -1,5 +1,6 @@
|
||||
import {
|
||||
Button,
|
||||
Checkbox,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
@ -19,66 +20,27 @@ import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
import { useGetModelConfigQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import BaseModelSelect from './Fields/BaseModelSelect';
|
||||
import BooleanSelect from './Fields/BooleanSelect';
|
||||
import ModelFormatSelect from './Fields/ModelFormatSelect';
|
||||
import ModelTypeSelect from './Fields/ModelTypeSelect';
|
||||
import ModelVariantSelect from './Fields/ModelVariantSelect';
|
||||
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
|
||||
import RepoVariantSelect from './Fields/RepoVariantSelect';
|
||||
|
||||
export const ModelEdit = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelsMutation();
|
||||
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
// const modelData = useMemo(() => {
|
||||
// if (!data) {
|
||||
// return null;
|
||||
// }
|
||||
// const modelFormat = data.format;
|
||||
// const modelType = data.type;
|
||||
|
||||
// if (modelType === 'main') {
|
||||
// if (modelFormat === 'diffusers') {
|
||||
// return data as DiffusersModelConfig;
|
||||
// } else if (modelFormat === 'checkpoint') {
|
||||
// return data as CheckpointModelConfig;
|
||||
// }
|
||||
// }
|
||||
|
||||
// switch (modelType) {
|
||||
// case 'lora':
|
||||
// return data as LoRAModelConfig;
|
||||
// case 'embedding':
|
||||
// return data as TextualInversionModelConfig;
|
||||
// case 't2i_adapter':
|
||||
// return data as T2IAdapterModelConfig;
|
||||
// case 'ip_adapter':
|
||||
// return data as IPAdapterModelConfig;
|
||||
// case 'controlnet':
|
||||
// return data as ControlNetModelConfig;
|
||||
// case 'vae':
|
||||
// return data as VAEModelConfig;
|
||||
// default:
|
||||
// return null;
|
||||
// }
|
||||
// }, [data]);
|
||||
|
||||
const {
|
||||
register,
|
||||
handleSubmit,
|
||||
control,
|
||||
formState: { errors },
|
||||
reset,
|
||||
watch,
|
||||
} = useForm<UpdateModelArg['body']>({
|
||||
defaultValues: {
|
||||
...data,
|
||||
@ -86,10 +48,7 @@ export const ModelEdit = () => {
|
||||
mode: 'onChange',
|
||||
});
|
||||
|
||||
const watchedModelType = watch('type');
|
||||
const watchedModelFormat = watch('format');
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
|
||||
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
|
||||
(values) => {
|
||||
if (!data?.key) {
|
||||
return;
|
||||
@ -143,33 +102,31 @@ export const ModelEdit = () => {
|
||||
return (
|
||||
<Flex flexDir="column" h="full">
|
||||
<form onSubmit={handleSubmit(onSubmit)}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
|
||||
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
||||
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
|
||||
<FormLabel hidden={true}>{t('modelManager.modelName')}</FormLabel>
|
||||
<Input
|
||||
{...register('name', {
|
||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
||||
validate: (value) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
|
||||
})}
|
||||
size="lg"
|
||||
/>
|
||||
|
||||
<Flex gap={2}>
|
||||
<Button size="sm" onClick={handleClickCancel}>
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
colorScheme="invokeYellow"
|
||||
onClick={handleSubmit(onSubmit)}
|
||||
isLoading={isSubmitting}
|
||||
isDisabled={Boolean(Object.keys(errors).length)}
|
||||
>
|
||||
{t('common.save')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</Flex>
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
<Button size="sm" onClick={handleClickCancel}>
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
colorScheme="invokeYellow"
|
||||
onClick={handleSubmit(onSubmit)}
|
||||
isLoading={isSubmitting}
|
||||
isDisabled={Boolean(Object.keys(errors).length)}
|
||||
>
|
||||
{t('common.save')}
|
||||
</Button>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" gap={3} mt="4">
|
||||
<Flex>
|
||||
@ -184,76 +141,22 @@ export const ModelEdit = () => {
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||
<BaseModelSelect control={control} name="base" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.modelType')}</FormLabel>
|
||||
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
|
||||
<BaseModelSelect control={control} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('common.format')}</FormLabel>
|
||||
<ModelFormatSelect control={control} name="format" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
|
||||
<FormLabel>{t('modelManager.path')}</FormLabel>
|
||||
<Input
|
||||
{...register('path', {
|
||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
||||
})}
|
||||
/>
|
||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
</Flex>
|
||||
{watchedModelType === 'main' && (
|
||||
<>
|
||||
<Flex gap={4}>
|
||||
{watchedModelFormat === 'diffusers' && (
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.repoVariant')}</FormLabel>
|
||||
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
|
||||
</FormControl>
|
||||
)}
|
||||
{watchedModelFormat === 'checkpoint' && (
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
||||
<Input {...register('config')} />
|
||||
</FormControl>
|
||||
)}
|
||||
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
|
||||
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
|
||||
<Input {...register('vae')} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
{watchedModelType === 'ip_adapter' && (
|
||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
|
||||
<Input {...register('image_encoder_model_id')} />
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect control={control} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||
<PredictionTypeSelect control={control} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||
<Checkbox {...register('upcast_attention')} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
)}
|
||||
|
@ -91,26 +91,19 @@ export const ModelView = () => {
|
||||
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
|
||||
</Flex>
|
||||
{modelData.type === 'main' && (
|
||||
<>
|
||||
<Flex gap={2}>
|
||||
{modelData.format === 'diffusers' && (
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
|
||||
)}
|
||||
{modelData.format === 'checkpoint' && (
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
|
||||
)}
|
||||
|
||||
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
|
||||
<ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
|
||||
</Flex>
|
||||
</>
|
||||
<Flex gap={2}>
|
||||
{modelData.format === 'diffusers' && modelData.repo_variant && (
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
|
||||
)}
|
||||
{modelData.format === 'checkpoint' && (
|
||||
<>
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config_path} />
|
||||
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
)}
|
||||
{modelData.type === 'ip_adapter' && (
|
||||
<Flex gap={2}>
|
||||
|
@ -1,10 +1,10 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
|
||||
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
|
||||
import { usePrompt } from 'features/embedding/usePrompt';
|
||||
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||
import { setNegativePrompt } from 'features/parameters/store/generationSlice';
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@ -19,19 +19,14 @@ export const ParamNegativePrompt = memo(() => {
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown } = usePrompt({
|
||||
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown } = usePrompt({
|
||||
prompt,
|
||||
textareaRef,
|
||||
onChange: _onChange,
|
||||
});
|
||||
|
||||
return (
|
||||
<EmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={onSelectEmbedding}
|
||||
width={textareaRef.current?.clientWidth}
|
||||
>
|
||||
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||
<Box pos="relative">
|
||||
<Textarea
|
||||
id="negativePrompt"
|
||||
@ -45,10 +40,10 @@ export const ParamNegativePrompt = memo(() => {
|
||||
variant="darkFilled"
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
</PromptOverlayButtonWrapper>
|
||||
</Box>
|
||||
</EmbeddingPopover>
|
||||
</PromptPopover>
|
||||
);
|
||||
});
|
||||
|
||||
|
@ -1,11 +1,11 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { ShowDynamicPromptsPreviewButton } from 'features/dynamicPrompts/components/ShowDynamicPromptsPreviewButton';
|
||||
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
|
||||
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
|
||||
import { usePrompt } from 'features/embedding/usePrompt';
|
||||
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||
import { setPositivePrompt } from 'features/parameters/store/generationSlice';
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { SDXLConcatButton } from 'features/sdxl/components/SDXLPrompts/SDXLConcatButton';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import type { HotkeyCallback } from 'react-hotkeys-hook';
|
||||
@ -25,7 +25,7 @@ export const ParamPositivePrompt = memo(() => {
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown, onFocus } = usePrompt({
|
||||
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
|
||||
prompt,
|
||||
textareaRef: textareaRef,
|
||||
onChange: handleChange,
|
||||
@ -42,12 +42,7 @@ export const ParamPositivePrompt = memo(() => {
|
||||
useHotkeys('alt+a', focus, []);
|
||||
|
||||
return (
|
||||
<EmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={onSelectEmbedding}
|
||||
width={textareaRef.current?.clientWidth}
|
||||
>
|
||||
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||
<Box pos="relative">
|
||||
<Textarea
|
||||
id="prompt"
|
||||
@ -61,12 +56,12 @@ export const ParamPositivePrompt = memo(() => {
|
||||
variant="darkFilled"
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
{baseModel === 'sdxl' && <SDXLConcatButton />}
|
||||
<ShowDynamicPromptsPreviewButton />
|
||||
</PromptOverlayButtonWrapper>
|
||||
</Box>
|
||||
</EmbeddingPopover>
|
||||
</PromptPopover>
|
||||
);
|
||||
});
|
||||
|
||||
|
@ -16,7 +16,6 @@ import type {
|
||||
ParameterScheduler,
|
||||
ParameterVAEModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { clamp } from 'lodash-es';
|
||||
@ -210,26 +209,6 @@ export const generationSlice = createSlice({
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(configChanged, (state, action) => {
|
||||
const defaultModel = action.payload.sd?.defaultModel;
|
||||
|
||||
if (defaultModel && !state.model) {
|
||||
const [base_model, model_type, model_name] = defaultModel.split('/');
|
||||
|
||||
const result = zParameterModel.safeParse({
|
||||
model_name,
|
||||
base_model,
|
||||
model_type,
|
||||
});
|
||||
|
||||
if (result.success) {
|
||||
state.model = result.data;
|
||||
|
||||
const optimalDimension = getOptimalDimension(result.data);
|
||||
|
||||
state.width = optimalDimension;
|
||||
state.height = optimalDimension;
|
||||
}
|
||||
}
|
||||
if (action.payload.sd?.scheduler) {
|
||||
state.scheduler = action.payload.sd.scheduler;
|
||||
}
|
||||
|
@ -8,15 +8,15 @@ type Props = {
|
||||
onOpen: () => void;
|
||||
};
|
||||
|
||||
export const AddEmbeddingButton = memo((props: Props) => {
|
||||
export const AddPromptTriggerButton = memo((props: Props) => {
|
||||
const { onOpen, isOpen } = props;
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Tooltip label={t('embedding.addEmbedding')}>
|
||||
<Tooltip label={t('prompt.addPromptTrigger')}>
|
||||
<IconButton
|
||||
variant="promptOverlay"
|
||||
isDisabled={isOpen}
|
||||
aria-label={t('embedding.addEmbedding')}
|
||||
aria-label={t('prompt.addPromptTrigger')}
|
||||
icon={<PiCodeBold />}
|
||||
onClick={onOpen}
|
||||
/>
|
||||
@ -24,4 +24,4 @@ export const AddEmbeddingButton = memo((props: Props) => {
|
||||
);
|
||||
});
|
||||
|
||||
AddEmbeddingButton.displayName = 'AddEmbeddingButton';
|
||||
AddPromptTriggerButton.displayName = 'AddPromptTriggerButton';
|
@ -1,9 +1,9 @@
|
||||
import { Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
|
||||
import { EmbeddingSelect } from 'features/embedding/EmbeddingSelect';
|
||||
import type { EmbeddingPopoverProps } from 'features/embedding/types';
|
||||
import { PromptTriggerSelect } from 'features/prompt/PromptTriggerSelect';
|
||||
import type { PromptPopoverProps } from 'features/prompt/types';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
|
||||
export const PromptPopover = memo((props: PromptPopoverProps) => {
|
||||
const { onSelect, isOpen, onClose, width, children } = props;
|
||||
|
||||
return (
|
||||
@ -14,7 +14,7 @@ export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
|
||||
openDelay={0}
|
||||
closeDelay={0}
|
||||
closeOnBlur={true}
|
||||
returnFocusOnClose={true}
|
||||
returnFocusOnClose={false}
|
||||
isLazy
|
||||
>
|
||||
<PopoverAnchor>{children}</PopoverAnchor>
|
||||
@ -27,11 +27,11 @@ export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
|
||||
borderStyle="solid"
|
||||
>
|
||||
<PopoverBody p={0} width={`calc(${width}px - 0.25rem)`}>
|
||||
<EmbeddingSelect onClose={onClose} onSelect={onSelect} />
|
||||
<PromptTriggerSelect onClose={onClose} onSelect={onSelect} />
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
});
|
||||
|
||||
EmbeddingPopover.displayName = 'EmbeddingPopover';
|
||||
PromptPopover.displayName = 'PromptPopover';
|
@ -0,0 +1,21 @@
|
||||
import type { Meta, StoryObj } from '@storybook/react';
|
||||
|
||||
import { PromptTriggerSelect } from './PromptTriggerSelect';
|
||||
import type { PromptTriggerSelectProps } from './types';
|
||||
|
||||
const meta: Meta<typeof PromptTriggerSelect> = {
|
||||
title: 'Feature/Prompt/PromptTriggerSelect',
|
||||
tags: ['autodocs'],
|
||||
component: PromptTriggerSelect,
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof PromptTriggerSelect>;
|
||||
|
||||
const Component = (props: PromptTriggerSelectProps) => {
|
||||
return <PromptTriggerSelect {...props}>Invoke</PromptTriggerSelect>;
|
||||
};
|
||||
|
||||
export const Default: Story = {
|
||||
render: Component,
|
||||
};
|
@ -0,0 +1,104 @@
|
||||
import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { GroupBase } from 'chakra-react-select';
|
||||
import { selectLoraSlice } from 'features/lora/store/loraSlice';
|
||||
import type { PromptTriggerSelectProps } from 'features/prompt/types';
|
||||
import { t } from 'i18next';
|
||||
import { flatten, map } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
loraModelsAdapterSelectors,
|
||||
textualInversionModelsAdapterSelectors,
|
||||
useGetLoRAModelsQuery,
|
||||
useGetTextualInversionModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
|
||||
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
|
||||
|
||||
const selectLoRAs = createMemoizedSelector(selectLoraSlice, (loras) => loras.loras);
|
||||
|
||||
export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSelectProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
const addedLoRAs = useAppSelector(selectLoRAs);
|
||||
const { data: loraModels, isLoading: isLoadingLoRAs } = useGetLoRAModelsQuery();
|
||||
const { data: tiModels, isLoading: isLoadingTIs } = useGetTextualInversionModelsQuery();
|
||||
|
||||
const _onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!v) {
|
||||
onSelect('');
|
||||
return;
|
||||
}
|
||||
|
||||
onSelect(v.value);
|
||||
},
|
||||
[onSelect]
|
||||
);
|
||||
|
||||
const options = useMemo(() => {
|
||||
const _options: GroupBase<ComboboxOption>[] = [];
|
||||
|
||||
if (tiModels) {
|
||||
const embeddingOptions = textualInversionModelsAdapterSelectors
|
||||
.selectAll(tiModels)
|
||||
.filter((ti) => ti.base === currentBaseModel)
|
||||
.map((model) => ({ label: model.name, value: `<${model.name}>` }));
|
||||
|
||||
if (embeddingOptions.length > 0) {
|
||||
_options.push({
|
||||
label: t('prompt.compatibleEmbeddings'),
|
||||
options: embeddingOptions,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (loraModels) {
|
||||
const triggerPhraseOptions = loraModelsAdapterSelectors
|
||||
.selectAll(loraModels)
|
||||
.filter((lora) => map(addedLoRAs, (l) => l.model.key).includes(lora.key))
|
||||
.map((lora) => {
|
||||
if (lora.trigger_phrases) {
|
||||
return lora.trigger_phrases.map((triggerPhrase) => ({ label: triggerPhrase, value: triggerPhrase }));
|
||||
}
|
||||
return [];
|
||||
})
|
||||
.flatMap((x) => x);
|
||||
|
||||
if (triggerPhraseOptions.length > 0) {
|
||||
_options.push({
|
||||
label: t('modelManager.triggerPhrases'),
|
||||
options: flatten(triggerPhraseOptions),
|
||||
});
|
||||
}
|
||||
}
|
||||
return _options;
|
||||
}, [tiModels, loraModels, t, currentBaseModel, addedLoRAs]);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<Combobox
|
||||
placeholder={isLoadingLoRAs || isLoadingTIs ? t('common.loading') : t('prompt.addPromptTrigger')}
|
||||
defaultMenuIsOpen
|
||||
autoFocus
|
||||
value={null}
|
||||
options={options}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
onChange={_onChange}
|
||||
onMenuClose={onClose}
|
||||
data-testid="add-prompt-trigger"
|
||||
sx={selectStyles}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
PromptTriggerSelect.displayName = 'PromptTriggerSelect';
|
||||
|
||||
const selectStyles: ChakraProps['sx'] = {
|
||||
w: 'full',
|
||||
};
|
@ -1,12 +1,12 @@
|
||||
import type { PropsWithChildren } from 'react';
|
||||
|
||||
export type EmbeddingSelectProps = {
|
||||
export type PromptTriggerSelectProps = {
|
||||
onSelect: (v: string) => void;
|
||||
onClose: () => void;
|
||||
};
|
||||
|
||||
export type EmbeddingPopoverProps = PropsWithChildren &
|
||||
EmbeddingSelectProps & {
|
||||
export type PromptPopoverProps = PropsWithChildren &
|
||||
PromptTriggerSelectProps & {
|
||||
isOpen: boolean;
|
||||
width?: number | string;
|
||||
};
|
@ -4,13 +4,13 @@ import type { ChangeEventHandler, KeyboardEventHandler, RefObject } from 'react'
|
||||
import { useCallback } from 'react';
|
||||
import { flushSync } from 'react-dom';
|
||||
|
||||
type UseInsertEmbeddingArg = {
|
||||
type UseInsertTriggerArg = {
|
||||
prompt: string;
|
||||
textareaRef: RefObject<HTMLTextAreaElement>;
|
||||
onChange: (v: string) => void;
|
||||
};
|
||||
|
||||
export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertEmbeddingArg) => {
|
||||
export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertTriggerArg) => {
|
||||
const { isOpen, onClose, onOpen } = useDisclosure();
|
||||
|
||||
const onChange: ChangeEventHandler<HTMLTextAreaElement> = useCallback(
|
||||
@ -20,13 +20,13 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
[_onChange]
|
||||
);
|
||||
|
||||
const insertEmbedding = useCallback(
|
||||
const insertTrigger = useCallback(
|
||||
(v: string) => {
|
||||
if (!textareaRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
// this is where we insert the TI trigger
|
||||
// this is where we insert the trigger
|
||||
const caret = textareaRef.current.selectionStart;
|
||||
|
||||
if (isNil(caret)) {
|
||||
@ -35,13 +35,9 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
|
||||
let newPrompt = prompt.slice(0, caret);
|
||||
|
||||
if (newPrompt[newPrompt.length - 1] !== '<') {
|
||||
newPrompt += '<';
|
||||
}
|
||||
newPrompt += `${v}`;
|
||||
|
||||
newPrompt += `${v}>`;
|
||||
|
||||
// we insert the cursor after the `>`
|
||||
// we insert the cursor after the end of trigger
|
||||
const finalCaretPos = newPrompt.length;
|
||||
|
||||
newPrompt += prompt.slice(caret);
|
||||
@ -51,7 +47,7 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
_onChange(newPrompt);
|
||||
});
|
||||
|
||||
// set the caret position to just after the TI trigger
|
||||
// set the cursor position to just after the trigger
|
||||
textareaRef.current.selectionStart = finalCaretPos;
|
||||
textareaRef.current.selectionEnd = finalCaretPos;
|
||||
},
|
||||
@ -62,17 +58,17 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
textareaRef.current?.focus();
|
||||
}, [textareaRef]);
|
||||
|
||||
const handleClose = useCallback(() => {
|
||||
const handleClosePopover = useCallback(() => {
|
||||
onClose();
|
||||
onFocus();
|
||||
}, [onFocus, onClose]);
|
||||
|
||||
const onSelectEmbedding = useCallback(
|
||||
const onSelect = useCallback(
|
||||
(v: string) => {
|
||||
insertEmbedding(v);
|
||||
handleClose();
|
||||
insertTrigger(v);
|
||||
handleClosePopover();
|
||||
},
|
||||
[handleClose, insertEmbedding]
|
||||
[handleClosePopover, insertTrigger]
|
||||
);
|
||||
|
||||
const onKeyDown: KeyboardEventHandler<HTMLTextAreaElement> = useCallback(
|
||||
@ -90,7 +86,7 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
|
||||
isOpen,
|
||||
onClose,
|
||||
onOpen,
|
||||
onSelectEmbedding,
|
||||
onSelect,
|
||||
onKeyDown,
|
||||
onFocus,
|
||||
};
|
@ -1,9 +1,9 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
|
||||
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
|
||||
import { usePrompt } from 'features/embedding/usePrompt';
|
||||
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { setNegativeStylePromptSDXL } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
@ -20,7 +20,7 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown, onFocus } = usePrompt({
|
||||
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
|
||||
prompt,
|
||||
textareaRef: textareaRef,
|
||||
onChange: handleChange,
|
||||
@ -29,12 +29,7 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
|
||||
useHotkeys('alt+a', onFocus, []);
|
||||
|
||||
return (
|
||||
<EmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={onSelectEmbedding}
|
||||
width={textareaRef.current?.clientWidth}
|
||||
>
|
||||
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||
<Box pos="relative">
|
||||
<Textarea
|
||||
id="prompt"
|
||||
@ -48,10 +43,10 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
|
||||
variant="darkFilled"
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
</PromptOverlayButtonWrapper>
|
||||
</Box>
|
||||
</EmbeddingPopover>
|
||||
</PromptPopover>
|
||||
);
|
||||
});
|
||||
|
||||
|
@ -1,9 +1,9 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
|
||||
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
|
||||
import { usePrompt } from 'features/embedding/usePrompt';
|
||||
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { setPositiveStylePromptSDXL } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -19,19 +19,14 @@ export const ParamSDXLPositiveStylePrompt = memo(() => {
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown } = usePrompt({
|
||||
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown } = usePrompt({
|
||||
prompt,
|
||||
textareaRef: textareaRef,
|
||||
onChange: handleChange,
|
||||
});
|
||||
|
||||
return (
|
||||
<EmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={onSelectEmbedding}
|
||||
width={textareaRef.current?.clientWidth}
|
||||
>
|
||||
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||
<Box pos="relative">
|
||||
<Textarea
|
||||
id="prompt"
|
||||
@ -45,10 +40,10 @@ export const ParamSDXLPositiveStylePrompt = memo(() => {
|
||||
variant="darkFilled"
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
</PromptOverlayButtonWrapper>
|
||||
</Box>
|
||||
</EmbeddingPopover>
|
||||
</PromptPopover>
|
||||
);
|
||||
});
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
|
||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||
import type { JSONObject } from 'common/types';
|
||||
import queryString from 'query-string';
|
||||
import type { operations, paths } from 'services/api/schema';
|
||||
import type {
|
||||
@ -24,49 +23,33 @@ export type UpdateModelArg = {
|
||||
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
|
||||
};
|
||||
|
||||
type UpdateModelMetadataArg = {
|
||||
key: paths['/api/v2/models/i/{key}/metadata']['patch']['parameters']['path']['key'];
|
||||
body: paths['/api/v2/models/i/{key}/metadata']['patch']['requestBody']['content']['application/json'];
|
||||
};
|
||||
|
||||
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
|
||||
type UpdateModelMetadataResponse =
|
||||
paths['/api/v2/models/i/{key}/metadata']['patch']['responses']['200']['content']['application/json'];
|
||||
|
||||
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type GetModelMetadataResponse =
|
||||
paths['/api/v2/models/i/{key}/metadata']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
|
||||
|
||||
type DeleteMainModelArg = {
|
||||
type DeleteModelArg = {
|
||||
key: string;
|
||||
};
|
||||
|
||||
type DeleteMainModelResponse = void;
|
||||
type DeleteModelResponse = void;
|
||||
|
||||
type ConvertMainModelResponse =
|
||||
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
|
||||
|
||||
type InstallModelArg = {
|
||||
source: paths['/api/v2/models/install']['post']['parameters']['query']['source'];
|
||||
access_token?: paths['/api/v2/models/install']['post']['parameters']['query']['access_token'];
|
||||
// TODO(MM2): This is typed as `Optional[Dict[str, Any]]` in backend...
|
||||
config?: JSONObject;
|
||||
// config: NonNullable<paths['/api/v2/models/install']['post']['requestBody']>['content']['application/json'];
|
||||
};
|
||||
|
||||
type InstallModelResponse = paths['/api/v2/models/install']['post']['responses']['201']['content']['application/json'];
|
||||
|
||||
type ListImportModelsResponse =
|
||||
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
|
||||
type ListModelInstallsResponse =
|
||||
paths['/api/v2/models/install']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type DeleteImportModelsResponse =
|
||||
paths['/api/v2/models/import/{id}']['delete']['responses']['201']['content']['application/json'];
|
||||
type CancelModelInstallResponse =
|
||||
paths['/api/v2/models/install/{id}']['delete']['responses']['201']['content']['application/json'];
|
||||
|
||||
type PruneModelImportsResponse =
|
||||
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
|
||||
type PruneCompletedModelInstallsResponse =
|
||||
paths['/api/v2/models/install']['delete']['responses']['200']['content']['application/json'];
|
||||
|
||||
export type ScanFolderResponse =
|
||||
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
|
||||
@ -83,6 +66,7 @@ const loraModelsAdapter = createEntityAdapter<LoRAModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
@ -102,6 +86,10 @@ const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelC
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors(
|
||||
undefined,
|
||||
getSelectorsOptions
|
||||
);
|
||||
const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
@ -146,31 +134,7 @@ const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
|
||||
|
||||
export const modelsApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
|
||||
query: (base_models) => {
|
||||
const params: ListModelsArg = {
|
||||
model_type: 'main',
|
||||
base_models,
|
||||
};
|
||||
|
||||
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
||||
return buildModelsUrl(`?${query}`);
|
||||
},
|
||||
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
|
||||
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
|
||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||
queryFulfilled.then(({ data }) => {
|
||||
upsertModelConfigs(data, dispatch);
|
||||
});
|
||||
},
|
||||
}),
|
||||
getModelMetadata: build.query<GetModelMetadataResponse, string>({
|
||||
query: (key) => {
|
||||
return buildModelsUrl(`i/${key}/metadata`);
|
||||
},
|
||||
providesTags: ['Model'],
|
||||
}),
|
||||
updateModels: build.mutation<UpdateModelResponse, UpdateModelArg>({
|
||||
updateModel: build.mutation<UpdateModelResponse, UpdateModelArg>({
|
||||
query: ({ key, body }) => {
|
||||
return {
|
||||
url: buildModelsUrl(`i/${key}`),
|
||||
@ -180,28 +144,17 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
updateModelMetadata: build.mutation<UpdateModelMetadataResponse, UpdateModelMetadataArg>({
|
||||
query: ({ key, body }) => {
|
||||
return {
|
||||
url: buildModelsUrl(`i/${key}/metadata`),
|
||||
method: 'PATCH',
|
||||
body: body,
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
|
||||
query: ({ source, config, access_token }) => {
|
||||
query: ({ source }) => {
|
||||
return {
|
||||
url: buildModelsUrl('install'),
|
||||
params: { source, access_token },
|
||||
params: { source },
|
||||
method: 'POST',
|
||||
body: config,
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['Model', 'ModelImports'],
|
||||
invalidatesTags: ['Model', 'ModelInstalls'],
|
||||
}),
|
||||
deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
|
||||
deleteModels: build.mutation<DeleteModelResponse, DeleteModelArg>({
|
||||
query: ({ key }) => {
|
||||
return {
|
||||
url: buildModelsUrl(`i/${key}`),
|
||||
@ -210,7 +163,7 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
convertMainModels: build.mutation<ConvertMainModelResponse, string>({
|
||||
convertModel: build.mutation<ConvertMainModelResponse, string>({
|
||||
query: (key) => {
|
||||
return {
|
||||
url: buildModelsUrl(`convert/${key}`),
|
||||
@ -253,6 +206,57 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
scanFolder: build.query<ScanFolderResponse, ScanFolderArg>({
|
||||
query: (arg) => {
|
||||
const folderQueryStr = arg ? queryString.stringify(arg, {}) : '';
|
||||
return {
|
||||
url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
|
||||
};
|
||||
},
|
||||
}),
|
||||
listModelInstalls: build.query<ListModelInstallsResponse, void>({
|
||||
query: () => {
|
||||
return {
|
||||
url: buildModelsUrl('install'),
|
||||
};
|
||||
},
|
||||
providesTags: ['ModelInstalls'],
|
||||
}),
|
||||
cancelModelInstall: build.mutation<CancelModelInstallResponse, number>({
|
||||
query: (id) => {
|
||||
return {
|
||||
url: buildModelsUrl(`install/${id}`),
|
||||
method: 'DELETE',
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['ModelInstalls'],
|
||||
}),
|
||||
pruneCompletedModelInstalls: build.mutation<PruneCompletedModelInstallsResponse, void>({
|
||||
query: () => {
|
||||
return {
|
||||
url: buildModelsUrl('install'),
|
||||
method: 'DELETE',
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['ModelInstalls'],
|
||||
}),
|
||||
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
|
||||
query: (base_models) => {
|
||||
const params: ListModelsArg = {
|
||||
model_type: 'main',
|
||||
base_models,
|
||||
};
|
||||
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
||||
return buildModelsUrl(`?${query}`);
|
||||
},
|
||||
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
|
||||
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
|
||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||
queryFulfilled.then(({ data }) => {
|
||||
upsertModelConfigs(data, dispatch);
|
||||
});
|
||||
},
|
||||
}),
|
||||
getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
|
||||
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
|
||||
@ -313,40 +317,6 @@ export const modelsApi = api.injectEndpoints({
|
||||
});
|
||||
},
|
||||
}),
|
||||
scanModels: build.query<ScanFolderResponse, ScanFolderArg>({
|
||||
query: (arg) => {
|
||||
const folderQueryStr = arg ? queryString.stringify(arg, {}) : '';
|
||||
return {
|
||||
url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
|
||||
};
|
||||
},
|
||||
}),
|
||||
getModelImports: build.query<ListImportModelsResponse, void>({
|
||||
query: () => {
|
||||
return {
|
||||
url: buildModelsUrl(`import`),
|
||||
};
|
||||
},
|
||||
providesTags: ['ModelImports'],
|
||||
}),
|
||||
deleteModelImport: build.mutation<DeleteImportModelsResponse, number>({
|
||||
query: (id) => {
|
||||
return {
|
||||
url: buildModelsUrl(`import/${id}`),
|
||||
method: 'DELETE',
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['ModelImports'],
|
||||
}),
|
||||
pruneModelImports: build.mutation<PruneModelImportsResponse, void>({
|
||||
query: () => {
|
||||
return {
|
||||
url: buildModelsUrl('import'),
|
||||
method: 'PATCH',
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['ModelImports'],
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
@ -360,16 +330,14 @@ export const {
|
||||
useGetTextualInversionModelsQuery,
|
||||
useGetVaeModelsQuery,
|
||||
useDeleteModelsMutation,
|
||||
useUpdateModelsMutation,
|
||||
useUpdateModelMutation,
|
||||
useInstallModelMutation,
|
||||
useConvertMainModelsMutation,
|
||||
useConvertModelMutation,
|
||||
useSyncModelsMutation,
|
||||
useLazyScanModelsQuery,
|
||||
useGetModelImportsQuery,
|
||||
useGetModelMetadataQuery,
|
||||
useDeleteModelImportMutation,
|
||||
usePruneModelImportsMutation,
|
||||
useUpdateModelMetadataMutation,
|
||||
useLazyScanFolderQuery,
|
||||
useListModelInstallsQuery,
|
||||
useCancelModelInstallMutation,
|
||||
usePruneCompletedModelInstallsMutation,
|
||||
} = modelsApi;
|
||||
|
||||
const upsertModelConfigs = (
|
||||
|
@ -28,7 +28,7 @@ export const tagTypes = [
|
||||
'InvocationCacheStatus',
|
||||
'Model',
|
||||
'ModelConfig',
|
||||
'ModelImports',
|
||||
'ModelInstalls',
|
||||
'T2IAdapterModel',
|
||||
'MainModel',
|
||||
'VaeModel',
|
||||
|
File diff suppressed because one or more lines are too long
@ -43,14 +43,13 @@ export type ControlField = S['ControlField'];
|
||||
// Model Configs
|
||||
|
||||
// TODO(MM2): Can we make key required in the pydantic model?
|
||||
export type LoRAModelConfig = S['LoRAConfig'];
|
||||
export type LoRAModelConfig = S['LoRADiffusersConfig'] | S['LoRALyCORISConfig'];
|
||||
// TODO(MM2): Can we rename this from Vae -> VAE
|
||||
export type VAEModelConfig = S['VaeCheckpointConfig'] | S['VaeDiffusersConfig'];
|
||||
export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
|
||||
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
|
||||
export type IPAdapterModelConfig = S['IPAdapterConfig'];
|
||||
// TODO(MM2): Can we rename this to T2IAdapterConfig
|
||||
export type T2IAdapterModelConfig = S['T2IConfig'];
|
||||
export type TextualInversionModelConfig = S['TextualInversionConfig'];
|
||||
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
|
||||
export type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
|
||||
export type DiffusersModelConfig = S['MainDiffusersConfig'];
|
||||
export type CheckpointModelConfig = S['MainCheckpointConfig'];
|
||||
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
|
||||
|
@ -187,7 +187,7 @@ version = { attr = "invokeai.version.__version__" }
|
||||
|
||||
#=== Begin: PyTest and Coverage
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers --timeout 60 -m \"not slow\""
|
||||
addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers -m \"not slow\""
|
||||
markers = [
|
||||
"slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\".",
|
||||
"timeout: Marks the timeout override."
|
||||
|
@ -59,12 +59,11 @@ def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase,
|
||||
def test_registration_meta_override_succeed(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||
store = mm2_installer.record_store
|
||||
key = mm2_installer.register_path(
|
||||
embedding_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash", "key": "xyzzy"}
|
||||
embedding_file, {"name": "banana_sushi", "source": "fake/repo_id", "key": "xyzzy"}
|
||||
)
|
||||
model_record = store.get_model(key)
|
||||
assert model_record.name == "banana_sushi"
|
||||
assert model_record.source == "fake/repo_id"
|
||||
assert model_record.current_hash == "New Hash"
|
||||
assert model_record.key == "xyzzy"
|
||||
|
||||
|
||||
|
@ -3,28 +3,28 @@ Test the refactored model config classes.
|
||||
"""
|
||||
|
||||
from hashlib import sha256
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
ModelRecordOrderBy,
|
||||
ModelRecordServiceBase,
|
||||
ModelRecordServiceSQL,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
MainCheckpointConfig,
|
||||
MainDiffusersConfig,
|
||||
ModelFormat,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
TextualInversionConfig,
|
||||
VaeDiffusersConfig,
|
||||
TextualInversionFileConfig,
|
||||
VAEDiffusersConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import BaseMetadata
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||
@ -37,90 +37,76 @@ def store(
|
||||
config = InvokeAIAppConfig(root=datadir)
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
db = create_mock_sqlite_database(config, logger)
|
||||
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
return ModelRecordServiceSQL(db)
|
||||
|
||||
|
||||
def example_config() -> TextualInversionConfig:
|
||||
return TextualInversionConfig(
|
||||
def example_ti_config(key: Optional[str] = None) -> TextualInversionFileConfig:
|
||||
config = TextualInversionFileConfig(
|
||||
source="test/source/",
|
||||
source_type=ModelSourceType.Path,
|
||||
path="/tmp/pokemon.bin",
|
||||
name="old name",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("embedding"),
|
||||
format="embedding_file",
|
||||
original_hash="ABC123",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.TextualInversion,
|
||||
format=ModelFormat.EmbeddingFile,
|
||||
hash="ABC123",
|
||||
)
|
||||
if key is not None:
|
||||
config.key = key
|
||||
return config
|
||||
|
||||
|
||||
def test_type(store: ModelRecordServiceBase):
|
||||
config = example_config()
|
||||
store.add_model("key1", config)
|
||||
config = example_ti_config("key1")
|
||||
store.add_model(config)
|
||||
config1 = store.get_model("key1")
|
||||
assert type(config1) == TextualInversionConfig
|
||||
assert isinstance(config1, TextualInversionFileConfig)
|
||||
|
||||
|
||||
def test_add(store: ModelRecordServiceBase):
|
||||
raw = {
|
||||
"path": "/tmp/foo.ckpt",
|
||||
"name": "model1",
|
||||
"base": BaseModelType("sd-1"),
|
||||
"type": "main",
|
||||
"config": "/tmp/foo.yaml",
|
||||
"variant": "normal",
|
||||
"format": "checkpoint",
|
||||
"original_hash": "111222333444",
|
||||
}
|
||||
store.add_model("key1", raw)
|
||||
config1 = store.get_model("key1")
|
||||
assert config1 is not None
|
||||
assert type(config1) == MainCheckpointConfig
|
||||
assert config1.base == BaseModelType("sd-1")
|
||||
assert config1.name == "model1"
|
||||
assert config1.original_hash == "111222333444"
|
||||
assert config1.current_hash is None
|
||||
|
||||
|
||||
def test_dup(store: ModelRecordServiceBase):
|
||||
config = example_config()
|
||||
store.add_model("key1", example_config())
|
||||
def test_raises_on_violating_uniqueness(store: ModelRecordServiceBase):
|
||||
# Models have a uniqueness constraint by their name, base and type
|
||||
config1 = example_ti_config("key1")
|
||||
config2 = config1.model_copy(deep=True)
|
||||
config2.key = "key2"
|
||||
store.add_model(config1)
|
||||
with pytest.raises(DuplicateModelException):
|
||||
store.add_model("key1", config)
|
||||
store.add_model(config1)
|
||||
with pytest.raises(DuplicateModelException):
|
||||
store.add_model("key2", config)
|
||||
store.add_model(config2)
|
||||
|
||||
|
||||
def test_update(store: ModelRecordServiceBase):
|
||||
config = example_config()
|
||||
store.add_model("key1", config)
|
||||
def test_model_records_updates_model(store: ModelRecordServiceBase):
|
||||
config = example_ti_config("key1")
|
||||
store.add_model(config)
|
||||
config = store.get_model("key1")
|
||||
assert config.name == "old name"
|
||||
|
||||
config.name = "new name"
|
||||
store.update_model("key1", config)
|
||||
new_name = "new name"
|
||||
changes = ModelRecordChanges(name=new_name)
|
||||
store.update_model(config.key, changes)
|
||||
new_config = store.get_model("key1")
|
||||
assert new_config.name == "new name"
|
||||
assert new_config.name == new_name
|
||||
|
||||
|
||||
def test_rename(store: ModelRecordServiceBase):
|
||||
config = example_config()
|
||||
store.add_model("key1", config)
|
||||
def test_model_records_rejects_invalid_changes(store: ModelRecordServiceBase):
|
||||
config = example_ti_config("key1")
|
||||
store.add_model(config)
|
||||
config = store.get_model("key1")
|
||||
assert config.name == "old name"
|
||||
|
||||
store.rename_model("key1", "new name")
|
||||
new_config = store.get_model("key1")
|
||||
assert new_config.name == "new name"
|
||||
# upcast_attention is an invalid field for TIs
|
||||
changes = ModelRecordChanges(upcast_attention=True)
|
||||
with pytest.raises(ValidationError):
|
||||
store.update_model(config.key, changes)
|
||||
|
||||
|
||||
def test_unknown_key(store: ModelRecordServiceBase):
|
||||
config = example_config()
|
||||
store.add_model("key1", config)
|
||||
config = example_ti_config("key1")
|
||||
store.add_model(config)
|
||||
with pytest.raises(UnknownModelException):
|
||||
store.update_model("unknown_key", config)
|
||||
store.update_model("unknown_key", ModelRecordChanges())
|
||||
|
||||
|
||||
def test_delete(store: ModelRecordServiceBase):
|
||||
config = example_config()
|
||||
store.add_model("key1", config)
|
||||
config = example_ti_config("key1")
|
||||
store.add_model(config)
|
||||
config = store.get_model("key1")
|
||||
store.del_model("key1")
|
||||
with pytest.raises(UnknownModelException):
|
||||
@ -128,49 +114,58 @@ def test_delete(store: ModelRecordServiceBase):
|
||||
|
||||
|
||||
def test_exists(store: ModelRecordServiceBase):
|
||||
config = example_config()
|
||||
store.add_model("key1", config)
|
||||
config = example_ti_config("key1")
|
||||
store.add_model(config)
|
||||
assert store.exists("key1")
|
||||
assert not store.exists("key2")
|
||||
|
||||
|
||||
def test_filter(store: ModelRecordServiceBase):
|
||||
config1 = MainDiffusersConfig(
|
||||
key="config1",
|
||||
path="/tmp/config1",
|
||||
name="config1",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
original_hash="CONFIG1HASH",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
hash="CONFIG1HASH",
|
||||
source="test/source",
|
||||
source_type=ModelSourceType.Path,
|
||||
)
|
||||
config2 = MainDiffusersConfig(
|
||||
key="config2",
|
||||
path="/tmp/config2",
|
||||
name="config2",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
original_hash="CONFIG2HASH",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
hash="CONFIG2HASH",
|
||||
source="test/source",
|
||||
source_type=ModelSourceType.Path,
|
||||
)
|
||||
config3 = VaeDiffusersConfig(
|
||||
config3 = VAEDiffusersConfig(
|
||||
key="config3",
|
||||
path="/tmp/config3",
|
||||
name="config3",
|
||||
base=BaseModelType("sd-2"),
|
||||
type=ModelType("vae"),
|
||||
original_hash="CONFIG3HASH",
|
||||
type=ModelType.VAE,
|
||||
hash="CONFIG3HASH",
|
||||
source="test/source",
|
||||
source_type=ModelSourceType.Path,
|
||||
)
|
||||
for c in config1, config2, config3:
|
||||
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
|
||||
matches = store.search_by_attr(model_type=ModelType("main"))
|
||||
store.add_model(c)
|
||||
matches = store.search_by_attr(model_type=ModelType.Main)
|
||||
assert len(matches) == 2
|
||||
assert matches[0].name in {"config1", "config2"}
|
||||
|
||||
matches = store.search_by_attr(model_type=ModelType("vae"))
|
||||
matches = store.search_by_attr(model_type=ModelType.VAE)
|
||||
assert len(matches) == 1
|
||||
assert matches[0].name == "config3"
|
||||
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
|
||||
assert matches[0].key == "config3"
|
||||
assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back
|
||||
|
||||
matches = store.search_by_hash("CONFIG1HASH")
|
||||
assert len(matches) == 1
|
||||
assert matches[0].original_hash == "CONFIG1HASH"
|
||||
assert matches[0].hash == "CONFIG1HASH"
|
||||
|
||||
matches = store.all_models()
|
||||
assert len(matches) == 3
|
||||
@ -179,143 +174,116 @@ def test_filter(store: ModelRecordServiceBase):
|
||||
def test_unique(store: ModelRecordServiceBase):
|
||||
config1 = MainDiffusersConfig(
|
||||
path="/tmp/config1",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
name="nonuniquename",
|
||||
original_hash="CONFIG1HASH",
|
||||
hash="CONFIG1HASH",
|
||||
source="test/source/",
|
||||
source_type=ModelSourceType.Path,
|
||||
)
|
||||
config2 = MainDiffusersConfig(
|
||||
path="/tmp/config2",
|
||||
base=BaseModelType("sd-2"),
|
||||
type=ModelType("main"),
|
||||
type=ModelType.Main,
|
||||
name="nonuniquename",
|
||||
original_hash="CONFIG1HASH",
|
||||
hash="CONFIG1HASH",
|
||||
source="test/source/",
|
||||
source_type=ModelSourceType.Path,
|
||||
)
|
||||
config3 = VaeDiffusersConfig(
|
||||
config3 = VAEDiffusersConfig(
|
||||
path="/tmp/config3",
|
||||
base=BaseModelType("sd-2"),
|
||||
type=ModelType("vae"),
|
||||
type=ModelType.VAE,
|
||||
name="nonuniquename",
|
||||
original_hash="CONFIG1HASH",
|
||||
hash="CONFIG1HASH",
|
||||
source="test/source/",
|
||||
source_type=ModelSourceType.Path,
|
||||
)
|
||||
config4 = MainDiffusersConfig(
|
||||
path="/tmp/config4",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
name="nonuniquename",
|
||||
original_hash="CONFIG1HASH",
|
||||
hash="CONFIG1HASH",
|
||||
source="test/source/",
|
||||
source_type=ModelSourceType.Path,
|
||||
)
|
||||
# config1, config2 and config3 are compatible because they have unique combos
|
||||
# of name, type and base
|
||||
for c in config1, config2, config3:
|
||||
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
|
||||
c.key = sha256(c.path.encode("utf-8")).hexdigest()
|
||||
store.add_model(c)
|
||||
|
||||
# config4 clashes with config1 and should raise an integrity error
|
||||
with pytest.raises(DuplicateModelException):
|
||||
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), config4)
|
||||
config4.key = sha256(config4.path.encode("utf-8")).hexdigest()
|
||||
store.add_model(config4)
|
||||
|
||||
|
||||
def test_filter_2(store: ModelRecordServiceBase):
|
||||
config1 = MainDiffusersConfig(
|
||||
path="/tmp/config1",
|
||||
name="config1",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
original_hash="CONFIG1HASH",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
hash="CONFIG1HASH",
|
||||
source="test/source/",
|
||||
source_type=ModelSourceType.Path,
|
||||
)
|
||||
config2 = MainDiffusersConfig(
|
||||
path="/tmp/config2",
|
||||
name="config2",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
original_hash="CONFIG2HASH",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
hash="CONFIG2HASH",
|
||||
source="test/source/",
|
||||
source_type=ModelSourceType.Path,
|
||||
)
|
||||
config3 = MainDiffusersConfig(
|
||||
path="/tmp/config3",
|
||||
name="dup_name1",
|
||||
base=BaseModelType("sd-2"),
|
||||
type=ModelType("main"),
|
||||
original_hash="CONFIG3HASH",
|
||||
type=ModelType.Main,
|
||||
hash="CONFIG3HASH",
|
||||
source="test/source/",
|
||||
source_type=ModelSourceType.Path,
|
||||
)
|
||||
config4 = MainDiffusersConfig(
|
||||
path="/tmp/config4",
|
||||
name="dup_name1",
|
||||
base=BaseModelType("sdxl"),
|
||||
type=ModelType("main"),
|
||||
original_hash="CONFIG3HASH",
|
||||
type=ModelType.Main,
|
||||
hash="CONFIG3HASH",
|
||||
source="test/source/",
|
||||
source_type=ModelSourceType.Path,
|
||||
)
|
||||
config5 = VaeDiffusersConfig(
|
||||
config5 = VAEDiffusersConfig(
|
||||
path="/tmp/config5",
|
||||
name="dup_name1",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("vae"),
|
||||
original_hash="CONFIG3HASH",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.VAE,
|
||||
hash="CONFIG3HASH",
|
||||
source="test/source/",
|
||||
source_type=ModelSourceType.Path,
|
||||
)
|
||||
for c in config1, config2, config3, config4, config5:
|
||||
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
|
||||
store.add_model(c)
|
||||
|
||||
matches = store.search_by_attr(
|
||||
model_type=ModelType("main"),
|
||||
model_type=ModelType.Main,
|
||||
model_name="dup_name1",
|
||||
)
|
||||
assert len(matches) == 2
|
||||
|
||||
matches = store.search_by_attr(
|
||||
base_model=BaseModelType("sd-1"),
|
||||
model_type=ModelType("main"),
|
||||
base_model=BaseModelType.StableDiffusion1,
|
||||
model_type=ModelType.Main,
|
||||
)
|
||||
assert len(matches) == 2
|
||||
|
||||
matches = store.search_by_attr(
|
||||
base_model=BaseModelType("sd-1"),
|
||||
model_type=ModelType("vae"),
|
||||
base_model=BaseModelType.StableDiffusion1,
|
||||
model_type=ModelType.VAE,
|
||||
model_name="dup_name1",
|
||||
)
|
||||
assert len(matches) == 1
|
||||
|
||||
|
||||
def test_summary(mm2_record_store: ModelRecordServiceSQL) -> None:
|
||||
# The fixture provides us with five configs.
|
||||
for x in range(1, 5):
|
||||
key = f"test_config_{x}"
|
||||
name = f"name_{x}"
|
||||
author = f"author_{x}"
|
||||
tags = {f"tag{y}" for y in range(1, x)}
|
||||
mm2_record_store.metadata_store.add_metadata(
|
||||
model_key=key, metadata=BaseMetadata(name=name, author=author, tags=tags)
|
||||
)
|
||||
# sanity check that the tags sent in all right
|
||||
assert mm2_record_store.get_metadata("test_config_3").tags == {"tag1", "tag2"}
|
||||
assert mm2_record_store.get_metadata("test_config_4").tags == {"tag1", "tag2", "tag3"}
|
||||
|
||||
# get summary
|
||||
summary1 = mm2_record_store.list_models(page=0, per_page=100)
|
||||
assert summary1.page == 0
|
||||
assert summary1.pages == 1
|
||||
assert summary1.per_page == 100
|
||||
assert summary1.total == 5
|
||||
assert len(summary1.items) == 5
|
||||
assert summary1.items[0].name == "test5" # lora / sd-1 / diffusers / test5
|
||||
|
||||
# find test_config_3
|
||||
config3 = [x for x in summary1.items if x.key == "test_config_3"][0]
|
||||
assert config3.description == "This is test 3"
|
||||
assert config3.tags == {"tag1", "tag2"}
|
||||
|
||||
# find test_config_5
|
||||
config5 = [x for x in summary1.items if x.key == "test_config_5"][0]
|
||||
assert config5.tags == set()
|
||||
assert config5.description == ""
|
||||
|
||||
# test paging
|
||||
summary2 = mm2_record_store.list_models(page=1, per_page=2)
|
||||
assert summary2.page == 1
|
||||
assert summary2.per_page == 2
|
||||
assert summary2.pages == 3
|
||||
assert summary1.items[2].name == summary2.items[0].name
|
||||
|
||||
# test sorting
|
||||
summary = mm2_record_store.list_models(page=0, per_page=100, order_by=ModelRecordOrderBy.Name)
|
||||
print(summary.items)
|
||||
assert summary.items[0].name == "model1"
|
||||
assert summary.items[-1].name == "test5"
|
||||
|
@ -18,12 +18,17 @@ from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
|
||||
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
LoRADiffusersConfig,
|
||||
MainCheckpointConfig,
|
||||
MainDiffusersConfig,
|
||||
ModelFormat,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
VAEDiffusersConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@ -107,11 +112,6 @@ def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> Downloa
|
||||
return download_queue
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
|
||||
return mm2_record_store.metadata_store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
|
||||
ram_cache = ModelCache(
|
||||
@ -137,7 +137,7 @@ def mm2_installer(
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
events = DummyEventService()
|
||||
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
store = ModelRecordServiceSQL(db)
|
||||
|
||||
installer = ModelInstallService(
|
||||
app_config=mm2_app_config,
|
||||
@ -160,61 +160,71 @@ def mm2_installer(
|
||||
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
||||
logger = InvokeAILogger.get_logger(config=mm2_app_config)
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
store = ModelRecordServiceSQL(db)
|
||||
# add five simple config records to the database
|
||||
raw1 = {
|
||||
"path": "/tmp/foo1",
|
||||
"format": ModelFormat("diffusers"),
|
||||
"name": "test2",
|
||||
"base": BaseModelType("sd-2"),
|
||||
"type": ModelType("vae"),
|
||||
"original_hash": "111222333444",
|
||||
"source": "stabilityai/sdxl-vae",
|
||||
}
|
||||
raw2 = {
|
||||
"path": "/tmp/foo2.ckpt",
|
||||
"name": "model1",
|
||||
"format": ModelFormat("checkpoint"),
|
||||
"base": BaseModelType("sd-1"),
|
||||
"type": "main",
|
||||
"config": "/tmp/foo.yaml",
|
||||
"variant": "normal",
|
||||
"original_hash": "111222333444",
|
||||
"source": "https://civitai.com/models/206883/split",
|
||||
}
|
||||
raw3 = {
|
||||
"path": "/tmp/foo3",
|
||||
"format": ModelFormat("diffusers"),
|
||||
"name": "test3",
|
||||
"base": BaseModelType("sdxl"),
|
||||
"type": ModelType("main"),
|
||||
"original_hash": "111222333444",
|
||||
"source": "author3/model3",
|
||||
"description": "This is test 3",
|
||||
}
|
||||
raw4 = {
|
||||
"path": "/tmp/foo4",
|
||||
"format": ModelFormat("diffusers"),
|
||||
"name": "test4",
|
||||
"base": BaseModelType("sdxl"),
|
||||
"type": ModelType("lora"),
|
||||
"original_hash": "111222333444",
|
||||
"source": "author4/model4",
|
||||
}
|
||||
raw5 = {
|
||||
"path": "/tmp/foo5",
|
||||
"format": ModelFormat("diffusers"),
|
||||
"name": "test5",
|
||||
"base": BaseModelType("sd-1"),
|
||||
"type": ModelType("lora"),
|
||||
"original_hash": "111222333444",
|
||||
"source": "author4/model5",
|
||||
}
|
||||
store.add_model("test_config_1", raw1)
|
||||
store.add_model("test_config_2", raw2)
|
||||
store.add_model("test_config_3", raw3)
|
||||
store.add_model("test_config_4", raw4)
|
||||
store.add_model("test_config_5", raw5)
|
||||
config1 = VAEDiffusersConfig(
|
||||
key="test_config_1",
|
||||
path="/tmp/foo1",
|
||||
format=ModelFormat.Diffusers,
|
||||
name="test2",
|
||||
base=BaseModelType.StableDiffusion2,
|
||||
type=ModelType.VAE,
|
||||
hash="111222333444",
|
||||
source="stabilityai/sdxl-vae",
|
||||
source_type=ModelSourceType.HFRepoID,
|
||||
)
|
||||
config2 = MainCheckpointConfig(
|
||||
key="test_config_2",
|
||||
path="/tmp/foo2.ckpt",
|
||||
name="model1",
|
||||
format=ModelFormat.Checkpoint,
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
config_path="/tmp/foo.yaml",
|
||||
variant=ModelVariantType.Normal,
|
||||
hash="111222333444",
|
||||
source="https://civitai.com/models/206883/split",
|
||||
source_type=ModelSourceType.CivitAI,
|
||||
)
|
||||
config3 = MainDiffusersConfig(
|
||||
key="test_config_3",
|
||||
path="/tmp/foo3",
|
||||
format=ModelFormat.Diffusers,
|
||||
name="test3",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
type=ModelType.Main,
|
||||
hash="111222333444",
|
||||
source="author3/model3",
|
||||
description="This is test 3",
|
||||
source_type=ModelSourceType.HFRepoID,
|
||||
)
|
||||
config4 = LoRADiffusersConfig(
|
||||
key="test_config_4",
|
||||
path="/tmp/foo4",
|
||||
format=ModelFormat.Diffusers,
|
||||
name="test4",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
type=ModelType.LoRA,
|
||||
hash="111222333444",
|
||||
source="author4/model4",
|
||||
source_type=ModelSourceType.HFRepoID,
|
||||
)
|
||||
config5 = LoRADiffusersConfig(
|
||||
key="test_config_5",
|
||||
path="/tmp/foo5",
|
||||
format=ModelFormat.Diffusers,
|
||||
name="test5",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.LoRA,
|
||||
hash="111222333444",
|
||||
source="author4/model5",
|
||||
source_type=ModelSourceType.HFRepoID,
|
||||
)
|
||||
store.add_model(config1)
|
||||
store.add_model(config2)
|
||||
store.add_model(config3)
|
||||
store.add_model(config4)
|
||||
store.add_model(config5)
|
||||
return store
|
||||
|
||||
|
||||
|
@ -1,202 +0,0 @@
|
||||
"""
|
||||
Test model metadata fetching and storage.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pydantic.networks import HttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreBase
|
||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
CivitaiMetadata,
|
||||
CivitaiMetadataFetch,
|
||||
CommercialUsage,
|
||||
HuggingFaceMetadata,
|
||||
HuggingFaceMetadataFetch,
|
||||
UnknownMetadataException,
|
||||
)
|
||||
from invokeai.backend.model_manager.util import select_hf_files
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
|
||||
|
||||
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None:
|
||||
tags = {"text-to-image", "diffusers"}
|
||||
input_metadata = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
tags=tags,
|
||||
id="stabilityai/sdxl-vae",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
mm2_metadata_store.add_metadata("test_config_1", input_metadata)
|
||||
output_metadata = mm2_metadata_store.get_metadata("test_config_1")
|
||||
assert input_metadata == output_metadata
|
||||
with pytest.raises(UnknownMetadataException):
|
||||
mm2_metadata_store.add_metadata("unknown_key", input_metadata)
|
||||
assert mm2_metadata_store.list_tags() == tags
|
||||
|
||||
|
||||
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStoreBase) -> None:
|
||||
input_metadata = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
tags={"text-to-image", "diffusers"},
|
||||
id="stabilityai/sdxl-vae",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
mm2_metadata_store.add_metadata("test_config_1", input_metadata)
|
||||
input_metadata.name = "new-name"
|
||||
mm2_metadata_store.update_metadata("test_config_1", input_metadata)
|
||||
output_metadata = mm2_metadata_store.get_metadata("test_config_1")
|
||||
assert output_metadata.name == "new-name"
|
||||
assert input_metadata == output_metadata
|
||||
|
||||
|
||||
def test_metadata_search(mm2_metadata_store: ModelMetadataStoreBase) -> None:
|
||||
metadata1 = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
tags={"text-to-image", "diffusers"},
|
||||
id="stabilityai/sdxl-vae",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
metadata2 = HuggingFaceMetadata(
|
||||
name="model2",
|
||||
author="stabilityai",
|
||||
tags={"text-to-image", "diffusers", "community-contributed"},
|
||||
id="author2/model2",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
metadata3 = HuggingFaceMetadata(
|
||||
name="model3",
|
||||
author="author3",
|
||||
tags={"text-to-image", "checkpoint", "community-contributed"},
|
||||
id="author3/model3",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
mm2_metadata_store.add_metadata("test_config_1", metadata1)
|
||||
mm2_metadata_store.add_metadata("test_config_2", metadata2)
|
||||
mm2_metadata_store.add_metadata("test_config_3", metadata3)
|
||||
|
||||
matches = mm2_metadata_store.search_by_author("stabilityai")
|
||||
assert len(matches) == 2
|
||||
assert "test_config_1" in matches
|
||||
assert "test_config_2" in matches
|
||||
matches = mm2_metadata_store.search_by_author("Sherlock Holmes")
|
||||
assert not matches
|
||||
|
||||
matches = mm2_metadata_store.search_by_name("model3")
|
||||
assert len(matches) == 1
|
||||
assert "test_config_3" in matches
|
||||
|
||||
matches = mm2_metadata_store.search_by_tag({"text-to-image"})
|
||||
assert len(matches) == 3
|
||||
|
||||
matches = mm2_metadata_store.search_by_tag({"text-to-image", "diffusers"})
|
||||
assert len(matches) == 2
|
||||
assert "test_config_1" in matches
|
||||
assert "test_config_2" in matches
|
||||
|
||||
matches = mm2_metadata_store.search_by_tag({"checkpoint", "community-contributed"})
|
||||
assert len(matches) == 1
|
||||
assert "test_config_3" in matches
|
||||
|
||||
# does the tag table update correctly?
|
||||
matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
|
||||
assert not matches
|
||||
assert mm2_metadata_store.list_tags() == {"text-to-image", "diffusers", "community-contributed", "checkpoint"}
|
||||
metadata3.tags.add("licensed-for-commercial-use")
|
||||
mm2_metadata_store.update_metadata("test_config_3", metadata3)
|
||||
assert mm2_metadata_store.list_tags() == {
|
||||
"text-to-image",
|
||||
"diffusers",
|
||||
"community-contributed",
|
||||
"checkpoint",
|
||||
"licensed-for-commercial-use",
|
||||
}
|
||||
matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
|
||||
assert len(matches) == 1
|
||||
|
||||
|
||||
def test_metadata_civitai_fetch(mm2_session: Session) -> None:
|
||||
fetcher = CivitaiMetadataFetch(mm2_session)
|
||||
metadata = fetcher.from_url(HttpUrl("https://civitai.com/models/215485/SDXL-turbo"))
|
||||
assert isinstance(metadata, CivitaiMetadata)
|
||||
assert metadata.id == 215485
|
||||
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
|
||||
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
|
||||
assert CommercialUsage("RentCivit") in metadata.restrictions.AllowCommercialUse
|
||||
assert metadata.version_id == 242807
|
||||
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
|
||||
|
||||
|
||||
def test_metadata_hf_fetch(mm2_session: Session) -> None:
|
||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
||||
metadata = fetcher.from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
assert isinstance(metadata, HuggingFaceMetadata)
|
||||
assert metadata.author == "test_author" # this is not the same as the original
|
||||
assert metadata.files
|
||||
assert metadata.tags == {
|
||||
"diffusers",
|
||||
"onnx",
|
||||
"safetensors",
|
||||
"text-to-image",
|
||||
"license:other",
|
||||
"has_space",
|
||||
"diffusers:StableDiffusionXLPipeline",
|
||||
"region:us",
|
||||
}
|
||||
|
||||
|
||||
def test_metadata_hf_filter(mm2_session: Session) -> None:
|
||||
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
assert isinstance(metadata, HuggingFaceMetadata)
|
||||
files = [x.path for x in metadata.files]
|
||||
fp16_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16"))
|
||||
assert Path("sdxl-turbo/text_encoder/model.fp16.safetensors") in fp16_files
|
||||
assert Path("sdxl-turbo/text_encoder/model.safetensors") not in fp16_files
|
||||
|
||||
fp32_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp32"))
|
||||
assert Path("sdxl-turbo/text_encoder/model.safetensors") in fp32_files
|
||||
assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in fp32_files
|
||||
|
||||
onnx_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("onnx"))
|
||||
assert Path("sdxl-turbo/text_encoder/model.onnx") in onnx_files
|
||||
assert Path("sdxl-turbo/text_encoder/model.safetensors") not in onnx_files
|
||||
|
||||
default_files = select_hf_files.filter_files(files)
|
||||
assert Path("sdxl-turbo/text_encoder/model.safetensors") in default_files
|
||||
assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in default_files
|
||||
|
||||
openvino_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("openvino"))
|
||||
print(openvino_files)
|
||||
assert len(openvino_files) == 0
|
||||
|
||||
flax_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("flax"))
|
||||
print(flax_files)
|
||||
assert not flax_files
|
||||
|
||||
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(
|
||||
HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo-nofp16")
|
||||
)
|
||||
assert isinstance(metadata, HuggingFaceMetadata)
|
||||
files = [x.path for x in metadata.files]
|
||||
filtered_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16"))
|
||||
assert (
|
||||
Path("sdxl-turbo-nofp16/text_encoder/model.safetensors") in filtered_files
|
||||
) # confirm that default is returned
|
||||
assert Path("sdxl-turbo-nofp16/text_encoder/model.16.safetensors") not in filtered_files
|
||||
|
||||
|
||||
def test_metadata_hf_urls(mm2_session: Session) -> None:
|
||||
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
assert isinstance(metadata, HuggingFaceMetadata)
|
@ -104,7 +104,7 @@ def sdxl_base_files() -> List[Path]:
|
||||
],
|
||||
),
|
||||
(
|
||||
ModelRepoVariant.DEFAULT,
|
||||
ModelRepoVariant.Default,
|
||||
[
|
||||
"model_index.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
@ -129,7 +129,7 @@ def sdxl_base_files() -> List[Path]:
|
||||
],
|
||||
),
|
||||
(
|
||||
ModelRepoVariant.OPENVINO,
|
||||
ModelRepoVariant.OpenVINO,
|
||||
[
|
||||
"model_index.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
@ -211,7 +211,7 @@ def sdxl_base_files() -> List[Path]:
|
||||
],
|
||||
),
|
||||
(
|
||||
ModelRepoVariant.FLAX,
|
||||
ModelRepoVariant.Flax,
|
||||
[
|
||||
"model_index.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
@ -235,7 +235,94 @@ def sdxl_base_files() -> List[Path]:
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_select(sdxl_base_files: List[Path], variant: ModelRepoVariant, expected_list: List[Path]) -> None:
|
||||
def test_select(sdxl_base_files: List[Path], variant: ModelRepoVariant, expected_list: List[str]) -> None:
|
||||
print(f"testing variant {variant}")
|
||||
filtered_files = filter_files(sdxl_base_files, variant)
|
||||
assert set(filtered_files) == {Path(x) for x in expected_list}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sd15_test_files() -> list[Path]:
|
||||
return [
|
||||
Path(f)
|
||||
for f in [
|
||||
"feature_extractor/preprocessor_config.json",
|
||||
"safety_checker/config.json",
|
||||
"safety_checker/model.fp16.safetensors",
|
||||
"safety_checker/model.safetensors",
|
||||
"safety_checker/pytorch_model.bin",
|
||||
"safety_checker/pytorch_model.fp16.bin",
|
||||
"scheduler/scheduler_config.json",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/model.fp16.safetensors",
|
||||
"text_encoder/model.safetensors",
|
||||
"text_encoder/pytorch_model.bin",
|
||||
"text_encoder/pytorch_model.fp16.bin",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/diffusion_pytorch_model.bin",
|
||||
"unet/diffusion_pytorch_model.fp16.bin",
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
"unet/diffusion_pytorch_model.non_ema.bin",
|
||||
"unet/diffusion_pytorch_model.non_ema.safetensors",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
"vae/config.json",
|
||||
"vae/diffusion_pytorch_model.bin",
|
||||
"vae/diffusion_pytorch_model.fp16.bin",
|
||||
"vae/diffusion_pytorch_model.fp16.safetensors",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"variant,expected_files",
|
||||
[
|
||||
(
|
||||
ModelRepoVariant.FP16,
|
||||
[
|
||||
"feature_extractor/preprocessor_config.json",
|
||||
"safety_checker/config.json",
|
||||
"safety_checker/model.fp16.safetensors",
|
||||
"scheduler/scheduler_config.json",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/model.fp16.safetensors",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
"vae/config.json",
|
||||
"vae/diffusion_pytorch_model.fp16.safetensors",
|
||||
],
|
||||
),
|
||||
(
|
||||
ModelRepoVariant.FP32,
|
||||
[
|
||||
"feature_extractor/preprocessor_config.json",
|
||||
"safety_checker/config.json",
|
||||
"safety_checker/model.safetensors",
|
||||
"scheduler/scheduler_config.json",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/model.safetensors",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
"vae/config.json",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_select_multiple_weights(
|
||||
sd15_test_files: list[Path], variant: ModelRepoVariant, expected_files: list[str]
|
||||
) -> None:
|
||||
filtered_files = filter_files(sd15_test_files, variant)
|
||||
assert set(filtered_files) == {Path(f) for f in expected_files}
|
||||
|
@ -21,7 +21,7 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat
|
||||
base_type = probe.get_base_type()
|
||||
assert base_type == expected_type
|
||||
repo_variant = probe.get_repo_variant()
|
||||
assert repo_variant == ModelRepoVariant.DEFAULT
|
||||
assert repo_variant == ModelRepoVariant.Default
|
||||
|
||||
|
||||
def test_repo_variant(datadir: Path):
|
||||
|
Reference in New Issue
Block a user