Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Compare commits — 94 commits
@@ -32,7 +32,6 @@ model. These are the:
Responsible for loading a model from disk
into RAM and VRAM and getting it ready for inference.

## Location of the Code

The four main services can be found in

@@ -63,23 +62,21 @@ provides the following fields:
|----------------|-----------------|------------------|
| `key` | str | Unique identifier for the model |
| `name` | str | Name of the model (not unique) |
| `model_type` | ModelType | The type of the model |
| `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
| `base_model` | BaseModelType | The base model that the model is compatible with |
| `path` | str | Location of model on disk |
| `hash` | str | Hash of the model |
| `description` | str | Human-readable description of the model (optional) |
| `source` | str | Model's source URL or repo id (optional) |

The `key` is a unique 32-character random ID which was generated at
install time. The `hash` field stores a hash of the model's
contents at install time obtained by sampling several parts of the
model's files using the `imohash` library. Over the course of the
model's lifetime it may be transformed in various ways, such as
changing its precision or converting it from a .safetensors to a
diffusers model.
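
For orientation, hashing a file with `imohash` looks roughly like the sketch below. This is only an illustration of the sampling-hash idea, not InvokeAI's exact hashing code, and the file path is made up.

```
# Illustrative only: imohash samples a few regions of the file rather than
# reading all of it, which keeps hashing of large checkpoints fast.
from imohash import hashfile

digest = hashfile("/path/to/model.safetensors", hexdigest=True)  # hypothetical path
print(digest)
```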

`ModelType`, `ModelFormat` and `BaseModelType` are string enums that
are defined in `invokeai.backend.model_manager.config`. They are also

@@ -94,7 +91,6 @@ The `path` field can be absolute or relative. If relative, it is taken
to be relative to the `models_dir` setting in the user's
`invokeai.yaml` file.

### CheckpointConfig

This adds support for checkpoint configurations, and adds the

@@ -174,7 +170,7 @@ store = context.services.model_manager.store
or from elsewhere in the code by accessing
`ApiDependencies.invoker.services.model_manager.store`.

### Creating a `ModelRecordService`

To create a new `ModelRecordService` database or open an existing one,
you can directly create either a `ModelRecordServiceSQL` or a

@@ -217,27 +213,27 @@ for use in the InvokeAI web server. Its signature is:

```
def open(
    cls,
    config: InvokeAIAppConfig,
    conn: Optional[sqlite3.Connection] = None,
    lock: Optional[threading.Lock] = None
) -> Union[ModelRecordServiceSQL, ModelRecordServiceFile]:
```

The way it works is as follows:

1. Retrieve the value of the `model_config_db` option from the user's
   `invokeai.yaml` config file.
2. If `model_config_db` is `auto` (the default), then:
   * Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object
     opened on the passed connection and lock.
   * Open up a new connection to `databases/invokeai.db` if `conn`
     and/or `lock` are missing (see note below).
3. If `model_config_db` is a Path, then use `from_db_file`
   to return the appropriate type of ModelRecordService.
4. If `model_config_db` is None, then retrieve the legacy
   `conf_path` option from `invokeai.yaml` and use the Path
   indicated there. This will default to `configs/models.yaml`.

So a typical startup pattern would be:

```
store = ModelRecordServiceBase.open(config, db_conn, lock)
```

Configurations can be retrieved in several ways.

#### get_model(key) -> AnyModelConfig

The basic functionality is to call the record store object's
`get_model()` method with the desired model's unique key. It returns

@@ -272,28 +268,28 @@ print(model_conf.path)
If the key is unrecognized, this call raises an
`UnknownModelException`.

#### exists(key) -> bool

Returns True if a model with the given key exists in the database.

#### search_by_path(path) -> AnyModelConfig

Returns the configuration of the model whose path is `path`. The path
is matched using a simple string comparison and won't correctly match
models referred to by different paths (e.g. using symbolic links).

#### search_by_name(name, base, type) -> List[AnyModelConfig]

This method searches for models that match some combination of `name`,
`BaseType` and `ModelType`. Calling without any arguments will return
all the models in the database.

#### all_models() -> List[AnyModelConfig]

Return all the model configs in the database. Exactly equivalent to
calling `search_by_name()` with no arguments.
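
As a quick sketch (not taken from the original document), a lookup against the record store opened above might look like the following. The key is a placeholder, and the `base`/`type` argument names simply follow the headings above; the enum members used are assumptions.

```
from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, ModelType

try:
    # look a model up by its unique 32-character key
    config = store.get_model("f13dd932c0c35c22dcb8d6cda4203764")
    print(config.name, config.path)
except UnknownModelException:
    print("no model with that key")

# narrow a search to SD-1 "main" models
sd1_mains = store.search_by_name(base=BaseModelType.StableDiffusion1, type=ModelType.Main)
```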

#### search_by_tag(tags) -> List[AnyModelConfig]

`tags` is a list of strings. This method returns a list of model
configs that contain all of the given tags. Examples:

@@ -312,11 +308,11 @@ commercializable_models = [x for x in store.all_models() \
    if x.license.contains('allowCommercialUse=Sell')]
```

#### version() -> str

Returns the version of the database, currently at `3.2`.

#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase

This method exists to ease the transition from the previous version of
the model manager, in which `get_model()` took the three arguments

@@ -337,7 +333,7 @@ model and pass its key to `get_model()`.
Several methods allow you to create and update stored model config
records.

#### add_model(key, config) -> AnyModelConfig

Given a key and a configuration, this will add the model's
configuration record to the database. `config` can either be a subclass of

@@ -352,7 +348,7 @@ model with the same key is already in the database, or an
`InvalidModelConfigException` if a dict was passed and Pydantic
experienced a parse or validation error.
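
A minimal sketch of registering a record from a plain dict might look like this. The key and field values are illustrative only; the dict keys mirror the example model input fields used elsewhere in this changeset.

```
from invokeai.app.services.model_records import DuplicateModelException

try:
    store.add_model(
        "0123456789abcdef0123456789abcdef",  # a made-up 32-character key
        {
            "name": "my-checkpoint",
            "base": "sd-1",
            "type": "main",
            "format": "checkpoint",
            "path": "sd-1/main/my-checkpoint.safetensors",
        },
    )
except DuplicateModelException:
    print("a model with this key is already registered")
```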

#### update_model(key, config) -> AnyModelConfig

Given a key and a configuration, this will update the model
configuration record in the database. `config` can be either a

@@ -370,31 +366,31 @@ The `ModelInstallService` class implements the
shop for all your model install needs. It provides the following
functionality:

* Registering a model config record for a model already located on the
  local filesystem, without moving it or changing its path.

* Installing a model already located on the local filesystem, by
  moving it into the InvokeAI root directory under the
  `models` folder (or wherever config parameter `models_dir`
  specifies).

* Probing of models to determine their type, base type and other key
  information.

* Interface with the InvokeAI event bus to provide status updates on
  the download, installation and registration process.

* Downloading a model from an arbitrary URL and installing it in
  `models_dir`.

* Special handling for Civitai model URLs which allow the user to
  paste in a model page's URL or download link.

* Special handling for HuggingFace repo_ids to recursively download
  the contents of the repository, paying attention to alternative
  variants such as fp16.

* Saving tags and other metadata about the model into the invokeai database
  when fetching from a repo that provides that type of information
  (currently only Civitai and HuggingFace).

@@ -427,8 +423,8 @@ queue.start()

```
installer = ModelInstallService(app_config=config,
                                record_store=record_store,
                                download_queue=queue
                                )
installer.start()
```

@@ -443,7 +439,6 @@ required parameters:
| `metadata_store` | Optional[ModelMetadataStore] | Metadata storage object |
| `session`        | Optional[requests.Session]   | Swap in a different Session object (usually for debugging) |

Once initialized, the installer will provide the following methods:

#### install_job = installer.heuristic_import(source, [config], [access_token])

@@ -457,15 +452,15 @@ The `source` is a string that can be any of these forms
1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`)
2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
3. A HuggingFace repo_id with any of the following formats:
   * `model/name` -- entire model
   * `model/name:fp32` -- entire model, using the fp32 variant
   * `model/name:fp16:vae` -- vae submodel, using the fp16 variant
   * `model/name::vae` -- vae submodel, using default precision
   * `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
   * `model/name::path/to/model.safetensors` -- an individual model file, default variant

Note that by specifying a relative path to the top of the HuggingFace
repo, you can download and install arbitrary model files.
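
Putting the source forms together, a hedged sketch of `heuristic_import()` calls might look like this; the token variable and the exact repo/file names are placeholders.

```
# local file
job1 = installer.heuristic_import("C:\\users\\fred\\model.safetensors")

# single downloadable file behind a URL
job2 = installer.heuristic_import("https://civitai.com/models/58390/detail-tweaker-lora-lora")

# vae submodel of a HuggingFace repo, default precision, with an access token
job3 = installer.heuristic_import("stabilityai/sdxl-turbo::vae", access_token=my_hf_token)
```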

The variant, if not provided, will be automatically filled in with
`fp32` if the user has requested full precision, and `fp16`

@@ -491,9 +486,9 @@ following illustrates basic usage:

```
from invokeai.app.services.model_install import (
    LocalModelSource,
    HFModelSource,
    URLModelSource,
)

source1 = LocalModelSource(path='/opt/models/sushi.safetensors')  # a local safetensors file

@@ -513,13 +508,13 @@ for source in [source1, source2, source3, source4, source5, source6, source7]:

source2job = installer.wait_for_installs(timeout=120)
for source in sources:
    job = source2job[source]
    if job.complete:
        model_config = job.config_out
        model_key = model_config.key
        print(f"{source} installed as {model_key}")
    elif job.errored:
        print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")
```

As shown here, the `import_model()` method accepts a variety of

@@ -528,7 +523,7 @@ HuggingFace repo_ids with and without a subfolder designation,
Civitai model URLs and arbitrary URLs that point to checkpoint files
(but not to folders).

Each call to `import_model()` returns a `ModelInstallJob` job,
an object which tracks the progress of the install.

If a remote model is requested, the model's files are downloaded in

@@ -555,7 +550,7 @@ The full list of arguments to `import_model()` is as follows:
| `config` | Dict[str, Any] | None | Override all or a portion of model's probed attributes |

The next few sections describe the various types of ModelSource that
can be passed to `import_model()`.

`config` can be used to override all or a portion of the configuration
attributes returned by the model prober. See the section below for

@@ -566,7 +561,6 @@ details.
This is used for a model that is located on a locally-accessible Posix
filesystem, such as a local disk or networked fileshare.

| **Argument**     | **Type**                     | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `path`           | str \| Path                  | None        | Path to the model file or directory |

@@ -625,7 +619,6 @@ HuggingFace has the most complicated `ModelSource` structure:
| `subfolder`    | Path | None | Look for the model in a subfolder of the repo. |
| `access_token` | str  | None | An access token needed to gain access to a subscriber's-only model. |

The `repo_id` is the repository ID, such as `stabilityai/sdxl-turbo`.

The `variant` is one of the various diffusers formats that HuggingFace

@@ -661,7 +654,6 @@ in. To download these files, you must provide an
`HfFolder.get_token()` will be called to fill it in with the cached
one.

#### Monitoring the install job process

When you create an install job with `import_model()`, it launches the

@@ -675,14 +667,13 @@ The `ModelInstallJob` class has the following structure:
| `id`          | `int`           | Integer ID for this job |
| `status`      | `InstallStatus` | An enum of [`waiting`, `downloading`, `running`, `completed`, `error` and `cancelled`] |
| `config_in`   | `dict`          | Overriding configuration values provided by the caller |
| `config_out`  | `AnyModelConfig`| After successful completion, contains the configuration record written to the database |
| `inplace`     | `boolean`       | True if the caller asked to install the model in place using its local path |
| `source`      | `ModelSource`   | The local path, remote URL or repo_id of the model to be installed |
| `local_path`  | `Path`          | If a remote model, holds the path of the model after it is downloaded; if a local model, same as `source` |
| `error_type`  | `str`           | Name of the exception that led to an error status |
| `error`       | `str`           | Traceback of the error |

If the `event_bus` argument was provided, events will also be
broadcast to the InvokeAI event bus. The events will appear on the bus
as an event of type `EventServiceBase.model_event`, a timestamp and

@@ -702,14 +693,13 @@ following keys:
| `total_bytes` | int       | Total size of all the files that make up the model |
| `parts`       | List[Dict]| Information on the progress of the individual files that make up the model |

The `parts` field is a list of dictionaries that give information on each of
the component pieces of the download. The dictionary's keys are
`source`, `local_path`, `bytes` and `total_bytes`, and correspond to
the like-named keys in the main event.

Note that downloading events will not be issued for local models, and
that downloading events occur _before_ the running event.

##### `model_install_running`

@@ -752,14 +742,13 @@ properties: `waiting`, `downloading`, `running`, `complete`, `errored`
and `cancelled`, as well as `in_terminal_state`. The last will return
True if the job is in the complete, errored or cancelled states.
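
For example, a caller that does not use the event bus can simply poll these properties. This is a sketch based only on the fields documented above:

```
import time

# wait for a single install job to reach a terminal state, then report the outcome
while not install_job.in_terminal_state:
    time.sleep(1)

if install_job.complete:
    print(f"installed as {install_job.config_out.key}")
elif install_job.errored:
    print(f"install failed: {install_job.error_type}\n{install_job.error}")
elif install_job.cancelled:
    print("install was cancelled")
```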

#### Model configuration and probing

The install service uses the `invokeai.backend.model_manager.probe`
module during import to determine the model's type, base type, and
other configuration parameters. Among other things, it assigns a
default name and description for the model based on probed
fields.

When downloading remote models is implemented, additional
configuration information, such as a list of trigger terms, will be

@@ -774,11 +763,11 @@ attributes. Here is an example of setting the

```
install_job = installer.import_model(
    source=HFModelSource(repo_id='stabilityai/stable-diffusion-2-1', variant='fp32'),
    config=dict(
        prediction_type=SchedulerPredictionType('v_prediction'),
        name='stable diffusion 2 base model',
    )
)
```

### Other installer methods

@@ -862,7 +851,6 @@ This method is similar to `unregister()`, but also unconditionally
deletes the corresponding model weights file(s), regardless of whether
they are inside or outside the InvokeAI models hierarchy.

#### path = installer.download_and_cache(remote_source, [access_token], [timeout])

This utility routine will download the model file located at source,

@@ -953,7 +941,7 @@ following fields:
When you create a job, you can assign it a `priority`. If multiple
jobs are queued, the job with the lowest priority runs first. (Don't
blame me! The Unix developers came up with this convention.)

Every job has a `source` and a `destination`. `source` is a string in
the base class, but subclasses redefine it more specifically.

@@ -974,7 +962,7 @@ is in its lifecycle. Values are defined in the string enum
`DownloadJobStatus`, a symbol available from
`invokeai.app.services.download_manager`. Possible values are:

| **Value**    | **String Value**    | **Description**  |
|--------------|---------------------|-------------------|
| `IDLE`       | idle                | Job created, but not submitted to the queue |
| `ENQUEUED`   | enqueued            | Job is patiently waiting on the queue |

@@ -991,7 +979,7 @@ debugging and performance testing.
In case of an error, the Exception that caused the error will be
placed in the `error` field, and the job's status will be set to
`DownloadJobStatus.ERROR`.

After an error occurs, any partially downloaded files will be deleted
from disk, unless `preserve_partial_downloads` was set to True at job

@@ -1040,11 +1028,11 @@ While a job is being downloaded, the queue will emit events at
periodic intervals. A typical series of events during a successful
download session will look like this:

* enqueued
* running
* running
* running
* completed

There will be a single enqueued event, followed by one or more running
events, and finally one `completed`, `error` or `cancelled`

@@ -1053,12 +1041,12 @@ events.
It is possible for a caller to pause a download temporarily, in which
case the events may look something like this:

* enqueued
* running
* running
* paused
* running
* completed

The download queue logs when downloads start and end (unless `quiet`
is set to True at initialization time) but doesn't log any progress

@@ -1120,11 +1108,11 @@ A typical initialization sequence will look like:

```
from invokeai.app.services.download_manager import DownloadQueueService

def log_download_event(job: DownloadJobBase):
    logger.info(f'job={job.id}: status={job.status}')

queue = DownloadQueueService(
    event_handlers=[log_download_event]
)
```

Event handlers can be provided to the queue at initialization time as

@@ -1155,9 +1143,9 @@ To use the former method, follow this example:

```
job = DownloadJobRemoteSource(
    source='http://www.civitai.com/models/13456',
    destination='/tmp/models/',
    event_handlers=[my_handler1, my_handler2],  # if desired
)
queue.submit_download_job(job, start=True)
```

@@ -1172,13 +1160,13 @@ To have the queue create the job for you, follow this example instead:

```
job = queue.create_download_job(
    source='http://www.civitai.com/models/13456',
    destdir='/tmp/models/',
    filename='my_model.safetensors',
    event_handlers=[my_handler1, my_handler2],  # if desired
    start=True,
)
```

The `filename` argument forces the downloader to use the specified
name for the file rather than the name provided by the remote source,
and is equivalent to manually specifying a destination of

@@ -1187,7 +1175,6 @@ and is equivalent to manually specifying a destination of
Here is the full list of arguments that can be provided to
`create_download_job()`:

| **Argument**     | **Type**                     | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `source`         | Union[str, Path, AnyHttpUrl] |             | Download remote or local source |

@@ -1200,7 +1187,7 @@ Here is the full list of arguments that can be provided to
Internally, `create_download_job()` has a little bit of logic
that looks at the type of the source and selects the right subclass of
`DownloadJobBase` to create and enqueue.

**TODO**: move this logic into its own method for overriding in
subclasses.
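
The idea is roughly the dispatch below. This is only a sketch of the behaviour described above; `DownloadJobPath` is a stand-in name for whatever local-source job class the queue actually uses.

```
def _select_job_class(source):
    # hypothetical illustration of the source-type dispatch described above
    if str(source).startswith(("http://", "https://")):
        return DownloadJobRemoteSource   # remote job class used in the earlier example
    return DownloadJobPath               # stand-in name for the local-source job class
```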

@@ -1275,7 +1262,7 @@ for getting the model to run. For example "author" is metadata, while
"type", "base" and "format" are not. The latter fields are part of the
model's config, as defined in `invokeai.backend.model_manager.config`.

### Example Usage

```
from invokeai.backend.model_manager.metadata import (

@@ -1328,7 +1315,6 @@ This is the common base class for metadata:
| `author` | str      | Model's author |
| `tags`   | Set[str] | Model tags |

Note that the model config record also has a `name` field. It is
intended that the config record version be locally customizable, while
the metadata version is read-only. However, enforcing this is expected

@@ -1348,7 +1334,6 @@ This descends from `ModelMetadataBase` and adds the following fields:
| `last_modified`| datetime   | Date of last commit of this model to the repo |
| `files`        | List[Path] | List of the files in the model repo |

#### `CivitaiMetadata`

This descends from `ModelMetadataBase` and adds the following fields:

@@ -1415,7 +1400,6 @@ testing suite to avoid hitting the internet.
The HuggingFace and Civitai fetcher subclasses add additional
repo-specific fetching methods:

#### HuggingFaceMetadataFetch

This overrides its base class `from_json()` method to return a

@@ -1434,13 +1418,12 @@ retrieves its metadata. Functionally equivalent to `from_id()`, the
only difference is that it returns a `CivitaiMetadata` object rather
than an `AnyModelRepoMetadata`.

### Metadata Storage

The `ModelMetadataStore` provides a simple facility to store model
metadata in the `invokeai.db` database. The data is stored as a JSON
blob, with a few common fields (`name`, `author`, `tags`) broken out
to be searchable.

When a metadata object is saved to the database, it is identified
using the model key, _and this key must correspond to an existing

@@ -1535,16 +1518,16 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegistry

```
config = InvokeAIAppConfig.get_config()
ram_cache = ModelCache(
    max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
)
convert_cache = ModelConvertCache(
    cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
)
loader = ModelLoadService(
    app_config=config,
    ram_cache=ram_cache,
    convert_cache=convert_cache,
    registry=ModelLoaderRegistry
)
```

@@ -1567,7 +1550,6 @@ The returned `LoadedModel` object contains a copy of the configuration
record returned by the model record `get_model()` method, as well as
the in-memory loaded model:

| **Attribute Name** | **Type**       | **Description** |
|--------------------|----------------|------------------|
| `config`           | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |

@@ -1581,7 +1563,6 @@ return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are obvious.

`LoadedModel` acts as a context manager. The context loads the model
into the execution device (e.g. VRAM on CUDA systems), locks the model
in the execution device for the duration of the context, and returns

@@ -1590,14 +1571,14 @@ the model. Use it like this:
```
model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with model_info as vae:
    image = vae.decode(latents)[0]
```

`get_model_by_key()` may raise any of the following exceptions:

* `UnknownModelException` -- key not in database
* `ModelNotFoundException` -- key in database but model not found at path
* `NotImplementedException` -- the loader doesn't know how to load this type of model
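
In practice you may want to guard the lookup. The following is a sketch that only uses the method, exceptions and context-manager pattern shown above (imports are omitted, as in the surrounding examples):

```
try:
    model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
except UnknownModelException:
    print("key not found in the database")
except ModelNotFoundException:
    print("record exists but the model files are missing on disk")
else:
    with model_info as vae:
        image = vae.decode(latents)[0]
```
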
### Emitting model loading events

@@ -1609,15 +1590,15 @@ following payload:

```
payload=dict(
    queue_id=queue_id,
    queue_item_id=queue_item_id,
    queue_batch_id=queue_batch_id,
    graph_execution_state_id=graph_execution_state_id,
    model_key=model_key,
    submodel_type=submodel,
    hash=model_info.hash,
    location=str(model_info.location),
    precision=str(model_info.precision),
)
```

@@ -1724,6 +1705,7 @@ object, or in `context.services.model_manager` from within an
invocation.

In the examples below, we have retrieved the manager using:

```
mm = ApiDependencies.invoker.services.model_manager
```
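
From the retrieved `mm` object, the three services described in the preceding sections are available as attributes. Of the attribute names below, `store` and `install` appear elsewhere in this document; `load` is an assumption.

```
record_store = mm.store   # ModelRecordService -- config records
installer = mm.install    # ModelInstallService -- registration and downloads
loader = mm.load          # ModelLoadService -- loading into RAM/VRAM (assumed name)
```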

@@ -19,6 +19,8 @@ their descriptions.
| Conditioning Primitive    | A conditioning tensor primitive value |
| Content Shuffle Processor | Applies content shuffle processing to image |
| ControlNet                | Collects ControlNet info to pass to other nodes |
| Create Denoise Mask       | Converts a greyscale or transparency image into a mask for denoising. |
| Create Gradient Mask      | Creates a mask for Gradient ("soft", "differential") inpainting that gradually expands during denoising. Improves edge coherence. |
| Denoise Latents           | Denoises noisy latents to decodable images |
| Divide Integers           | Divides two numbers |
| Dynamic Prompt            | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator |

@@ -26,7 +26,6 @@ from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_metadata import ModelMetadataStoreSQL
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor

@@ -93,10 +92,9 @@ class ApiDependencies:
            ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
        )
        download_queue_service = DownloadQueueService(event_bus=events)
        model_metadata_service = ModelMetadataStoreSQL(db=db)
        model_manager = ModelManagerService.build_model_manager(
            app_config=configuration,
            model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
            model_record_service=ModelRecordServiceSQL(db=db),
            download_queue=download_queue_service,
            events=events,
        )

@@ -3,9 +3,7 @@
import pathlib
import shutil
from hashlib import sha1
from random import randbytes
from typing import Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Optional

from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter

@@ -15,13 +13,10 @@ from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_records import (
    DuplicateModelException,
    InvalidModelException,
    ModelRecordOrderBy,
    ModelSummary,
    UnknownModelException,
)
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
from invokeai.backend.model_manager.config import (
    AnyModelConfig,
    BaseModelType,

@@ -30,8 +25,6 @@ from invokeai.backend.model_manager.config import (
    ModelType,
    SubModelType,
)
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.search import ModelSearch

from ..dependencies import ApiDependencies

@@ -47,15 +40,6 @@ class ModelsList(BaseModel):
    model_config = ConfigDict(use_enum_values=True)


class ModelTagSet(BaseModel):
    """Return tags for a set of models."""

    key: str
    name: str
    author: str
    tags: Set[str]


##############################################################################
# These are example inputs and outputs that are used in places where Swagger
# is unable to generate a correct example.

@@ -66,19 +50,16 @@ example_model_config = {
    "base": "sd-1",
    "type": "main",
    "format": "checkpoint",
    "config": "string",
    "config_path": "string",
    "key": "string",
    "original_hash": "string",
    "current_hash": "string",
    "hash": "string",
    "description": "string",
    "source": "string",
    "last_modified": 0,
    "vae": "string",
    "converted_at": 0,
    "variant": "normal",
    "prediction_type": "epsilon",
    "repo_variant": "fp16",
    "upcast_attention": False,
    "ztsnr_training": False,
}

example_model_input = {

@@ -87,50 +68,12 @@ example_model_input = {
    "base": "sd-1",
    "type": "main",
    "format": "checkpoint",
    "config": "configs/stable-diffusion/v1-inference.yaml",
    "config_path": "configs/stable-diffusion/v1-inference.yaml",
    "description": "Model description",
    "vae": None,
    "variant": "normal",
}

example_model_metadata = {
    "name": "ip_adapter_sd_image_encoder",
    "author": "InvokeAI",
    "tags": [
        "transformers",
        "safetensors",
        "clip_vision_model",
        "endpoints_compatible",
        "region:us",
        "has_space",
        "license:apache-2.0",
    ],
    "files": [
        {
            "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
            "path": "ip_adapter_sd_image_encoder/README.md",
            "size": 628,
            "sha256": None,
        },
        {
            "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
            "path": "ip_adapter_sd_image_encoder/config.json",
            "size": 560,
            "sha256": None,
        },
        {
            "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
            "path": "ip_adapter_sd_image_encoder/model.safetensors",
            "size": 2528373448,
            "sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
        },
    ],
    "type": "huggingface",
    "id": "InvokeAI/ip_adapter_sd_image_encoder",
    "tag_dict": {"license": "apache-2.0"},
    "last_modified": "2023-09-23T17:33:25Z",
}

##############################################################################
# ROUTES
##############################################################################

@@ -210,48 +153,16 @@ async def get_model_record(
        raise HTTPException(status_code=404, detail=str(e))


@model_manager_router.get("/summary", operation_id="list_model_summary")
async def list_model_summary(
    page: int = Query(default=0, description="The page to get"),
    per_page: int = Query(default=10, description="The number of models per page"),
    order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
) -> PaginatedResults[ModelSummary]:
    """Gets a page of model summary data."""
    record_store = ApiDependencies.invoker.services.model_manager.store
    results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
    return results


@model_manager_router.get(
    "/i/{key}/metadata",
    operation_id="get_model_metadata",
    responses={
        200: {
            "description": "The model metadata was retrieved successfully",
            "content": {"application/json": {"example": example_model_metadata}},
        },
        400: {"description": "Bad request"},
    },
)
async def get_model_metadata(
    key: str = Path(description="Key of the model repo metadata to fetch."),
) -> Optional[AnyModelRepoMetadata]:
    """Get a model metadata object."""
    record_store = ApiDependencies.invoker.services.model_manager.store
    result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)

    return result


@model_manager_router.get(
    "/tags",
    operation_id="list_tags",
)
async def list_tags() -> Set[str]:
    """Get a unique set of all the model tags."""
    record_store = ApiDependencies.invoker.services.model_manager.store
    result: Set[str] = record_store.list_tags()
    return result

# @model_manager_router.get("/summary", operation_id="list_model_summary")
# async def list_model_summary(
#     page: int = Query(default=0, description="The page to get"),
#     per_page: int = Query(default=10, description="The number of models per page"),
#     order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
# ) -> PaginatedResults[ModelSummary]:
#     """Gets a page of model summary data."""
#     record_store = ApiDependencies.invoker.services.model_manager.store
#     results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
#     return results


class FoundModel(BaseModel):

@@ -323,19 +234,6 @@ async def scan_for_models(
    return scan_results


@model_manager_router.get(
    "/tags/search",
    operation_id="search_by_metadata_tags",
)
async def search_by_metadata_tags(
    tags: Set[str] = Query(default=None, description="Tags to search for"),
) -> ModelsList:
    """Get a list of models."""
    record_store = ApiDependencies.invoker.services.model_manager.store
    results = record_store.search_by_metadata_tag(tags)
    return ModelsList(models=results)


@model_manager_router.patch(
    "/i/{key}",
    operation_id="update_model_record",

@@ -352,15 +250,13 @@ async def search_by_metadata_tags(
)
async def update_model_record(
    key: Annotated[str, Path(description="Unique key of model")],
    info: Annotated[
        AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
    ],
    changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
) -> AnyModelConfig:
    """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
    """Update a model's config."""
    logger = ApiDependencies.invoker.services.logger
    record_store = ApiDependencies.invoker.services.model_manager.store
    try:
        model_response: AnyModelConfig = record_store.update_model(key, config=info)
        model_response: AnyModelConfig = record_store.update_model(key, changes=changes)
        logger.info(f"Updated model: {key}")
    except UnknownModelException as e:
        raise HTTPException(status_code=404, detail=str(e))

@@ -372,14 +268,14 @@ async def update_model_record(

@model_manager_router.delete(
    "/i/{key}",
    operation_id="del_model_record",
    operation_id="delete_model",
    responses={
        204: {"description": "Model deleted successfully"},
        404: {"description": "Model not found"},
    },
    status_code=204,
)
async def del_model_record(
async def delete_model(
    key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
    """

@@ -400,42 +296,39 @@ async def del_model_record(
        raise HTTPException(status_code=404, detail=str(e))


@model_manager_router.post(
    "/i/",
    operation_id="add_model_record",
    responses={
        201: {
            "description": "The model added successfully",
            "content": {"application/json": {"example": example_model_config}},
        },
        409: {"description": "There is already a model corresponding to this path or repo_id"},
        415: {"description": "Unrecognized file/folder format"},
    },
    status_code=201,
)
async def add_model_record(
    config: Annotated[
        AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
    ],
) -> AnyModelConfig:
    """Add a model using the configuration information appropriate for its type."""
    logger = ApiDependencies.invoker.services.logger
    record_store = ApiDependencies.invoker.services.model_manager.store
    if config.key == "<NOKEY>":
        config.key = sha1(randbytes(100)).hexdigest()
        logger.info(f"Created model {config.key} for {config.name}")
    try:
        record_store.add_model(config.key, config)
    except DuplicateModelException as e:
        logger.error(str(e))
        raise HTTPException(status_code=409, detail=str(e))
    except InvalidModelException as e:
        logger.error(str(e))
        raise HTTPException(status_code=415)
# @model_manager_router.post(
#     "/i/",
#     operation_id="add_model_record",
#     responses={
#         201: {
#             "description": "The model added successfully",
#             "content": {"application/json": {"example": example_model_config}},
#         },
#         409: {"description": "There is already a model corresponding to this path or repo_id"},
#         415: {"description": "Unrecognized file/folder format"},
#     },
#     status_code=201,
# )
# async def add_model_record(
#     config: Annotated[
#         AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
#     ],
# ) -> AnyModelConfig:
#     """Add a model using the configuration information appropriate for its type."""
#     logger = ApiDependencies.invoker.services.logger
#     record_store = ApiDependencies.invoker.services.model_manager.store
#     try:
#         record_store.add_model(config)
#     except DuplicateModelException as e:
#         logger.error(str(e))
#         raise HTTPException(status_code=409, detail=str(e))
#     except InvalidModelException as e:
#         logger.error(str(e))
#         raise HTTPException(status_code=415)

    # now fetch it out
    result: AnyModelConfig = record_store.get_model(config.key)
    return result
#     # now fetch it out
#     result: AnyModelConfig = record_store.get_model(config.key)
#     return result


@model_manager_router.post(

@@ -451,6 +344,7 @@ async def add_model_record(
)
async def install_model(
    source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
    inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
    # TODO(MM2): Can we type this?
    config: Optional[Dict[str, Any]] = Body(
        description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type",

@@ -493,6 +387,7 @@ async def install_model(
            source=source,
            config=config,
            access_token=access_token,
            inplace=bool(inplace),
        )
        logger.info(f"Started installation of {source}")
    except UnknownModelException as e:

@@ -508,10 +403,10 @@ async def install_model(

@model_manager_router.get(
    "/import",
    operation_id="list_model_install_jobs",
    "/install",
    operation_id="list_model_installs",
)
async def list_model_install_jobs() -> List[ModelInstallJob]:
async def list_model_installs() -> List[ModelInstallJob]:
    """Return the list of model install jobs.

    Install jobs have a numeric `id`, a `status`, and other fields that provide information on

@@ -525,9 +420,8 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
    * "cancelled" -- Job was cancelled before completion.

    Once completed, information about the model such as its size, base
    model, type, and metadata can be retrieved from the `config_out`
    field. For multi-file models such as diffusers, information on individual files
    can be retrieved from `download_parts`.
    model and type can be retrieved from the `config_out` field. For multi-file models such as diffusers,
    information on individual files can be retrieved from `download_parts`.

    See the example and schema below for more information.
    """

@@ -536,7 +430,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:

@model_manager_router.get(
    "/import/{id}",
    "/install/{id}",
    operation_id="get_model_install_job",
    responses={
        200: {"description": "Success"},

@@ -556,7 +450,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))

@model_manager_router.delete(
    "/import/{id}",
    "/install/{id}",
    operation_id="cancel_model_install_job",
    responses={
        201: {"description": "The job was cancelled successfully"},

@@ -574,8 +468,8 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
    installer.cancel_job(job)


@model_manager_router.patch(
    "/import",
@model_manager_router.delete(
    "/install",
    operation_id="prune_model_install_jobs",
    responses={
        204: {"description": "All completed and errored jobs have been pruned"},

@@ -654,7 +548,8 @@ async def convert_model(
    # temporarily rename the original safetensors file so that there is no naming conflict
    original_name = model_config.name
    model_config.name = f"{original_name}.DELETE"
    store.update_model(key, config=model_config)
    changes = ModelRecordChanges(name=model_config.name)
    store.update_model(key, changes=changes)

    # install the diffusers
    try:

@@ -663,7 +558,7 @@ async def convert_model(
            config={
                "name": original_name,
                "description": model_config.description,
                "original_hash": model_config.original_hash,
                "hash": model_config.hash,
                "source": model_config.source,
            },
        )

@@ -671,10 +566,6 @@ async def convert_model(
        logger.error(str(e))
        raise HTTPException(status_code=409, detail=str(e))

    # get the original metadata
    if orig_metadata := store.get_metadata(key):
        store.metadata_store.add_metadata(new_key, orig_metadata)

    # delete the original safetensors file
    installer.delete(key)

@@ -686,66 +577,66 @@ async def convert_model(
    return new_config


@model_manager_router.put(
    "/merge",
    operation_id="merge",
    responses={
        200: {
            "description": "Model converted successfully",
            "content": {"application/json": {"example": example_model_config}},
        },
        400: {"description": "Bad request"},
        404: {"description": "Model not found"},
        409: {"description": "There is already a model registered at this location"},
    },
)
async def merge(
    keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
    merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
    alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
    force: bool = Body(
        description="Force merging of models created with different versions of diffusers",
        default=False,
    ),
    interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
    merge_dest_directory: Optional[str] = Body(
        description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
        default=None,
    ),
) -> AnyModelConfig:
    """
    Merge diffusers models. The process is controlled by a set of parameters provided in the body of the request.
    ```
    Argument              Description [default]
    --------              ----------------------
    keys                  List of 2-3 model keys to merge together. All models must use the same base type.
    merged_model_name     Name for the merged model [Concat model names]
    alpha                 Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
    force                 If true, force the merge even if the models were generated by different versions of the diffusers library [False]
    interp                Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
    merge_dest_directory  Specify a directory to store the merged model in [models directory]
    ```
    """
    logger = ApiDependencies.invoker.services.logger
    try:
        logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
        dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
        installer = ApiDependencies.invoker.services.model_manager.install
        merger = ModelMerger(installer)
        model_names = [installer.record_store.get_model(x).name for x in keys]
        response = merger.merge_diffusion_models_and_save(
            model_keys=keys,
            merged_model_name=merged_model_name or "+".join(model_names),
            alpha=alpha,
            interp=interp,
            force=force,
            merge_dest_directory=dest,
        )
    except UnknownModelException:
        raise HTTPException(
            status_code=404,
            detail=f"One or more of the models '{keys}' not found",
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return response
# @model_manager_router.put(
#     "/merge",
#     operation_id="merge",
#     responses={
#         200: {
#             "description": "Model converted successfully",
#             "content": {"application/json": {"example": example_model_config}},
#         },
#         400: {"description": "Bad request"},
#         404: {"description": "Model not found"},
#         409: {"description": "There is already a model registered at this location"},
#     },
# )
# async def merge(
#     keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
#     merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
#     alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
#     force: bool = Body(
#         description="Force merging of models created with different versions of diffusers",
#         default=False,
#     ),
#     interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
#     merge_dest_directory: Optional[str] = Body(
#         description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
#         default=None,
#     ),
# ) -> AnyModelConfig:
#     """
#     Merge diffusers models. The process is controlled by a set of parameters provided in the body of the request.
#     ```
#     Argument              Description [default]
#     --------              ----------------------
#     keys                  List of 2-3 model keys to merge together. All models must use the same base type.
#     merged_model_name     Name for the merged model [Concat model names]
#     alpha                 Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
#     force                 If true, force the merge even if the models were generated by different versions of the diffusers library [False]
#     interp                Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
#     merge_dest_directory  Specify a directory to store the merged model in [models directory]
#     ```
#     """
#     logger = ApiDependencies.invoker.services.logger
#     try:
#         logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
#         dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
#         installer = ApiDependencies.invoker.services.model_manager.install
#         merger = ModelMerger(installer)
#         model_names = [installer.record_store.get_model(x).name for x in keys]
#         response = merger.merge_diffusion_models_and_save(
#             model_keys=keys,
#             merged_model_name=merged_model_name or "+".join(model_names),
#             alpha=alpha,
#             interp=interp,
#             force=force,
#             merge_dest_directory=dest,
#         )
#     except UnknownModelException:
#         raise HTTPException(
#             status_code=404,
#             detail=f"One or more of the models '{keys}' not found",
#         )
#     except ValueError as e:
#         raise HTTPException(status_code=400, detail=str(e))
#     return response
@ -173,6 +173,16 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("gradient_mask_output")
class GradientMaskOutput(BaseInvocationOutput):
"""Outputs a denoise mask and an image representing the total gradient of the mask."""

denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
expanded_mask_area: ImageField = OutputField(
description="Image representing the total gradient area of the mask. For paste-back purposes."
)
@invocation(
|
||||
"create_gradient_mask",
|
||||
title="Create Gradient Mask",
|
||||
@ -193,38 +203,42 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
||||
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
||||
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
if self.coherence_mode == "Box Blur":
|
||||
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
||||
else: # Gaussian Blur OR Staged
|
||||
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
||||
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
||||
if self.edge_radius > 0:
|
||||
if self.coherence_mode == "Box Blur":
|
||||
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
||||
else: # Gaussian Blur OR Staged
|
||||
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
||||
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
||||
|
||||
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
||||
|
||||
# redistribute blur so that the edges are 0 and blur out to 1
|
||||
blur_tensor = (blur_tensor - 0.5) * 2
|
||||
# redistribute blur so that the original edges are 0 and blur outwards to 1
|
||||
blur_tensor = (blur_tensor - 0.5) * 2
|
||||
|
||||
threshold = 1 - self.minimum_denoise
|
||||
threshold = 1 - self.minimum_denoise
|
||||
|
||||
if self.coherence_mode == "Staged":
|
||||
# wherever the blur_tensor is less than fully masked, convert it to threshold
|
||||
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
|
||||
else:
|
||||
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
||||
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
||||
|
||||
if self.coherence_mode == "Staged":
|
||||
# wherever the blur_tensor is masked to any degree, convert it to threshold
|
||||
blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
|
||||
else:
|
||||
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
||||
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
||||
|
||||
# multiply original mask to force actually masked regions to 0
|
||||
blur_tensor = mask_tensor * blur_tensor
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
|
||||
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
|
||||
|
||||
return DenoiseMaskOutput.build(
|
||||
mask_name=mask_name,
|
||||
masked_latents_name=None,
|
||||
gradient=True,
|
||||
# compute a [0, 1] mask from the blur_tensor
|
||||
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
|
||||
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
|
||||
expanded_image_dto = context.images.save(expanded_mask_image)
|
||||
|
||||
return GradientMaskOutput(
|
||||
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
|
||||
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
|
||||
)
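To make the mask arithmetic in `invoke()` concrete, here is a small standalone sketch (toy values, not code from this PR) of the redistribution and threshold steps:
```
# Toy illustration of the gradient-mask math above (values are approximate).
import torch

minimum_denoise = 0.2
blur = torch.tensor([0.50, 0.80, 0.95, 1.00])    # blurred mask values in [0, 1]

# redistribute so the original mask edge (0.5) maps to 0 and fully unmasked (1.0) stays at 1
blur = (blur - 0.5) * 2                          # -> [0.0, 0.6, 0.9, 1.0]

threshold = 1 - minimum_denoise                  # -> 0.8

# default branch: values above the threshold but below 1 are clamped down to the threshold
clamped = torch.where((blur > threshold) & (blur < 1), torch.full_like(blur, threshold), blur)
print(clamped)                                   # -> [0.0, 0.6, 0.8, 1.0]

# "Staged" branch: anything not fully unmasked becomes the threshold
staged = torch.where(blur < 1, torch.full_like(blur, threshold), blur)
print(staged)                                    # -> [0.8, 0.8, 0.8, 1.0]
```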
|
||||
|
||||
|
||||
@ -775,10 +789,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
(
|
||||
result_latents,
|
||||
result_attention_map_saver,
|
||||
) = pipeline.latents_from_embeddings(
|
||||
result_latents = pipeline.latents_from_embeddings(
|
||||
latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
|
@ -133,7 +133,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
submodel_type=SubModelType.VAE,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
@ -85,7 +85,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
submodel_type=SubModelType.VAE,
|
||||
),
|
||||
),
|
||||
)
|
||||
@ -142,7 +142,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
submodel_type=SubModelType.VAE,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
@ -256,6 +256,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
profile_graphs : bool = Field(default=False, description="Enable graph profiling", json_schema_extra=Categories.Development)
|
||||
profile_prefix : Optional[str] = Field(default=None, description="An optional prefix for profile output files.", json_schema_extra=Categories.Development)
|
||||
profiles_dir : Path = Field(default=Path('profiles'), description="Directory for graph profiles", json_schema_extra=Categories.Development)
|
||||
skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce startup time.", json_schema_extra=Categories.Development)
|
||||
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
|
||||
|
||||
|
@ -18,10 +18,9 @@ from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
"""State of an install job running in the background."""
|
||||
@ -151,6 +150,13 @@ ModelSource = Annotated[
|
||||
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
|
||||
]
|
||||
|
||||
MODEL_SOURCE_TO_TYPE_MAP = {
|
||||
URLModelSource: ModelSourceType.Url,
|
||||
HFModelSource: ModelSourceType.HFRepoID,
|
||||
CivitaiModelSource: ModelSourceType.CivitAI,
|
||||
LocalModelSource: ModelSourceType.Path,
|
||||
}
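As a quick illustration (a sketch using the names defined in this module, not code from the diff), the map lets the installer derive a `source_type` from the class of an incoming source object:
```
# Sketch: deriving source_type from a source object's class via MODEL_SOURCE_TO_TYPE_MAP.
from pathlib import Path

source = LocalModelSource(path=Path("/tmp/model.safetensors"), inplace=False)
config_in = {
    "source": str(source),
    "source_type": MODEL_SOURCE_TO_TYPE_MAP[source.__class__],  # -> ModelSourceType.Path
}
```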
|
||||
|
||||
|
||||
class ModelInstallJob(BaseModel):
|
||||
"""Object that tracks the current status of an install request."""
|
||||
@ -260,7 +266,6 @@ class ModelInstallServiceBase(ABC):
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
metadata_store: ModelMetadataStoreBase,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
):
|
||||
"""
|
||||
@ -347,6 +352,7 @@ class ModelInstallServiceBase(ABC):
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
r"""Install the indicated model using heuristics to interpret user intentions.
|
||||
|
||||
@ -392,7 +398,7 @@ class ModelInstallServiceBase(ABC):
|
||||
will override corresponding autoassigned probe fields in the
|
||||
model's config record. Use it to override
|
||||
`name`, `description`, `base_type`, `model_type`, `format`,
|
||||
`prediction_type`, `image_size`, and/or `ztsnr_training`.
|
||||
`prediction_type`, and/or `image_size`.
|
||||
|
||||
This will download the model located at `source`,
|
||||
probe it, and install it into the models directory.
|
||||
|
@ -7,7 +7,6 @@ import time
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from random import randbytes
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
@ -21,11 +20,15 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
CheckpointConfigBase,
|
||||
InvalidModelConfigException,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
@ -35,12 +38,14 @@ from invokeai.backend.model_manager.metadata import (
|
||||
ModelMetadataWithFiles,
|
||||
RemoteModelFile,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import CivitaiMetadata, HuggingFaceMetadata
|
||||
from invokeai.backend.model_manager.probe import ModelProbe
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.util import Chdir, InvokeAILogger
|
||||
from invokeai.backend.util.devices import choose_precision, choose_torch_device
|
||||
|
||||
from .model_install_base import (
|
||||
MODEL_SOURCE_TO_TYPE_MAP,
|
||||
CivitaiModelSource,
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
@ -90,7 +95,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._running = False
|
||||
self._session = session
|
||||
self._next_job_id = 0
|
||||
self._metadata_store = record_store.metadata_store # for convenience
|
||||
|
||||
@property
|
||||
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
||||
@ -139,6 +143,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
config["source_type"] = ModelSourceType.Path
|
||||
return self._register(model_path, config)
|
||||
|
||||
def install_path(
|
||||
@ -148,11 +153,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
config["key"] = config.get("key", self._create_key())
|
||||
|
||||
info: AnyModelConfig = self._probe_model(Path(model_path), config)
|
||||
if self._app_config.skip_model_hash:
|
||||
config["hash"] = uuid_string()
|
||||
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
|
||||
|
||||
if preferred_name := config.get("name"):
|
||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||
@ -178,13 +183,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||
source_obj: Optional[StringLikeSource] = None
|
||||
|
||||
if Path(source).exists(): # A local file or directory
|
||||
source_obj = LocalModelSource(path=Path(source))
|
||||
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
|
||||
elif match := re.match(hf_repoid_re, source):
|
||||
source_obj = HFModelSource(
|
||||
repo_id=match.group(1),
|
||||
@ -373,15 +379,18 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job.bytes = job.total_bytes
|
||||
self._signal_job_running(job)
|
||||
job.config_in["source"] = str(job.source)
|
||||
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||
# enter the metadata, if there is any
|
||||
if isinstance(job.source_metadata, (CivitaiMetadata, HuggingFaceMetadata)):
|
||||
job.config_in["source_api_response"] = job.source_metadata.api_response
|
||||
if isinstance(job.source_metadata, CivitaiMetadata) and job.source_metadata.trigger_phrases:
|
||||
job.config_in["trigger_phrases"] = job.source_metadata.trigger_phrases
|
||||
|
||||
if job.inplace:
|
||||
key = self.register_path(job.local_path, job.config_in)
|
||||
else:
|
||||
key = self.install_path(job.local_path, job.config_in)
|
||||
job.config_out = self.record_store.get_model(key)
|
||||
|
||||
# enter the metadata, if there is any
|
||||
if job.source_metadata:
|
||||
self._metadata_store.add_metadata(key, job.source_metadata)
|
||||
self._signal_job_completed(job)
|
||||
|
||||
except InvalidModelConfigException as excp:
|
||||
@ -467,7 +476,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||
new_path = self._move_model(old_path, new_path)
|
||||
model.path = new_path.relative_to(models_dir).as_posix()
|
||||
self.record_store.update_model(key, model)
|
||||
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
|
||||
return model
|
||||
|
||||
def _scan_register(self, model: Path) -> bool:
|
||||
@ -519,22 +528,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
move(old_path, new_path)
|
||||
return new_path
|
||||
|
||||
def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path))
|
||||
if config: # used to override probe fields
|
||||
for key, value in config.items():
|
||||
setattr(info, key, value)
|
||||
return info
|
||||
|
||||
def _create_key(self) -> str:
|
||||
return sha256(randbytes(100)).hexdigest()[0:32]
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
# Note that we may be passed a pre-populated AnyModelConfig object,
|
||||
# in which case the key field should have been populated by the caller (e.g. in `install_path`).
|
||||
config["key"] = config.get("key", self._create_key())
|
||||
config = config or {}
|
||||
|
||||
if self._app_config.skip_model_hash:
|
||||
config["hash"] = uuid_string()
|
||||
|
||||
info = info or ModelProbe.probe(model_path, config)
|
||||
|
||||
model_path = model_path.absolute()
|
||||
@ -544,11 +545,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
info.path = model_path.as_posix()
|
||||
|
||||
# add 'main' specific fields
|
||||
if hasattr(info, "config"):
|
||||
if isinstance(info, CheckpointConfigBase):
|
||||
# make config relative to our root
|
||||
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
|
||||
info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
|
||||
self.record_store.add_model(info.key, info)
|
||||
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config_path).resolve()
|
||||
info.config_path = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
|
||||
self.record_store.add_model(info)
|
||||
return info.key
|
||||
|
||||
def _next_id(self) -> int:
|
||||
@ -569,13 +570,15 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source=source,
|
||||
config_in=config or {},
|
||||
local_path=Path(source.path),
|
||||
inplace=source.inplace,
|
||||
inplace=source.inplace or False,
|
||||
)
|
||||
|
||||
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
if not source.access_token:
|
||||
self._logger.info("No Civitai access token provided; some models may not be downloadable.")
|
||||
metadata = CivitaiMetadataFetch(self._session).from_id(str(source.version_id))
|
||||
metadata = CivitaiMetadataFetch(self._session, self.app_config.get_config().civitai_api_key).from_id(
|
||||
str(source.version_id)
|
||||
)
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
remote_files = metadata.download_urls(session=self._session)
|
||||
return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)
|
||||
@ -603,15 +606,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
# URLs from Civitai or HuggingFace will be handled specially
|
||||
url_patterns = {
|
||||
r"^https?://civitai.com/": CivitaiMetadataFetch,
|
||||
r"^https?://huggingface.co/[^/]+/[^/]+$": HuggingFaceMetadataFetch,
|
||||
}
|
||||
metadata = None
|
||||
for pattern, fetcher in url_patterns.items():
|
||||
if re.match(pattern, str(source.url), re.IGNORECASE):
|
||||
metadata = fetcher(self._session).from_url(source.url)
|
||||
break
|
||||
fetcher = None
|
||||
try:
|
||||
fetcher = self.get_fetcher_from_url(str(source.url))
|
||||
except ValueError:
|
||||
pass
|
||||
kwargs: dict[str, Any] = {"session": self._session}
|
||||
if fetcher is CivitaiMetadataFetch:
|
||||
kwargs["api_key"] = self._app_config.get_config().civitai_api_key
|
||||
if fetcher is not None:
|
||||
metadata = fetcher(**kwargs).from_url(source.url)
|
||||
self._logger.debug(f"metadata={metadata}")
|
||||
if metadata and isinstance(metadata, ModelMetadataWithFiles):
|
||||
remote_files = metadata.download_urls(session=self._session)
|
||||
@ -626,7 +631,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def _import_remote_model(
|
||||
self,
|
||||
source: ModelSource,
|
||||
source: HFModelSource | CivitaiModelSource | URLModelSource,
|
||||
remote_files: List[RemoteModelFile],
|
||||
metadata: Optional[AnyModelRepoMetadata],
|
||||
config: Optional[Dict[str, Any]],
|
||||
@ -654,7 +659,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# In the event that there is a subfolder specified in the source,
|
||||
# we need to remove it from the destination path in order to avoid
|
||||
# creating unwanted subfolders
|
||||
if hasattr(source, "subfolder") and source.subfolder:
|
||||
if isinstance(source, HFModelSource) and source.subfolder:
|
||||
root = Path(remote_files[0].path.parts[0])
|
||||
subfolder = root / source.subfolder
|
||||
else:
|
||||
@ -841,3 +846,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.info(f"{job.source}: model installation was cancelled")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_cancelled(str(job.source))
|
||||
|
||||
@staticmethod
|
||||
def get_fetcher_from_url(url: str):
|
||||
if re.match(r"^https?://civitai.com/", url.lower()):
|
||||
return CivitaiMetadataFetch
|
||||
elif re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
||||
return HuggingFaceMetadataFetch
|
||||
raise ValueError(f"Unsupported model source: '{url}'")
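For orientation, the fetcher selection above can be exercised directly; the URLs below are illustrative only, not taken from this diff.
```
# Sketch: choosing a metadata fetcher class from a model URL (illustrative URLs).
fetcher_cls = ModelInstallService.get_fetcher_from_url("https://civitai.com/models/12345")
assert fetcher_cls is CivitaiMetadataFetch

fetcher_cls = ModelInstallService.get_fetcher_from_url("https://huggingface.co/some-org/some-model")
assert fetcher_cls is HuggingFaceMetadataFetch
```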
|
||||
|
@ -1,9 +0,0 @@
|
||||
"""Init file for ModelMetadataStoreService module."""
|
||||
|
||||
from .metadata_store_base import ModelMetadataStoreBase
|
||||
from .metadata_store_sql import ModelMetadataStoreSQL
|
||||
|
||||
__all__ = [
|
||||
"ModelMetadataStoreBase",
|
||||
"ModelMetadataStoreSQL",
|
||||
]
|
@ -1,65 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Storage for Model Metadata
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Set, Tuple
|
||||
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
|
||||
class ModelMetadataStoreBase(ABC):
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
@abstractmethod
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
|
||||
@abstractmethod
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
|
||||
@abstractmethod
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
@ -1,222 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
SQL Storage for Model Metadata
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
|
||||
|
||||
from .metadata_store_base import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json()
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_metadata(
|
||||
id,
|
||||
metadata
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
model_key,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||
self._db.conn.rollback()
|
||||
raise UnknownMetadataException from excp
|
||||
except sqlite3.Error as excp:
|
||||
self._db.conn.rollback()
|
||||
raise excp
|
||||
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM model_metadata
|
||||
WHERE id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
return ModelMetadataFetchBase.from_json(rows[0])
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id,metadata FROM model_metadata;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
|
||||
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_metadata
|
||||
SET
|
||||
metadata=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, model_key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_metadata(model_key)
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select tag_text from tags;
|
||||
"""
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
matches: Optional[Set[str]] = None
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.model_id FROM model_tags AS a,
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
if matches is None:
|
||||
matches = model_keys
|
||||
matches = matches.intersection(model_keys)
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return matches if matches else set()
|
||||
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE author=?;
|
||||
""",
|
||||
(author,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE name=?;
|
||||
""",
|
||||
(name,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
@ -6,20 +6,19 @@ Abstract base class for storing and retrieving model configuration records.
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase
|
||||
from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
@ -60,11 +59,33 @@ class ModelSummary(BaseModel):
|
||||
tags: Set[str] = Field(description="tags associated with model")
|
||||
|
||||
|
||||
class ModelRecordChanges(BaseModelExcludeNull):
|
||||
"""A set of changes to apply to a model."""
|
||||
|
||||
# Changes applicable to all models
|
||||
name: Optional[str] = Field(description="Name of the model.", default=None)
|
||||
path: Optional[str] = Field(description="Path to the model.", default=None)
|
||||
description: Optional[str] = Field(description="Model description", default=None)
|
||||
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
|
||||
# Checkpoint-specific changes
|
||||
# TODO(MM2): Should we expose these? Feels footgun-y...
|
||||
variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
|
||||
prediction_type: Optional[SchedulerPredictionType] = Field(
|
||||
description="The prediction type of the model.", default=None
|
||||
)
|
||||
upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
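A brief usage sketch (not part of this diff; the store variable and model key are placeholders) shows how such a change set is intended to be applied:
```
# Sketch: applying a validated change set to an existing model record.
changes = ModelRecordChanges(name="My Finetune", trigger_phrases={"my-trigger"})
updated = record_store.update_model(key="<model key>", changes=changes)
print(updated.name)  # -> "My Finetune"
```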
|
||||
|
||||
|
||||
class ModelRecordServiceBase(ABC):
|
||||
"""Abstract base class for storage and retrieval of model configs."""
|
||||
|
||||
@abstractmethod
|
||||
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
@ -88,13 +109,12 @@ class ModelRecordServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
:param key: Unique key for the model to be updated
|
||||
:param config: Model configuration record. Either a dict with the
|
||||
required fields, or a ModelConfigBase instance.
|
||||
:param key: Unique key for the model to be updated.
|
||||
:param changes: A set of changes to apply to this model. Changes are validated before being written.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -109,40 +129,6 @@ class ModelRecordServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
Retrieve metadata (if any) from when model was downloaded from a repo.
|
||||
|
||||
:param key: Model key
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
||||
"""List metadata for all models that have it."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Search model metadata for ones with all listed tags and return their corresponding configs.
|
||||
|
||||
:param tags: Set of tags to search for. All tags must be present.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return a unique set of all the model tags in the metadata database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_models(
|
||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||
@ -217,21 +203,3 @@ class ModelRecordServiceBase(ABC):
|
||||
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
|
||||
)
|
||||
return model_configs[0]
|
||||
|
||||
def rename_model(
|
||||
self,
|
||||
key: str,
|
||||
new_name: str,
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Rename the indicated model. Just a special case of update_model().
|
||||
|
||||
In some implementations, renaming the model may involve changing where
|
||||
it is stored on the filesystem. So this is broken out.
|
||||
|
||||
:param key: Model key
|
||||
:param new_name: New name for model
|
||||
"""
|
||||
config = self.get_model(key)
|
||||
config.name = new_name
|
||||
return self.update_model(key, config)
|
||||
|
@ -43,7 +43,7 @@ import json
|
||||
import sqlite3
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@ -53,12 +53,11 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from .model_records_base import (
|
||||
DuplicateModelException,
|
||||
ModelRecordChanges,
|
||||
ModelRecordOrderBy,
|
||||
ModelRecordServiceBase,
|
||||
ModelSummary,
|
||||
@ -69,7 +68,7 @@ from .model_records_base import (
|
||||
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase):
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
@ -78,14 +77,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = db.conn.cursor()
|
||||
self._metadata_store = metadata_store
|
||||
|
||||
@property
|
||||
def db(self) -> SqliteDatabase:
|
||||
"""Return the underlying database."""
|
||||
return self._db
|
||||
|
||||
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
@ -95,23 +93,19 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config object.
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_config (
|
||||
INSERT INTO models (
|
||||
id,
|
||||
original_hash,
|
||||
config
|
||||
)
|
||||
VALUES (?,?,?);
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
key,
|
||||
record.original_hash,
|
||||
json_serialized,
|
||||
config.key,
|
||||
config.model_dump_json(),
|
||||
),
|
||||
)
|
||||
self._db.conn.commit()
|
||||
@ -119,12 +113,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
except sqlite3.IntegrityError as e:
|
||||
self._db.conn.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "model_config.path" in str(e):
|
||||
msg = f"A model with path '{record.path}' is already installed"
|
||||
elif "model_config.name" in str(e):
|
||||
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
|
||||
if "models.path" in str(e):
|
||||
msg = f"A model with path '{config.path}' is already installed"
|
||||
elif "models.name" in str(e):
|
||||
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
|
||||
else:
|
||||
msg = f"A model with key '{key}' is already installed"
|
||||
msg = f"A model with key '{config.key}' is already installed"
|
||||
raise DuplicateModelException(msg) from e
|
||||
else:
|
||||
raise e
|
||||
@ -132,7 +126,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(key)
|
||||
return self.get_model(config.key)
|
||||
|
||||
def del_model(self, key: str) -> None:
|
||||
"""
|
||||
@ -146,7 +140,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_config
|
||||
DELETE FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@ -158,21 +152,20 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
record = self.get_model(key)
|
||||
|
||||
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
|
||||
for field_name in changes.model_fields_set:
|
||||
setattr(record, field_name, getattr(changes, field_name))
|
||||
|
||||
json_serialized = record.model_dump_json()
|
||||
|
||||
:param key: Unique key for the model to be updated
|
||||
:param config: Model configuration record. Either a dict with the
|
||||
required fields, or a ModelConfigBase instance.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config object
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_config
|
||||
UPDATE models
|
||||
SET
|
||||
config=?
|
||||
WHERE id=?;
|
||||
@ -199,7 +192,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@ -220,7 +213,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM model_config
|
||||
select count(*) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@ -246,9 +239,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
"""
|
||||
results = []
|
||||
where_clause = []
|
||||
bindings = []
|
||||
where_clause: list[str] = []
|
||||
bindings: list[str] = []
|
||||
if model_name:
|
||||
where_clause.append("name=?")
|
||||
bindings.append(model_name)
|
||||
@ -265,14 +257,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
select config, strftime('%s',updated_at) FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
{where};
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
results = [
|
||||
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
||||
]
|
||||
result = self._cursor.fetchall()
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result]
|
||||
return results
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
||||
@ -281,7 +272,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
@ -292,13 +283,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
return results
|
||||
|
||||
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated original_hash."""
|
||||
"""Return models with the indicated hash."""
|
||||
results = []
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
WHERE original_hash=?;
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
@ -307,83 +298,35 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
]
|
||||
return results
|
||||
|
||||
@property
|
||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
return self._metadata_store
|
||||
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
Retrieve metadata (if any) from when model was downloaded from a repo.
|
||||
|
||||
:param key: Model key
|
||||
"""
|
||||
store = self.metadata_store
|
||||
try:
|
||||
metadata = store.get_metadata(key)
|
||||
return metadata
|
||||
except UnknownMetadataException:
|
||||
return None
|
||||
|
||||
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Search model metadata for ones with all listed tags and return their corresponding configs.
|
||||
|
||||
:param tags: Set of tags to search for. All tags must be present.
|
||||
"""
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
keys = store.search_by_tag(tags)
|
||||
return [self.get_model(x) for x in keys]
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return a unique set of all the model tags in the metadata database."""
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
return store.list_tags()
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
||||
"""List metadata for all models that have it."""
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
return store.list_all_metadata()
|
||||
|
||||
def list_models(
|
||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Return a paginated summary listing of each model in the database."""
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name",
|
||||
ModelRecordOrderBy.Type: "a.type",
|
||||
ModelRecordOrderBy.Base: "a.base",
|
||||
ModelRecordOrderBy.Name: "a.name",
|
||||
ModelRecordOrderBy.Format: "a.format",
|
||||
ModelRecordOrderBy.Default: "type, base, format, name",
|
||||
ModelRecordOrderBy.Type: "type",
|
||||
ModelRecordOrderBy.Base: "base",
|
||||
ModelRecordOrderBy.Name: "name",
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
}
|
||||
|
||||
def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]:
|
||||
"""Fix up results so that there are no null values."""
|
||||
result: Dict[str, Union[str, int, Set[str]]] = {}
|
||||
for key, item in summary.items():
|
||||
result[key] = item or ""
|
||||
result["tags"] = set(json.loads(summary["tags"] or "[]"))
|
||||
return result
|
||||
|
||||
# Lock so that the database isn't updated while we're doing the two queries.
|
||||
with self._db.lock:
|
||||
# query1: get the total number of model configs
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) from model_config;
|
||||
select count(*) from models;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
total = int(self._cursor.fetchone()[0])
|
||||
|
||||
# query2: fetch key fields from the join of model_config and model_metadata
|
||||
# query2: fetch key fields
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT a.id as key, a.type, a.base, a.format, a.name,
|
||||
json_extract(a.config, '$.description') as description,
|
||||
json_extract(b.metadata, '$.tags') as tags
|
||||
FROM model_config AS a
|
||||
LEFT JOIN model_metadata AS b on a.id=b.id
|
||||
SELECT config
|
||||
FROM models
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
||||
LIMIT ?
|
||||
OFFSET ?;
|
||||
@ -394,7 +337,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
items = [ModelSummary.model_validate(_fixup(dict(x))) for x in rows]
|
||||
items = [ModelSummary.model_validate(dict(x)) for x in rows]
|
||||
return PaginatedResults(
|
||||
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
|
||||
)
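A usage sketch (assumed caller and store variable, not from this diff) for the paginated listing:
```
# Sketch: fetching the first page of model summaries, ordered by name.
page = record_store.list_models(page=0, per_page=25, order_by=ModelRecordOrderBy.Name)
for summary in page.items:
    print(summary.model_dump())
```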
|
||||
|
@ -1,6 +1,35 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from threading import Event
|
||||
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
|
||||
|
||||
class SessionRunnerBase(ABC):
|
||||
"""
|
||||
Base class for a session runner.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def start(self, services: InvocationServices, cancel_event: Event) -> None:
|
||||
"""Starts the session runner"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run(self, queue_item: SessionQueueItem) -> None:
|
||||
"""Runs the session"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def complete(self, queue_item: SessionQueueItem) -> None:
|
||||
"""Completes the session"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
|
||||
"""Runs an already prepared node on the session"""
|
||||
pass
|
||||
|
||||
|
||||
class SessionProcessorBase(ABC):
|
||||
|
@ -2,13 +2,14 @@ import traceback
|
||||
from contextlib import suppress
|
||||
from threading import BoundedSemaphore, Thread
|
||||
from threading import Event as ThreadEvent
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
@ -16,15 +17,164 @@ from invokeai.app.services.shared.invocation_context import InvocationContextDat
|
||||
from invokeai.app.util.profiler import Profiler
|
||||
|
||||
from ..invoker import Invoker
|
||||
from .session_processor_base import SessionProcessorBase
|
||||
from .session_processor_base import SessionProcessorBase, SessionRunnerBase
|
||||
from .session_processor_common import SessionProcessorStatus
|
||||
|
||||
|
||||
class DefaultSessionRunner(SessionRunnerBase):
|
||||
"""Processes a single session's invocations"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
||||
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
||||
):
|
||||
self.on_before_run_node = on_before_run_node
|
||||
self.on_after_run_node = on_after_run_node
|
||||
|
||||
def start(self, services: InvocationServices, cancel_event: ThreadEvent):
|
||||
"""Start the session runner"""
|
||||
self.services = services
|
||||
self.cancel_event = cancel_event
|
||||
|
||||
def run(self, queue_item: SessionQueueItem):
|
||||
"""Run the graph"""
|
||||
if not queue_item.session:
|
||||
raise ValueError("Queue item has no session")
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
|
||||
# Prepare the next node
|
||||
invocation = queue_item.session.next()
|
||||
if invocation is None:
|
||||
# If there are no more invocations, complete the graph
|
||||
break
|
||||
# Build the invocation context (the node-facing API)
|
||||
self.run_node(invocation.id, queue_item)
|
||||
self.complete(queue_item)
|
||||
|
||||
def complete(self, queue_item: SessionQueueItem):
|
||||
"""Complete the graph"""
|
||||
self.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
)
|
||||
|
||||
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
"""Run before a node is executed"""
|
||||
# Send starting event
|
||||
self.services.events.emit_invocation_started(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session_id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
)
|
||||
if self.on_before_run_node is not None:
|
||||
self.on_before_run_node(invocation, queue_item)
|
||||
|
||||
def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
"""Run after a node is executed"""
|
||||
if self.on_after_run_node is not None:
|
||||
self.on_after_run_node(invocation, queue_item)
|
||||
|
||||
def run_node(self, node_id: str, queue_item: SessionQueueItem):
|
||||
"""Run a single node in the graph"""
|
||||
# If this raises a NodeNotFoundError, it is handled by the processor
|
||||
invocation = queue_item.session.execution_graph.get_node(node_id)
|
||||
try:
|
||||
self._on_before_run_node(invocation, queue_item)
|
||||
data = InvocationContextData(
|
||||
invocation=invocation,
|
||||
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
queue_item=queue_item,
|
||||
)
|
||||
|
||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self.services,
|
||||
cancel_event=self.cancel_event,
|
||||
)
|
||||
|
||||
# Invoke the node
|
||||
outputs = invocation.invoke_internal(context=context, services=self.services)
|
||||
|
||||
# Save outputs and history
|
||||
queue_item.session.complete(invocation.id, outputs)
|
||||
|
||||
self._on_after_run_node(invocation, queue_item)
|
||||
# Send complete event on successful runs
|
||||
self.services.events.emit_invocation_complete(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=data.source_invocation_id,
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
pass
|
||||
except CanceledException:
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
#
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
|
||||
# Save error
|
||||
queue_item.session.set_node_error(invocation.id, error)
|
||||
self.services.logger.error(
|
||||
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
|
||||
)
|
||||
self.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self.services.events.emit_invocation_error(
|
||||
queue_batch_id=queue_item.session_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
|
||||
|
||||
class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
|
||||
"""Processes sessions from the session queue"""
|
||||
|
||||
def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None:
|
||||
super().__init__()
|
||||
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
|
||||
|
||||
def start(
|
||||
self,
|
||||
invoker: Invoker,
|
||||
thread_limit: int = 1,
|
||||
polling_interval: int = 1,
|
||||
on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
|
||||
on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
|
||||
) -> None:
|
||||
self._invoker: Invoker = invoker
|
||||
self._queue_item: Optional[SessionQueueItem] = None
|
||||
self._invocation: Optional[BaseInvocation] = None
|
||||
self.on_before_run_session = on_before_run_session
|
||||
self.on_after_run_session = on_after_run_session
|
||||
|
||||
self._resume_event = ThreadEvent()
|
||||
self._stop_event = ThreadEvent()
|
||||
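The session-level callbacks, by contrast, are passed directly to `start()`. A short sketch, assuming `invoker` is the application's already-constructed `Invoker` instance:

```py
def before_session(queue_item: SessionQueueItem) -> bool:
    # Runs once before the session's graph executes; e.g. check external preconditions.
    return True

def after_session(queue_item: SessionQueueItem) -> bool:
    # Runs once after the session finishes; e.g. flush logs or notify another service.
    return True

processor.start(
    invoker=invoker,
    on_before_run_session=before_session,
    on_after_run_session=after_session,
)
```
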
@ -59,6 +209,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
"cancel_event": self._cancel_event,
|
||||
},
|
||||
)
|
||||
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
@ -117,130 +268,34 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||
cancel_event.clear()
|
||||
|
||||
# If we have a on_before_run_session callback, call it
|
||||
if self.on_before_run_session is not None:
|
||||
self.on_before_run_session(self._queue_item)
|
||||
|
||||
# If profiling is enabled, start the profiler
|
||||
if self._profiler is not None:
|
||||
self._profiler.start(profile_id=self._queue_item.session_id)
|
||||
|
||||
# Prepare invocations and take the first
|
||||
self._invocation = self._queue_item.session.next()
|
||||
# Run the graph
|
||||
self.session_runner.run(queue_item=self._queue_item)
|
||||
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
while self._invocation is not None and not cancel_event.is_set():
|
||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
|
||||
|
||||
# Send starting event
|
||||
self._invoker.services.events.emit_invocation_started(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session_id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if self._profiler:
|
||||
profile_path = self._profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self._invoker.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
||||
)
|
||||
|
||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||
try:
|
||||
with self._invoker.services.performance_statistics.collect_stats(
|
||||
self._invocation, self._queue_item.session.id
|
||||
):
|
||||
# Build invocation context (the node-facing API)
|
||||
data = InvocationContextData(
|
||||
invocation=self._invocation,
|
||||
source_invocation_id=source_invocation_id,
|
||||
queue_item=self._queue_item,
|
||||
)
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self._invoker.services,
|
||||
cancel_event=self._cancel_event,
|
||||
)
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
|
||||
self._invoker.services.performance_statistics.reset_stats()
|
||||
|
||||
# Invoke the node
|
||||
outputs = self._invocation.invoke_internal(
|
||||
context=context, services=self._invoker.services
|
||||
)
|
||||
|
||||
# Save outputs and history
|
||||
self._queue_item.session.complete(self._invocation.id, outputs)
|
||||
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_invocation_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
pass
|
||||
|
||||
except CanceledException:
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
#
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
|
||||
# Save error
|
||||
self._queue_item.session.set_node_error(self._invocation.id, error)
|
||||
self._invoker.services.logger.error(
|
||||
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
|
||||
)
|
||||
|
||||
# Send error event
|
||||
self._invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=self._queue_item.session_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
pass
|
||||
|
||||
# The session is complete if all invocations are complete or there was an error
|
||||
if self._queue_item.session.is_complete() or cancel_event.is_set():
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
)
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if self._profiler:
|
||||
profile_path = self._profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self._invoker.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
||||
)
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
|
||||
self._invoker.services.performance_statistics.reset_stats()
|
||||
|
||||
# Set the invocation to None to prepare for the next session
|
||||
self._invocation = None
|
||||
else:
|
||||
# Prepare the next invocation
|
||||
self._invocation = self._queue_item.session.next()
|
||||
# If we have a on_after_run_session callback, call it
|
||||
if self.on_after_run_session is not None:
|
||||
self.on_after_run_session(self._queue_item)
|
||||
|
||||
# The session is complete, immediately poll for next session
|
||||
self._queue_item = None
|
||||
@ -274,3 +329,4 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
poll_now_event.clear()
|
||||
self._queue_item = None
|
||||
self._thread_semaphore.release()
|
||||
self._invoker.services.logger.debug("Session processor stopped")
|
||||
|
@ -9,6 +9,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator


@ -35,6 +36,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
    migrator.register_migration(build_migration_4())
    migrator.register_migration(build_migration_5())
    migrator.register_migration(build_migration_6())
    migrator.register_migration(build_migration_7())
    migrator.run_migrations()

    return db

@ -0,0 +1,88 @@
import sqlite3

from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration


class Migration7Callback:
    def __call__(self, cursor: sqlite3.Cursor) -> None:
        self._create_models_table(cursor)
        self._drop_old_models_tables(cursor)

    def _drop_old_models_tables(self, cursor: sqlite3.Cursor) -> None:
        """Drops the old model_records, model_metadata, model_tags and tags tables."""

        tables = ["model_records", "model_metadata", "model_tags", "tags"]

        for table in tables:
            cursor.execute(f"DROP TABLE IF EXISTS {table};")

    def _create_models_table(self, cursor: sqlite3.Cursor) -> None:
        """Creates the v4.0.0 models table."""

        tables = [
            """--sql
            CREATE TABLE IF NOT EXISTS models (
                id TEXT NOT NULL PRIMARY KEY,
                hash TEXT GENERATED ALWAYS as (json_extract(config, '$.hash')) VIRTUAL NOT NULL,
                base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
                type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
                path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
                format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
                name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
                description TEXT GENERATED ALWAYS as (json_extract(config, '$.description')) VIRTUAL,
                source TEXT GENERATED ALWAYS as (json_extract(config, '$.source')) VIRTUAL NOT NULL,
                source_type TEXT GENERATED ALWAYS as (json_extract(config, '$.source_type')) VIRTUAL NOT NULL,
                source_api_response TEXT GENERATED ALWAYS as (json_extract(config, '$.source_api_response')) VIRTUAL,
                trigger_phrases TEXT GENERATED ALWAYS as (json_extract(config, '$.trigger_phrases')) VIRTUAL,
                -- Serialized JSON representation of the whole config object, which will contain additional fields from subclasses
                config TEXT NOT NULL,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- unique constraint on combo of name, base and type
                UNIQUE(name, base, type)
            );
            """
        ]

        # Add trigger for `updated_at`.
        triggers = [
            """--sql
            CREATE TRIGGER IF NOT EXISTS models_updated_at
            AFTER UPDATE
            ON models FOR EACH ROW
            BEGIN
                UPDATE models SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE id = old.id;
            END;
            """
        ]

        # Add indexes for searchable fields
        indices = [
            "CREATE INDEX IF NOT EXISTS base_index ON models(base);",
            "CREATE INDEX IF NOT EXISTS type_index ON models(type);",
            "CREATE INDEX IF NOT EXISTS name_index ON models(name);",
            "CREATE UNIQUE INDEX IF NOT EXISTS path_index ON models(path);",
        ]

        for stmt in tables + indices + triggers:
            cursor.execute(stmt)


def build_migration_7() -> Migration:
    """
    Build the migration from database version 6 to 7.

    This migration does the following:
    - Adds the new models table
    - Drops the old model_records, model_metadata, model_tags and tags tables.
    - TODO(MM2): Migrates model names and descriptions from `models.yaml` to the new table (?).
    """
    migration_7 = Migration(
        from_version=6,
        to_version=7,
        callback=Migration7Callback(),
    )

    return migration_7

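A quick sketch of what the generated columns above enable, using an in-memory database; the model values are made up and only illustrate the shape of the `config` JSON:

```py
import json
import sqlite3

conn = sqlite3.connect(":memory:")
Migration7Callback()(conn.cursor())  # create the models table, trigger and indexes

config = {
    "key": "abc123",
    "hash": "blake3:deadbeef",
    "path": "sd-1/main/my-model",
    "name": "my-model",
    "base": "sd-1",
    "type": "main",
    "format": "diffusers",
    "source": "some-repo/my-model",
    "source_type": "hf_repo_id",
}
# Only `id` and the serialized config are written; the other columns are derived from the JSON.
conn.execute("INSERT INTO models (id, config) VALUES (?, ?);", (config["key"], json.dumps(config)))

# `base`, `type`, `name`, etc. are VIRTUAL columns extracted from the JSON config at query time.
print(conn.execute("SELECT name, type FROM models WHERE base = 'sd-1';").fetchall())
```
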
@ -3,7 +3,6 @@
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from hashlib import sha1
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@ -22,7 +21,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelConfigFactory,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.model_manager.hash import ModelHash
|
||||
|
||||
ModelsValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
@ -73,19 +72,27 @@ class MigrateModelYamlToDb1:
|
||||
|
||||
base_type, model_type, model_name = str(model_key).split("/")
|
||||
try:
|
||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
||||
hash = ModelHash().hash(self.config.models_path / stanza.path)
|
||||
except OSError:
|
||||
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
|
||||
continue
|
||||
|
||||
assert isinstance(model_key, str)
|
||||
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
||||
|
||||
stanza["base"] = BaseModelType(base_type)
|
||||
stanza["type"] = ModelType(model_type)
|
||||
stanza["name"] = model_name
|
||||
stanza["original_hash"] = hash
|
||||
stanza["current_hash"] = hash
|
||||
new_key = hash # deterministic key assignment
|
||||
|
||||
# special case for ip adapters, which need the new `image_encoder_model_id` field
|
||||
if stanza["type"] == ModelType.IPAdapter:
|
||||
try:
|
||||
stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
|
||||
self.config.models_path / stanza.path
|
||||
)
|
||||
except OSError:
|
||||
self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
|
||||
continue
|
||||
|
||||
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
|
||||
|
||||
@ -95,7 +102,7 @@ class MigrateModelYamlToDb1:
|
||||
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
|
||||
self._update_model(key, new_config)
|
||||
else:
|
||||
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
||||
self.logger.info(f"Adding model {model_name} with key {new_key}")
|
||||
self._add_model(new_key, new_config)
|
||||
except DuplicateModelException:
|
||||
self.logger.warning(f"Model {model_name} is already in the database")
|
||||
@ -143,9 +150,14 @@ class MigrateModelYamlToDb1:
|
||||
""",
|
||||
(
|
||||
key,
|
||||
record.original_hash,
|
||||
record.hash,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
except sqlite3.IntegrityError as exc:
|
||||
raise DuplicateModelException(f"{record.name}: model is already in database") from exc
|
||||
|
||||
def _get_image_encoder_model_id(self, model_path: Path) -> str:
|
||||
with open(model_path / "image_encoder.txt") as f:
|
||||
encoder = f.read()
|
||||
return encoder.strip()
|
||||
|
@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
|
||||
See :class:`Migration` for an example.
|
||||
"""
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None: ...
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
...
|
||||
|
||||
|
||||
class MigrationError(RuntimeError):
|
||||
|
@ -1,55 +0,0 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from invokeai.app.services.shared.graph import Edge
|
||||
|
||||
|
||||
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
|
||||
"""
|
||||
Parses raw session string, returning a dict of the graph.
|
||||
|
||||
Only the general graph shape is validated; none of the fields are validated.
|
||||
|
||||
Any `metadata_accumulator` nodes and edges are removed.
|
||||
|
||||
Any validation failure will return None.
|
||||
"""
|
||||
|
||||
graph = json.loads(session_raw).get("graph", None)
|
||||
|
||||
# sanity check make sure the graph is at least reasonably shaped
|
||||
if (
|
||||
not isinstance(graph, dict)
|
||||
or "nodes" not in graph
|
||||
or not isinstance(graph["nodes"], dict)
|
||||
or "edges" not in graph
|
||||
or not isinstance(graph["edges"], list)
|
||||
):
|
||||
# something has gone terribly awry, return an empty dict
|
||||
return None
|
||||
|
||||
try:
|
||||
# delete the `metadata_accumulator` node
|
||||
del graph["nodes"]["metadata_accumulator"]
|
||||
except KeyError:
|
||||
# no accumulator node, all good
|
||||
pass
|
||||
|
||||
# delete any edges to or from it
|
||||
for i, edge in enumerate(graph["edges"]):
|
||||
try:
|
||||
# try to parse the edge
|
||||
Edge(**edge)
|
||||
except ValidationError:
|
||||
# something has gone terribly awry, return an empty dict
|
||||
return None
|
||||
|
||||
if (
|
||||
edge["source"]["node_id"] == "metadata_accumulator"
|
||||
or edge["destination"]["node_id"] == "metadata_accumulator"
|
||||
):
|
||||
del graph["edges"][i]
|
||||
|
||||
return graph
|
@ -25,10 +25,13 @@ from enum import Enum
|
||||
from typing import Literal, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
from diffusers import ModelMixin
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
from ..raw_model import RawModel
|
||||
|
||||
# ModelMixin is the base class for all diffusers and transformers models
|
||||
@ -56,8 +59,8 @@ class ModelType(str, Enum):
|
||||
|
||||
ONNX = "onnx"
|
||||
Main = "main"
|
||||
Vae = "vae"
|
||||
Lora = "lora"
|
||||
VAE = "vae"
|
||||
LoRA = "lora"
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
TextualInversion = "embedding"
|
||||
IPAdapter = "ip_adapter"
|
||||
@ -73,9 +76,9 @@ class SubModelType(str, Enum):
|
||||
TextEncoder2 = "text_encoder_2"
|
||||
Tokenizer = "tokenizer"
|
||||
Tokenizer2 = "tokenizer_2"
|
||||
Vae = "vae"
|
||||
VaeDecoder = "vae_decoder"
|
||||
VaeEncoder = "vae_encoder"
|
||||
VAE = "vae"
|
||||
VAEDecoder = "vae_decoder"
|
||||
VAEEncoder = "vae_encoder"
|
||||
Scheduler = "scheduler"
|
||||
SafetyChecker = "safety_checker"
|
||||
|
||||
@ -93,8 +96,8 @@ class ModelFormat(str, Enum):
|
||||
|
||||
Diffusers = "diffusers"
|
||||
Checkpoint = "checkpoint"
|
||||
Lycoris = "lycoris"
|
||||
Onnx = "onnx"
|
||||
LyCORIS = "lycoris"
|
||||
ONNX = "onnx"
|
||||
Olive = "olive"
|
||||
EmbeddingFile = "embedding_file"
|
||||
EmbeddingFolder = "embedding_folder"
|
||||
@ -112,128 +115,187 @@ class SchedulerPredictionType(str, Enum):
|
||||
class ModelRepoVariant(str, Enum):
|
||||
"""Various hugging face variants on the diffusers format."""
|
||||
|
||||
DEFAULT = "" # model files without "fp16" or other qualifier - empty str
|
||||
Default = "" # model files without "fp16" or other qualifier - empty str
|
||||
FP16 = "fp16"
|
||||
FP32 = "fp32"
|
||||
ONNX = "onnx"
|
||||
OPENVINO = "openvino"
|
||||
FLAX = "flax"
|
||||
OpenVINO = "openvino"
|
||||
Flax = "flax"
|
||||
|
||||
|
||||
class ModelSourceType(str, Enum):
|
||||
"""Model source type."""
|
||||
|
||||
Path = "path"
|
||||
Url = "url"
|
||||
HFRepoID = "hf_repo_id"
|
||||
CivitAI = "civitai"
|
||||
|
||||
|
||||
class ModelDefaultSettings(BaseModel):
|
||||
vae: str | None
|
||||
vae_precision: str | None
|
||||
scheduler: SCHEDULER_NAME_VALUES | None
|
||||
steps: int | None
|
||||
cfg_scale: float | None
|
||||
cfg_rescale_multiplier: float | None
|
||||
|
||||
|
||||
class ModelConfigBase(BaseModel):
|
||||
"""Base class for model configuration information."""
|
||||
|
||||
path: str = Field(description="filesystem path to the model file or directory")
|
||||
name: str = Field(description="model name")
|
||||
base: BaseModelType = Field(description="base model")
|
||||
type: ModelType = Field(description="type of the model")
|
||||
format: ModelFormat = Field(description="model format")
|
||||
key: str = Field(description="unique key for model", default="<NOKEY>")
|
||||
original_hash: Optional[str] = Field(
|
||||
description="original fasthash of model contents", default=None
|
||||
) # this is assigned at install time and will not change
|
||||
current_hash: Optional[str] = Field(
|
||||
description="current fasthash of model contents", default=None
|
||||
) # if model is converted or otherwise modified, this will hold updated hash
|
||||
description: Optional[str] = Field(description="human readable description of the model", default=None)
|
||||
source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
|
||||
last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time)
|
||||
key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
|
||||
hash: str = Field(description="The hash of the model file(s).")
|
||||
path: str = Field(
|
||||
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
|
||||
)
|
||||
name: str = Field(description="Name of the model.")
|
||||
base: BaseModelType = Field(description="The base model.")
|
||||
description: Optional[str] = Field(description="Model description", default=None)
|
||||
source: str = Field(description="The original source of the model (path, URL or repo_id).")
|
||||
source_type: ModelSourceType = Field(description="The type of source")
|
||||
source_api_response: Optional[str] = Field(
|
||||
description="The original API response from the source, as stringified JSON.", default=None
|
||||
)
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
schema["required"].extend(
|
||||
["key", "base", "type", "format", "original_hash", "current_hash", "source", "last_modified"]
|
||||
)
|
||||
schema["required"].extend(["key", "type", "format"])
|
||||
|
||||
model_config = ConfigDict(
|
||||
use_enum_values=False,
|
||||
validate_assignment=True,
|
||||
json_schema_extra=json_schema_extra,
|
||||
)
|
||||
|
||||
def update(self, attributes: Dict[str, Any]) -> None:
|
||||
"""Update the object with fields in dict."""
|
||||
for key, value in attributes.items():
|
||||
setattr(self, key, value) # may raise a validation error
|
||||
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
|
||||
|
||||
|
||||
class _CheckpointConfig(ModelConfigBase):
|
||||
class CheckpointConfigBase(ModelConfigBase):
|
||||
"""Model config for checkpoint-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
config: str = Field(description="path to the checkpoint model config file")
|
||||
config_path: str = Field(description="path to the checkpoint model config file")
|
||||
converted_at: Optional[float] = Field(
|
||||
description="When this model was last converted to diffusers", default_factory=time.time
|
||||
)
|
||||
|
||||
|
||||
class _DiffusersConfig(ModelConfigBase):
|
||||
class DiffusersConfigBase(ModelConfigBase):
|
||||
"""Model config for diffusers-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
|
||||
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
|
||||
|
||||
|
||||
class LoRAConfig(ModelConfigBase):
|
||||
class LoRALyCORISConfig(ModelConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
|
||||
type: Literal[ModelType.Lora] = ModelType.Lora
|
||||
format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
|
||||
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
||||
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
|
||||
|
||||
|
||||
class VaeCheckpointConfig(ModelConfigBase):
|
||||
class LoRADiffusersConfig(ModelConfigBase):
|
||||
"""Model config for LoRA/Diffusers models."""
|
||||
|
||||
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class VAECheckpointConfig(CheckpointConfigBase):
|
||||
"""Model config for standalone VAE models."""
|
||||
|
||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||
type: Literal[ModelType.VAE] = ModelType.VAE
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
class VaeDiffusersConfig(ModelConfigBase):
|
||||
|
||||
class VAEDiffusersConfig(ModelConfigBase):
|
||||
"""Model config for standalone VAE models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||
type: Literal[ModelType.VAE] = ModelType.VAE
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
class ControlNetDiffusersConfig(_DiffusersConfig):
|
||||
|
||||
class ControlNetDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
class ControlNetCheckpointConfig(_CheckpointConfig):
|
||||
|
||||
class ControlNetCheckpointConfig(CheckpointConfigBase):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
class TextualInversionConfig(ModelConfigBase):
|
||||
|
||||
class TextualInversionFileConfig(ModelConfigBase):
|
||||
"""Model config for textual inversion embeddings."""
|
||||
|
||||
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
|
||||
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
|
||||
|
||||
|
||||
class _MainConfig(ModelConfigBase):
|
||||
"""Model config for main models."""
|
||||
class TextualInversionFolderConfig(ModelConfigBase):
|
||||
"""Model config for textual inversion embeddings."""
|
||||
|
||||
vae: Optional[str] = Field(default=None)
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
ztsnr_training: bool = False
|
||||
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
|
||||
|
||||
|
||||
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
|
||||
class MainCheckpointConfig(CheckpointConfigBase):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
||||
class MainDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for main diffusers models."""
|
||||
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class IPAdapterConfig(ModelConfigBase):
|
||||
"""Model config for IP Adaptor format models."""
|
||||
@ -242,6 +304,10 @@ class IPAdapterConfig(ModelConfigBase):
|
||||
image_encoder_model_id: str
|
||||
format: Literal[ModelFormat.InvokeAI]
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
|
||||
|
||||
|
||||
class CLIPVisionDiffusersConfig(ModelConfigBase):
|
||||
"""Model config for ClipVision."""
|
||||
@ -249,58 +315,65 @@ class CLIPVisionDiffusersConfig(ModelConfigBase):
|
||||
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
class T2IConfig(ModelConfigBase):
|
||||
|
||||
class T2IAdapterConfig(ModelConfigBase):
|
||||
"""Model config for T2I."""
|
||||
|
||||
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
_ControlNetConfig = Annotated[
|
||||
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
|
||||
Field(discriminator="format"),
|
||||
]
|
||||
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
|
||||
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
|
||||
|
||||
AnyModelConfig = Union[
|
||||
_MainModelConfig,
|
||||
_VaeConfig,
|
||||
_ControlNetConfig,
|
||||
# ModelConfigBase,
|
||||
LoRAConfig,
|
||||
TextualInversionConfig,
|
||||
IPAdapterConfig,
|
||||
CLIPVisionDiffusersConfig,
|
||||
T2IConfig,
|
||||
def get_model_discriminator_value(v: Any) -> str:
|
||||
"""
|
||||
Computes the discriminator value for a model config.
|
||||
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
|
||||
"""
|
||||
format_ = None
|
||||
type_ = None
|
||||
if isinstance(v, dict):
|
||||
format_ = v.get("format")
|
||||
if isinstance(format_, Enum):
|
||||
format_ = format_.value
|
||||
type_ = v.get("type")
|
||||
if isinstance(type_, Enum):
|
||||
type_ = type_.value
|
||||
else:
|
||||
format_ = v.format.value
|
||||
type_ = v.type.value
|
||||
v = f"{type_}.{format_}"
|
||||
return v
|
||||
|
||||
|
||||
AnyModelConfig = Annotated[
|
||||
Union[
|
||||
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
||||
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
|
||||
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
|
||||
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
|
||||
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
||||
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
|
||||
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
|
||||
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
||||
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
|
||||
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
|
||||
Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
|
||||
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
|
||||
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
||||
],
|
||||
Discriminator(get_model_discriminator_value),
|
||||
]
|
||||
|
||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
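A brief sketch of how the callable discriminator routes a plain dict to the right config class; the field values here are invented and only meant to satisfy the required fields listed in `ModelConfigBase`:

```py
raw = {
    "hash": "blake3:abc123",
    "path": "loras/detail-slider.safetensors",
    "name": "detail-slider",
    "base": "sd-1",
    "type": "lora",
    "format": "lycoris",
    "source": "/downloads/detail-slider.safetensors",
    "source_type": "path",
}
# get_model_discriminator_value(raw) yields "lora.lycoris", which matches LoRALyCORISConfig.get_tag()
config = AnyModelConfigValidator.validate_python(raw)
assert isinstance(config, LoRALyCORISConfig)
```
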
|
||||
|
||||
|
||||
# IMPLEMENTATION NOTE:
|
||||
# The preferred alternative to the above is a discriminated Union as shown
|
||||
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
|
||||
# This is a known issue. Please see:
|
||||
# https://github.com/tiangolo/fastapi/discussions/9761 and
|
||||
# https://github.com/tiangolo/fastapi/discussions/9287
|
||||
# AnyModelConfig = Annotated[
|
||||
# Union[
|
||||
# _MainModelConfig,
|
||||
# _ONNXConfig,
|
||||
# _VaeConfig,
|
||||
# _ControlNetConfig,
|
||||
# LoRAConfig,
|
||||
# TextualInversionConfig,
|
||||
# IPAdapterConfig,
|
||||
# CLIPVisionDiffusersConfig,
|
||||
# T2IConfig,
|
||||
# ],
|
||||
# Field(discriminator="type"),
|
||||
# ]
|
||||
|
||||
|
||||
class ModelConfigFactory(object):
|
||||
"""Class for parsing config dicts into StableDiffusion Config obects."""
|
||||
|
||||
@ -332,6 +405,6 @@ class ModelConfigFactory(object):
|
||||
assert model is not None
|
||||
if key:
|
||||
model.key = key
|
||||
if timestamp:
|
||||
model.last_modified = timestamp
|
||||
if isinstance(model, CheckpointConfigBase) and timestamp is not None:
|
||||
model.converted_at = timestamp
|
||||
return model # type: ignore
|
||||
|
@ -11,56 +11,175 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
|
||||
import hashlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
from typing import Callable, Literal, Optional, Union
|
||||
|
||||
from imohash import hashfile
|
||||
from blake3 import blake3
|
||||
|
||||
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
|
||||
|
||||
ALGORITHM = Literal[
|
||||
"md5",
|
||||
"sha1",
|
||||
"sha224",
|
||||
"sha256",
|
||||
"sha384",
|
||||
"sha512",
|
||||
"blake2b",
|
||||
"blake2s",
|
||||
"sha3_224",
|
||||
"sha3_256",
|
||||
"sha3_384",
|
||||
"sha3_512",
|
||||
"shake_128",
|
||||
"shake_256",
|
||||
"blake3",
|
||||
]
|
||||
|
||||
|
||||
class FastModelHash(object):
|
||||
"""FastModelHash obect provides one public class method, hash()."""
|
||||
class ModelHash:
|
||||
"""
|
||||
Creates a hash of a model using a specified algorithm.
|
||||
|
||||
@classmethod
|
||||
def hash(cls, model_location: Union[str, Path]) -> str:
|
||||
"""
|
||||
Return hexdigest string for model located at model_location.
|
||||
Args:
|
||||
algorithm: Hashing algorithm to use. Defaults to BLAKE3.
|
||||
file_filter: A function that takes a file name and returns True if the file should be included in the hash.
|
||||
|
||||
:param model_location: Path to the model
|
||||
"""
|
||||
model_location = Path(model_location)
|
||||
if model_location.is_file():
|
||||
return cls._hash_file(model_location)
|
||||
elif model_location.is_dir():
|
||||
return cls._hash_dir(model_location)
|
||||
If the model is a single file, it is hashed directly using the provided algorithm.
|
||||
|
||||
If the model is a directory, each model weights file in the directory is hashed using the provided algorithm.
|
||||
|
||||
Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth
|
||||
|
||||
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
|
||||
that directory hashes are never weaker than the file hashes.
|
||||
|
||||
Usage:
|
||||
```py
|
||||
# BLAKE3 hash
|
||||
ModelHash().hash("path/to/some/model.safetensors")
|
||||
# MD5
|
||||
ModelHash("md5").hash("path/to/model/dir/")
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
|
||||
if algorithm == "blake3":
|
||||
self._hash_file = self._blake3
|
||||
elif algorithm in hashlib.algorithms_available:
|
||||
self._hash_file = self._get_hashlib(algorithm)
|
||||
else:
|
||||
raise OSError(f"Not a valid file or directory: {model_location}")
|
||||
raise ValueError(f"Algorithm {algorithm} not available")
|
||||
|
||||
@classmethod
|
||||
def _hash_file(cls, model_location: Union[str, Path]) -> str:
|
||||
self._file_filter = file_filter or self._default_file_filter
|
||||
|
||||
def hash(self, model_path: Union[str, Path]) -> str:
|
||||
"""
|
||||
Fasthash a single file and return its hexdigest.
|
||||
Return hexdigest of hash of model located at model_path using the algorithm provided at class instantiation.
|
||||
|
||||
:param model_location: Path to the model file
|
||||
If model_path is a directory, the hash is computed by hashing the hashes of all model files in the
|
||||
directory. The final composite hash is always computed using BLAKE3.
|
||||
|
||||
Args:
|
||||
model_path: Path to the model
|
||||
|
||||
Returns:
|
||||
str: Hexdigest of the hash of the model
|
||||
"""
|
||||
# we return md5 hash of the filehash to make it shorter
|
||||
# cryptographic security not needed here
|
||||
return hashlib.md5(hashfile(model_location)).hexdigest()
|
||||
|
||||
@classmethod
|
||||
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
|
||||
components: Dict[str, str] = {}
|
||||
model_path = Path(model_path)
|
||||
if model_path.is_file():
|
||||
return self._hash_file(model_path)
|
||||
elif model_path.is_dir():
|
||||
return self._hash_dir(model_path)
|
||||
else:
|
||||
raise OSError(f"Not a valid file or directory: {model_path}")
|
||||
|
||||
for root, _dirs, files in os.walk(model_location):
|
||||
for file in files:
|
||||
# only tally tensor files because diffusers config files change slightly
|
||||
# depending on how the model was downloaded/converted.
|
||||
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
|
||||
continue
|
||||
path = (Path(root) / file).as_posix()
|
||||
fast_hash = cls._hash_file(path)
|
||||
components.update({path: fast_hash})
|
||||
def _hash_dir(self, dir: Path) -> str:
|
||||
"""Compute the hash for all files in a directory and return a hexdigest.
|
||||
|
||||
# hash all the model hashes together, using alphabetic file order
|
||||
md5 = hashlib.md5()
|
||||
for _path, fast_hash in sorted(components.items()):
|
||||
md5.update(fast_hash.encode("utf-8"))
|
||||
return md5.hexdigest()
|
||||
        Args:
            dir: Path to the directory

        Returns:
            str: Hexdigest of the hash of the directory
        """
        model_component_paths = self._get_file_paths(dir, self._file_filter)

        component_hashes: list[str] = []
        for component in sorted(model_component_paths):
            component_hashes.append(self._hash_file(component))

        # BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
        # for the composite hash
        composite_hasher = blake3()
        for h in component_hashes:
            composite_hasher.update(h.encode("utf-8"))
        return composite_hasher.hexdigest()

    @staticmethod
    def _get_file_paths(model_path: Path, file_filter: Callable[[str], bool]) -> list[Path]:
        """Return a list of all model files in the directory.

        Args:
            model_path: Path to the model
            file_filter: Function that takes a file name and returns True if the file should be included in the list.

        Returns:
            List of all model files in the directory
        """

        files: list[Path] = []
        for root, _dirs, _files in os.walk(model_path):
            for file in _files:
                if file_filter(file):
                    files.append(Path(root, file))
        return files

    @staticmethod
    def _blake3(file_path: Path) -> str:
        """Hashes a file using BLAKE3

        Args:
            file_path: Path to the file to hash

        Returns:
            Hexdigest of the hash of the file
        """
        file_hasher = blake3(max_threads=blake3.AUTO)
        file_hasher.update_mmap(file_path)
        return file_hasher.hexdigest()

    @staticmethod
    def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
        """Factory function that returns a function to hash a file with the given algorithm.

        Args:
            algorithm: Hashing algorithm to use

        Returns:
            A function that hashes a file using the given algorithm
        """

        def hashlib_hasher(file_path: Path) -> str:
            """Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory."""
            hasher = hashlib.new(algorithm)
            buffer = bytearray(128 * 1024)
            mv = memoryview(buffer)
            with open(file_path, "rb", buffering=0) as f:
                while n := f.readinto(mv):
                    hasher.update(mv[:n])
            return hasher.hexdigest()

        return hashlib_hasher

    @staticmethod
    def _default_file_filter(file_path: str) -> bool:
        """A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth

        Args:
            file_path: Path to the file

        Returns:
            True if the file matches the given extensions, otherwise False
        """
        return file_path.endswith(MODEL_FILE_EXTENSIONS)

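A short usage sketch of the constructor options above; the paths are illustrative:

```py
# Default: BLAKE3, hashing a single file or every weights file in a directory.
digest = ModelHash().hash("models/sd-1/main/my-model")

# SHA-256 with a custom filter that only considers .safetensors files.
sha256_digest = ModelHash(
    algorithm="sha256",
    file_filter=lambda name: name.endswith(".safetensors"),
).hash("models/sd-1/main/my-model")
```
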
@ -13,6 +13,7 @@ from invokeai.backend.model_manager import (
|
||||
ModelRepoVariant,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
@ -50,7 +51,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
:param submodel_type: an ModelType enum indicating the portion of
|
||||
the model to retrieve (e.g. ModelType.Vae)
|
||||
"""
|
||||
if model_config.type == "main" and not submodel_type:
|
||||
if model_config.type is ModelType.Main and not submodel_type:
|
||||
raise InvalidModelConfigException("submodel_type is required when loading a main model")
|
||||
|
||||
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
|
||||
@ -80,7 +81,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
||||
return self._convert_model(config, model_path, cache_path)
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
return False
|
||||
|
||||
def _load_if_needed(
|
||||
@ -119,7 +120,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
return calc_model_size_by_fs(
|
||||
model_path=model_path,
|
||||
subfolder=submodel_type.value if submodel_type else None,
|
||||
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
|
||||
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
|
||||
)
|
||||
|
||||
# This needs to be implemented in subclasses that handle checkpoints
|
||||
|
@ -15,10 +15,8 @@ Use like this:
|
||||
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Optional, Tuple, Type
|
||||
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
|
||||
|
||||
from ..config import (
|
||||
AnyModelConfig,
|
||||
@ -27,8 +25,6 @@ from ..config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
VaeCheckpointConfig,
|
||||
VaeDiffusersConfig,
|
||||
)
|
||||
from . import ModelLoaderBase
|
||||
|
||||
@ -61,6 +57,9 @@ class ModelLoaderRegistryBase(ABC):
|
||||
"""
|
||||
|
||||
|
||||
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
|
||||
|
||||
|
||||
class ModelLoaderRegistry:
|
||||
"""
|
||||
This class allows model loaders to register their type, base and format.
|
||||
@ -71,10 +70,10 @@ class ModelLoaderRegistry:
|
||||
@classmethod
|
||||
def register(
|
||||
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
|
||||
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
|
||||
) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
|
||||
"""Define a decorator which registers the subclass of loader."""
|
||||
|
||||
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
|
||||
def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
|
||||
key = cls._to_registry_key(base, type, format)
|
||||
if key in cls._registry:
|
||||
raise Exception(
|
||||
@ -90,33 +89,15 @@ class ModelLoaderRegistry:
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
||||
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
||||
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
|
||||
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
|
||||
|
||||
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
|
||||
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
|
||||
key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type
|
||||
key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
|
||||
implementation = cls._registry.get(key1) or cls._registry.get(key2)
|
||||
if not implementation:
|
||||
raise NotImplementedError(
|
||||
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
|
||||
)
|
||||
return implementation, conf2, submodel_type
|
||||
|
||||
@classmethod
|
||||
def _handle_subtype_overrides(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
|
||||
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
|
||||
model_path = Path(config.vae)
|
||||
config_class = (
|
||||
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
|
||||
)
|
||||
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
|
||||
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
|
||||
submodel_type = None
|
||||
else:
|
||||
new_conf = config
|
||||
return new_conf, submodel_type
|
||||
return implementation, config, submodel_type
|
||||
|
||||
@staticmethod
|
||||
def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
|
||||
|
@ -3,8 +3,8 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from safetensors.torch import load_file as safetensors_load_file
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
@ -12,6 +12,7 @@ from invokeai.backend.model_manager import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
@ -20,15 +21,15 @@ from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
|
||||
class ControlnetLoader(GenericDiffusersLoader):
|
||||
class ControlNetLoader(GenericDiffusersLoader):
|
||||
"""Class to load ControlNet models."""
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if config.format != ModelFormat.Checkpoint:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||
):
|
||||
return False
|
||||
@ -37,13 +38,13 @@ class ControlnetLoader(GenericDiffusersLoader):
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||
raise Exception(f"Vae conversion not supported for model type: {config.base}")
|
||||
raise Exception(f"ControlNet conversion not supported for model type: {config.base}")
|
||||
else:
|
||||
assert hasattr(config, "config")
|
||||
config_file = config.config
|
||||
assert isinstance(config, CheckpointConfigBase)
|
||||
config_file = config.config_path
|
||||
|
||||
if model_path.suffix == ".safetensors":
|
||||
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
|
||||
checkpoint = safetensors_load_file(model_path, device="cpu")
|
||||
else:
|
||||
checkpoint = torch.load(model_path, map_location="cpu")
|
||||
|
||||
|
@ -3,9 +3,10 @@
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
@ -41,6 +42,7 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
# TO DO: Add exception handling
|
||||
def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
|
||||
"""Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load."""
|
||||
result = None
|
||||
if submodel_type:
|
||||
try:
|
||||
config = self._load_diffusers_config(model_path, config_name="model_index.json")
|
||||
@ -64,6 +66,7 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
|
||||
except KeyError as e:
|
||||
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
||||
assert result is not None
|
||||
return result
|
||||
|
||||
# TO DO: Add exception handling
|
||||
@ -75,7 +78,7 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
result: ModelMixin = getattr(res_type, class_name)
|
||||
return result
|
||||
|
||||
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
|
||||
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> dict[str, Any]:
|
||||
return ConfigLoader.load_config(model_path, config_name=config_name)
|
||||
|
||||
|
||||
@ -83,8 +86,8 @@ class ConfigLoader(ConfigMixin):
|
||||
"""Subclass of ConfigMixin for loading diffusers configuration files."""
|
||||
|
||||
@classmethod
|
||||
def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
def load_config(cls, *args: Any, **kwargs: Any) -> dict[str, Any]: # pyright: ignore [reportIncompatibleMethodOverride]
|
||||
"""Load a diffusrs ConfigMixin configuration."""
|
||||
cls.config_name = kwargs.pop("config_name")
|
||||
# Diffusers doesn't provide typing info
|
||||
# TODO(psyche): the types on this diffusers method are not correct
|
||||
return super().load_config(*args, **kwargs) # type: ignore
|
||||
|
@ -31,7 +31,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in an IP-Adapter model.")
|
||||
model = build_ip_adapter(
|
||||
ip_adapter_ckpt_path=model_path / "ip_adapter.bin",
|
||||
ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
|
||||
device=torch.device("cpu"),
|
||||
dtype=self._torch_dtype,
|
||||
)
|
||||
|
@ -22,8 +22,8 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
|
||||
class LoraLoader(ModelLoader):
|
||||
"""Class to load LoRA models."""
|
||||
|
||||
|
@ -18,7 +18,7 @@ from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.ONNX)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
|
||||
class OnnyxDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load onnx models."""
|
||||
|
@ -4,7 +4,8 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
@ -16,7 +17,7 @@ from invokeai.backend.model_manager import (
|
||||
ModelVariantType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import MainCheckpointConfig
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
@ -54,11 +55,11 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
return result
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if config.format != ModelFormat.Checkpoint:
if not isinstance(config, CheckpointConfigBase):
return False
elif (
dest_path.exists()
and (dest_path / "model_index.json").stat().st_mtime >= (config.last_modified or 0.0)
and (dest_path / "model_index.json").stat().st_mtime >= (config.converted_at or 0.0)
and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
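The reconversion test in this hunk reduces to a timestamp rule: the cached `model_index.json` must be at least as new as both the recorded conversion time and the source checkpoint, otherwise the checkpoint is converted again. A minimal standalone sketch of that rule (the function name and parameters are illustrative, not part of the InvokeAI API):

```python
from pathlib import Path


def cache_is_fresh(dest_path: Path, source_path: Path, converted_at: float | None) -> bool:
    """Return True when a previously converted diffusers cache can be reused."""
    index = dest_path / "model_index.json"
    if not index.exists():
        return False  # never converted, or the cache was cleared
    cache_mtime = index.stat().st_mtime
    # The cache must postdate both the recorded conversion and the source checkpoint.
    return cache_mtime >= (converted_at or 0.0) and cache_mtime >= source_path.stat().st_mtime
```

A loader would skip its conversion step whenever this check returns True.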
@ -73,7 +74,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
)
config_file = config.config
config_file = config.config_path
self._logger.info(f"Converting {model_path} to diffusers format")
convert_ckpt_to_diffusers(
@ -3,9 +3,9 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from safetensors.torch import load_file as safetensors_load_file
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
@ -13,24 +13,25 @@ from invokeai.backend.model_manager import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
class VaeLoader(GenericDiffusersLoader):
|
||||
"""Class to load VAE models."""
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if config.format != ModelFormat.Checkpoint:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||
):
|
||||
return False
|
||||
@ -38,16 +39,15 @@ class VaeLoader(GenericDiffusersLoader):
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
# TO DO: check whether sdxl VAE models convert.
# TODO(MM2): check whether sdxl VAE models convert.
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}")
raise Exception(f"VAE conversion not supported for model type: {config.base}")
else:
config_file = (
"v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
)
assert isinstance(config, CheckpointConfigBase)
config_file = config.config_path
if model_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
checkpoint = safetensors_load_file(model_path, device="cpu")
else:
checkpoint = torch.load(model_path, map_location="cpu")
@ -55,7 +55,7 @@ class VaeLoader(GenericDiffusersLoader):
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
ckpt_config = OmegaConf.load(self._app_config.root_path / config_file)
assert isinstance(ckpt_config, DictConfig)
vae_model = convert_ldm_vae_to_diffusers(
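The VAE conversion path now loads the raw weights with `safetensors.torch.load_file` for `.safetensors` files and falls back to `torch.load` otherwise, then unwraps a nested `"state_dict"` key when one is present. A small sketch of just that loading step, assuming `torch` and `safetensors` are installed (the function name is illustrative):

```python
from pathlib import Path
from typing import Any, Dict

import torch
from safetensors.torch import load_file as safetensors_load_file


def load_checkpoint_state_dict(model_path: Path) -> Dict[str, Any]:
    """Load a checkpoint into a flat state dict, regardless of on-disk format."""
    if model_path.suffix == ".safetensors":
        checkpoint = safetensors_load_file(model_path, device="cpu")
    else:
        # .ckpt / .pt files are pickled; map to CPU so conversion does not need a GPU.
        checkpoint = torch.load(model_path, map_location="cpu")
    # Some legacy checkpoints wrap the weights under a "state_dict" key.
    return checkpoint.get("state_dict", checkpoint)
```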
@ -16,6 +16,7 @@ from diffusers import AutoPipelineForText2Image
|
||||
from diffusers.utils import logging as dlogging
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||
|
||||
from . import (
|
||||
@ -117,7 +118,6 @@ class ModelMerger(object):
|
||||
config = self._installer.app_config
|
||||
store = self._installer.record_store
|
||||
base_models: Set[BaseModelType] = set()
|
||||
vae = None
|
||||
variant = None if self._installer.app_config.full_precision else "fp16"
|
||||
|
||||
assert (
|
||||
@ -134,10 +134,6 @@ class ModelMerger(object):
|
||||
"normal"
|
||||
), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
|
||||
|
||||
# pick up the first model's vae
|
||||
if key == model_keys[0]:
|
||||
vae = info.vae
|
||||
|
||||
# tally base models used
|
||||
base_models.add(info.base)
|
||||
model_paths.extend([config.models_path / info.path])
|
||||
@ -163,12 +159,10 @@ class ModelMerger(object):
# update model's config
model_config = self._installer.record_store.get_model(key)
model_config.update(
{
"name": merged_model_name,
"description": f"Merge of models {', '.join(model_names)}",
"vae": vae,
}
model_config.name = merged_model_name
model_config.description = f"Merge of models {', '.join(model_names)}"
self._installer.record_store.update_model(
key, ModelRecordChanges(name=model_config.name, description=model_config.description)
)
self._installer.record_store.update_model(key, model_config)
return model_config
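With this change the merge routine updates the record through the typed `ModelRecordChanges` object instead of mutating and re-saving the whole config dict. A hedged sketch of the calling pattern (the helper name and the untyped `record_store` parameter are placeholders; `ModelRecordChanges` and `update_model` are the objects shown in the diff above):

```python
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges


def rename_merged_model(record_store, key: str, merged_name: str, model_names: list[str]):
    """Apply the post-merge name/description update via the typed changes object."""
    changes = ModelRecordChanges(
        name=merged_name,
        description=f"Merge of models {', '.join(model_names)}",
    )
    return record_store.update_model(key, changes)
```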
@ -25,9 +25,7 @@ from .metadata_base import (
|
||||
AnyModelRepoMetadataValidator,
|
||||
BaseMetadata,
|
||||
CivitaiMetadata,
|
||||
CommercialUsage,
|
||||
HuggingFaceMetadata,
|
||||
LicenseRestrictions,
|
||||
ModelMetadataWithFiles,
|
||||
RemoteModelFile,
|
||||
UnknownMetadataException,
|
||||
@ -38,10 +36,8 @@ __all__ = [
|
||||
"AnyModelRepoMetadataValidator",
|
||||
"CivitaiMetadata",
|
||||
"CivitaiMetadataFetch",
|
||||
"CommercialUsage",
|
||||
"HuggingFaceMetadata",
|
||||
"HuggingFaceMetadataFetch",
|
||||
"LicenseRestrictions",
|
||||
"ModelMetadataFetchBase",
|
||||
"BaseMetadata",
|
||||
"ModelMetadataWithFiles",
|
||||
|
@ -23,22 +23,21 @@ metadata = fetcher.from_url("https://civitai.com/models/206883/split")
|
||||
print(metadata.trained_words)
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||
|
||||
from ..metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
CivitaiMetadata,
|
||||
CommercialUsage,
|
||||
LicenseRestrictions,
|
||||
RemoteModelFile,
|
||||
UnknownMetadataException,
|
||||
)
|
||||
@ -52,10 +51,13 @@ CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
|
||||
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
|
||||
|
||||
|
||||
StringSetAdapter = TypeAdapter(set[str])
|
||||
|
||||
|
||||
class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
"""Fetch model metadata from Civitai."""
|
||||
|
||||
def __init__(self, session: Optional[Session] = None):
|
||||
def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None):
|
||||
"""
|
||||
Initialize the fetcher with an optional requests.sessions.Session object.
|
||||
|
||||
@ -63,6 +65,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
this module without an internet connection.
|
||||
"""
|
||||
self._requests = session or requests.Session()
|
||||
self._api_key = api_key
|
||||
|
||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
@ -102,22 +105,21 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
May raise an `UnknownMetadataException`.
|
||||
"""
|
||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||
model_json = self._requests.get(model_url).json()
|
||||
return self._from_model_json(model_json)
|
||||
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
|
||||
return self._from_api_response(model_json)
|
||||
|
||||
def _from_model_json(self, model_json: Dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
try:
|
||||
version_id = version_id or model_json["modelVersions"][0]["id"]
|
||||
version_id = version_id or api_response["modelVersions"][0]["id"]
|
||||
except TypeError as excp:
|
||||
raise UnknownMetadataException from excp
|
||||
|
||||
# loop till we find the section containing the version requested
|
||||
version_sections = [x for x in model_json["modelVersions"] if x["id"] == version_id]
|
||||
version_sections = [x for x in api_response["modelVersions"] if x["id"] == version_id]
|
||||
if not version_sections:
|
||||
raise UnknownMetadataException(f"Version {version_id} not found in model metadata")
|
||||
|
||||
version_json = version_sections[0]
|
||||
safe_thumbnails = [x["url"] for x in version_json["images"] if x["nsfw"] == "None"]
|
||||
|
||||
# Civitai has one "primary" file plus others such as VAEs. We only fetch the primary.
|
||||
primary = [x for x in version_json["files"] if x.get("primary")]
|
||||
@ -134,36 +136,23 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
url = url + f"?type={primary_file['type']}{metadata_string}"
|
||||
model_files = [
|
||||
RemoteModelFile(
|
||||
url=url,
|
||||
url=self._get_url_with_api_key(url),
|
||||
path=Path(primary_file["name"]),
|
||||
size=int(primary_file["sizeKB"] * 1024),
|
||||
sha256=primary_file["hashes"]["SHA256"],
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
trigger_phrases = StringSetAdapter.validate_python(version_json.get("trainedWords"))
|
||||
except ValidationError:
|
||||
trigger_phrases: set[str] = set()
|
||||
|
||||
return CivitaiMetadata(
|
||||
id=model_json["id"],
|
||||
name=version_json["name"],
|
||||
version_id=version_json["id"],
|
||||
version_name=version_json["name"],
|
||||
created=datetime.fromisoformat(_fix_timezone(version_json["createdAt"])),
|
||||
updated=datetime.fromisoformat(_fix_timezone(version_json["updatedAt"])),
|
||||
published=datetime.fromisoformat(_fix_timezone(version_json["publishedAt"])),
|
||||
base_model_trained_on=version_json["baseModel"], # note - need a dictionary to turn into a BaseModelType
|
||||
files=model_files,
|
||||
download_url=version_json["downloadUrl"],
|
||||
thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
|
||||
author=model_json["creator"]["username"],
|
||||
description=model_json["description"],
|
||||
version_description=version_json["description"] or "",
|
||||
tags=model_json["tags"],
|
||||
trained_words=version_json["trainedWords"],
|
||||
nsfw=model_json["nsfw"],
|
||||
restrictions=LicenseRestrictions(
|
||||
AllowNoCredit=model_json["allowNoCredit"],
|
||||
AllowCommercialUse={CommercialUsage(x) for x in model_json["allowCommercialUse"]},
|
||||
AllowDerivatives=model_json["allowDerivatives"],
|
||||
AllowDifferentLicense=model_json["allowDifferentLicense"],
|
||||
),
|
||||
trigger_phrases=trigger_phrases,
|
||||
api_response=json.dumps(version_json),
|
||||
)
|
||||
|
||||
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
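The `StringSetAdapter` introduced above is a pydantic `TypeAdapter(set[str])`; wrapping the `trainedWords` lookup in a `try/except ValidationError` lets the fetcher fall back to an empty set when the field is missing or malformed. A self-contained sketch of that pattern:

```python
from typing import Any

from pydantic import TypeAdapter, ValidationError

StringSetAdapter = TypeAdapter(set[str])


def parse_trigger_phrases(raw: Any) -> set[str]:
    """Coerce the API's trainedWords field to a set of strings, tolerating bad or missing data."""
    try:
        return StringSetAdapter.validate_python(raw)
    except ValidationError:
        return set()


print(parse_trigger_phrases(["cinematic", "portrait"]))  # {'cinematic', 'portrait'}
print(parse_trigger_phrases(None))                       # set()
```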
@ -174,14 +163,14 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
"""
|
||||
if model_id is None:
|
||||
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
|
||||
version = self._requests.get(version_url).json()
|
||||
version = self._requests.get(self._get_url_with_api_key(version_url)).json()
|
||||
if error := version.get("error"):
|
||||
raise UnknownMetadataException(error)
|
||||
model_id = version["modelId"]
|
||||
|
||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||
model_json = self._requests.get(model_url).json()
|
||||
return self._from_model_json(model_json, version_id)
|
||||
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
|
||||
return self._from_api_response(model_json, version_id)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: str) -> CivitaiMetadata:
|
||||
@ -189,6 +178,11 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
metadata = CivitaiMetadata.model_validate_json(json)
return metadata
def _get_url_with_api_key(self, url: str) -> str:
if not self._api_key:
return url
def _fix_timezone(date: str) -> str:
return re.sub(r"Z$", "+00:00", date)
if "?" in url:
return f"{url}&token={self._api_key}"
return f"{url}?token={self._api_key}"
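The new `_get_url_with_api_key` helper simply appends the key as a `token` query parameter, choosing `&` when the URL already carries a query string and `?` otherwise. The same logic as a standalone sketch, with a quick check of both cases (the function name is illustrative):

```python
def url_with_api_key(url: str, api_key: str | None) -> str:
    """Append a Civitai API token to a download URL, preserving any existing query string."""
    if not api_key:
        return url
    separator = "&" if "?" in url else "?"
    return f"{url}{separator}token={api_key}"


assert url_with_api_key("https://civitai.com/api/download/models/1", "abc").endswith("?token=abc")
assert url_with_api_key("https://civitai.com/api/download/models/1?type=Model", "abc").endswith("&token=abc")
assert url_with_api_key("https://civitai.com/api/download/models/1", None).endswith("/1")
```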
@ -13,6 +13,7 @@ metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
|
||||
print(metadata.tags)
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@ -23,7 +24,7 @@ from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFo
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||
|
||||
from ..metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
@ -60,6 +61,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
# Little loop which tries fetching a revision corresponding to the selected variant.
|
||||
# If not available, then set variant to None and get the default.
|
||||
# If this too fails, raise exception.
|
||||
|
||||
model_info = None
|
||||
while not model_info:
|
||||
try:
|
||||
@ -72,23 +74,24 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
else:
|
||||
variant = None
|
||||
|
||||
files: list[RemoteModelFile] = []
|
||||
|
||||
_, name = id.split("/")
|
||||
return HuggingFaceMetadata(
|
||||
id=model_info.id,
|
||||
author=model_info.author,
|
||||
name=name,
|
||||
last_modified=model_info.last_modified,
|
||||
tag_dict=model_info.card_data.to_dict() if model_info.card_data else {},
|
||||
tags=model_info.tags,
|
||||
files=[
|
||||
|
||||
for s in model_info.siblings or []:
|
||||
assert s.rfilename is not None
|
||||
assert s.size is not None
|
||||
files.append(
|
||||
RemoteModelFile(
|
||||
url=hf_hub_url(id, x.rfilename, revision=variant),
|
||||
path=Path(name, x.rfilename),
|
||||
size=x.size,
|
||||
sha256=x.lfs.get("sha256") if x.lfs else None,
|
||||
url=hf_hub_url(id, s.rfilename, revision=variant),
|
||||
path=Path(name, s.rfilename),
|
||||
size=s.size,
|
||||
sha256=s.lfs.get("sha256") if s.lfs else None,
|
||||
)
|
||||
for x in model_info.siblings
|
||||
],
|
||||
)
|
||||
|
||||
return HuggingFaceMetadata(
|
||||
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str)
|
||||
)
|
||||
|
||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||
|
@ -14,10 +14,8 @@ versions of these fields are intended to be kept in sync with the
|
||||
remote repo.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from huggingface_hub import configure_http_backend, hf_hub_url
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
@ -34,31 +32,6 @@ class UnknownMetadataException(Exception):
|
||||
"""Raised when no metadata is available for a model."""
|
||||
|
||||
|
||||
class CommercialUsage(str, Enum):
|
||||
"""Type of commercial usage allowed."""
|
||||
|
||||
No = "None"
|
||||
Image = "Image"
|
||||
Rent = "Rent"
|
||||
RentCivit = "RentCivit"
|
||||
Sell = "Sell"
|
||||
|
||||
|
||||
class LicenseRestrictions(BaseModel):
|
||||
"""Broad categories of licensing restrictions."""
|
||||
|
||||
AllowNoCredit: bool = Field(
|
||||
description="if true, model can be redistributed without crediting author", default=False
|
||||
)
|
||||
AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
|
||||
AllowDifferentLicense: bool = Field(
|
||||
description="if true, derivatives of this model be redistributed under a different license", default=False
|
||||
)
|
||||
AllowCommercialUse: Optional[Set[CommercialUsage] | CommercialUsage] = Field(
|
||||
description="Type of commercial use allowed if no commercial use is allowed.", default=None
|
||||
)
|
||||
|
||||
|
||||
class RemoteModelFile(BaseModel):
|
||||
"""Information about a downloadable file that forms part of a model."""
|
||||
|
||||
@ -72,8 +45,6 @@ class ModelMetadataBase(BaseModel):
|
||||
"""Base class for model metadata information."""
|
||||
|
||||
name: str = Field(description="model's name")
|
||||
author: str = Field(description="model's author")
|
||||
tags: Set[str] = Field(description="tags provided by model source")
|
||||
|
||||
|
||||
class BaseMetadata(ModelMetadataBase):
|
||||
@ -111,60 +82,16 @@ class CivitaiMetadata(ModelMetadataWithFiles):
|
||||
"""Extended metadata fields provided by Civitai."""
|
||||
|
||||
type: Literal["civitai"] = "civitai"
|
||||
id: int = Field(description="Civitai version identifier")
|
||||
version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
|
||||
version_id: int = Field(description="Civitai model version identifier")
|
||||
created: datetime = Field(description="date the model was created")
|
||||
updated: datetime = Field(description="date the model was last modified")
|
||||
published: datetime = Field(description="date the model was published to Civitai")
|
||||
description: str = Field(description="text description of model; may contain HTML")
|
||||
version_description: str = Field(
|
||||
description="text description of the model's reversion; usually change history; may contain HTML"
|
||||
)
|
||||
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
|
||||
restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
|
||||
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
|
||||
download_url: AnyHttpUrl = Field(description="download URL for this model")
|
||||
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
|
||||
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
|
||||
weight_minmax: Tuple[float, float] = Field(
|
||||
description="minimum and maximum slider values for a LoRA or other secondary model", default=(-1.0, +2.0)
|
||||
) # note: For future use
|
||||
|
||||
@property
|
||||
def credit_required(self) -> bool:
|
||||
"""Return True if you must give credit for derivatives of this model and images generated from it."""
|
||||
return not self.restrictions.AllowNoCredit
|
||||
|
||||
@property
|
||||
def allow_commercial_use(self) -> bool:
|
||||
"""Return True if commercial use is allowed."""
|
||||
if self.restrictions.AllowCommercialUse is None:
|
||||
return False
|
||||
else:
|
||||
# accommodate schema change
|
||||
acu = self.restrictions.AllowCommercialUse
|
||||
commercial_usage = acu if isinstance(acu, set) else {acu}
|
||||
return CommercialUsage.No not in commercial_usage
|
||||
|
||||
@property
|
||||
def allow_derivatives(self) -> bool:
|
||||
"""Return True if derivatives of this model can be redistributed."""
|
||||
return self.restrictions.AllowDerivatives
|
||||
|
||||
@property
|
||||
def allow_different_license(self) -> bool:
|
||||
"""Return true if derivatives of this model can use a different license."""
|
||||
return self.restrictions.AllowDifferentLicense
|
||||
trigger_phrases: set[str] = Field(description="Trigger phrases extracted from the API response")
|
||||
api_response: Optional[str] = Field(description="Response from the Civitai API as stringified JSON", default=None)
|
||||
|
||||
|
||||
class HuggingFaceMetadata(ModelMetadataWithFiles):
|
||||
"""Extended metadata fields provided by HuggingFace."""
|
||||
|
||||
type: Literal["huggingface"] = "huggingface"
|
||||
id: str = Field(description="huggingface model id")
|
||||
tag_dict: Dict[str, Any]
|
||||
last_modified: datetime = Field(description="date of last commit to repo")
|
||||
id: str = Field(description="The HF model id")
|
||||
api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
|
||||
|
||||
def download_urls(
|
||||
self,
|
||||
@ -193,7 +120,7 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
|
||||
# the next step reads model_index.json to determine which subdirectories belong
|
||||
# to the model
|
||||
if Path(f"{prefix}model_index.json") in paths:
|
||||
url = hf_hub_url(self.id, filename="model_index.json", subfolder=subfolder)
|
||||
url = hf_hub_url(self.id, filename="model_index.json", subfolder=str(subfolder) if subfolder else None)
|
||||
resp = session.get(url)
|
||||
resp.raise_for_status()
|
||||
submodels = resp.json()
|
||||
|
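The fix in this hunk passes the subfolder to `hf_hub_url` as a plain string (or `None`) rather than a `Path`, since `huggingface_hub` builds the URL by string concatenation. A quick illustration of the call, assuming `huggingface_hub` is installed:

```python
from huggingface_hub import hf_hub_url

# Top-level file: no subfolder.
print(hf_hub_url("stabilityai/sdxl-turbo", filename="model_index.json"))
# -> https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/model_index.json

# File inside a subdirectory of the repo.
print(hf_hub_url("stabilityai/sdxl-turbo", filename="config.json", subfolder="unet"))
# -> https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/unet/config.json
```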
@ -1,221 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
SQL Storage for Model Metadata
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
from .fetch import ModelMetadataFetchBase
|
||||
from .metadata_base import AnyModelRepoMetadata, UnknownMetadataException
|
||||
|
||||
|
||||
class ModelMetadataStore:
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json()
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_metadata(
|
||||
id,
|
||||
metadata
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
model_key,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||
self._db.conn.rollback()
|
||||
raise UnknownMetadataException from excp
|
||||
except sqlite3.Error as excp:
|
||||
self._db.conn.rollback()
|
||||
raise excp
|
||||
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM model_metadata
|
||||
WHERE id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
return ModelMetadataFetchBase.from_json(rows[0])
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id,metadata FROM model_metadata;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
|
||||
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_metadata
|
||||
SET
|
||||
metadata=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, model_key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_metadata(model_key)
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select tag_text from tags;
|
||||
"""
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
matches: Optional[Set[str]] = None
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.model_id FROM model_tags AS a,
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
if matches is None:
|
||||
matches = model_keys
|
||||
matches = matches.intersection(model_keys)
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return matches if matches else set()
|
||||
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE author=?;
|
||||
""",
|
||||
(author,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE name=?;
|
||||
""",
|
||||
(name,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
@ -8,6 +8,7 @@ import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.util.util import SilenceWarnings
|
||||
|
||||
from .config import (
|
||||
@ -17,11 +18,12 @@ from .config import (
|
||||
ModelConfigFactory,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from .hash import FastModelHash
|
||||
from .hash import ModelHash
|
||||
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||
|
||||
CkptType = Dict[str, Any]
|
||||
@ -95,8 +97,8 @@ class ModelProbe(object):
|
||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||
"LatentConsistencyModelPipeline": ModelType.Main,
|
||||
"AutoencoderKL": ModelType.Vae,
|
||||
"AutoencoderTiny": ModelType.Vae,
|
||||
"AutoencoderKL": ModelType.VAE,
|
||||
"AutoencoderTiny": ModelType.VAE,
|
||||
"ControlNetModel": ModelType.ControlNet,
|
||||
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
||||
"T2IAdapter": ModelType.T2IAdapter,
|
||||
@ -108,14 +110,6 @@ class ModelProbe(object):
|
||||
) -> None:
|
||||
cls.PROBES[format][model_type] = probe_class
|
||||
|
||||
@classmethod
|
||||
def heuristic_probe(
|
||||
cls,
|
||||
model_path: Path,
|
||||
fields: Optional[Dict[str, Any]] = None,
|
||||
) -> AnyModelConfig:
|
||||
return cls.probe(model_path, fields)
|
||||
|
||||
@classmethod
|
||||
def probe(
|
||||
cls,
|
||||
@ -137,19 +131,21 @@ class ModelProbe(object):
|
||||
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
||||
model_info = None
|
||||
model_type = None
|
||||
if format_type == "diffusers":
|
||||
if format_type is ModelFormat.Diffusers:
|
||||
model_type = cls.get_model_type_from_folder(model_path)
|
||||
else:
|
||||
model_type = cls.get_model_type_from_checkpoint(model_path)
|
||||
format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type
|
||||
format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type
|
||||
|
||||
probe_class = cls.PROBES[format_type].get(model_type)
|
||||
if not probe_class:
|
||||
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
|
||||
|
||||
hash = FastModelHash.hash(model_path)
|
||||
probe = probe_class(model_path)
|
||||
|
||||
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
|
||||
fields["source"] = fields.get("source") or model_path.as_posix()
|
||||
fields["key"] = fields.get("key", uuid_string())
|
||||
fields["path"] = model_path.as_posix()
|
||||
fields["type"] = fields.get("type") or model_type
|
||||
fields["base"] = fields.get("base") or probe.get_base_type()
|
||||
@ -161,15 +157,17 @@ class ModelProbe(object):
|
||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
||||
)
|
||||
fields["format"] = fields.get("format") or probe.get_format()
|
||||
fields["original_hash"] = fields.get("original_hash") or hash
|
||||
fields["current_hash"] = fields.get("current_hash") or hash
|
||||
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
|
||||
|
||||
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
|
||||
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
|
||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||
|
||||
# additional fields needed for main and controlnet models
|
||||
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
|
||||
fields["config"] = cls._get_checkpoint_config_path(
|
||||
if (
|
||||
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
|
||||
and fields["format"] is ModelFormat.Checkpoint
|
||||
):
|
||||
fields["config_path"] = cls._get_checkpoint_config_path(
|
||||
model_path,
|
||||
model_type=fields["type"],
|
||||
base_type=fields["base"],
|
||||
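The probe fills in any field the caller does not supply: source, key, path, type, base, name, description, format, and now a single `hash`, plus `config_path` for checkpoint-format main, ControlNet, and VAE models. A hedged usage sketch (the module path and the local checkpoint path are assumptions; `probe()` and the `type`/`base`/`format`/`hash` fields are taken from the diff above):

```python
from pathlib import Path

from invokeai.backend.model_manager.probe import ModelProbe

# Probe a local checkpoint; type, base, format and the content hash are inferred.
config = ModelProbe.probe(Path("/models/v1-5-pruned-emaonly.safetensors"))
print(config.type, config.base, config.format, config.hash)
```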
@ -179,7 +177,7 @@ class ModelProbe(object):
|
||||
|
||||
# additional fields needed for main non-checkpoint models
|
||||
elif fields["type"] == ModelType.Main and fields["format"] in [
|
||||
ModelFormat.Onnx,
|
||||
ModelFormat.ONNX,
|
||||
ModelFormat.Olive,
|
||||
ModelFormat.Diffusers,
|
||||
]:
|
||||
@ -213,11 +211,11 @@ class ModelProbe(object):
|
||||
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
||||
return ModelType.Main
|
||||
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
||||
return ModelType.Vae
|
||||
return ModelType.VAE
|
||||
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||
return ModelType.Lora
|
||||
return ModelType.LoRA
|
||||
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
||||
return ModelType.Lora
|
||||
return ModelType.LoRA
|
||||
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
||||
return ModelType.ControlNet
|
||||
elif key in {"emb_params", "string_to_param"}:
|
||||
@ -239,7 +237,7 @@ class ModelProbe(object):
|
||||
if (folder_path / f"learned_embeds.{suffix}").exists():
|
||||
return ModelType.TextualInversion
|
||||
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
|
||||
return ModelType.Lora
|
||||
return ModelType.LoRA
|
||||
if (folder_path / "unet/model.onnx").exists():
|
||||
return ModelType.ONNX
|
||||
if (folder_path / "image_encoder.txt").exists():
|
||||
@ -285,13 +283,21 @@ class ModelProbe(object):
|
||||
if possible_conf.exists():
|
||||
return possible_conf.absolute()
|
||||
|
||||
if model_type == ModelType.Main:
|
||||
if model_type is ModelType.Main:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
config_file = config_file[prediction_type]
|
||||
elif model_type == ModelType.ControlNet:
|
||||
elif model_type is ModelType.ControlNet:
|
||||
config_file = (
|
||||
"../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
|
||||
"../controlnet/cldm_v15.yaml"
|
||||
if base_type is BaseModelType.StableDiffusion1
|
||||
else "../controlnet/cldm_v21.yaml"
|
||||
)
|
||||
elif model_type is ModelType.VAE:
|
||||
config_file = (
|
||||
"../stable-diffusion/v1-inference.yaml"
|
||||
if base_type is BaseModelType.StableDiffusion1
|
||||
else "../stable-diffusion/v2-inference.yaml"
|
||||
)
|
||||
else:
|
||||
raise InvalidModelConfigException(
|
||||
@ -497,12 +503,12 @@ class FolderProbeBase(ProbeBase):
|
||||
if ".fp16" in x.suffixes:
|
||||
return ModelRepoVariant.FP16
|
||||
if "openvino_model" in x.name:
|
||||
return ModelRepoVariant.OPENVINO
|
||||
return ModelRepoVariant.OpenVINO
|
||||
if "flax_model" in x.name:
|
||||
return ModelRepoVariant.FLAX
|
||||
return ModelRepoVariant.Flax
|
||||
if x.suffix == ".onnx":
|
||||
return ModelRepoVariant.ONNX
|
||||
return ModelRepoVariant.DEFAULT
|
||||
return ModelRepoVariant.Default
|
||||
|
||||
|
||||
class PipelineFolderProbe(FolderProbeBase):
|
||||
@ -708,8 +714,8 @@ class T2IAdapterFolderProbe(FolderProbeBase):
|
||||
|
||||
############## register probe classes ######
|
||||
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||
@ -717,8 +723,8 @@ ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderPro
|
||||
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||
|
@ -13,6 +13,7 @@ files_to_download = select_hf_model_files(metadata.files, variant='onnx')
|
||||
"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
@ -34,7 +35,7 @@ def filter_files(
|
||||
The file list can be obtained from the `files` field of HuggingFaceMetadata,
|
||||
as defined in `invokeai.backend.model_manager.metadata.metadata_base`.
|
||||
"""
|
||||
variant = variant or ModelRepoVariant.DEFAULT
|
||||
variant = variant or ModelRepoVariant.Default
|
||||
paths: List[Path] = []
|
||||
root = files[0].parts[0]
|
||||
|
||||
@ -73,64 +74,81 @@ def filter_files(
|
||||
return sorted(_filter_by_variant(paths, variant))
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubfolderCandidate:
|
||||
path: Path
|
||||
score: int
|
||||
|
||||
|
||||
def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
|
||||
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
|
||||
result = set()
|
||||
basenames: Dict[Path, Path] = {}
|
||||
result: set[Path] = set()
|
||||
subfolder_weights: dict[Path, list[SubfolderCandidate]] = {}
|
||||
for path in files:
|
||||
if path.suffix in [".onnx", ".pb", ".onnx_data"]:
|
||||
if variant == ModelRepoVariant.ONNX:
|
||||
result.add(path)
|
||||
|
||||
elif "openvino_model" in path.name:
|
||||
if variant == ModelRepoVariant.OPENVINO:
|
||||
if variant == ModelRepoVariant.OpenVINO:
|
||||
result.add(path)
|
||||
|
||||
elif "flax_model" in path.name:
|
||||
if variant == ModelRepoVariant.FLAX:
|
||||
if variant == ModelRepoVariant.Flax:
|
||||
result.add(path)
|
||||
|
||||
elif path.suffix in [".json", ".txt"]:
|
||||
result.add(path)
|
||||
|
||||
elif path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"] and variant in [
|
||||
elif variant in [
|
||||
ModelRepoVariant.FP16,
|
||||
ModelRepoVariant.FP32,
|
||||
ModelRepoVariant.DEFAULT,
|
||||
]:
|
||||
parent = path.parent
|
||||
suffixes = path.suffixes
|
||||
if len(suffixes) == 2:
|
||||
variant_label, suffix = suffixes
|
||||
basename = parent / Path(path.stem).stem
|
||||
else:
|
||||
variant_label = ""
|
||||
suffix = suffixes[0]
|
||||
basename = parent / path.stem
|
||||
ModelRepoVariant.Default,
|
||||
] and path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"]:
|
||||
# For weights files, we want to select the best one for each subfolder. For example, we may have multiple
|
||||
# text encoders:
|
||||
#
|
||||
# - text_encoder/model.fp16.safetensors
|
||||
# - text_encoder/model.safetensors
|
||||
# - text_encoder/pytorch_model.bin
|
||||
# - text_encoder/pytorch_model.fp16.bin
|
||||
#
|
||||
# We prefer safetensors over other file formats and an exact variant match. We'll score each file based on
|
||||
# variant and format and select the best one.
|
||||
|
||||
if previous := basenames.get(basename):
|
||||
if (
|
||||
previous.suffix != ".safetensors" and suffix == ".safetensors"
|
||||
): # replace non-safetensors with safetensors when available
|
||||
basenames[basename] = path
|
||||
if variant_label == f".{variant}":
|
||||
basenames[basename] = path
|
||||
elif not variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]:
|
||||
basenames[basename] = path
|
||||
else:
|
||||
basenames[basename] = path
|
||||
parent = path.parent
|
||||
score = 0
|
||||
|
||||
if path.suffix == ".safetensors":
|
||||
score += 1
|
||||
|
||||
candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None
|
||||
|
||||
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant
|
||||
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
|
||||
if candidate_variant_label == f".{variant}" or (
|
||||
not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]
|
||||
):
|
||||
score += 1
|
||||
|
||||
if parent not in subfolder_weights:
|
||||
subfolder_weights[parent] = []
|
||||
|
||||
subfolder_weights[parent].append(SubfolderCandidate(path=path, score=score))
|
||||
|
||||
else:
|
||||
continue
|
||||
|
||||
for v in basenames.values():
|
||||
result.add(v)
|
||||
for candidate_list in subfolder_weights.values():
|
||||
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
|
||||
if highest_score_candidate:
|
||||
result.add(highest_score_candidate.path)
|
||||
|
||||
# If one of the architecture-related variants was specified and no files matched other than
|
||||
# config and text files then we return an empty list
|
||||
if (
|
||||
variant
|
||||
and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OPENVINO, ModelRepoVariant.FLAX]
|
||||
and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OpenVINO, ModelRepoVariant.Flax]
|
||||
and not any(variant.value in x.name for x in result)
|
||||
):
|
||||
return set()
|
||||
|
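The comment block in this hunk explains how one weights file is chosen per subfolder: a `.safetensors` file scores a point over `.bin`/`.pt`, and a filename whose embedded variant label matches the requested variant scores another (with unlabeled files accepted when FP32 or the default variant is requested). A small sketch of the fp16 case, using illustrative helper names and paths:

```python
from pathlib import Path


def score_weights_file(path: Path, variant: str) -> int:
    """Score a candidate weights file: prefer safetensors and an exact variant-label match."""
    score = 0
    if path.suffix == ".safetensors":
        score += 1
    # "model.fp16.safetensors" -> label ".fp16"; "model.safetensors" -> no label
    label = path.suffixes[0] if len(path.suffixes) == 2 else None
    if label == f".{variant}":
        score += 1
    return score


candidates = [
    Path("text_encoder/pytorch_model.bin"),
    Path("text_encoder/pytorch_model.fp16.bin"),
    Path("text_encoder/model.safetensors"),
    Path("text_encoder/model.fp16.safetensors"),
]
print(max(candidates, key=lambda p: score_weights_file(p, "fp16")))
# -> text_encoder/model.fp16.safetensors
```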
@ -4,13 +4,11 @@ Initialization file for the invokeai.backend.stable_diffusion package
|
||||
|
||||
from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline # noqa: F401
|
||||
from .diffusion import InvokeAIDiffuserComponent # noqa: F401
|
||||
from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401
|
||||
from .seamless import set_seamless # noqa: F401
|
||||
|
||||
__all__ = [
|
||||
"PipelineIntermediateState",
|
||||
"StableDiffusionGeneratorPipeline",
|
||||
"InvokeAIDiffuserComponent",
|
||||
"AttentionMapSaver",
|
||||
"set_seamless",
|
||||
]
|
||||
|
@ -12,7 +12,6 @@ import torch
|
||||
import torchvision.transforms as T
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.controlnet import ControlNetModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
@ -26,9 +25,9 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
|
||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
from ..util import auto_detect_slice_size, normalize_device
|
||||
from .diffusion import AttentionMapSaver, InvokeAIDiffuserComponent
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -39,7 +38,6 @@ class PipelineIntermediateState:
|
||||
timestep: int
|
||||
latents: torch.Tensor
|
||||
predicted_original: Optional[torch.Tensor] = None
|
||||
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -190,19 +188,6 @@ class T2IAdapterData:
|
||||
end_step_percent: float = Field(default=1.0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
|
||||
r"""
|
||||
Output class for InvokeAI's Stable Diffusion pipeline.
|
||||
|
||||
Args:
|
||||
attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
|
||||
after generation completes. Optional.
|
||||
"""
|
||||
|
||||
attention_map_saver: Optional[AttentionMapSaver]
|
||||
|
||||
|
||||
class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
@ -343,9 +328,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
masked_latents: Optional[torch.Tensor] = None,
|
||||
gradient_mask: Optional[bool] = False,
|
||||
seed: Optional[int] = None,
|
||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||
) -> torch.Tensor:
|
||||
if init_timestep.shape[0] == 0:
|
||||
return latents, None
|
||||
return latents
|
||||
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
@ -385,7 +370,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
|
||||
|
||||
try:
|
||||
latents, attention_map_saver = self.generate_latents_from_embeddings(
|
||||
latents = self.generate_latents_from_embeddings(
|
||||
latents,
|
||||
timesteps,
|
||||
conditioning_data,
|
||||
@ -402,7 +387,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if mask is not None and not gradient_mask:
|
||||
latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
|
||||
|
||||
return latents, attention_map_saver
|
||||
return latents
|
||||
|
||||
def generate_latents_from_embeddings(
|
||||
self,
|
||||
@ -415,16 +400,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||
|
||||
if timesteps.shape[0] == 0:
|
||||
return latents, attention_map_saver
|
||||
return latents
|
||||
|
||||
ip_adapter_unet_patcher = None
|
||||
extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
|
||||
@ -432,7 +416,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
attn_ctx = self.invokeai_diffuser.custom_attention_context(
|
||||
self.invokeai_diffuser.model,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
)
|
||||
self.use_ip_adapter = False
|
||||
elif ip_adapter_data is not None:
|
||||
@ -483,13 +466,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
predicted_original = getattr(step_output, "pred_original_sample", None)
|
||||
|
||||
# TODO resuscitate attention map saving
|
||||
# if i == len(timesteps)-1 and extra_conditioning_info is not None:
|
||||
# eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
|
||||
# attention_map_token_ids = range(1, eos_token_index)
|
||||
# attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
|
||||
# self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
|
||||
|
||||
if callback is not None:
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
@ -499,11 +475,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
timestep=int(t),
|
||||
latents=latents,
|
||||
predicted_original=predicted_original,
|
||||
attention_map_saver=attention_map_saver,
|
||||
)
|
||||
)
|
||||
|
||||
return latents, attention_map_saver
|
||||
return latents
|
||||
|
||||
@torch.inference_mode()
|
||||
def step(
|
||||
|
@ -2,6 +2,4 @@
|
||||
Initialization file for invokeai.models.diffusion
|
||||
"""
|
||||
|
||||
from .cross_attention_control import InvokeAICrossAttentionMixin # noqa: F401
|
||||
from .cross_attention_map_saving import AttentionMapSaver # noqa: F401
|
||||
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent # noqa: F401
|
||||
|
@ -3,19 +3,13 @@
|
||||
|
||||
|
||||
import enum
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Optional
|
||||
from typing import Optional
|
||||
|
||||
import diffusers
|
||||
import psutil
|
||||
import torch
|
||||
from compel.cross_attention_control import Arguments
|
||||
from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, SlicedAttnProcessor
|
||||
from diffusers.models.attention_processor import Attention, SlicedAttnProcessor
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from torch import nn
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from ...util import torch_dtype
|
||||
|
||||
@ -25,72 +19,14 @@ class CrossAttentionType(enum.Enum):
|
||||
TOKENS = 2
|
||||
|
||||
|
||||
class Context:
|
||||
cross_attention_mask: Optional[torch.Tensor]
|
||||
cross_attention_index_map: Optional[torch.Tensor]
|
||||
|
||||
class Action(enum.Enum):
|
||||
NONE = 0
|
||||
SAVE = (1,)
|
||||
APPLY = 2
|
||||
|
||||
def __init__(self, arguments: Arguments, step_count: int):
|
||||
class CrossAttnControlContext:
|
||||
def __init__(self, arguments: Arguments):
|
||||
"""
|
||||
:param arguments: Arguments for the cross-attention control process
|
||||
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
|
||||
"""
|
||||
self.cross_attention_mask = None
|
||||
self.cross_attention_index_map = None
|
||||
self.self_cross_attention_action = Context.Action.NONE
|
||||
self.tokens_cross_attention_action = Context.Action.NONE
|
||||
self.cross_attention_mask: Optional[torch.Tensor] = None
|
||||
self.cross_attention_index_map: Optional[torch.Tensor] = None
|
||||
self.arguments = arguments
|
||||
self.step_count = step_count
|
||||
|
||||
self.self_cross_attention_module_identifiers = []
|
||||
self.tokens_cross_attention_module_identifiers = []
|
||||
|
||||
self.saved_cross_attention_maps = {}
|
||||
|
||||
self.clear_requests(cleanup=True)
|
||||
|
||||
def register_cross_attention_modules(self, model):
|
||||
for name, _module in get_cross_attention_modules(model, CrossAttentionType.SELF):
|
||||
if name in self.self_cross_attention_module_identifiers:
|
||||
raise AssertionError(f"name {name} cannot appear more than once")
|
||||
self.self_cross_attention_module_identifiers.append(name)
|
||||
for name, _module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
|
||||
if name in self.tokens_cross_attention_module_identifiers:
|
||||
raise AssertionError(f"name {name} cannot appear more than once")
|
||||
self.tokens_cross_attention_module_identifiers.append(name)
|
||||
|
||||
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
|
||||
if cross_attention_type == CrossAttentionType.SELF:
|
||||
self.self_cross_attention_action = Context.Action.SAVE
|
||||
else:
|
||||
self.tokens_cross_attention_action = Context.Action.SAVE
|
||||
|
||||
def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
|
||||
if cross_attention_type == CrossAttentionType.SELF:
|
||||
self.self_cross_attention_action = Context.Action.APPLY
|
||||
else:
|
||||
self.tokens_cross_attention_action = Context.Action.APPLY
|
||||
|
||||
def is_tokens_cross_attention(self, module_identifier) -> bool:
|
||||
return module_identifier in self.tokens_cross_attention_module_identifiers
|
||||
|
||||
def get_should_save_maps(self, module_identifier: str) -> bool:
|
||||
if module_identifier in self.self_cross_attention_module_identifiers:
|
||||
return self.self_cross_attention_action == Context.Action.SAVE
|
||||
elif module_identifier in self.tokens_cross_attention_module_identifiers:
|
||||
return self.tokens_cross_attention_action == Context.Action.SAVE
|
||||
return False
|
||||
|
||||
def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
|
||||
if module_identifier in self.self_cross_attention_module_identifiers:
|
||||
return self.self_cross_attention_action == Context.Action.APPLY
|
||||
elif module_identifier in self.tokens_cross_attention_module_identifiers:
|
||||
return self.tokens_cross_attention_action == Context.Action.APPLY
|
||||
return False
|
||||
|
||||
def get_active_cross_attention_control_types_for_step(
|
||||
self, percent_through: float = None
|
||||
@ -111,219 +47,8 @@ class Context:
|
||||
to_control.append(CrossAttentionType.TOKENS)
|
||||
return to_control
|
||||
|
||||
def save_slice(
|
||||
self,
|
||||
identifier: str,
|
||||
slice: torch.Tensor,
|
||||
dim: Optional[int],
|
||||
offset: int,
|
||||
slice_size: Optional[int],
|
||||
):
|
||||
if identifier not in self.saved_cross_attention_maps:
|
||||
self.saved_cross_attention_maps[identifier] = {
|
||||
"dim": dim,
|
||||
"slice_size": slice_size,
|
||||
"slices": {offset or 0: slice},
|
||||
}
|
||||
else:
|
||||
self.saved_cross_attention_maps[identifier]["slices"][offset or 0] = slice
|
||||
|
||||
def get_slice(
|
||||
self,
|
||||
identifier: str,
|
||||
requested_dim: Optional[int],
|
||||
requested_offset: int,
|
||||
slice_size: int,
|
||||
):
|
||||
saved_attention_dict = self.saved_cross_attention_maps[identifier]
|
||||
if requested_dim is None:
|
||||
if saved_attention_dict["dim"] is not None:
|
||||
raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
|
||||
return saved_attention_dict["slices"][0]
|
||||
|
||||
if saved_attention_dict["dim"] == requested_dim:
|
||||
if slice_size != saved_attention_dict["slice_size"]:
|
||||
raise RuntimeError(
|
||||
f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}"
|
||||
)
|
||||
return saved_attention_dict["slices"][requested_offset]
|
||||
|
||||
if saved_attention_dict["dim"] is None:
|
||||
whole_saved_attention = saved_attention_dict["slices"][0]
|
||||
if requested_dim == 0:
|
||||
return whole_saved_attention[requested_offset : requested_offset + slice_size]
|
||||
elif requested_dim == 1:
|
||||
return whole_saved_attention[:, requested_offset : requested_offset + slice_size]
|
||||
|
||||
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
|
||||
|
||||
def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
|
||||
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
|
||||
if saved_attention is None:
|
||||
return None, None
|
||||
return saved_attention["dim"], saved_attention["slice_size"]
|
||||
|
||||
def clear_requests(self, cleanup=True):
|
||||
self.tokens_cross_attention_action = Context.Action.NONE
|
||||
self.self_cross_attention_action = Context.Action.NONE
|
||||
if cleanup:
|
||||
self.saved_cross_attention_maps = {}
|
||||
|
||||
def offload_saved_attention_slices_to_cpu(self):
|
||||
for _key, map_dict in self.saved_cross_attention_maps.items():
|
||||
for offset, slice in map_dict["slices"].items():
|
||||
map_dict[offset] = slice.to("cpu")
|
||||
|
||||
|
||||
class InvokeAICrossAttentionMixin:
|
||||
"""
|
||||
Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
|
||||
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
|
||||
and dynamic slicing strategy selection.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
self.attention_slice_wrangler = None
|
||||
self.slicing_strategy_getter = None
|
||||
self.attention_slice_calculated_callback = None
|
||||
|
||||
def set_attention_slice_wrangler(
|
||||
self,
|
||||
wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]],
|
||||
):
|
||||
"""
|
||||
Set a custom attention calculator to be called when attention is calculated
|
||||
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
|
||||
which returns either the suggested_attention_slice or an adjusted equivalent.
|
||||
`module` is the current Attention module for which the callback is being invoked.
|
||||
`suggested_attention_slice` is the default-calculated attention slice
|
||||
`dim` is -1 if the attention map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
||||
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
|
||||
|
||||
Pass None to use the default attention calculation.
|
||||
:return:
|
||||
"""
|
||||
self.attention_slice_wrangler = wrangler
|
||||
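# Illustrative sketch: a wrangler matching the callback signature documented above.
# `sharpen_wrangler` and `boost` are hypothetical names, not InvokeAI API; torch is
# assumed to be imported by this module.
def sharpen_wrangler(module, suggested_attention_slice, dim, offset, slice_size, boost=2.0):
    # raise the softmax-ed scores to a power and re-normalise, concentrating
    # attention on the highest-scoring tokens
    sharpened = suggested_attention_slice.pow(boost)
    return sharpened / sharpened.sum(dim=-1, keepdim=True)
# usage:   attn_module.set_attention_slice_wrangler(sharpen_wrangler)
# restore: attn_module.set_attention_slice_wrangler(None)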
|
||||
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]):
|
||||
self.slicing_strategy_getter = getter
|
||||
|
||||
def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
|
||||
self.attention_slice_calculated_callback = callback
|
||||
|
||||
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
|
||||
# calculate attention scores
|
||||
# attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
|
||||
attention_scores = torch.baddbmm(
|
||||
torch.empty(
|
||||
query.shape[0],
|
||||
query.shape[1],
|
||||
key.shape[1],
|
||||
dtype=query.dtype,
|
||||
device=query.device,
|
||||
),
|
||||
query,
|
||||
key.transpose(-1, -2),
|
||||
beta=0,
|
||||
alpha=self.scale,
|
||||
)
|
||||
|
||||
# calculate attention slice by taking the best scores for each latent pixel
|
||||
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
|
||||
attention_slice_wrangler = self.attention_slice_wrangler
|
||||
if attention_slice_wrangler is not None:
|
||||
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
|
||||
else:
|
||||
attention_slice = default_attention_slice
|
||||
|
||||
if self.attention_slice_calculated_callback is not None:
|
||||
self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
|
||||
|
||||
hidden_states = torch.bmm(attention_slice, value)
|
||||
return hidden_states
|
||||
|
||||
def einsum_op_slice_dim0(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = i + slice_size
|
||||
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_slice_dim1(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v):
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
def einsum_op_mps_v2(self, q, k, v):
|
||||
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
return self.einsum_op_slice_dim0(q, k, v, 1)
|
||||
|
||||
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
|
||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
if size_mb <= max_tensor_mb:
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
||||
if div <= q.shape[0]:
|
||||
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
|
||||
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
|
||||
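# Worked example (illustrative numbers only): for q of shape [16, 4096, 40] in float16
# and k with 4096 tokens, size_mb = 16 * 4096 * 4096 * 2 // 2**20 = 512. With
# max_tensor_mb = 32, int((512 - 1) / 32).bit_length() = 4, so div = 1 << 4 = 16;
# since div <= q.shape[0], dim 0 is sliced into chunks of 16 // 16 = 1 head-batch each.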
|
||||
def einsum_op_cuda(self, q, k, v):
|
||||
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
|
||||
slicing_strategy_getter = self.slicing_strategy_getter
|
||||
if slicing_strategy_getter is not None:
|
||||
(dim, slice_size) = slicing_strategy_getter(self)
|
||||
if dim is not None:
|
||||
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
|
||||
if dim == 0:
|
||||
return self.einsum_op_slice_dim0(q, k, v, slice_size)
|
||||
elif dim == 1:
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
# fallback for when there is no saved strategy, or saved strategy does not slice
|
||||
mem_free_total = get_mem_free_total(q.device)
|
||||
# Divide by a factor of safety as there's copying and fragmentation
|
||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||
|
||||
def get_invokeai_attention_mem_efficient(self, q, k, v):
|
||||
if q.device.type == "cuda":
|
||||
# print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
|
||||
return self.einsum_op_cuda(q, k, v)
|
||||
|
||||
if q.device.type == "mps" or q.device.type == "cpu":
|
||||
if self.mem_total_gb >= 32:
|
||||
return self.einsum_op_mps_v1(q, k, v)
|
||||
return self.einsum_op_mps_v2(q, k, v)
|
||||
|
||||
# Smaller slices are faster due to L2/L3/SLC caches.
|
||||
# Tested on i7 with 8MB L3 cache.
|
||||
return self.einsum_op_tensor_mem(q, k, v, 32)
|
||||
|
||||
|
||||
def restore_default_cross_attention(
|
||||
model,
|
||||
is_running_diffusers: bool,
|
||||
restore_attention_processor: Optional[AttentionProcessor] = None,
|
||||
):
|
||||
if is_running_diffusers:
|
||||
unet = model
|
||||
unet.set_attn_processor(restore_attention_processor or AttnProcessor())
|
||||
else:
|
||||
remove_attention_function(model)
|
||||
|
||||
|
||||
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
|
||||
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: CrossAttnControlContext):
|
||||
"""
|
||||
Inject attention parameters and functions into the passed-in model to enable cross-attention editing.
|
||||
|
||||
@ -362,170 +87,6 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionMode
|
||||
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
|
||||
|
||||
|
||||
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
|
||||
cross_attention_class: type = InvokeAIDiffusersCrossAttention
|
||||
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
|
||||
attention_module_tuples = [
|
||||
(name, module)
|
||||
for name, module in model.named_modules()
|
||||
if isinstance(module, cross_attention_class) and which_attn in name
|
||||
]
|
||||
cross_attention_modules_in_model_count = len(attention_module_tuples)
|
||||
expected_count = 16
|
||||
if cross_attention_modules_in_model_count != expected_count:
|
||||
# non-fatal error but .swap() won't work.
|
||||
logger.error(
|
||||
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
|
||||
f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching "
|
||||
"failed or some assumption has changed about the structure of the model itself. Please fix the "
|
||||
f"monkey-patching, and/or update the {expected_count} above to an appropriate number, and/or find and "
|
||||
"inform someone who knows what it means. This error is non-fatal, but it is likely that .swap() and "
|
||||
"attention map display will not work properly until it is fixed."
|
||||
)
|
||||
return attention_module_tuples
|
||||
|
||||
|
||||
def inject_attention_function(unet, context: Context):
|
||||
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
|
||||
|
||||
def attention_slice_wrangler(module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size):
|
||||
# memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
|
||||
|
||||
attention_slice = suggested_attention_slice
|
||||
|
||||
if context.get_should_save_maps(module.identifier):
|
||||
# print(module.identifier, "saving suggested_attention_slice of shape",
|
||||
# suggested_attention_slice.shape, "dim", dim, "offset", offset)
|
||||
slice_to_save = attention_slice.to("cpu") if dim is not None else attention_slice
|
||||
context.save_slice(
|
||||
module.identifier,
|
||||
slice_to_save,
|
||||
dim=dim,
|
||||
offset=offset,
|
||||
slice_size=slice_size,
|
||||
)
|
||||
elif context.get_should_apply_saved_maps(module.identifier):
|
||||
# print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
|
||||
saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
|
||||
|
||||
# slice may have been offloaded to CPU
|
||||
saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
|
||||
|
||||
if context.is_tokens_cross_attention(module.identifier):
|
||||
index_map = context.cross_attention_index_map
|
||||
remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
|
||||
this_attention_slice = suggested_attention_slice
|
||||
|
||||
mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
|
||||
saved_mask = mask
|
||||
this_mask = 1 - mask
|
||||
attention_slice = remapped_saved_attention_slice * saved_mask + this_attention_slice * this_mask
|
||||
else:
|
||||
# just use everything
|
||||
attention_slice = saved_attention_slice
|
||||
|
||||
return attention_slice
|
||||
|
||||
cross_attention_modules = get_cross_attention_modules(
|
||||
unet, CrossAttentionType.TOKENS
|
||||
) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
|
||||
for identifier, module in cross_attention_modules:
|
||||
module.identifier = identifier
|
||||
try:
|
||||
module.set_attention_slice_wrangler(attention_slice_wrangler)
|
||||
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) # noqa: B023
|
||||
except AttributeError as e:
|
||||
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
|
||||
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def remove_attention_function(unet):
|
||||
cross_attention_modules = get_cross_attention_modules(
|
||||
unet, CrossAttentionType.TOKENS
|
||||
) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
|
||||
for _identifier, module in cross_attention_modules:
|
||||
try:
|
||||
# clear wrangler callback
|
||||
module.set_attention_slice_wrangler(None)
|
||||
module.set_slicing_strategy_getter(None)
|
||||
except AttributeError as e:
|
||||
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
|
||||
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def is_attribute_error_about(error: AttributeError, attribute: str):
|
||||
if hasattr(error, "name"): # Python 3.10
|
||||
return error.name == attribute
|
||||
else: # Python 3.9
|
||||
return attribute in str(error)
|
||||
|
||||
|
||||
def get_mem_free_total(device):
|
||||
# only on cuda
|
||||
if not torch.cuda.is_available():
|
||||
return None
|
||||
stats = torch.cuda.memory_stats(device)
|
||||
mem_active = stats["active_bytes.all.current"]
|
||||
mem_reserved = stats["reserved_bytes.all.current"]
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(device)
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
return mem_free_total
|
||||
|
||||
|
||||
class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
InvokeAICrossAttentionMixin.__init__(self)
|
||||
|
||||
def _attention(self, query, key, value, attention_mask=None):
|
||||
# default_result = super()._attention(query, key, value)
|
||||
if attention_mask is not None:
|
||||
print(f"{type(self).__name__} ignoring passed-in attention_mask")
|
||||
attention_result = self.get_invokeai_attention_mem_efficient(query, key, value)
|
||||
|
||||
hidden_states = self.reshape_batch_dim_to_heads(attention_result)
|
||||
return hidden_states
|
||||
|
||||
|
||||
## 🧨diffusers implementation follows
|
||||
|
||||
|
||||
"""
|
||||
# base implementation
|
||||
|
||||
class AttnProcessor:
|
||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SwapCrossAttnContext:
|
||||
modified_text_embeddings: torch.Tensor
|
||||
|
@ -1,100 +0,0 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
|
||||
class AttentionMapSaver:
|
||||
def __init__(self, token_ids: range, latents_shape: torch.Size):
|
||||
self.token_ids = token_ids
|
||||
self.latents_shape = latents_shape
|
||||
# self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
|
||||
self.collated_maps: dict[str, torch.Tensor] = {}
|
||||
|
||||
def clear_maps(self):
|
||||
self.collated_maps = {}
|
||||
|
||||
def add_attention_maps(self, maps: torch.Tensor, key: str):
|
||||
"""
|
||||
Accumulate the given attention maps, summing them with any existing maps stored at the passed-in key.
|
||||
:param maps: Attention maps to store. Expected shape [A, (H*W), N], where A is the number of attention heads, H and W are the map size (fixed per key) and N is the number of tokens (typically 77).
|
||||
:param key: Storage key. If a map already exists for this key, it will be summed with the incoming data; in that case the map sizes (H and W) should match.
|
||||
:return: None
|
||||
"""
|
||||
key_and_size = f"{key}_{maps.shape[1]}"
|
||||
|
||||
# extract desired tokens
|
||||
maps = maps[:, :, self.token_ids]
|
||||
|
||||
# merge attention heads to a single map per token
|
||||
maps = torch.sum(maps, 0)
|
||||
|
||||
# store
|
||||
if key_and_size not in self.collated_maps:
|
||||
self.collated_maps[key_and_size] = torch.zeros_like(maps, device="cpu")
|
||||
self.collated_maps[key_and_size] += maps.cpu()
|
||||
|
||||
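# Hedged usage sketch (the token range and shapes below are made-up values, chosen to
# match the expected shapes described in the docstring above):
# saver = AttentionMapSaver(token_ids=range(1, 7), latents_shape=torch.Size([64, 64]))
# maps = torch.rand(8, 64 * 64, 77)  # [heads, H*W, tokens]
# saver.add_attention_maps(maps, key="up")
# saver.write_maps_to_disk("attention_maps.png")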
def write_maps_to_disk(self, path: str):
|
||||
pil_image = self.get_stacked_maps_image()
|
||||
if pil_image is not None:
|
||||
pil_image.save(path, "PNG")
|
||||
|
||||
def get_stacked_maps_image(self) -> Optional[Image.Image]:
|
||||
"""
|
||||
Scale all collected attention maps to the same size, blend them together and return as an image.
|
||||
:return: An image containing a vertical stack of blended attention maps, one for each requested token.
|
||||
"""
|
||||
num_tokens = len(self.token_ids)
|
||||
if num_tokens == 0:
|
||||
return None
|
||||
|
||||
latents_height = self.latents_shape[0]
|
||||
latents_width = self.latents_shape[1]
|
||||
|
||||
merged = None
|
||||
|
||||
for _key, maps in self.collated_maps.items():
|
||||
# maps has shape [(H*W), N] for N tokens
|
||||
# but we want [N, H, W]
|
||||
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
|
||||
this_maps_height = int(float(latents_height) * this_scale_factor)
|
||||
this_maps_width = int(float(latents_width) * this_scale_factor)
|
||||
# and we need to do some dimension juggling
|
||||
maps = torch.reshape(
|
||||
torch.swapdims(maps, 0, 1),
|
||||
[num_tokens, this_maps_height, this_maps_width],
|
||||
)
|
||||
|
||||
# scale to output size if necessary
|
||||
if this_scale_factor != 1:
|
||||
maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
|
||||
|
||||
# normalize
|
||||
maps_min = torch.min(maps)
|
||||
maps_range = torch.max(maps) - maps_min
|
||||
# print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
|
||||
maps_normalized = (maps - maps_min) / maps_range
|
||||
# expand to (-0.1, 1.1) and clamp
|
||||
maps_normalized_expanded = maps_normalized * 1.1 - 0.05
|
||||
maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
|
||||
|
||||
# merge together, producing a vertical stack
|
||||
maps_stacked = torch.reshape(
|
||||
maps_normalized_expanded_clamped,
|
||||
[num_tokens * latents_height, latents_width],
|
||||
)
|
||||
|
||||
if merged is None:
|
||||
merged = maps_stacked
|
||||
else:
|
||||
# screen blend
|
||||
merged = 1 - (1 - maps_stacked) * (1 - merged)
|
||||
|
||||
if merged is None:
|
||||
return None
|
||||
|
||||
merged_bytes = merged.mul(0xFF).byte()
|
||||
return Image.fromarray(merged_bytes.numpy(), mode="L")
|
@ -17,13 +17,11 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
)
|
||||
|
||||
from .cross_attention_control import (
|
||||
Context,
|
||||
CrossAttentionType,
|
||||
CrossAttnControlContext,
|
||||
SwapCrossAttnContext,
|
||||
get_cross_attention_modules,
|
||||
setup_cross_attention_control_attention_processors,
|
||||
)
|
||||
from .cross_attention_map_saving import AttentionMapSaver
|
||||
|
||||
ModelForwardCallback: TypeAlias = Union[
|
||||
# x, t, conditioning, Optional[cross-attention kwargs]
|
||||
@ -69,14 +67,12 @@ class InvokeAIDiffuserComponent:
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
||||
step_count: int,
|
||||
):
|
||||
old_attn_processors = unet.attn_processors
|
||||
|
||||
try:
|
||||
self.cross_attention_control_context = Context(
|
||||
self.cross_attention_control_context = CrossAttnControlContext(
|
||||
arguments=extra_conditioning_info.cross_attention_control_args,
|
||||
step_count=step_count,
|
||||
)
|
||||
setup_cross_attention_control_attention_processors(
|
||||
unet,
|
||||
@ -87,27 +83,6 @@ class InvokeAIDiffuserComponent:
|
||||
finally:
|
||||
self.cross_attention_control_context = None
|
||||
unet.set_attn_processor(old_attn_processors)
|
||||
# TODO resuscitate attention map saving
|
||||
# self.remove_attention_map_saving()
|
||||
|
||||
def setup_attention_map_saving(self, saver: AttentionMapSaver):
|
||||
def callback(slice, dim, offset, slice_size, key):
|
||||
if dim is not None:
|
||||
# sliced tokens attention map saving is not implemented
|
||||
return
|
||||
saver.add_attention_maps(slice, key)
|
||||
|
||||
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
|
||||
for identifier, module in tokens_cross_attention_modules:
|
||||
key = "down" if identifier.startswith("down") else "up" if identifier.startswith("up") else "mid"
|
||||
module.set_attention_slice_calculated_callback(
|
||||
lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key)
|
||||
)
|
||||
|
||||
def remove_attention_map_saving(self):
|
||||
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
|
||||
for _, module in tokens_cross_attention_modules:
|
||||
module.set_attention_slice_calculated_callback(None)
|
||||
|
||||
def do_controlnet_step(
|
||||
self,
|
||||
@ -592,54 +567,3 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
self.last_percent_through = percent_through
|
||||
return latents.to(device=dev)
|
||||
|
||||
# todo: make this work
|
||||
@classmethod
|
||||
def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2) # aka sigmas
|
||||
|
||||
deltas = None
|
||||
uncond_latents = None
|
||||
weighted_cond_list = (
|
||||
c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)]
|
||||
)
|
||||
|
||||
# below is fugly omg
|
||||
conditionings = [uc] + [c for c, weight in weighted_cond_list]
|
||||
weights = [1] + [weight for c, weight in weighted_cond_list]
|
||||
chunk_count = math.ceil(len(conditionings) / 2)
|
||||
deltas = None
|
||||
for chunk_index in range(chunk_count):
|
||||
offset = chunk_index * 2
|
||||
chunk_size = min(2, len(conditionings) - offset)
|
||||
|
||||
if chunk_size == 1:
|
||||
c_in = conditionings[offset]
|
||||
latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
|
||||
latents_b = None
|
||||
else:
|
||||
c_in = torch.cat(conditionings[offset : offset + 2])
|
||||
latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
|
||||
|
||||
# first chunk is guaranteed to be 2 entries: uncond_latents + first conditioning
|
||||
if chunk_index == 0:
|
||||
uncond_latents = latents_a
|
||||
deltas = latents_b - uncond_latents
|
||||
else:
|
||||
deltas = torch.cat((deltas, latents_a - uncond_latents))
|
||||
if latents_b is not None:
|
||||
deltas = torch.cat((deltas, latents_b - uncond_latents))
|
||||
|
||||
# merge the weighted deltas together into a single merged delta
|
||||
per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
|
||||
normalize = False
|
||||
if normalize:
|
||||
per_delta_weights /= torch.sum(per_delta_weights)
|
||||
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
|
||||
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
|
||||
|
||||
# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
|
||||
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
|
||||
|
||||
return uncond_latents + deltas_merged * global_guidance_scale
|
||||
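# Worked example (illustrative): with weighted conditionings [(c1, 0.75), (c2, 0.25)] and
# guidance scale g, the return value is
#   uncond + (0.75 * (latents(c1) - uncond) + 0.25 * (latents(c2) - uncond)) * g,
# i.e. classifier-free guidance applied to a weighted sum of per-conditioning deltas
# (the weights are used as given, since normalize is False above).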
|
@ -858,9 +858,9 @@ def do_textual_inversion_training(
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
|
||||
orig_embeds_params[index_no_updates]
|
||||
)
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
|
@ -144,7 +144,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
|
||||
self.nextrely = top_of_table
|
||||
self.lora_models = self.add_model_widgets(
|
||||
model_type=ModelType.Lora,
|
||||
model_type=ModelType.LoRA,
|
||||
window_width=window_width,
|
||||
)
|
||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||
|
@ -30,7 +30,7 @@
|
||||
"lint:prettier": "prettier --check .",
|
||||
"lint:tsc": "tsc --noEmit",
|
||||
"lint": "concurrently -g -c red,green,yellow,blue,magenta pnpm:lint:*",
|
||||
"fix": "knip --fix && eslint --fix . && prettier --log-level warn --write .",
|
||||
"fix": "eslint --fix . && prettier --log-level warn --write .",
|
||||
"preinstall": "npx only-allow pnpm",
|
||||
"storybook": "storybook dev -p 6006",
|
||||
"build-storybook": "storybook build",
|
||||
|
@ -134,8 +134,6 @@
|
||||
"loadMore": "Mehr laden",
|
||||
"noImagesInGallery": "Keine Bilder in der Galerie",
|
||||
"loading": "Lade",
|
||||
"preparingDownload": "bereite Download vor",
|
||||
"preparingDownloadFailed": "Problem beim Download vorbereiten",
|
||||
"deleteImage": "Lösche Bild",
|
||||
"copy": "Kopieren",
|
||||
"download": "Runterladen",
|
||||
@ -967,7 +965,7 @@
|
||||
"resumeFailed": "Problem beim Fortsetzen des Prozesses",
|
||||
"pruneFailed": "Problem beim leeren der Warteschlange",
|
||||
"pauseTooltip": "Prozess anhalten",
|
||||
"back": "Hinten",
|
||||
"back": "Ende",
|
||||
"resumeSucceeded": "Prozess wird fortgesetzt",
|
||||
"resumeTooltip": "Prozess wieder aufnehmen",
|
||||
"time": "Zeit",
|
||||
|
@ -78,6 +78,7 @@
|
||||
"aboutDesc": "Using Invoke for work? Check out:",
|
||||
"aboutHeading": "Own Your Creative Power",
|
||||
"accept": "Accept",
|
||||
"add": "Add",
|
||||
"advanced": "Advanced",
|
||||
"advancedOptions": "Advanced Options",
|
||||
"ai": "ai",
|
||||
@ -303,6 +304,12 @@
|
||||
"method": "High Resolution Fix Method"
|
||||
}
|
||||
},
|
||||
"prompt": {
|
||||
"addPromptTrigger": "Add Prompt Trigger",
|
||||
"compatibleEmbeddings": "Compatible Embeddings",
|
||||
"noPromptTriggers": "No triggers available",
|
||||
"noMatchingTriggers": "No matching triggers"
|
||||
},
|
||||
"embedding": {
|
||||
"addEmbedding": "Add Embedding",
|
||||
"incompatibleModel": "Incompatible base model:",
|
||||
@ -734,6 +741,8 @@
|
||||
"customConfig": "Custom Config",
|
||||
"customConfigFileLocation": "Custom Config File Location",
|
||||
"customSaveLocation": "Custom Save Location",
|
||||
"defaultSettings": "Default Settings",
|
||||
"defaultSettingsSaved": "Default Settings Saved",
|
||||
"delete": "Delete",
|
||||
"deleteConfig": "Delete Config",
|
||||
"deleteModel": "Delete Model",
|
||||
@ -768,6 +777,7 @@
|
||||
"mergedModelName": "Merged Model Name",
|
||||
"mergedModelSaveLocation": "Save Location",
|
||||
"mergeModels": "Merge Models",
|
||||
"metadata": "Metadata",
|
||||
"model": "Model",
|
||||
"modelAdded": "Model Added",
|
||||
"modelConversionFailed": "Model Conversion Failed",
|
||||
@ -839,9 +849,12 @@
|
||||
"statusConverting": "Converting",
|
||||
"syncModels": "Sync Models",
|
||||
"syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. This is generally handy in cases where you add models to the InvokeAI root folder or autoimport directory after the application has booted.",
|
||||
"triggerPhrases": "Trigger Phrases",
|
||||
"typePhraseHere": "Type phrase here",
|
||||
"upcastAttention": "Upcast Attention",
|
||||
"updateModel": "Update Model",
|
||||
"useCustomConfig": "Use Custom Config",
|
||||
"useDefaultSettings": "Use Default Settings",
|
||||
"v1": "v1",
|
||||
"v2_768": "v2 (768px)",
|
||||
"v2_base": "v2 (512px)",
|
||||
@ -860,6 +873,7 @@
|
||||
"models": {
|
||||
"addLora": "Add LoRA",
|
||||
"allLoRAsAdded": "All LoRAs added",
|
||||
"concepts": "Concepts",
|
||||
"loraAlreadyAdded": "LoRA already added",
|
||||
"esrganModel": "ESRGAN Model",
|
||||
"loading": "loading",
|
||||
|
@ -505,8 +505,6 @@
|
||||
"seamLowThreshold": "Bajo",
|
||||
"coherencePassHeader": "Parámetros de la coherencia",
|
||||
"compositingSettingsHeader": "Ajustes de la composición",
|
||||
"coherenceSteps": "Pasos",
|
||||
"coherenceStrength": "Fuerza",
|
||||
"patchmatchDownScaleSize": "Reducir a escala",
|
||||
"coherenceMode": "Modo"
|
||||
},
|
||||
|
@ -114,7 +114,8 @@
|
||||
"checkpoint": "Checkpoint",
|
||||
"safetensors": "Safetensors",
|
||||
"ai": "ia",
|
||||
"file": "File"
|
||||
"file": "File",
|
||||
"toResolve": "Da risolvere"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Generazioni",
|
||||
@ -142,8 +143,6 @@
|
||||
"copy": "Copia",
|
||||
"download": "Scarica",
|
||||
"setCurrentImage": "Imposta come immagine corrente",
|
||||
"preparingDownload": "Preparazione del download",
|
||||
"preparingDownloadFailed": "Problema durante la preparazione del download",
|
||||
"downloadSelection": "Scarica gli elementi selezionati",
|
||||
"noImageSelected": "Nessuna immagine selezionata",
|
||||
"deleteSelection": "Elimina la selezione",
|
||||
@ -609,8 +608,6 @@
|
||||
"seamLowThreshold": "Basso",
|
||||
"seamHighThreshold": "Alto",
|
||||
"coherencePassHeader": "Passaggio di coerenza",
|
||||
"coherenceSteps": "Passi",
|
||||
"coherenceStrength": "Forza",
|
||||
"compositingSettingsHeader": "Impostazioni di composizione",
|
||||
"patchmatchDownScaleSize": "Ridimensiona",
|
||||
"coherenceMode": "Modalità",
|
||||
@ -1400,19 +1397,6 @@
|
||||
"Regola la maschera."
|
||||
]
|
||||
},
|
||||
"compositingCoherenceSteps": {
|
||||
"heading": "Passi",
|
||||
"paragraphs": [
|
||||
"Numero di passi utilizzati nel Passaggio di Coerenza.",
|
||||
"Simile ai passi di generazione."
|
||||
]
|
||||
},
|
||||
"compositingBlur": {
|
||||
"heading": "Sfocatura",
|
||||
"paragraphs": [
|
||||
"Il raggio di sfocatura della maschera."
|
||||
]
|
||||
},
|
||||
"compositingCoherenceMode": {
|
||||
"heading": "Modalità",
|
||||
"paragraphs": [
|
||||
@ -1431,13 +1415,6 @@
|
||||
"Un secondo ciclo di riduzione del rumore aiuta a comporre l'immagine Inpaint/Outpaint."
|
||||
]
|
||||
},
|
||||
"compositingStrength": {
|
||||
"heading": "Forza",
|
||||
"paragraphs": [
|
||||
"Quantità di rumore aggiunta per il Passaggio di Coerenza.",
|
||||
"Simile alla forza di riduzione del rumore."
|
||||
]
|
||||
},
|
||||
"paramNegativeConditioning": {
|
||||
"paragraphs": [
|
||||
"Il processo di generazione evita i concetti nel prompt negativo. Utilizzatelo per escludere qualità o oggetti dall'output.",
|
||||
|
@ -123,8 +123,6 @@
|
||||
"autoSwitchNewImages": "새로운 이미지로 자동 전환",
|
||||
"loading": "불러오는 중",
|
||||
"unableToLoad": "갤러리를 로드할 수 없음",
|
||||
"preparingDownload": "다운로드 준비",
|
||||
"preparingDownloadFailed": "다운로드 준비 중 발생한 문제",
|
||||
"singleColumnLayout": "단일 열 레이아웃",
|
||||
"image": "이미지",
|
||||
"loadMore": "더 불러오기",
|
||||
|
@ -97,8 +97,6 @@
|
||||
"featuresWillReset": "Als je deze afbeelding verwijdert, dan worden deze functies onmiddellijk teruggezet.",
|
||||
"loading": "Bezig met laden",
|
||||
"unableToLoad": "Kan galerij niet laden",
|
||||
"preparingDownload": "Bezig met voorbereiden van download",
|
||||
"preparingDownloadFailed": "Fout bij voorbereiden van download",
|
||||
"downloadSelection": "Download selectie",
|
||||
"currentlyInUse": "Deze afbeelding is momenteel in gebruik door de volgende functies:",
|
||||
"copy": "Kopieer",
|
||||
@ -535,8 +533,6 @@
|
||||
"coherencePassHeader": "Coherentiestap",
|
||||
"maskBlur": "Vervaag",
|
||||
"maskBlurMethod": "Vervagingsmethode",
|
||||
"coherenceSteps": "Stappen",
|
||||
"coherenceStrength": "Sterkte",
|
||||
"seamHighThreshold": "Hoog",
|
||||
"seamLowThreshold": "Laag",
|
||||
"invoke": {
|
||||
@ -1139,13 +1135,6 @@
|
||||
"Een afbeeldingsgrootte (in aantal pixels) equivalent aan 512x512 wordt aanbevolen voor SD1.5-modellen. Een grootte-equivalent van 1024x1024 wordt aanbevolen voor SDXL-modellen."
|
||||
]
|
||||
},
|
||||
"compositingCoherenceSteps": {
|
||||
"heading": "Stappen",
|
||||
"paragraphs": [
|
||||
"Het aantal te gebruiken ontruisingsstappen in de coherentiefase.",
|
||||
"Gelijk aan de hoofdparameter Stappen."
|
||||
]
|
||||
},
|
||||
"dynamicPrompts": {
|
||||
"paragraphs": [
|
||||
"Dynamische prompts vormt een enkele prompt om in vele.",
|
||||
@ -1160,12 +1149,6 @@
|
||||
],
|
||||
"heading": "VAE"
|
||||
},
|
||||
"compositingBlur": {
|
||||
"heading": "Vervaging",
|
||||
"paragraphs": [
|
||||
"De vervagingsstraal van het masker."
|
||||
]
|
||||
},
|
||||
"paramIterations": {
|
||||
"paragraphs": [
|
||||
"Het aantal te genereren afbeeldingen.",
|
||||
@ -1240,13 +1223,6 @@
|
||||
],
|
||||
"heading": "Ontruisingssterkte"
|
||||
},
|
||||
"compositingStrength": {
|
||||
"heading": "Sterkte",
|
||||
"paragraphs": [
|
||||
"Ontruisingssterkte voor de coherentiefase.",
|
||||
"Gelijk aan de parameter Ontruisingssterkte Afbeelding naar afbeelding."
|
||||
]
|
||||
},
|
||||
"paramNegativeConditioning": {
|
||||
"paragraphs": [
|
||||
"Het genereerproces voorkomt de gegeven begrippen in de negatieve prompt. Gebruik dit om bepaalde zaken of voorwerpen uit te sluiten van de uitvoerafbeelding.",
|
||||
|
@ -143,8 +143,6 @@
|
||||
"problemDeletingImagesDesc": "Не удалось удалить одно или несколько изображений",
|
||||
"loading": "Загрузка",
|
||||
"unableToLoad": "Невозможно загрузить галерею",
|
||||
"preparingDownload": "Подготовка к скачиванию",
|
||||
"preparingDownloadFailed": "Проблема с подготовкой к скачиванию",
|
||||
"image": "изображение",
|
||||
"drop": "перебросить",
|
||||
"problemDeletingImages": "Проблема с удалением изображений",
|
||||
@ -612,9 +610,7 @@
|
||||
"maskBlurMethod": "Метод размытия",
|
||||
"seamLowThreshold": "Низкий",
|
||||
"seamHighThreshold": "Высокий",
|
||||
"coherenceSteps": "Шагов",
|
||||
"coherencePassHeader": "Порог Coherence",
|
||||
"coherenceStrength": "Сила",
|
||||
"compositingSettingsHeader": "Настройки компоновки",
|
||||
"invoke": {
|
||||
"noNodesInGraph": "Нет узлов в графе",
|
||||
@ -1321,13 +1317,6 @@
|
||||
"Размер изображения (в пикселях), эквивалентный 512x512, рекомендуется для моделей SD1.5, а размер, эквивалентный 1024x1024, рекомендуется для моделей SDXL."
|
||||
]
|
||||
},
|
||||
"compositingCoherenceSteps": {
|
||||
"heading": "Шаги",
|
||||
"paragraphs": [
|
||||
"Количество шагов снижения шума, используемых при прохождении когерентности.",
|
||||
"То же, что и основной параметр «Шаги»."
|
||||
]
|
||||
},
|
||||
"dynamicPrompts": {
|
||||
"paragraphs": [
|
||||
"Динамические запросы превращают одно приглашение на множество.",
|
||||
@ -1342,12 +1331,6 @@
|
||||
],
|
||||
"heading": "VAE"
|
||||
},
|
||||
"compositingBlur": {
|
||||
"heading": "Размытие",
|
||||
"paragraphs": [
|
||||
"Радиус размытия маски."
|
||||
]
|
||||
},
|
||||
"paramIterations": {
|
||||
"paragraphs": [
|
||||
"Количество изображений, которые нужно сгенерировать.",
|
||||
@ -1422,13 +1405,6 @@
|
||||
],
|
||||
"heading": "Шумоподавление"
|
||||
},
|
||||
"compositingStrength": {
|
||||
"heading": "Сила",
|
||||
"paragraphs": [
|
||||
null,
|
||||
"То же, что параметр «Сила шумоподавления img2img»."
|
||||
]
|
||||
},
|
||||
"paramNegativeConditioning": {
|
||||
"paragraphs": [
|
||||
"Stable Diffusion пытается избежать указанных в отрицательном запросе концепций. Используйте это, чтобы исключить качества или объекты из вывода.",
|
||||
|
@ -355,7 +355,6 @@
|
||||
"starImage": "Yıldız Koy",
|
||||
"download": "İndir",
|
||||
"deleteSelection": "Seçileni Sil",
|
||||
"preparingDownloadFailed": "İndirme Hazırlanırken Sorun",
|
||||
"problemDeletingImages": "Görsel Silmede Sorun",
|
||||
"featuresWillReset": "Bu görseli silerseniz, o özellikler resetlenecektir.",
|
||||
"galleryImageResetSize": "Boyutu Resetle",
|
||||
@ -377,7 +376,6 @@
|
||||
"setCurrentImage": "Çalışma Görseli Yap",
|
||||
"unableToLoad": "Galeri Yüklenemedi",
|
||||
"downloadSelection": "Seçileni İndir",
|
||||
"preparingDownload": "İndirmeye Hazırlanıyor",
|
||||
"singleColumnLayout": "Tek Sütun Düzen",
|
||||
"generations": "Çıktılar",
|
||||
"showUploads": "Yüklenenleri Göster",
|
||||
@ -723,7 +721,6 @@
|
||||
"clipSkip": "CLIP Atlama",
|
||||
"randomizeSeed": "Rastgele Tohum",
|
||||
"cfgScale": "CFG Ölçeği",
|
||||
"coherenceStrength": "Etki",
|
||||
"controlNetControlMode": "Yönetim Kipi",
|
||||
"general": "Genel",
|
||||
"img2imgStrength": "Görselden Görsel Ölçüsü",
|
||||
@ -793,7 +790,6 @@
|
||||
"cfgRescaleMultiplier": "CFG Rescale Çarpanı",
|
||||
"cfgRescale": "CFG Rescale",
|
||||
"coherencePassHeader": "Uyum Geçişi",
|
||||
"coherenceSteps": "Adım",
|
||||
"infillMethod": "Doldurma Yöntemi",
|
||||
"maskBlurMethod": "Bulandırma Yöntemi",
|
||||
"steps": "Adım",
|
||||
|
@ -136,8 +136,6 @@
|
||||
"copy": "复制",
|
||||
"download": "下载",
|
||||
"setCurrentImage": "设为当前图像",
|
||||
"preparingDownload": "准备下载",
|
||||
"preparingDownloadFailed": "准备下载时出现问题",
|
||||
"downloadSelection": "下载所选内容",
|
||||
"noImageSelected": "无选中的图像",
|
||||
"deleteSelection": "删除所选内容",
|
||||
@ -616,11 +614,9 @@
|
||||
"incompatibleBaseModelForControlAdapter": "有 #{{number}} 个 Control Adapter 模型与主模型不兼容。"
|
||||
},
|
||||
"patchmatchDownScaleSize": "缩小",
|
||||
"coherenceSteps": "步数",
|
||||
"clipSkip": "CLIP 跳过层",
|
||||
"compositingSettingsHeader": "合成设置",
|
||||
"useCpuNoise": "使用 CPU 噪声",
|
||||
"coherenceStrength": "强度",
|
||||
"enableNoiseSettings": "启用噪声设置",
|
||||
"coherenceMode": "模式",
|
||||
"cpuNoise": "CPU 噪声",
|
||||
@ -1402,19 +1398,6 @@
|
||||
"图像尺寸(单位:像素)建议 SD 1.5 模型使用等效 512x512 的尺寸,SDXL 模型使用等效 1024x1024 的尺寸。"
|
||||
]
|
||||
},
|
||||
"compositingCoherenceSteps": {
|
||||
"heading": "步数",
|
||||
"paragraphs": [
|
||||
"一致性层中使用的去噪步数。",
|
||||
"与主参数中的步数相同。"
|
||||
]
|
||||
},
|
||||
"compositingBlur": {
|
||||
"heading": "模糊",
|
||||
"paragraphs": [
|
||||
"遮罩模糊半径。"
|
||||
]
|
||||
},
|
||||
"noiseUseCPU": {
|
||||
"heading": "使用 CPU 噪声",
|
||||
"paragraphs": [
|
||||
@ -1467,13 +1450,6 @@
|
||||
"第二轮去噪有助于合成内补/外扩图像。"
|
||||
]
|
||||
},
|
||||
"compositingStrength": {
|
||||
"heading": "强度",
|
||||
"paragraphs": [
|
||||
"一致性层使用的去噪强度。",
|
||||
"去噪强度与图生图的参数相同。"
|
||||
]
|
||||
},
|
||||
"paramNegativeConditioning": {
|
||||
"paragraphs": [
|
||||
"生成过程会避免生成负向提示词中的概念。使用此选项来使输出排除部分质量或对象。",
|
||||
|
@ -55,6 +55,8 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
|
||||
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
|
||||
@ -151,5 +153,7 @@ addFirstListImagesListener(startAppListening);
|
||||
// Ad-hoc upscale workflwo
|
||||
addUpscaleRequestedListener(startAppListening);
|
||||
|
||||
// Dynamic prompts
|
||||
// Prompts
|
||||
addDynamicPromptsListener(startAppListening);
|
||||
|
||||
addSetDefaultSettingsListener(startAppListening);
|
||||
|
@ -7,8 +7,10 @@ import {
|
||||
selectAllT2IAdapters,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { loraRemoved } from 'features/lora/store/loraSlice';
|
||||
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
|
||||
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
|
||||
import { heightChanged, modelChanged, vaeSelected, widthChanged } from 'features/parameters/store/generationSlice';
|
||||
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
|
||||
import { forEach, some } from 'lodash-es';
|
||||
import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models';
|
||||
@ -24,7 +26,9 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
|
||||
const log = logger('models');
|
||||
log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`);
|
||||
|
||||
const currentModel = getState().generation.model;
|
||||
const state = getState();
|
||||
|
||||
const currentModel = state.generation.model;
|
||||
const models = mainModelsAdapterSelectors.selectAll(action.payload);
|
||||
|
||||
if (models.length === 0) {
|
||||
@ -39,6 +43,29 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
|
||||
return;
|
||||
}
|
||||
|
||||
const defaultModel = state.config.sd.defaultModel;
|
||||
const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;
|
||||
|
||||
if (defaultModelInList) {
|
||||
const result = zParameterModel.safeParse(defaultModelInList);
|
||||
if (result.success) {
|
||||
dispatch(modelChanged(defaultModelInList, currentModel));
|
||||
|
||||
const optimalDimension = getOptimalDimension(defaultModelInList);
|
||||
if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
|
||||
return;
|
||||
}
|
||||
const { width, height } = calculateNewSize(
|
||||
state.generation.aspectRatio.value,
|
||||
optimalDimension * optimalDimension
|
||||
);
|
||||
|
||||
dispatch(widthChanged(width));
|
||||
dispatch(heightChanged(height));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const result = zParameterModel.safeParse(models[0]);
|
||||
|
||||
if (!result.success) {
|
||||
|
@ -0,0 +1,96 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { setDefaultSettings } from 'features/parameters/store/actions';
|
||||
import {
|
||||
setCfgRescaleMultiplier,
|
||||
setCfgScale,
|
||||
setScheduler,
|
||||
setSteps,
|
||||
vaePrecisionChanged,
|
||||
vaeSelected,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import {
|
||||
isParameterCFGRescaleMultiplier,
|
||||
isParameterCFGScale,
|
||||
isParameterPrecision,
|
||||
isParameterScheduler,
|
||||
isParameterSteps,
|
||||
zParameterVAEModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { map } from 'lodash-es';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
|
||||
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: setDefaultSettings,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
|
||||
const currentModel = state.generation.model;
|
||||
|
||||
if (!currentModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();
|
||||
|
||||
if (!modelConfig || !modelConfig.default_settings) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = modelConfig.default_settings;
|
||||
|
||||
if (vae) {
|
||||
// we store this as "default" within default settings
|
||||
// to distinguish it from no default set
|
||||
if (vae === 'default') {
|
||||
dispatch(vaeSelected(null));
|
||||
} else {
|
||||
const { data } = modelsApi.endpoints.getVaeModels.select()(state);
|
||||
const vaeArray = map(data?.entities);
|
||||
const validVae = vaeArray.find((model) => model.key === vae);
|
||||
|
||||
const result = zParameterVAEModel.safeParse(validVae);
|
||||
if (!result.success) {
|
||||
return;
|
||||
}
|
||||
dispatch(vaeSelected(result.data));
|
||||
}
|
||||
}
|
||||
|
||||
if (vae_precision) {
|
||||
if (isParameterPrecision(vae_precision)) {
|
||||
dispatch(vaePrecisionChanged(vae_precision));
|
||||
}
|
||||
}
|
||||
|
||||
if (cfg_scale) {
|
||||
if (isParameterCFGScale(cfg_scale)) {
|
||||
dispatch(setCfgScale(cfg_scale));
|
||||
}
|
||||
}
|
||||
|
||||
if (cfg_rescale_multiplier) {
|
||||
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
|
||||
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
|
||||
}
|
||||
}
|
||||
|
||||
if (steps) {
|
||||
if (isParameterSteps(steps)) {
|
||||
dispatch(setSteps(steps));
|
||||
}
|
||||
}
|
||||
|
||||
if (scheduler) {
|
||||
if (isParameterScheduler(scheduler)) {
|
||||
dispatch(setScheduler(scheduler));
|
||||
}
|
||||
}
|
||||
|
||||
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
|
||||
},
|
||||
});
|
||||
};
|
@ -14,7 +14,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
||||
const { bytes, total_bytes, id } = action.payload.data;
|
||||
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.bytes = bytes;
|
||||
@ -33,7 +33,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
||||
const { id } = action.payload.data;
|
||||
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'completed';
|
||||
@ -41,7 +41,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelConfig' }]));
|
||||
dispatch(api.util.invalidateTags(['Model']));
|
||||
},
|
||||
});
|
||||
|
||||
@ -51,7 +51,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
||||
const { id, error, error_type } = action.payload.data;
|
||||
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'error';
|
||||
|
@ -1,4 +1,5 @@
|
||||
import type { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import type { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import type { O } from 'ts-toolbelt';
|
||||
|
||||
@ -82,6 +83,8 @@ export type AppConfig = {
|
||||
guidance: NumericalParameterConfig;
|
||||
cfgRescaleMultiplier: NumericalParameterConfig;
|
||||
img2imgStrength: NumericalParameterConfig;
|
||||
scheduler?: ParameterScheduler;
|
||||
vaePrecision?: ParameterPrecision;
|
||||
// Canvas
|
||||
boundingBoxHeight: NumericalParameterConfig; // initial value comes from model
|
||||
boundingBoxWidth: NumericalParameterConfig; // initial value comes from model
|
||||
|
@ -1,21 +0,0 @@
|
||||
import type { Meta, StoryObj } from '@storybook/react';
|
||||
|
||||
import { EmbeddingSelect } from './EmbeddingSelect';
|
||||
import type { EmbeddingSelectProps } from './types';
|
||||
|
||||
const meta: Meta<typeof EmbeddingSelect> = {
|
||||
title: 'Feature/Prompt/EmbeddingSelect',
|
||||
tags: ['autodocs'],
|
||||
component: EmbeddingSelect,
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof EmbeddingSelect>;
|
||||
|
||||
const Component = (props: EmbeddingSelectProps) => {
|
||||
return <EmbeddingSelect {...props}>Invoke</EmbeddingSelect>;
|
||||
};
|
||||
|
||||
export const Default: Story = {
|
||||
render: Component,
|
||||
};
|
@ -1,67 +0,0 @@
|
||||
import type { ChakraProps } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import type { EmbeddingSelectProps } from 'features/embedding/types';
|
||||
import { t } from 'i18next';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { TextualInversionModelConfig } from 'services/api/types';
|
||||
|
||||
const noOptionsMessage = () => t('embedding.noMatchingEmbedding');
|
||||
|
||||
export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
|
||||
const getIsDisabled = useCallback(
|
||||
(embedding: TextualInversionModelConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === embedding.base;
|
||||
const hasMainModel = Boolean(currentBaseModel);
|
||||
return !hasMainModel || !isCompatible;
|
||||
},
|
||||
[currentBaseModel]
|
||||
);
|
||||
const { data, isLoading } = useGetTextualInversionModelsQuery();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(embedding: TextualInversionModelConfig | null) => {
|
||||
if (!embedding) {
|
||||
return;
|
||||
}
|
||||
onSelect(embedding.name);
|
||||
},
|
||||
[onSelect]
|
||||
);
|
||||
|
||||
const { options, onChange } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
getIsDisabled,
|
||||
onChange: _onChange,
|
||||
});
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<Combobox
|
||||
placeholder={isLoading ? t('common.loading') : t('embedding.addEmbedding')}
|
||||
defaultMenuIsOpen
|
||||
autoFocus
|
||||
value={null}
|
||||
options={options}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
onChange={onChange}
|
||||
onMenuClose={onClose}
|
||||
data-testid="add-embedding"
|
||||
sx={selectStyles}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
EmbeddingSelect.displayName = 'EmbeddingSelect';
|
||||
|
||||
const selectStyles: ChakraProps['sx'] = {
|
||||
w: 'full',
|
||||
};
|
@ -59,7 +59,7 @@ const LoRASelect = () => {
|
||||
return (
|
||||
<FormControl isDisabled={!options.length}>
|
||||
<InformationalPopover feature="lora">
|
||||
<FormLabel>{t('models.lora')} </FormLabel>
|
||||
<FormLabel>{t('models.concepts')} </FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox
|
||||
placeholder={placeholder}
|
||||
|
@ -1,228 +0,0 @@
|
||||
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import BaseModelSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect';
|
||||
import BooleanSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect';
|
||||
import ModelFormatSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect';
|
||||
import ModelTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect';
|
||||
import ModelVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect';
|
||||
import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect';
|
||||
import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { isNil, omitBy } from 'lodash-es';
|
||||
import { useCallback, useEffect } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useInstallModelMutation } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
export const AdvancedImport = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [installModel] = useInstallModelMutation();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const {
|
||||
register,
|
||||
handleSubmit,
|
||||
control,
|
||||
formState: { errors },
|
||||
setValue,
|
||||
resetField,
|
||||
reset,
|
||||
watch,
|
||||
} = useForm<AnyModelConfig>({
|
||||
defaultValues: {
|
||||
name: '',
|
||||
base: 'sd-1',
|
||||
type: 'main',
|
||||
path: '',
|
||||
description: '',
|
||||
format: 'diffusers',
|
||||
vae: '',
|
||||
variant: 'normal',
|
||||
},
|
||||
mode: 'onChange',
|
||||
});
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
|
||||
(values) => {
|
||||
installModel({
source: values.path,
config: omitBy(values, isNil),
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelAdded', {
modelName: values.name,
}),
status: 'success',
})
)
);
reset();
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddFailed'),
status: 'error',
})
)
);
}
});
},
[installModel, dispatch, t, reset]
);

const watchedModelType = watch('type');
const watchedModelFormat = watch('format');

useEffect(() => {
if (watchedModelType === 'main') {
setValue('format', 'diffusers');
setValue('repo_variant', '');
setValue('variant', 'normal');
}
if (watchedModelType === 'lora') {
setValue('format', 'lycoris');
} else if (watchedModelType === 'embedding') {
setValue('format', 'embedding_file');
} else if (watchedModelType === 'ip_adapter') {
setValue('format', 'invokeai');
} else {
setValue('format', 'diffusers');
}
resetField('upcast_attention');
resetField('ztsnr_training');
resetField('vae');
resetField('config');
resetField('prediction_type');
resetField('image_encoder_model_id');
}, [watchedModelType, resetField, setValue]);

return (
<ScrollableContent>
<form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" gap={4} width="100%" pb={10}>
<Flex alignItems="flex-end" gap="4">
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.modelType')}</FormLabel>
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
</FormControl>
<Text px="2" fontSize="xs" textAlign="center">
{t('modelManager.advancedImportInfo')}
</Text>
</Flex>

<Flex p={4} borderRadius={4} bg="base.850" height="100%" direction="column" gap="3">
<FormControl isInvalid={Boolean(errors.name)}>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.name')}</FormLabel>
<Input
{...register('name', {
validate: (value) => value.trim().length >= 3 || 'Must be at least 3 characters',
})}
/>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
<Flex>
<FormControl>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.description')}</FormLabel>
<Textarea size="sm" {...register('description')} />
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect control={control} name="base" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('common.format')}</FormLabel>
<ModelFormatSelect control={control} name="format" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.path')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</FormControl>
</Flex>
{watchedModelType === 'main' && (
<>
<Flex gap={4}>
{watchedModelFormat === 'diffusers' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.repoVariant')}</FormLabel>
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
</FormControl>
)}
{watchedModelFormat === 'checkpoint' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
<Input {...register('config')} />
</FormControl>
)}

<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</FormControl>
</Flex>
</>
)}
{watchedModelType === 'ip_adapter' && (
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
<Input {...register('image_encoder_model_id')} />
</FormControl>
</Flex>
)}
<Button mt={2} type="submit">
{t('modelManager.addModel')}
</Button>
</Flex>
</Flex>
</form>
</ScrollableContent>
);
};
@ -12,7 +12,7 @@ type SimpleImportModelConfig = {
location: string;
};

export const SimpleImport = () => {
export const InstallModelForm = () => {
const dispatch = useAppDispatch();

const [installModel, { isLoading }] = useInstallModelMutation();
@ -5,19 +5,19 @@ import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { useCallback, useMemo } from 'react';
import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models';
import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models';

import { ImportQueueItem } from './ImportQueueItem';
import { ModelInstallQueueItem } from './ModelInstallQueueItem';

export const ImportQueue = () => {
export const ModelInstallQueue = () => {
const dispatch = useAppDispatch();

const { data } = useGetModelImportsQuery();
const { data } = useListModelInstallsQuery();

const [pruneModelImports] = usePruneModelImportsMutation();
const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation();

const pruneQueue = useCallback(() => {
pruneModelImports()
const pruneCompletedModelInstalls = useCallback(() => {
_pruneCompletedModelInstalls()
.unwrap()
.then((_) => {
dispatch(
@ -41,7 +41,7 @@ export const ImportQueue = () => {
);
}
});
}, [pruneModelImports, dispatch]);
}, [_pruneCompletedModelInstalls, dispatch]);

const pruneAvailable = useMemo(() => {
return data?.some(
@ -53,14 +53,19 @@ export const ImportQueue = () => {
<Flex flexDir="column" p={3} h="full">
<Flex justifyContent="space-between" alignItems="center">
<Text>{t('modelManager.importQueue')}</Text>
<Button size="sm" isDisabled={!pruneAvailable} onClick={pruneQueue} tooltip={t('modelManager.pruneTooltip')}>
<Button
size="sm"
isDisabled={!pruneAvailable}
onClick={pruneCompletedModelInstalls}
tooltip={t('modelManager.pruneTooltip')}
>
{t('modelManager.prune')}
</Button>
</Flex>
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
<ScrollableContent>
<Flex flexDir="column-reverse" gap="2">
{data?.map((model) => <ImportQueueItem key={model.id} model={model} />)}
{data?.map((model) => <ModelInstallQueueItem key={model.id} installJob={model} />)}
</Flex>
</ScrollableContent>
</Box>
@ -6,17 +6,24 @@ import type { ModelInstallStatus } from 'services/api/types';
const STATUSES = {
waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' },
downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
downloads_done: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
completed: { colorScheme: 'green', translationKey: 'queue.completed' },
error: { colorScheme: 'red', translationKey: 'queue.failed' },
cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
};

const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus; errorReason?: string | null }) => {
const ModelInstallQueueBadge = ({
status,
errorReason,
}: {
status?: ModelInstallStatus;
errorReason?: string | null;
}) => {
const { t } = useTranslation();

if (!status) {
return <></>;
if (!status || !Object.keys(STATUSES).includes(status)) {
return null;
}

return (
@ -25,4 +32,4 @@ const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus
</Tooltip>
);
};
export default memo(ImportQueueBadge);
export default memo(ModelInstallQueueBadge);
@ -3,15 +3,16 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { isNil } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { PiXBold } from 'react-icons/pi';
import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types';

import ImportQueueBadge from './ImportQueueBadge';
import ModelInstallQueueBadge from './ModelInstallQueueBadge';

type ModelListItemProps = {
model: ModelInstallJob;
installJob: ModelInstallJob;
};

const formatBytes = (bytes: number) => {
@ -26,26 +27,26 @@ const formatBytes = (bytes: number) => {
return `${bytes.toFixed(2)} ${units[i]}`;
};

export const ImportQueueItem = (props: ModelListItemProps) => {
const { model } = props;
export const ModelInstallQueueItem = (props: ModelListItemProps) => {
const { installJob } = props;
const dispatch = useAppDispatch();

const [deleteImportModel] = useDeleteModelImportMutation();
const [deleteImportModel] = useCancelModelInstallMutation();

const source = useMemo(() => {
if (model.source.type === 'hf') {
return model.source as HFModelSource;
} else if (model.source.type === 'local') {
return model.source as LocalModelSource;
} else if (model.source.type === 'url') {
return model.source as URLModelSource;
if (installJob.source.type === 'hf') {
return installJob.source as HFModelSource;
} else if (installJob.source.type === 'local') {
return installJob.source as LocalModelSource;
} else if (installJob.source.type === 'url') {
return installJob.source as URLModelSource;
} else {
return model.source as LocalModelSource;
return installJob.source as LocalModelSource;
}
}, [model.source]);
}, [installJob.source]);

const handleDeleteModelImport = useCallback(() => {
deleteImportModel(model.id)
deleteImportModel(installJob.id)
.unwrap()
.then((_) => {
dispatch(
@ -69,7 +70,7 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
);
}
});
}, [deleteImportModel, model, dispatch]);
}, [deleteImportModel, installJob, dispatch]);

const modelName = useMemo(() => {
switch (source.type) {
@ -85,19 +86,23 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
}, [source]);

const progressValue = useMemo(() => {
if (model.bytes === undefined || model.total_bytes === undefined) {
if (isNil(installJob.bytes) || isNil(installJob.total_bytes)) {
return null;
}

if (installJob.total_bytes === 0) {
return 0;
}

return (model.bytes / model.total_bytes) * 100;
}, [model.bytes, model.total_bytes]);
return (installJob.bytes / installJob.total_bytes) * 100;
}, [installJob.bytes, installJob.total_bytes]);

const progressString = useMemo(() => {
if (model.status !== 'downloading' || model.bytes === undefined || model.total_bytes === undefined) {
if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
return '';
}
return `${formatBytes(model.bytes)} / ${formatBytes(model.total_bytes)}`;
}, [model.bytes, model.total_bytes, model.status]);
return `${formatBytes(installJob.bytes)} / ${formatBytes(installJob.total_bytes)}`;
}, [installJob.bytes, installJob.total_bytes, installJob.status]);

return (
<Flex gap="2" w="full" alignItems="center">
@ -109,19 +114,21 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
<Flex flexDir="column" flex={1}>
<Tooltip label={progressString}>
<Progress
value={progressValue}
isIndeterminate={progressValue === undefined}
value={progressValue ?? 0}
isIndeterminate={progressValue === null}
aria-label={t('accessibility.invokeProgressBar')}
h={2}
/>
</Tooltip>
</Flex>
<Box minW="100px" textAlign="center">
<ImportQueueBadge status={model.status} errorReason={model.error_reason} />
<ModelInstallQueueBadge status={installJob.status} errorReason={installJob.error_reason} />
</Box>

<Box minW="20px">
{(model.status === 'downloading' || model.status === 'waiting' || model.status === 'running') && (
{(installJob.status === 'downloading' ||
installJob.status === 'waiting' ||
installJob.status === 'running') && (
<IconButton
isRound={true}
size="xs"
@ -2,24 +2,24 @@ import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@
import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useLazyScanModelsQuery } from 'services/api/endpoints/models';
import { useLazyScanFolderQuery } from 'services/api/endpoints/models';

import { ScanModelsResults } from './ScanModelsResults';
import { ScanModelsResults } from './ScanFolderResults';

export const ScanModelsForm = () => {
const [scanPath, setScanPath] = useState('');
const [errorMessage, setErrorMessage] = useState('');
const { t } = useTranslation();

const [_scanModels, { isLoading, data }] = useLazyScanModelsQuery();
const [_scanFolder, { isLoading, data }] = useLazyScanFolderQuery();

const handleSubmitScan = useCallback(async () => {
_scanModels({ scan_path: scanPath }).catch((error) => {
const scanFolder = useCallback(async () => {
_scanFolder({ scan_path: scanPath }).catch((error) => {
if (error) {
setErrorMessage(error.data.detail);
}
});
}, [_scanModels, scanPath]);
}, [_scanFolder, scanPath]);

const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
setScanPath(e.target.value);
@ -36,7 +36,7 @@ export const ScanModelsForm = () => {
<Input value={scanPath} onChange={handleSetScanPath} />
</Flex>

<Button onClick={handleSubmitScan} isLoading={isLoading} isDisabled={scanPath.length === 0}>
<Button onClick={scanFolder} isLoading={isLoading} isDisabled={scanPath.length === 0}>
{t('modelManager.scanFolder')}
</Button>
</Flex>
@ -18,7 +18,7 @@ import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models';

import { ScanModelResultItem } from './ScanModelResultItem';
import { ScanModelResultItem } from './ScanFolderResultItem';

type ScanModelResultsProps = {
results: ScanFolderResponse;
@ -1,12 +1,11 @@
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { useTranslation } from 'react-i18next';

import { AdvancedImport } from './AddModelPanel/AdvancedImport';
import { ImportQueue } from './AddModelPanel/ImportQueue/ImportQueue';
import { ScanModelsForm } from './AddModelPanel/ScanModels/ScanModelsForm';
import { SimpleImport } from './AddModelPanel/SimpleImport';
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';

export const ImportModels = () => {
export const InstallModels = () => {
const { t } = useTranslation();
return (
<Flex layerStyle="first" p={3} borderRadius="base" w="full" h="full" flexDir="column" gap={2}>
@ -17,15 +16,11 @@ export const ImportModels = () => {
<Tabs variant="collapse" height="100%">
<TabList>
<Tab>{t('common.simple')}</Tab>
<Tab>{t('modelManager.advanced')}</Tab>
<Tab>{t('modelManager.scan')}</Tab>
</TabList>
<TabPanels p={3} height="100%">
<TabPanel>
<SimpleImport />
</TabPanel>
<TabPanel height="100%">
<AdvancedImport />
<InstallModelForm />
</TabPanel>
<TabPanel height="100%">
<ScanModelsForm />
@ -34,7 +29,7 @@ export const ImportModels = () => {
</Tabs>
</Box>
<Box layerStyle="second" borderRadius="base" w="full" h="50%">
<ImportQueue />
<ModelInstallQueue />
</Box>
</Flex>
);
@ -5,7 +5,7 @@ import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { IoFilter } from 'react-icons/io5';

export const MODEL_TYPE_LABELS: { [key: string]: string } = {
const MODEL_TYPE_LABELS: { [key: string]: string } = {
main: 'Main',
lora: 'LoRA',
embedding: 'Textual Inversion',

@ -1,14 +1,14 @@
import { Box } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';

import { ImportModels } from './ImportModels';
import { InstallModels } from './InstallModels';
import { Model } from './ModelPanel/Model';

export const ModelPane = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
return (
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
{selectedModelKey ? <Model /> : <ImportModels />}
{selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}
</Box>
);
};
@ -0,0 +1,68 @@
import { Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { isNil } from 'lodash-es';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';

import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm';

const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision } = config.sd;

return {
initialSteps: steps.initial,
initialCfg: guidance.initial,
initialScheduler: scheduler,
initialCfgRescaleMultiplier: cfgRescaleMultiplier.initial,
initialVaePrecision: vaePrecision,
};
});

export const DefaultSettings = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { t } = useTranslation();

const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
useAppSelector(initialStatesSelector);

const defaultSettingsDefaults = useMemo(() => {
return {
vae: { isEnabled: !isNil(data?.default_settings?.vae), value: data?.default_settings?.vae || 'default' },
vaePrecision: {
isEnabled: !isNil(data?.default_settings?.vae_precision),
value: data?.default_settings?.vae_precision || initialVaePrecision || 'fp32',
},
scheduler: {
isEnabled: !isNil(data?.default_settings?.scheduler),
value: data?.default_settings?.scheduler || initialScheduler || 'euler',
},
steps: { isEnabled: !isNil(data?.default_settings?.steps), value: data?.default_settings?.steps || initialSteps },
cfgScale: {
isEnabled: !isNil(data?.default_settings?.cfg_scale),
value: data?.default_settings?.cfg_scale || initialCfg,
},
cfgRescaleMultiplier: {
isEnabled: !isNil(data?.default_settings?.cfg_rescale_multiplier),
value: data?.default_settings?.cfg_rescale_multiplier || initialCfgRescaleMultiplier,
},
};
}, [
data?.default_settings,
initialSteps,
initialCfg,
initialScheduler,
initialCfgRescaleMultiplier,
initialVaePrecision,
]);

if (isLoading) {
return <Text>{t('common.loading')}</Text>;
}

return <DefaultSettingsForm defaultSettingsDefaults={defaultSettingsDefaults} />;
};
@ -0,0 +1,72 @@
|
||||
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultCfgRescaleMultiplierType = DefaultSettingsFormData['cfgRescaleMultiplier'];
|
||||
|
||||
export function DefaultCfgRescaleMultiplier(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMin);
|
||||
const sliderMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMax);
|
||||
const numberInputMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMin);
|
||||
const numberInputMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMax);
|
||||
const coarseStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.coarseStep);
|
||||
const fineStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.fineStep);
|
||||
const { t } = useTranslation();
|
||||
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultCfgRescaleMultiplierType),
|
||||
value: v,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return (field.value as DefaultCfgRescaleMultiplierType).value;
|
||||
}, [field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultCfgRescaleMultiplierType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramCFGRescaleMultiplier">
|
||||
<FormLabel>{t('parameters.cfgRescaleMultiplier')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Flex w="full" gap={1}>
|
||||
<CompositeSlider
|
||||
value={value}
|
||||
min={sliderMin}
|
||||
max={sliderMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
marks={marks}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={value}
|
||||
min={numberInputMin}
|
||||
max={numberInputMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,72 @@
|
||||
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultCfgType = DefaultSettingsFormData['cfgScale'];
|
||||
|
||||
export function DefaultCfgScale(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.guidance.sliderMin);
|
||||
const sliderMax = useAppSelector((s) => s.config.sd.guidance.sliderMax);
|
||||
const numberInputMin = useAppSelector((s) => s.config.sd.guidance.numberInputMin);
|
||||
const numberInputMax = useAppSelector((s) => s.config.sd.guidance.numberInputMax);
|
||||
const coarseStep = useAppSelector((s) => s.config.sd.guidance.coarseStep);
|
||||
const fineStep = useAppSelector((s) => s.config.sd.guidance.fineStep);
|
||||
const { t } = useTranslation();
|
||||
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultCfgType),
|
||||
value: v,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return (field.value as DefaultCfgType).value;
|
||||
}, [field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultCfgType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramCFGScale">
|
||||
<FormLabel>{t('parameters.cfgScale')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Flex w="full" gap={1}>
|
||||
<CompositeSlider
|
||||
value={value}
|
||||
min={sliderMin}
|
||||
max={sliderMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
marks={marks}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={value}
|
||||
min={numberInputMin}
|
||||
max={numberInputMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants';
|
||||
import { isParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultSchedulerType = DefaultSettingsFormData['scheduler'];
|
||||
|
||||
export function DefaultScheduler(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!isParameterScheduler(v?.value)) {
|
||||
return;
|
||||
}
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultSchedulerType),
|
||||
value: v.value,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(
|
||||
() => SCHEDULER_OPTIONS.find((o) => o.value === (field.value as DefaultSchedulerType).value),
|
||||
[field]
|
||||
);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultSchedulerType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramScheduler">
|
||||
<FormLabel>{t('parameters.scheduler')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox isDisabled={isDisabled} value={value} options={SCHEDULER_OPTIONS} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,147 @@
|
||||
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoPencil } from 'react-icons/io5';
|
||||
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
|
||||
import { DefaultCfgScale } from './DefaultCfgScale';
|
||||
import { DefaultScheduler } from './DefaultScheduler';
|
||||
import { DefaultSteps } from './DefaultSteps';
|
||||
import { DefaultVae } from './DefaultVae';
|
||||
import { DefaultVaePrecision } from './DefaultVaePrecision';
|
||||
import { SettingToggle } from './SettingToggle';
|
||||
|
||||
export interface FormField<T> {
|
||||
value: T;
|
||||
isEnabled: boolean;
|
||||
}
|
||||
|
||||
export type DefaultSettingsFormData = {
|
||||
vae: FormField<string>;
|
||||
vaePrecision: FormField<string>;
|
||||
scheduler: FormField<ParameterScheduler>;
|
||||
steps: FormField<number>;
|
||||
cfgScale: FormField<number>;
|
||||
cfgRescaleMultiplier: FormField<number>;
|
||||
};
|
||||
|
||||
export const DefaultSettingsForm = ({
|
||||
defaultSettingsDefaults,
|
||||
}: {
|
||||
defaultSettingsDefaults: DefaultSettingsFormData;
|
||||
}) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
|
||||
const [updateModel, { isLoading }] = useUpdateModelMutation();
|
||||
|
||||
const { handleSubmit, control, formState } = useForm<DefaultSettingsFormData>({
|
||||
defaultValues: defaultSettingsDefaults,
|
||||
});
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<DefaultSettingsFormData>>(
|
||||
(data) => {
|
||||
if (!selectedModelKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
const body = {
|
||||
vae: data.vae.isEnabled ? data.vae.value : null,
|
||||
vae_precision: data.vaePrecision.isEnabled ? data.vaePrecision.value : null,
|
||||
cfg_scale: data.cfgScale.isEnabled ? data.cfgScale.value : null,
|
||||
cfg_rescale_multiplier: data.cfgRescaleMultiplier.isEnabled ? data.cfgRescaleMultiplier.value : null,
|
||||
steps: data.steps.isEnabled ? data.steps.value : null,
|
||||
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
|
||||
};
|
||||
|
||||
updateModel({
|
||||
key: selectedModelKey,
|
||||
body: { default_settings: body },
|
||||
})
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.defaultSettingsSaved'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${error.data.detail} `,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
},
|
||||
[selectedModelKey, dispatch, updateModel, t]
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex gap="2" justifyContent="space-between" w="full" mb={5}>
|
||||
<Heading fontSize="md">{t('modelManager.defaultSettings')}</Heading>
|
||||
<Button
|
||||
size="sm"
|
||||
leftIcon={<IoPencil />}
|
||||
colorScheme="invokeYellow"
|
||||
isDisabled={!formState.isDirty}
|
||||
onClick={handleSubmit(onSubmit)}
|
||||
type="submit"
|
||||
isLoading={isLoading}
|
||||
>
|
||||
{t('common.save')}
|
||||
</Button>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" gap={8}>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="vae" />
|
||||
<DefaultVae control={control} name="vae" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="vaePrecision" />
|
||||
<DefaultVaePrecision control={control} name="vaePrecision" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="scheduler" />
|
||||
<DefaultScheduler control={control} name="scheduler" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="steps" />
|
||||
<DefaultSteps control={control} name="steps" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="cfgScale" />
|
||||
<DefaultCfgScale control={control} name="cfgScale" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<SettingToggle control={control} name="cfgRescaleMultiplier" />
|
||||
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</>
|
||||
);
|
||||
};
|
@ -0,0 +1,72 @@
|
||||
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultSteps = DefaultSettingsFormData['steps'];
|
||||
|
||||
export function DefaultSteps(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const sliderMin = useAppSelector((s) => s.config.sd.steps.sliderMin);
|
||||
const sliderMax = useAppSelector((s) => s.config.sd.steps.sliderMax);
|
||||
const numberInputMin = useAppSelector((s) => s.config.sd.steps.numberInputMin);
|
||||
const numberInputMax = useAppSelector((s) => s.config.sd.steps.numberInputMax);
|
||||
const coarseStep = useAppSelector((s) => s.config.sd.steps.coarseStep);
|
||||
const fineStep = useAppSelector((s) => s.config.sd.steps.fineStep);
|
||||
const { t } = useTranslation();
|
||||
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultSteps),
|
||||
value: v,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return (field.value as DefaultSteps).value;
|
||||
}, [field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultSteps).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramSteps">
|
||||
<FormLabel>{t('parameters.steps')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Flex w="full" gap={1}>
|
||||
<CompositeSlider
|
||||
value={value}
|
||||
min={sliderMin}
|
||||
max={sliderMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
marks={marks}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={value}
|
||||
min={numberInputMin}
|
||||
max={numberInputMax}
|
||||
step={coarseStep}
|
||||
fineStep={fineStep}
|
||||
onChange={onChange}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Flex>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,65 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { map } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelConfigQuery, useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
type DefaultVaeType = DefaultSettingsFormData['vae'];
|
||||
|
||||
export function DefaultVae(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
const { compatibleOptions } = useGetVaeModelsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => {
|
||||
const modelArray = map(data?.entities);
|
||||
const compatibleOptions = modelArray
|
||||
.filter((vae) => vae.base === modelData?.base)
|
||||
.map((vae) => ({ label: vae.name, value: vae.key }));
|
||||
|
||||
const defaultOption = { label: 'Default VAE', value: 'default' };
|
||||
|
||||
return { compatibleOptions: [defaultOption, ...compatibleOptions] };
|
||||
},
|
||||
});
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
const newValue = !v?.value ? 'default' : v.value;
|
||||
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultVaeType),
|
||||
value: newValue,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return compatibleOptions.find((vae) => vae.value === (field.value as DefaultVaeType).value);
|
||||
}, [compatibleOptions, field.value]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultVaeType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramVAE">
|
||||
<FormLabel>{t('modelManager.vae')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox isDisabled={isDisabled} value={value} options={compatibleOptions} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,51 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { isParameterPrecision } from 'features/parameters/types/parameterSchemas';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
|
||||
|
||||
const options = [
|
||||
{ label: 'FP16', value: 'fp16' },
|
||||
{ label: 'FP32', value: 'fp32' },
|
||||
];
|
||||
|
||||
type DefaultVaePrecisionType = DefaultSettingsFormData['vaePrecision'];
|
||||
|
||||
export function DefaultVaePrecision(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!isParameterPrecision(v?.value)) {
|
||||
return;
|
||||
}
|
||||
const updatedValue = {
|
||||
...(field.value as DefaultVaePrecisionType),
|
||||
value: v.value,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const value = useMemo(() => options.find((o) => o.value === (field.value as DefaultVaePrecisionType).value), [field]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
return !(field.value as DefaultVaePrecisionType).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
return (
|
||||
<FormControl flexDir="column" gap={1} alignItems="flex-start">
|
||||
<InformationalPopover feature="paramVAEPrecision">
|
||||
<FormLabel>{t('modelManager.vaePrecision')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox isDisabled={isDisabled} value={value} options={options} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
import { Switch } from '@invoke-ai/ui-library';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
|
||||
import type { DefaultSettingsFormData, FormField } from './DefaultSettingsForm';
|
||||
|
||||
export function SettingToggle<T>(props: UseControllerProps<DefaultSettingsFormData>) {
|
||||
const { field } = useController(props);
|
||||
|
||||
const value = useMemo(() => {
|
||||
return !!(field.value as FormField<T>).isEnabled;
|
||||
}, [field.value]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
const updatedValue: FormField<T> = {
|
||||
...(field.value as FormField<T>),
|
||||
isEnabled: e.target.checked,
|
||||
};
|
||||
field.onChange(updatedValue);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
return <Switch isChecked={value} onChange={onChange} />;
|
||||
}
|
@ -3,9 +3,9 @@ import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import type { Control } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
||||
@ -14,8 +14,12 @@ const options: ComboboxOption[] = [
|
||||
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
|
||||
];
|
||||
|
||||
const BaseModelSelect = (props: UseControllerProps<AnyModelConfig>) => {
|
||||
const { field } = useController(props);
|
||||
type Props = {
|
||||
control: Control<UpdateModelArg['body']>;
|
||||
};
|
||||
|
||||
const BaseModelSelect = ({ control }: Props) => {
|
||||
const { field } = useController({ control, name: 'base' });
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
|
@ -1,27 +0,0 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'none', label: '-' },
|
||||
{ value: 'true', label: 'True' },
|
||||
{ value: 'false', label: 'False' },
|
||||
];
|
||||
|
||||
const BooleanSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value === 'true');
|
||||
},
|
||||
[field]
|
||||
);
|
||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||
};
|
||||
|
||||
export default typedMemo(BooleanSelect);
|
@ -1,47 +0,0 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController, useWatch } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const ModelFormatSelect = (props: UseControllerProps<AnyModelConfig>) => {
|
||||
const { field, formState } = useController(props);
|
||||
const type = useWatch({ control: props.control, name: 'type' });
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
field.onChange(v?.value);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const options: ComboboxOption[] = useMemo(() => {
|
||||
const modelType = type || formState.defaultValues?.type;
|
||||
if (modelType === 'lora') {
|
||||
return [
|
||||
{ value: 'lycoris', label: 'LyCORIS' },
|
||||
{ value: 'diffusers', label: 'Diffusers' },
|
||||
];
|
||||
} else if (modelType === 'embedding') {
|
||||
return [
|
||||
{ value: 'embedding_file', label: 'Embedding File' },
|
||||
{ value: 'embedding_folder', label: 'Embedding Folder' },
|
||||
];
|
||||
} else if (modelType === 'ip_adapter') {
|
||||
return [{ value: 'invokeai', label: 'invokeai' }];
|
||||
} else {
|
||||
return [
|
||||
{ value: 'diffusers', label: 'Diffusers' },
|
||||
{ value: 'checkpoint', label: 'Checkpoint' },
|
||||
];
|
||||
}
|
||||
}, [type, formState.defaultValues?.type]);
|
||||
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [options, field.value]);
|
||||
|
||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||
};
|
||||
|
||||
export default typedMemo(ModelFormatSelect);
|
@ -1,32 +0,0 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { MODEL_TYPE_LABELS } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'main', label: MODEL_TYPE_LABELS['main'] as string },
|
||||
{ value: 'lora', label: MODEL_TYPE_LABELS['lora'] as string },
|
||||
{ value: 'embedding', label: MODEL_TYPE_LABELS['embedding'] as string },
|
||||
{ value: 'vae', label: MODEL_TYPE_LABELS['vae'] as string },
|
||||
{ value: 'controlnet', label: MODEL_TYPE_LABELS['controlnet'] as string },
|
||||
{ value: 'ip_adapter', label: MODEL_TYPE_LABELS['ip_adapter'] as string },
|
||||
{ value: 't2i_adapater', label: MODEL_TYPE_LABELS['t2i_adapter'] as string },
|
||||
] as const;
|
||||
|
||||
const ModelTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
field.onChange(v?.value);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||
};
|
||||
|
||||
export default typedMemo(ModelTypeSelect);
|
@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import type { Control } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'normal', label: 'Normal' },
|
||||
@ -12,8 +12,12 @@ const options: ComboboxOption[] = [
|
||||
{ value: 'depth', label: 'Depth' },
|
||||
];
|
||||
|
||||
const ModelVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
type Props = {
|
||||
control: Control<UpdateModelArg['body']>;
|
||||
};
|
||||
|
||||
const ModelVariantSelect = ({ control }: Props) => {
|
||||
const { field } = useController({ control, name: 'variant' });
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
|
@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import type { Control } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'none', label: '-' },
|
||||
@ -13,8 +13,12 @@ const options: ComboboxOption[] = [
|
||||
{ value: 'sample', label: 'sample' },
|
||||
];
|
||||
|
||||
const PredictionTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
type Props = {
|
||||
control: Control<UpdateModelArg['body']>;
|
||||
};
|
||||
|
||||
const PredictionTypeSelect = ({ control }: Props) => {
|
||||
const { field } = useController({ control, name: 'prediction_type' });
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
|
@ -1,27 +0,0 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'none', label: '-' },
|
||||
{ value: 'fp16', label: 'fp16' },
|
||||
{ value: 'fp32', label: 'fp32' },
|
||||
];
|
||||
|
||||
const RepoVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||
};
|
||||
|
||||
export default typedMemo(RepoVariantSelect);
|
@ -0,0 +1,41 @@
import { Box, Flex } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { useMemo } from 'react';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import type { ModelType } from 'services/api/types';

import { TriggerPhrases } from './TriggerPhrases';

const MODEL_TYPE_TRIGGER_PHRASE: ModelType[] = ['main', 'lora'];

export const ModelMetadata = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data } = useGetModelConfigQuery(selectedModelKey ?? skipToken);

const shouldShowTriggerPhraseSettings = useMemo(() => {
if (!data?.type) {
return false;
}
return MODEL_TYPE_TRIGGER_PHRASE.includes(data.type);
}, [data]);

const apiResponseFormatted = useMemo(() => {
if (!data?.source_api_response) {
return {};
}
return JSON.parse(data.source_api_response);
}, [data?.source_api_response]);

return (
<Flex flexDir="column" height="full" gap="3">
{shouldShowTriggerPhraseSettings && (
<Box layerStyle="second" borderRadius="base" p={3}>
<TriggerPhrases />
</Box>
)}
<DataViewer label="metadata" data={apiResponseFormatted} />
</Flex>
);
};
@ -0,0 +1,106 @@
import {
Button,
Flex,
FormControl,
FormErrorMessage,
Input,
Tag,
TagCloseButton,
TagLabel,
} from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { ModelListHeader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelListHeader';
import type { ChangeEvent } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';

export const TriggerPhrases = () => {
const { t } = useTranslation();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data: modelConfig } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const [phrase, setPhrase] = useState('');

const [updateModel, { isLoading }] = useUpdateModelMutation();

const handlePhraseChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setPhrase(e.target.value);
}, []);

const triggerPhrases = useMemo(() => {
return modelConfig?.trigger_phrases || [];
}, [modelConfig?.trigger_phrases]);

const errors = useMemo(() => {
const errors = [];

if (phrase.length && triggerPhrases.includes(phrase)) {
errors.push('Phrase is already in list');
}

return errors;
}, [phrase, triggerPhrases]);

const addTriggerPhrase = useCallback(async () => {
if (!selectedModelKey) {
return;
}

if (!phrase.length || triggerPhrases.includes(phrase)) {
return;
}

await updateModel({
key: selectedModelKey,
body: { trigger_phrases: [...triggerPhrases, phrase] },
}).unwrap();
setPhrase('');
}, [updateModel, selectedModelKey, phrase, triggerPhrases]);

const removeTriggerPhrase = useCallback(
async (phraseToRemove: string) => {
if (!selectedModelKey) {
return;
}

const filteredPhrases = triggerPhrases.filter((p) => p !== phraseToRemove);

await updateModel({ key: selectedModelKey, body: { trigger_phrases: filteredPhrases } }).unwrap();
},
[updateModel, selectedModelKey, triggerPhrases]
);

return (
<Flex flexDir="column" w="full" gap="5">
<ModelListHeader title={t('modelManager.triggerPhrases')} />
<form>
<FormControl w="full" isInvalid={Boolean(errors.length)}>
<Flex flexDir="column" w="full">
<Flex gap="3" alignItems="center" w="full">
<Input value={phrase} onChange={handlePhraseChange} placeholder={t('modelManager.typePhraseHere')} />
<Button
type="submit"
onClick={addTriggerPhrase}
isDisabled={Boolean(errors.length)}
isLoading={isLoading}
>
{t('common.add')}
</Button>
</Flex>
{!!errors.length && errors.map((error) => <FormErrorMessage key={error}>{error}</FormErrorMessage>)}
</Flex>
</FormControl>
</form>

<Flex gap="4" flexWrap="wrap" mt="3" mb="3">
{triggerPhrases.map((phrase, index) => (
<Tag size="md" key={index}>
<TagLabel>{phrase}</TagLabel>
<TagCloseButton onClick={removeTriggerPhrase.bind(null, phrase)} isDisabled={isLoading} />
</Tag>
))}
</Flex>
</Flex>
);
};
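As a reading aid only (not part of the diff above), the list handling that `TriggerPhrases` performs before calling `updateModel` can be sketched as two pure helpers; the names `addPhrase` and `removePhrase` are hypothetical and do not appear in the repository.

```ts
// Hypothetical standalone helpers mirroring the component logic above.
// Adding skips empty input and duplicates; removing filters the phrase out.
const addPhrase = (phrases: string[], phrase: string): string[] =>
  phrase.length === 0 || phrases.includes(phrase) ? phrases : [...phrases, phrase];

const removePhrase = (phrases: string[], phraseToRemove: string): string[] =>
  phrases.filter((p) => p !== phraseToRemove);

// Example: addPhrase(['castle'], 'castle') -> ['castle'] (unchanged),
// removePhrase(['castle', 'dragon'], 'dragon') -> ['castle'].
```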
@ -1,9 +1,58 @@
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs, Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';

import { ModelMetadata } from './Metadata/ModelMetadata';
import { ModelAttrView } from './ModelAttrView';
import { ModelEdit } from './ModelEdit';
import { ModelView } from './ModelView';

export const Model = () => {
const { t } = useTranslation();
const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode);
return selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />;
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);

if (isLoading) {
return <Text>{t('common.loading')}</Text>;
}

if (!data) {
return <Text>{t('common.somethingWentWrong')}</Text>;
}

return (
<>
<Flex flexDir="column" gap={1} p={2}>
<Heading as="h2" fontSize="lg">
{data.name}
</Heading>

{data.source && (
<Text variant="subtext">
{t('modelManager.source')}: {data?.source}
</Text>
)}
<Box mt="4">
<ModelAttrView label="Description" value={data.description} />
</Box>
</Flex>

<Tabs mt="4" h="100%">
<TabList>
<Tab>{t('modelManager.settings')}</Tab>
<Tab>{t('modelManager.metadata')}</Tab>
</TabList>

<TabPanels h="100%">
<TabPanel>{selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />}</TabPanel>
<TabPanel h="full">
<ModelMetadata />
</TabPanel>
</TabPanels>
</Tabs>
</>
);
};
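A brief aside, not part of the diff: several of these components pass `selectedModelKey ?? skipToken` to their query hooks. `skipToken` is RTK Query's sentinel for "do not fetch". A minimal sketch of the pattern, with a hypothetical stand-in value for the store's selected key:

```ts
import { skipToken } from '@reduxjs/toolkit/query';

// Hypothetical value standing in for the store's selected model key.
const selectedModelKey: string | null = null;

// When no key is selected the query hook receives skipToken and does not fetch;
// once a key exists, the hook fetches that model's config instead.
const queryArg = selectedModelKey ?? skipToken;
```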
@ -13,7 +13,7 @@ import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
import { useConvertModelMutation } from 'services/api/endpoints/models';
import type { CheckpointModelConfig } from 'services/api/types';

interface ModelConvertProps {
@ -24,7 +24,7 @@ export const ModelConvert = (props: ModelConvertProps) => {
const { model } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [convertModel, { isLoading }] = useConvertMainModelsMutation();
const [convertModel, { isLoading }] = useConvertModelMutation();
const { isOpen, onOpen, onClose } = useDisclosure();

const modelConvertHandler = useCallback(() => {
@ -1,5 +1,6 @@
import {
Button,
Checkbox,
Flex,
FormControl,
FormErrorMessage,
@ -19,66 +20,27 @@ import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { UpdateModelArg } from 'services/api/endpoints/models';
import { useGetModelConfigQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';

import BaseModelSelect from './Fields/BaseModelSelect';
import BooleanSelect from './Fields/BooleanSelect';
import ModelFormatSelect from './Fields/ModelFormatSelect';
import ModelTypeSelect from './Fields/ModelTypeSelect';
import ModelVariantSelect from './Fields/ModelVariantSelect';
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
import RepoVariantSelect from './Fields/RepoVariantSelect';

export const ModelEdit = () => {
const dispatch = useAppDispatch();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);

const [updateModel, { isLoading: isSubmitting }] = useUpdateModelsMutation();
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();

const { t } = useTranslation();

// const modelData = useMemo(() => {
// if (!data) {
// return null;
// }
// const modelFormat = data.format;
// const modelType = data.type;

// if (modelType === 'main') {
// if (modelFormat === 'diffusers') {
// return data as DiffusersModelConfig;
// } else if (modelFormat === 'checkpoint') {
// return data as CheckpointModelConfig;
// }
// }

// switch (modelType) {
// case 'lora':
// return data as LoRAModelConfig;
// case 'embedding':
// return data as TextualInversionModelConfig;
// case 't2i_adapter':
// return data as T2IAdapterModelConfig;
// case 'ip_adapter':
// return data as IPAdapterModelConfig;
// case 'controlnet':
// return data as ControlNetModelConfig;
// case 'vae':
// return data as VAEModelConfig;
// default:
// return null;
// }
// }, [data]);

const {
register,
handleSubmit,
control,
formState: { errors },
reset,
watch,
} = useForm<UpdateModelArg['body']>({
defaultValues: {
...data,
@ -86,10 +48,7 @@ export const ModelEdit = () => {
mode: 'onChange',
});

const watchedModelType = watch('type');
const watchedModelFormat = watch('format');

const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
(values) => {
if (!data?.key) {
return;
@ -143,33 +102,31 @@ export const ModelEdit = () => {
return (
<Flex flexDir="column" h="full">
<form onSubmit={handleSubmit(onSubmit)}>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
<FormLabel hidden={true}>{t('modelManager.modelName')}</FormLabel>
<Input
{...register('name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
validate: (value) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
})}
size="lg"
/>

<Flex gap={2}>
<Button size="sm" onClick={handleClickCancel}>
{t('common.cancel')}
</Button>
<Button
size="sm"
colorScheme="invokeYellow"
onClick={handleSubmit(onSubmit)}
isLoading={isSubmitting}
isDisabled={Boolean(Object.keys(errors).length)}
>
{t('common.save')}
</Button>
</Flex>
</Flex>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</FormControl>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</FormControl>
<Button size="sm" onClick={handleClickCancel}>
{t('common.cancel')}
</Button>
<Button
size="sm"
colorScheme="invokeYellow"
onClick={handleSubmit(onSubmit)}
isLoading={isSubmitting}
isDisabled={Boolean(Object.keys(errors).length)}
>
{t('common.save')}
</Button>
</Flex>

<Flex flexDir="column" gap={3} mt="4">
<Flex>
@ -184,76 +141,22 @@ export const ModelEdit = () => {
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect control={control} name="base" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.modelType')}</FormLabel>
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
<BaseModelSelect control={control} />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('common.format')}</FormLabel>
<ModelFormatSelect control={control} name="format" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.path')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</FormControl>
</Flex>
{watchedModelType === 'main' && (
<>
<Flex gap={4}>
{watchedModelFormat === 'diffusers' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.repoVariant')}</FormLabel>
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
</FormControl>
)}
{watchedModelFormat === 'checkpoint' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
<Input {...register('config')} />
</FormControl>
)}

<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</FormControl>
</Flex>
</>
)}
{watchedModelType === 'ip_adapter' && (
{data.type === 'main' && data.format === 'checkpoint' && (
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
<Input {...register('image_encoder_model_id')} />
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<Checkbox {...register('upcast_attention')} />
</FormControl>
</Flex>
)}
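An aside, not part of the diff: the `name` rule above relies on react-hook-form's convention that a validator returning `true` passes while a returned string becomes the error message. Written as a standalone predicate (the helper name `validateName` is hypothetical) it looks like this:

```ts
// Hedged sketch of the validation rule used for the name field above.
const validateName = (value?: string): string | boolean =>
  (value && value.trim().length > 3) || 'Must be at least 3 characters';

// validateName('sd-1.5') === true
// validateName('  a ') === 'Must be at least 3 characters'
```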
@ -1,12 +1,11 @@
import { Box, Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
import { Box, Button, Flex, Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { IoPencil } from 'react-icons/io5';
import { useGetModelConfigQuery, useGetModelMetadataQuery } from 'services/api/endpoints/models';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import type {
CheckpointModelConfig,
ControlNetModelConfig,
@ -18,6 +17,7 @@ import type {
VAEModelConfig,
} from 'services/api/types';

import { DefaultSettings } from './DefaultSettings';
import { ModelAttrView } from './ModelAttrView';
import { ModelConvert } from './ModelConvert';

@ -26,7 +26,6 @@ export const ModelView = () => {
const dispatch = useAppDispatch();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);

const modelData = useMemo(() => {
if (!data) {
@ -73,85 +72,49 @@ export const ModelView = () => {
return <Text>{t('common.somethingWentWrong')}</Text>;
}
return (
<Flex flexDir="column" h="full">
<Flex w="full" justifyContent="space-between">
<Flex flexDir="column" gap={1} p={2}>
<Heading as="h2" fontSize="lg">
{modelData.name}
</Heading>

{modelData.source && (
<Text variant="subtext">
{t('modelManager.source')}: {modelData.source}
</Text>
)}
</Flex>
<Flex gap={2}>
<Flex flexDir="column" h="full" gap="2">
<Box layerStyle="second" borderRadius="base" p={3}>
<Flex gap="2" justifyContent="flex-end" w="full">
<Button size="sm" leftIcon={<IoPencil />} colorScheme="invokeYellow" onClick={handleEditModel}>
{t('modelManager.edit')}
</Button>

{modelData.type === 'main' && modelData.format === 'checkpoint' && <ModelConvert model={modelData} />}
</Flex>
</Flex>

<Flex flexDir="column" p={2} gap={3}>
<Flex>
<ModelAttrView label="Description" value={modelData.description} />
</Flex>
<Heading as="h3" fontSize="md" mt="4">
{t('modelManager.modelSettings')}
</Heading>
<Box layerStyle="second" borderRadius="base" p={3}>
<Flex flexDir="column" gap={3}>
<Flex flexDir="column" gap={3}>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.baseModel')} value={modelData.base} />
<ModelAttrView label={t('modelManager.modelType')} value={modelData.type} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('common.format')} value={modelData.format} />
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
</Flex>
{modelData.type === 'main' && (
<Flex gap={2}>
<ModelAttrView label={t('modelManager.baseModel')} value={modelData.base} />
<ModelAttrView label={t('modelManager.modelType')} value={modelData.type} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('common.format')} value={modelData.format} />
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
</Flex>
{modelData.type === 'main' && (
<>
<Flex gap={2}>
{modelData.format === 'diffusers' && (
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
)}
{modelData.format === 'checkpoint' && (
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
)}

{modelData.format === 'diffusers' && modelData.repo_variant && (
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
)}
{modelData.format === 'checkpoint' && (
<>
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config_path} />
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
<ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
</Flex>
</>
)}
{modelData.type === 'ip_adapter' && (
<Flex gap={2}>
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} />
</Flex>
)}
</Flex>
</Box>
</Flex>

{metadata && (
<>
<Heading as="h3" fontSize="md" mt="4">
{t('modelManager.modelMetadata')}
</Heading>
<Flex h="full" w="full" p={2}>
<DataViewer label="metadata" data={metadata} />
</Flex>
</>
)}
</>
)}
</Flex>
)}
{modelData.type === 'ip_adapter' && (
<Flex gap={2}>
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} />
</Flex>
)}
</Flex>
</Box>
<Box layerStyle="second" borderRadius="base" p={3}>
<DefaultSettings />
</Box>
</Flex>
);
};
Some files were not shown because too many files have changed in this diff.