Compare commits

...

70 Commits

Author SHA1 Message Date
cd3f5f30dc Run ruff format 2024-03-05 16:38:55 -05:00
71ee28ac12 Refactor session runner, move profiling back to processor, create abstract class for session runners, create path for passing in custom session runner to default session processor 2024-03-05 16:01:47 -05:00
46c904d08a Rename graph processor to session runner to better describe what it's doing, add before/after callbacks for sessions 2024-03-05 16:01:47 -05:00
7d5a88b69d Move graph processor into session_processor_default 2024-03-05 16:01:47 -05:00
afa4df1991 Separate the logic that actually runs a graph in the session_processor into its own class 2024-03-05 16:01:47 -05:00
e30cb4b52f updates for defaultModel (#5866)
* move defaultModel logic to modelsLoaded and update to work for key instead of name/base/type string

* lint fix

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2024-03-05 09:55:22 -05:00
ba1f6bf926 chore: lint 2024-03-05 23:50:19 +11:00
4a9cca6c2d fix(ui): format model API response data 2024-03-05 23:50:19 +11:00
b0275700b3 refactor(ui): compute prompt trigger options in the component
We can derive the valid trigger options in the component without needing to lift the options list into global state.
2024-03-05 23:50:19 +11:00
8319aca5f9 chore(ui): typegen 2024-03-05 23:50:19 +11:00
51a604f907 pkg(ui): do not fix knip in lint:fix script 2024-03-05 23:50:19 +11:00
7515d73628 make trigger phrases a list of options and add lora name as description to appear in dropdown 2024-03-05 23:50:19 +11:00
2c453aa531 fix type error 2024-03-05 23:50:19 +11:00
2cca6e4c76 check if lora is enabled before adding trigger phrases 2024-03-05 23:50:19 +11:00
ef171e890a use a listener to recalculate trigger phrases when model or lora list changes 2024-03-05 23:50:19 +11:00
caafbf2f0d only show trigger phrase settings on main and lora 2024-03-05 23:50:19 +11:00
2db5eaf907 lint fix 2024-03-05 23:50:19 +11:00
f234bf6256 cleanup 2024-03-05 23:50:19 +11:00
cfa78b4052 adapt embedding popover to work for trigger phrases also 2024-03-05 23:50:19 +11:00
ba1dd4b02b UI in MM to create trigger phrases 2024-03-05 23:50:19 +11:00
bcf58cac59 feat(mm): add config to skip model hash
This is useful for when you are using a memory DB and do not want to wait for all models to be hashed on startup.
2024-03-05 23:50:19 +11:00
e866d90ab2 tidy(mm): remove unused method on probe 2024-03-05 23:50:19 +11:00
e8797787cf fix(mm): fix incorrect calls to update_model 2024-03-05 23:50:19 +11:00
0082ecb22b feat(mm): add path to ModelRecordChanges 2024-03-05 23:50:19 +11:00
656839fcd1 fix(mm): fix typing on heuristic_import 2024-03-05 23:50:19 +11:00
99407c899f feat(ui): update UI to use new model config backend
- Update all queries
- Remove Advanced Add
- Removed un-editable, internal-only model attributes from model edit UI (e.g. format, repo variant, model type)
- Update model tags so the list refreshes when a model installs
- Rename some queries, components, variables, types to match backend
- Fix divide-by-zero in install queue
2024-03-05 23:50:19 +11:00
48119d9010 revert(mm): restore convert route 2024-03-05 23:50:19 +11:00
7c9128b253 tidy(mm): use canonical capitalization for all model-related enums, classes
For example, "Lora" -> "LoRA", "Vae" -> "VAE".
2024-03-05 23:50:19 +11:00
4f9bb00275 tidy(api): tidy mm routes
Rename MM routes to be consistent:
- "import" -> "install"
- "model_record" -> "model"

Comment several unused routes while I work (may end up removing them?):
- list model summary (we use the search route instead)
- add model record
- convert model
- merge models
2024-03-05 23:50:19 +11:00
78895b3e80 fix(mm): add missing inplace parameter to model install abc 2024-03-05 23:50:19 +11:00
3030a34b88 fix(mm): make type and format required in openapi schema for model config 2024-03-05 23:50:19 +11:00
58fa9c2fac fix(mm): do not allow extra fields on ModelRecordChanges 2024-03-05 23:50:19 +11:00
a8b6635050 fix(mm): make key required in openapi schema for model config 2024-03-05 23:50:19 +11:00
6829610a71 tests: rename "example_config" -> "example_it_config" 2024-03-05 23:50:19 +11:00
5551cf8ac4 feat(mm): revise update_model to use ModelRecordChanges 2024-03-05 23:50:19 +11:00
37b969d339 tidy(mm): add default_settings to model config 2024-03-05 23:50:19 +11:00
c953e61294 tidy(mm): "trigger_words" -> "trigger_phrases" 2024-03-05 23:50:19 +11:00
93dd3c848e tidy(mm): remove unused code in select_hf_files.py 2024-03-05 23:50:19 +11:00
02bde7bb75 tests: fix test_hf_model_select::test_select_multiple_weights on windows 2024-03-05 23:50:19 +11:00
3391c19926 chore: ruff 2024-03-05 23:50:19 +11:00
0f60b1ced4 fix(mm): use .value for model config discriminators
There is a breaking change in python 3.11 related to how enums with `str` as a mixin are formatted. This appears to have not caused any grief for us until now.

Re-jigger the discriminator setup to use `.value` so everything works on both python 3.10 and 3.11.
2024-03-05 23:50:19 +11:00
44c40d7d1a refactor(mm): remove unused metadata logic, fix tests
- Metadata is merged with the config. We can simplify the MM substantially and remove the handling for metadata.
- Per discussion, we don't have an ETA for frontend implementation of tags, and with the realization that the tags from CivitAI are largely useless, there's no reason to keep tags in the MM right now. When we are ready to implement tags on the frontend, we can refer back to the implementation here and use it if it supports the design.
- Fix all tests.
2024-03-05 23:50:19 +11:00
0b9a212363 tests: remove 60s timeout for tests
This makes it very difficult to troubleshoot tests. Our github actions now have timeouts, so there's no risk of a test stalling for ages.
2024-03-05 23:50:19 +11:00
c3aa985c93 refactor(mm): get metadata working 2024-03-05 23:50:19 +11:00
7cb0da1f66 refactor(mm): wip schema changes 2024-03-05 23:50:19 +11:00
3534366146 fix(mm): fix extraneous downloaded files in diffusers
Sometimes, diffusers model components (tokenizer, unet, etc.) have multiple weights files in the same directory.

In this situation, we assume the files are different versions of the same weights. For example, we may have multiple
formats (`.bin`, `.safetensors`) with different precisions. When downloading model files, we want to select only
the best of these files for the requested format and precision/variant.

The previous logic assumed that each model weights file would have the same base filename, but this assumption was
not always true. The logic is revised to score each file and choose the best-scoring file, resulting in only a single
file being downloaded for each submodel/subdirectory.
2024-03-05 23:50:19 +11:00
f2b5f8753f tidy(mm): remove json_schema_extra from config - not needed 2024-03-05 23:50:19 +11:00
f13f5984c0 fix(mm): update db schema & migration 2024-03-05 23:50:19 +11:00
94e1e64296 chore: ruff 2024-03-05 23:50:19 +11:00
2411bf53c0 tidy(mm): better descriptions for model configs 2024-03-05 23:50:19 +11:00
9378e47a06 feat(mm): add source_type to model configs 2024-03-05 23:50:19 +11:00
4471ea8ad1 refactor(mm): simplify model metadata schemas 2024-03-05 23:50:19 +11:00
2c835fd550 refactor(mm): WIP db schema 2024-03-05 23:50:19 +11:00
61b737bb9f tidy(mm): remove update method from ModelConfigBase
It's only used in the soon-to-be-removed model merge logic
2024-03-05 23:50:19 +11:00
a8cd3dfc99 refactor(mm): add models table (schema WIP), rename "original_hash" -> "hash" 2024-03-05 23:50:19 +11:00
0cce582f2f tidy(mm): remove current_hash 2024-03-05 23:50:19 +11:00
4347d1c7f7 tests(mm): fix some objects in tests 2024-03-05 23:50:19 +11:00
bd4fd9693d tidy(mm): rename ckpt "last_modified" -> "converted_at"
Clarify what this timestamp means
2024-03-05 23:50:19 +11:00
9b40c28144 tidy(mm): rename ckpy "config" -> "config_path" 2024-03-05 23:50:19 +11:00
16a5d718bf fix(mm): add config field to ckpt vaes 2024-03-05 23:50:19 +11:00
76cbc745e1 refactor(mm): add CheckpointConfigBase for all ckpt models 2024-03-05 23:50:19 +11:00
0a614943f6 fix(mm): fix broken get_model_discriminator_value 2024-03-05 23:50:19 +11:00
e426096d32 fix(mm): misc typing fixes for model loaders 2024-03-05 23:50:19 +11:00
c561cd751f fix(mm): use correct import path for ConfigMixin, ModelMixin 2024-03-05 23:50:19 +11:00
af9298f0ef tidy(mm): tidy class names in config.py 2024-03-05 23:50:19 +11:00
5b74117836 fix(mm): use generic for model loader registry
This preserves the typing for classes using the decorator
2024-03-05 23:50:19 +11:00
38474c9797 fix(mm): use correct import path for ModelMixin 2024-03-05 23:50:19 +11:00
b880a31039 refactor(mm): remove ztsnr_training field on _MainConfig
This is used to determine the CFG Rescale Multiplier setting. We'll handle this in the UI as a default setting.
2024-03-05 23:50:19 +11:00
dd31bc4586 refactor(mm): remove vae field on _MainConfig
We will handle default VAE selection in the UI.
2024-03-05 23:50:19 +11:00
316573df2d feat(mm): use callable discriminator for AnyModelConfig union 2024-03-05 23:50:19 +11:00
95 changed files with 2716 additions and 4193 deletions

View File

@ -32,7 +32,6 @@ model. These are the:
Responsible for loading a model from disk
into RAM and VRAM and getting it ready for inference.
## Location of the Code
The four main services can be found in
@ -63,23 +62,21 @@ provides the following fields:
|----------------|-----------------|------------------|
| `key` | str | Unique identifier for the model |
| `name` | str | Name of the model (not unique) |
| `model_type` | ModelType | The type of the model |
| `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
| `base_model` | BaseModelType | The base model that the model is compatible with |
| `path` | str | Location of model on disk |
| `original_hash` | str | Hash of the model when it was first installed |
| `current_hash` | str | Most recent hash of the model's contents |
| `hash` | str | Hash of the model |
| `description` | str | Human-readable description of the model (optional) |
| `source` | str | Model's source URL or repo id (optional) |
The `key` is a unique 32-character random ID which was generated at
install time. The `original_hash` field stores a hash of the model's
install time. The `hash` field stores a hash of the model's
contents at install time obtained by sampling several parts of the
model's files using the `imohash` library. Over the course of the
model's lifetime it may be transformed in various ways, such as
changing its precision or converting it from a .safetensors to a
diffusers model. When this happens, `original_hash` is unchanged, but
`current_hash` is updated to indicate the current contents.
diffusers model.
`ModelType`, `ModelFormat` and `BaseModelType` are string enums that
are defined in `invokeai.backend.model_manager.config`. They are also
@ -94,7 +91,6 @@ The `path` field can be absolute or relative. If relative, it is taken
to be relative to the `models_dir` setting in the user's
`invokeai.yaml` file.
### CheckpointConfig
This adds support for checkpoint configurations, and adds the
@ -174,7 +170,7 @@ store = context.services.model_manager.store
or from elsewhere in the code by accessing
`ApiDependencies.invoker.services.model_manager.store`.
### Creating a `ModelRecordService`
To create a new `ModelRecordService` database or open an existing one,
you can directly create either a `ModelRecordServiceSQL` or a
@ -217,27 +213,27 @@ for use in the InvokeAI web server. Its signature is:
```
def open(
    cls,
    config: InvokeAIAppConfig,
    conn: Optional[sqlite3.Connection] = None,
    lock: Optional[threading.Lock] = None
) -> Union[ModelRecordServiceSQL, ModelRecordServiceFile]:
```
The way it works is as follows:
1. Retrieve the value of the `model_config_db` option from the user's
`invokeai.yaml` config file.
2. If `model_config_db` is `auto` (the default), then:
* Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object
opened on the passed connection and lock.
* Open up a new connection to `databases/invokeai.db` if `conn`
and/or `lock` are missing (see note below).
3. If `model_config_db` is a Path, then use `from_db_file`
to return the appropriate type of ModelRecordService.
4. If `model_config_db` is None, then retrieve the legacy
`conf_path` option from `invokeai.yaml` and use the Path
indicated there. This will default to `configs/models.yaml`.
So a typical startup pattern would be:
```
@ -255,7 +251,7 @@ store = ModelRecordServiceBase.open(config, db_conn, lock)
Configurations can be retrieved in several ways.
#### get_model(key) -> AnyModelConfig
The basic functionality is to call the record store object's
`get_model()` method with the desired model's unique key. It returns
@ -272,28 +268,28 @@ print(model_conf.path)
If the key is unrecognized, this call raises an
`UnknownModelException`.
#### exists(key) -> bool
Returns True if a model with the given key exists in the database.
#### search_by_path(path) -> AnyModelConfig
Returns the configuration of the model whose path is `path`. The path
is matched using a simple string comparison and won't correctly match
models referred to by different paths (e.g. using symbolic links).
#### search_by_name(name, base, type) -> List[AnyModelConfig]
This method searches for models that match some combination of `name`,
`BaseType` and `ModelType`. Calling without any arguments will return
all the models in the database.
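For example (a sketch only; the enum members shown are illustrative, and any combination of the three filters may be supplied or omitted):

```
from invokeai.backend.model_manager.config import BaseModelType, ModelType

# `store` is the record store created earlier in this document.
# All "main" models compatible with the SD-1 base:
sd1_mains = store.search_by_name(base=BaseModelType.StableDiffusion1, type=ModelType.Main)

# No filters: equivalent to all_models()
everything = store.search_by_name()
```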
#### all_models() -> List[AnyModelConfig]
Return all the model configs in the database. Exactly equivalent to
calling `search_by_name()` with no arguments.
#### search_by_tag(tags) -> List[AnyModelConfig]
`tags` is a list of strings. This method returns a list of model
configs that contain all of the given tags. Examples:
@ -312,11 +308,11 @@ commercializable_models = [x for x in store.all_models() \
if x.license.contains('allowCommercialUse=Sell')]
```
#### version() -> str
Returns the version of the database, currently at `3.2`
#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase
This method exists to ease the transition from the previous version of
the model manager, in which `get_model()` took the three arguments
@ -337,7 +333,7 @@ model and pass its key to `get_model()`.
Several methods allow you to create and update stored model config
records.
#### add_model(key, config) -> AnyModelConfig
Given a key and a configuration, this will add the model's
configuration record to the database. `config` can either be a subclass of
@ -352,7 +348,7 @@ model with the same key is already in the database, or an
`InvalidModelConfigException` if a dict was passed and Pydantic
experienced a parse or validation error.
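A minimal sketch, assuming a dict that satisfies the checkpoint config schema (the key and field values below are illustrative only):

```
from invokeai.app.services.model_records import DuplicateModelException

key = 'f13dd932c0c35c22dcb8d6cda4203764'  # illustrative key; `store` as created above
try:
    config = store.add_model(
        key,
        {
            "name": "sushi",
            "path": "/opt/models/sushi.safetensors",
            "base": "sd-1",
            "type": "main",
            "format": "checkpoint",
        },
    )
except DuplicateModelException:
    print(f"a model with key {key} is already registered")
```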
### update_model(key, config) -> AnyModelConfig
Given a key and a configuration, this will update the model
configuration record in the database. `config` can be either a
@ -370,31 +366,31 @@ The `ModelInstallService` class implements the
shop for all your model install needs. It provides the following
functionality:
* Registering a model config record for a model already located on the
local filesystem, without moving it or changing its path.
* Installing a model already located on the local filesystem, by
moving it into the InvokeAI root directory under the
`models` folder (or wherever config parameter `models_dir`
specifies).
* Probing of models to determine their type, base type and other key
information.
* Interface with the InvokeAI event bus to provide status updates on
the download, installation and registration process.
* Downloading a model from an arbitrary URL and installing it in
`models_dir`.
* Special handling for Civitai model URLs which allow the user to
paste in a model page's URL or download link.
* Special handling for HuggingFace repo_ids to recursively download
the contents of the repository, paying attention to alternative
variants such as fp16.
* Saving tags and other metadata about the model into the invokeai database
when fetching from a repo that provides that type of information
(currently only Civitai and HuggingFace).
@ -427,8 +423,8 @@ queue.start()
installer = ModelInstallService(app_config=config,
    record_store=record_store,
    download_queue=queue
)
installer.start()
```
@ -443,7 +439,6 @@ required parameters:
| `metadata_store` | Optional[ModelMetadataStore] | Metadata storage object |
|`session` | Optional[requests.Session] | Swap in a different Session object (usually for debugging) |
Once initialized, the installer will provide the following methods:
#### install_job = installer.heuristic_import(source, [config], [access_token])
@ -457,15 +452,15 @@ The `source` is a string that can be any of these forms
1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`)
2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
3. A HuggingFace repo_id with any of the following formats:
* `model/name` -- entire model
* `model/name:fp32` -- entire model, using the fp32 variant
* `model/name:fp16:vae` -- vae submodel, using the fp16 variant
* `model/name::vae` -- vae submodel, using default precision
* `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
* `model/name::path/to/model.safetensors` -- an individual model file, default variant
Note that by specifying a relative path to the top of the HuggingFace
repo, you can download and install arbitrary model files.
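Here is a short sketch of the happy path, reusing the `installer` object created earlier in this document (the repo_id is just an example):

```
job = installer.heuristic_import('stabilityai/sdxl-turbo::vae')  # vae submodel, default precision
installer.wait_for_installs(timeout=600)
if job.complete:
    print(f"installed as {job.config_out.key}")
```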
The variant, if not provided, will be automatically filled in with
`fp32` if the user has requested full precision, and `fp16`
@ -491,9 +486,9 @@ following illustrates basic usage:
```
from invokeai.app.services.model_install import (
    LocalModelSource,
    HFModelSource,
    URLModelSource,
)
source1 = LocalModelSource(path='/opt/models/sushi.safetensors') # a local safetensors file
@ -513,13 +508,13 @@ for source in [source1, source2, source3, source4, source5, source6, source7]:
source2job = installer.wait_for_installs(timeout=120)
for source in sources:
    job = source2job[source]
    if job.complete:
        model_config = job.config_out
        model_key = model_config.key
        print(f"{source} installed as {model_key}")
    elif job.errored:
        print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")
```
As shown here, the `import_model()` method accepts a variety of
@ -528,7 +523,7 @@ HuggingFace repo_ids with and without a subfolder designation,
Civitai model URLs and arbitrary URLs that point to checkpoint files
(but not to folders).
Each call to `import_model()` returns a `ModelInstallJob`,
an object which tracks the progress of the install.
If a remote model is requested, the model's files are downloaded in
@ -555,7 +550,7 @@ The full list of arguments to `import_model()` is as follows:
| `config` | Dict[str, Any] | None | Override all or a portion of model's probed attributes |
The next few sections describe the various types of ModelSource that
can be passed to `import_model()`.
`config` can be used to override all or a portion of the configuration
attributes returned by the model prober. See the section below for
@ -566,7 +561,6 @@ details.
This is used for a model that is located on a locally-accessible Posix
filesystem, such as a local disk or networked fileshare.
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `path` | str | Path | None | Path to the model file or directory |
@ -625,7 +619,6 @@ HuggingFace has the most complicated `ModelSource` structure:
| `subfolder` | Path | None | Look for the model in a subfolder of the repo. |
| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |
The `repo_id` is the repository ID, such as `stabilityai/sdxl-turbo`.
The `variant` is one of the various diffusers formats that HuggingFace
@ -661,7 +654,6 @@ in. To download these files, you must provide an
`HfFolder.get_token()` will be called to fill it in with the cached
one.
#### Monitoring the install job process
When you create an install job with `import_model()`, it launches the
@ -675,14 +667,13 @@ The `ModelInstallJob` class has the following structure:
| `id` | `int` | Integer ID for this job |
| `status` | `InstallStatus` | An enum of [`waiting`, `downloading`, `running`, `completed`, `error` and `cancelled`]|
| `config_in` | `dict` | Overriding configuration values provided by the caller |
| `config_out` | `AnyModelConfig`| After successful completion, contains the configuration record written to the database |
| `inplace` | `boolean` | True if the caller asked to install the model in place using its local path |
| `source` | `ModelSource` | The local path, remote URL or repo_id of the model to be installed |
| `local_path` | `Path` | If a remote model, holds the path of the model after it is downloaded; if a local model, same as `source` |
| `error_type` | `str` | Name of the exception that led to an error status |
| `error` | `str` | Traceback of the error |
If the `event_bus` argument was provided, events will also be
broadcast to the InvokeAI event bus. The events will appear on the bus
as an event of type `EventServiceBase.model_event`, a timestamp and
@ -702,14 +693,13 @@ following keys:
| `total_bytes` | int | Total size of all the files that make up the model |
| `parts` | List[Dict]| Information on the progress of the individual files that make up the model |
The `parts` field is a list of dictionaries that give information on each of
the component pieces of the download. The dictionary's keys are
`source`, `local_path`, `bytes` and `total_bytes`, and correspond to
the like-named keys in the main event.
Note that downloading events will not be issued for local models, and
that downloading events occur _before_ the running event.
##### `model_install_running`
@ -752,14 +742,13 @@ properties: `waiting`, `downloading`, `running`, `complete`, `errored`
and `cancelled`, as well as `in_terminal_state`. The last will return
True if the job is in the complete, errored or cancelled states.
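Assuming `job` is the `ModelInstallJob` returned by one of the import calls above, a simple (if inelegant) polling sketch looks like this; subscribing to the event bus, as described above, is usually the better approach:

```
import time

while not job.in_terminal_state:
    time.sleep(1)
if job.errored:
    print(f"{job.error_type}: {job.error}")
elif job.complete:
    print(f"installed as {job.config_out.key}")
```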
#### Model configuration and probing
The install service uses the `invokeai.backend.model_manager.probe`
module during import to determine the model's type, base type, and
other configuration parameters. Among other things, it assigns a
default name and description for the model based on probed
fields.
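If you need to probe a file outside of an install, the same module can be invoked directly. This is a sketch only; check the exact class name and signature in `invokeai.backend.model_manager.probe` before relying on it:

```
from pathlib import Path

from invokeai.backend.model_manager.probe import ModelProbe

config = ModelProbe.probe(Path('/opt/models/sushi.safetensors'))
print(config.base, config.type, config.format)
```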
When downloading remote models is implemented, additional
configuration information, such as list of trigger terms, will be
@ -774,11 +763,11 @@ attributes. Here is an example of setting the
```
install_job = installer.import_model(
    source=HFModelSource(repo_id='stabilityai/stable-diffusion-2-1', variant='fp32'),
    config=dict(
        prediction_type=SchedulerPredictionType('v_prediction'),
        name='stable diffusion 2 base model',
    ),
)
```
### Other installer methods
@ -862,7 +851,6 @@ This method is similar to `unregister()`, but also unconditionally
deletes the corresponding model weights file(s), regardless of whether
they are inside or outside the InvokeAI models hierarchy.
#### path = installer.download_and_cache(remote_source, [access_token], [timeout])
This utility routine will download the model file located at source,
@ -953,7 +941,7 @@ following fields:
When you create a job, you can assign it a `priority`. If multiple
jobs are queued, the job with the lowest priority runs first. (Don't
blame me! The Unix developers came up with this convention.)
Every job has a `source` and a `destination`. `source` is a string in
the base class, but subclasses redefine it more specifically.
@ -974,7 +962,7 @@ is in its lifecycle. Values are defined in the string enum
`DownloadJobStatus`, a symbol available from
`invokeai.app.services.download_manager`. Possible values are:
| **Value** | **String Value** | **Description** |
|--------------|---------------------|-------------------|
| `IDLE` | idle | Job created, but not submitted to the queue |
| `ENQUEUED` | enqueued | Job is patiently waiting on the queue |
@ -991,7 +979,7 @@ debugging and performance testing.
In case of an error, the Exception that caused the error will be
placed in the `error` field, and the job's status will be set to
`DownloadJobStatus.ERROR`.
After an error occurs, any partially downloaded files will be deleted
from disk, unless `preserve_partial_downloads` was set to True at job
@ -1040,11 +1028,11 @@ While a job is being downloaded, the queue will emit events at
periodic intervals. A typical series of events during a successful
download session will look like this:
* enqueued
* running
* running
* running
* completed
There will be a single enqueued event, followed by one or more running
events, and finally one `completed`, `error` or `cancelled`
@ -1053,12 +1041,12 @@ events.
It is possible for a caller to pause download temporarily, in which
case the events may look something like this:
* enqueued
* running
* running
* paused
* running
* completed
The download queue logs when downloads start and end (unless `quiet`
is set to True at initialization time) but doesn't log any progress
@ -1120,11 +1108,11 @@ A typical initialization sequence will look like:
from invokeai.app.services.download_manager import DownloadQueueService
def log_download_event(job: DownloadJobBase):
    logger.info(f'job={job.id}: status={job.status}')

queue = DownloadQueueService(
    event_handlers=[log_download_event]
)
```
Event handlers can be provided to the queue at initialization time as
@ -1155,9 +1143,9 @@ To use the former method, follow this example:
```
job = DownloadJobRemoteSource(
    source='http://www.civitai.com/models/13456',
    destination='/tmp/models/',
    event_handlers=[my_handler1, my_handler2],  # if desired
)
queue.submit_download_job(job, start=True)
```
@ -1172,13 +1160,13 @@ To have the queue create the job for you, follow this example instead:
```
job = queue.create_download_job(
    source='http://www.civitai.com/models/13456',
    destdir='/tmp/models/',
    filename='my_model.safetensors',
    event_handlers=[my_handler1, my_handler2],  # if desired
    start=True,
)
```
The `filename` argument forces the downloader to use the specified
name for the file rather than the name provided by the remote source,
and is equivalent to manually specifying a destination of
@ -1187,7 +1175,6 @@ and is equivalent to manually specifying a destination of
Here is the full list of arguments that can be provided to
`create_download_job()`:
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `source` | Union[str, Path, AnyHttpUrl] | | Download remote or local source |
@ -1200,7 +1187,7 @@ Here is the full list of arguments that can be provided to
Internally, `create_download_job()` has a little bit of internal logic
that looks at the type of the source and selects the right subclass of
`DownloadJobBase` to create and enqueue.
**TODO**: move this logic into its own method for overriding in
subclasses.
@ -1275,7 +1262,7 @@ for getting the model to run. For example "author" is metadata, while
"type", "base" and "format" are not. The latter fields are part of the
model's config, as defined in `invokeai.backend.model_manager.config`.
### Example Usage
```
from invokeai.backend.model_manager.metadata import (
@ -1328,7 +1315,6 @@ This is the common base class for metadata:
| `author` | str | Model's author |
| `tags` | Set[str] | Model tags |
Note that the model config record also has a `name` field. It is
intended that the config record version be locally customizable, while
the metadata version is read-only. However, enforcing this is expected
@ -1348,7 +1334,6 @@ This descends from `ModelMetadataBase` and adds the following fields:
| `last_modified`| datetime | Date of last commit of this model to the repo |
| `files` | List[Path] | List of the files in the model repo |
#### `CivitaiMetadata`
This descends from `ModelMetadataBase` and adds the following fields:
@ -1415,7 +1400,6 @@ testing suite to avoid hitting the internet.
The HuggingFace and Civitai fetcher subclasses add additional
repo-specific fetching methods:
#### HuggingFaceMetadataFetch
This overrides its base class `from_json()` method to return a
@ -1434,13 +1418,12 @@ retrieves its metadata. Functionally equivalent to `from_id()`, the
only difference is that it returns a `CivitaiMetadata` object rather
than an `AnyModelRepoMetadata`.
### Metadata Storage
The `ModelMetadataStore` provides a simple facility to store model
metadata in the `invokeai.db` database. The data is stored as a JSON
blob, with a few common fields (`name`, `author`, `tags`) broken out
to be searchable.
When a metadata object is saved to the database, it is identified
using the model key, _and this key must correspond to an existing
@ -1535,16 +1518,16 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegist
config = InvokeAIAppConfig.get_config()
ram_cache = ModelCache(
    max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
)
convert_cache = ModelConvertCache(
    cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
)
loader = ModelLoadService(
    app_config=config,
    ram_cache=ram_cache,
    convert_cache=convert_cache,
    registry=ModelLoaderRegistry
)
```
@ -1567,7 +1550,6 @@ The returned `LoadedModel` object contains a copy of the configuration
record returned by the model record `get_model()` method, as well as
the in-memory loaded model:
| **Attribute Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
@ -1581,7 +1563,6 @@ return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are obvious.
`LoadedModel` acts as a context manager. The context loads the model
into the execution device (e.g. VRAM on CUDA systems), locks the model
in the execution device for the duration of the context, and returns
@ -1590,14 +1571,14 @@ the model. Use it like this:
```
model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with model_info as vae:
    image = vae.decode(latents)[0]
```
`get_model_by_key()` may raise any of the following exceptions:
* `UnknownModelException` -- key not in database
* `ModelNotFoundException` -- key in database but model not found at path
* `NotImplementedException` -- the loader doesn't know how to load this type of model
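A defensive sketch (the import path shown for `UnknownModelException` is the record-service module used earlier in this document; the other two exceptions are caught analogously):

```
from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager.config import SubModelType

try:
    model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
except UnknownModelException:
    print('no model configuration is registered under that key')
```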
### Emitting model loading events
@ -1609,15 +1590,15 @@ following payload:
```
payload=dict(
    queue_id=queue_id,
    queue_item_id=queue_item_id,
    queue_batch_id=queue_batch_id,
    graph_execution_state_id=graph_execution_state_id,
    model_key=model_key,
    submodel_type=submodel,
    hash=model_info.hash,
    location=str(model_info.location),
    precision=str(model_info.precision),
)
```
@ -1724,6 +1705,7 @@ object, or in `context.services.model_manager` from within an
invocation.
In the examples below, we have retrieved the manager using:
```
mm = ApiDependencies.invoker.services.model_manager
```
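The three sub-services can then be reached as attributes of the manager object. The `store` and `install` attributes follow the usage shown earlier in this document; the `load` attribute name is assumed here:

```
record_store = mm.store  # ModelRecordServiceBase: config records
installer = mm.install   # ModelInstallServiceBase: downloads and installs
loader = mm.load         # ModelLoadServiceBase: RAM/VRAM loading (assumed attribute name)
```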

View File

@ -26,7 +26,6 @@ from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_metadata import ModelMetadataStoreSQL
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
@ -93,10 +92,9 @@ class ApiDependencies:
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
)
download_queue_service = DownloadQueueService(event_bus=events)
model_metadata_service = ModelMetadataStoreSQL(db=db)
model_manager = ModelManagerService.build_model_manager(
app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
model_record_service=ModelRecordServiceSQL(db=db),
download_queue=download_queue_service,
events=events,
)

View File

@ -3,9 +3,7 @@
import pathlib
import shutil
from hashlib import sha1
from random import randbytes
from typing import Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Optional
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
@ -14,15 +12,11 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
ModelRecordOrderBy,
ModelSummary,
UnknownModelException,
)
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
@ -31,9 +25,6 @@ from invokeai.backend.model_manager.config import (
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata
from invokeai.backend.model_manager.search import ModelSearch
from ..dependencies import ApiDependencies
@ -49,15 +40,6 @@ class ModelsList(BaseModel):
model_config = ConfigDict(use_enum_values=True)
class ModelTagSet(BaseModel):
"""Return tags for a set of models."""
key: str
name: str
author: str
tags: Set[str]
##############################################################################
# These are example inputs and outputs that are used in places where Swagger
# is unable to generate a correct example.
@ -68,19 +50,16 @@ example_model_config = {
"base": "sd-1",
"type": "main",
"format": "checkpoint",
"config": "string",
"config_path": "string",
"key": "string",
"original_hash": "string",
"current_hash": "string",
"hash": "string",
"description": "string",
"source": "string",
"last_modified": 0,
"vae": "string",
"converted_at": 0,
"variant": "normal",
"prediction_type": "epsilon",
"repo_variant": "fp16",
"upcast_attention": False,
"ztsnr_training": False,
}
example_model_input = {
@ -89,50 +68,12 @@ example_model_input = {
"base": "sd-1",
"type": "main",
"format": "checkpoint",
"config": "configs/stable-diffusion/v1-inference.yaml",
"config_path": "configs/stable-diffusion/v1-inference.yaml",
"description": "Model description",
"vae": None,
"variant": "normal",
}
example_model_metadata = {
"name": "ip_adapter_sd_image_encoder",
"author": "InvokeAI",
"tags": [
"transformers",
"safetensors",
"clip_vision_model",
"endpoints_compatible",
"region:us",
"has_space",
"license:apache-2.0",
],
"files": [
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
"path": "ip_adapter_sd_image_encoder/README.md",
"size": 628,
"sha256": None,
},
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
"path": "ip_adapter_sd_image_encoder/config.json",
"size": 560,
"sha256": None,
},
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
"path": "ip_adapter_sd_image_encoder/model.safetensors",
"size": 2528373448,
"sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
},
],
"type": "huggingface",
"id": "InvokeAI/ip_adapter_sd_image_encoder",
"tag_dict": {"license": "apache-2.0"},
"last_modified": "2023-09-23T17:33:25Z",
}
##############################################################################
# ROUTES
##############################################################################
@ -212,89 +153,16 @@ async def get_model_record(
raise HTTPException(status_code=404, detail=str(e))
@model_manager_router.get("/summary", operation_id="list_model_summary")
async def list_model_summary(
page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of models per page"),
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
) -> PaginatedResults[ModelSummary]:
"""Gets a page of model summary data."""
record_store = ApiDependencies.invoker.services.model_manager.store
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
return results
@model_manager_router.get(
"/i/{key}/metadata",
operation_id="get_model_metadata",
responses={
200: {
"description": "The model metadata was retrieved successfully",
"content": {"application/json": {"example": example_model_metadata}},
},
400: {"description": "Bad request"},
},
)
async def get_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
) -> Optional[AnyModelRepoMetadata]:
"""Get a model metadata object."""
record_store = ApiDependencies.invoker.services.model_manager.store
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
return result
@model_manager_router.patch(
"/i/{key}/metadata",
operation_id="update_model_metadata",
responses={
201: {
"description": "The model metadata was updated successfully",
"content": {"application/json": {"example": example_model_metadata}},
},
400: {"description": "Bad request"},
},
)
async def update_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
changes: ModelMetadataChanges = Body(description="The changes"),
) -> Optional[AnyModelRepoMetadata]:
"""Updates or creates a model metadata object."""
record_store = ApiDependencies.invoker.services.model_manager.store
metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
try:
original_metadata = record_store.get_metadata(key)
if original_metadata:
if changes.default_settings:
original_metadata.default_settings = changes.default_settings
metadata_store.update_metadata(key, original_metadata)
else:
metadata_store.add_metadata(
key, BaseMetadata(name="", author="", default_settings=changes.default_settings)
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"An error occurred while updating the model metadata: {e}",
)
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
return result
@model_manager_router.get(
"/tags",
operation_id="list_tags",
)
async def list_tags() -> Set[str]:
"""Get a unique set of all the model tags."""
record_store = ApiDependencies.invoker.services.model_manager.store
result: Set[str] = record_store.list_tags()
return result
# @model_manager_router.get("/summary", operation_id="list_model_summary")
# async def list_model_summary(
# page: int = Query(default=0, description="The page to get"),
# per_page: int = Query(default=10, description="The number of models per page"),
# order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
# ) -> PaginatedResults[ModelSummary]:
# """Gets a page of model summary data."""
# record_store = ApiDependencies.invoker.services.model_manager.store
# results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
# return results
class FoundModel(BaseModel):
@ -366,19 +234,6 @@ async def scan_for_models(
return scan_results
@model_manager_router.get(
"/tags/search",
operation_id="search_by_metadata_tags",
)
async def search_by_metadata_tags(
tags: Set[str] = Query(default=None, description="Tags to search for"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_manager.store
results = record_store.search_by_metadata_tag(tags)
return ModelsList(models=results)
@model_manager_router.patch(
"/i/{key}",
operation_id="update_model_record",
@ -395,15 +250,13 @@ async def search_by_metadata_tags(
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
info: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
],
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
"""Update a model's config."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
try:
model_response: AnyModelConfig = record_store.update_model(key, config=info)
model_response: AnyModelConfig = record_store.update_model(key, changes=changes)
logger.info(f"Updated model: {key}")
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@ -415,14 +268,14 @@ async def update_model_record(
@model_manager_router.delete(
"/i/{key}",
operation_id="del_model_record",
operation_id="delete_model",
responses={
204: {"description": "Model deleted successfully"},
404: {"description": "Model not found"},
},
status_code=204,
)
async def del_model_record(
async def delete_model(
key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
"""
@ -443,42 +296,39 @@ async def del_model_record(
raise HTTPException(status_code=404, detail=str(e))
@model_manager_router.post(
"/i/",
operation_id="add_model_record",
responses={
201: {
"description": "The model added successfully",
"content": {"application/json": {"example": example_model_config}},
},
409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"},
},
status_code=201,
)
async def add_model_record(
config: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
],
) -> AnyModelConfig:
"""Add a model using the configuration information appropriate for its type."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
if config.key == "<NOKEY>":
config.key = sha1(randbytes(100)).hexdigest()
logger.info(f"Created model {config.key} for {config.name}")
try:
record_store.add_model(config.key, config)
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
# @model_manager_router.post(
# "/i/",
# operation_id="add_model_record",
# responses={
# 201: {
# "description": "The model added successfully",
# "content": {"application/json": {"example": example_model_config}},
# },
# 409: {"description": "There is already a model corresponding to this path or repo_id"},
# 415: {"description": "Unrecognized file/folder format"},
# },
# status_code=201,
# )
# async def add_model_record(
# config: Annotated[
# AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
# ],
# ) -> AnyModelConfig:
# """Add a model using the configuration information appropriate for its type."""
# logger = ApiDependencies.invoker.services.logger
# record_store = ApiDependencies.invoker.services.model_manager.store
# try:
# record_store.add_model(config)
# except DuplicateModelException as e:
# logger.error(str(e))
# raise HTTPException(status_code=409, detail=str(e))
# except InvalidModelException as e:
# logger.error(str(e))
# raise HTTPException(status_code=415)
# now fetch it out
result: AnyModelConfig = record_store.get_model(config.key)
return result
# # now fetch it out
# result: AnyModelConfig = record_store.get_model(config.key)
# return result
@model_manager_router.post(
@ -553,10 +403,10 @@ async def install_model(
@model_manager_router.get(
"/import",
operation_id="list_model_install_jobs",
"/install",
operation_id="list_model_installs",
)
async def list_model_install_jobs() -> List[ModelInstallJob]:
async def list_model_installs() -> List[ModelInstallJob]:
"""Return the list of model install jobs.
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
@ -570,9 +420,8 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
* "cancelled" -- Job was cancelled before completion.
Once completed, information about the model such as its size, base
model, type, and metadata can be retrieved from the `config_out`
field. For multi-file models such as diffusers, information on individual files
can be retrieved from `download_parts`.
model and type can be retrieved from the `config_out` field. For multi-file models such as diffusers,
information on individual files can be retrieved from `download_parts`.
See the example and schema below for more information.
"""
@ -581,7 +430,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
@model_manager_router.get(
"/import/{id}",
"/install/{id}",
operation_id="get_model_install_job",
responses={
200: {"description": "Success"},
@ -601,7 +450,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
@model_manager_router.delete(
"/import/{id}",
"/install/{id}",
operation_id="cancel_model_install_job",
responses={
201: {"description": "The job was cancelled successfully"},
@ -619,8 +468,8 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
installer.cancel_job(job)
@model_manager_router.patch(
"/import",
@model_manager_router.delete(
"/install",
operation_id="prune_model_install_jobs",
responses={
204: {"description": "All completed and errored jobs have been pruned"},
@ -699,7 +548,8 @@ async def convert_model(
# temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name
model_config.name = f"{original_name}.DELETE"
store.update_model(key, config=model_config)
changes = ModelRecordChanges(name=model_config.name)
store.update_model(key, changes=changes)
# install the diffusers
try:
@ -708,7 +558,7 @@ async def convert_model(
config={
"name": original_name,
"description": model_config.description,
"original_hash": model_config.original_hash,
"hash": model_config.hash,
"source": model_config.source,
},
)
@ -716,10 +566,6 @@ async def convert_model(
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
# get the original metadata
if orig_metadata := store.get_metadata(key):
store.metadata_store.add_metadata(new_key, orig_metadata)
# delete the original safetensors file
installer.delete(key)
@ -731,66 +577,66 @@ async def convert_model(
return new_config
@model_manager_router.put(
"/merge",
operation_id="merge",
responses={
200: {
"description": "Model converted successfully",
"content": {"application/json": {"example": example_model_config}},
},
400: {"description": "Bad request"},
404: {"description": "Model not found"},
409: {"description": "There is already a model registered at this location"},
},
)
async def merge(
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
force: bool = Body(
description="Force merging of models created with different versions of diffusers",
default=False,
),
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
merge_dest_directory: Optional[str] = Body(
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
default=None,
),
) -> AnyModelConfig:
"""
Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
```
Argument Description [default]
-------- ----------------------
keys List of 2-3 model keys to merge together. All models must use the same base type.
merged_model_name Name for the merged model [Concat model names]
alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
merge_dest_directory Specify a directory to store the merged model in [models directory]
```
"""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
installer = ApiDependencies.invoker.services.model_manager.install
merger = ModelMerger(installer)
model_names = [installer.record_store.get_model(x).name for x in keys]
response = merger.merge_diffusion_models_and_save(
model_keys=keys,
merged_model_name=merged_model_name or "+".join(model_names),
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory=dest,
)
except UnknownModelException:
raise HTTPException(
status_code=404,
detail=f"One or more of the models '{keys}' not found",
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
# @model_manager_router.put(
# "/merge",
# operation_id="merge",
# responses={
# 200: {
# "description": "Model converted successfully",
# "content": {"application/json": {"example": example_model_config}},
# },
# 400: {"description": "Bad request"},
# 404: {"description": "Model not found"},
# 409: {"description": "There is already a model registered at this location"},
# },
# )
# async def merge(
# keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
# merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
# alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
# force: bool = Body(
# description="Force merging of models created with different versions of diffusers",
# default=False,
# ),
# interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
# merge_dest_directory: Optional[str] = Body(
# description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
# default=None,
# ),
# ) -> AnyModelConfig:
# """
# Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
# ```
# Argument Description [default]
# -------- ----------------------
# keys List of 2-3 model keys to merge together. All models must use the same base type.
# merged_model_name Name for the merged model [Concat model names]
# alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
# force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
# interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
# merge_dest_directory Specify a directory to store the merged model in [models directory]
# ```
# """
# logger = ApiDependencies.invoker.services.logger
# try:
# logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
# dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
# installer = ApiDependencies.invoker.services.model_manager.install
# merger = ModelMerger(installer)
# model_names = [installer.record_store.get_model(x).name for x in keys]
# response = merger.merge_diffusion_models_and_save(
# model_keys=keys,
# merged_model_name=merged_model_name or "+".join(model_names),
# alpha=alpha,
# interp=interp,
# force=force,
# merge_dest_directory=dest,
# )
# except UnknownModelException:
# raise HTTPException(
# status_code=404,
# detail=f"One or more of the models '{keys}' not found",
# )
# except ValueError as e:
# raise HTTPException(status_code=400, detail=str(e))
# return response

View File

@ -133,7 +133,7 @@ class MainModelLoaderInvocation(BaseInvocation):
vae=VaeField(
vae=ModelInfo(
key=key,
submodel_type=SubModelType.Vae,
submodel_type=SubModelType.VAE,
),
),
)

View File

@ -85,7 +85,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.Vae,
submodel_type=SubModelType.VAE,
),
),
)
@ -142,7 +142,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.Vae,
submodel_type=SubModelType.VAE,
),
),
)

View File

@ -256,6 +256,7 @@ class InvokeAIAppConfig(InvokeAISettings):
profile_graphs : bool = Field(default=False, description="Enable graph profiling", json_schema_extra=Categories.Development)
profile_prefix : Optional[str] = Field(default=None, description="An optional prefix for profile output files.", json_schema_extra=Categories.Development)
profiles_dir : Path = Field(default=Path('profiles'), description="Directory for graph profiles", json_schema_extra=Categories.Development)
skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce startup time.", json_schema_extra=Categories.Development)
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
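As a rough illustration of the new `skip_model_hash` flag, the sketch below constructs the app config with hashing disabled. The import path and constructor usage are assumptions; in practice the flag would normally be set in the InvokeAI config file or on the command line.
```
# Sketch only: enable skip_model_hash programmatically (import path assumed).
from invokeai.app.services.config import InvokeAIAppConfig

config = InvokeAIAppConfig(skip_model_hash=True)
# Installed models now receive a UUID in place of a content hash,
# which avoids hashing every model on startup (useful with a memory DB).
print(config.skip_model_hash)  # True
```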

View File

@ -18,10 +18,9 @@ from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..model_metadata import ModelMetadataStoreBase
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
@ -151,6 +150,13 @@ ModelSource = Annotated[
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
CivitaiModelSource: ModelSourceType.CivitAI,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
@ -260,7 +266,6 @@ class ModelInstallServiceBase(ABC):
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
metadata_store: ModelMetadataStoreBase,
event_bus: Optional["EventServiceBase"] = None,
):
"""
@ -347,6 +352,7 @@ class ModelInstallServiceBase(ABC):
source: str,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
) -> ModelInstallJob:
r"""Install the indicated model using heuristics to interpret user intentions.
@ -392,7 +398,7 @@ class ModelInstallServiceBase(ABC):
will override corresponding autoassigned probe fields in the
model's config record. Use it to override
`name`, `description`, `base_type`, `model_type`, `format`,
`prediction_type`, `image_size`, and/or `ztsnr_training`.
`prediction_type`, and/or `image_size`.
This will download the model located at `source`,
probe it, and install it into the models directory.
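A hedged usage sketch for the updated signature: `installer` stands in for the model-manager install service (as obtained earlier via `ApiDependencies.invoker.services.model_manager.install`), and the source URL and config values are placeholders.
```
# Sketch: install a model from a URL using heuristics (source URL is illustrative).
job = installer.heuristic_import(
    source="https://civitai.com/models/12345",
    config={"name": "my-model", "description": "Example override of probed fields"},
    access_token=None,
    inplace=False,  # new optional flag (primarily meaningful for local paths)
)
print(job.status)
```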

View File

@ -20,12 +20,15 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
CheckpointConfigBase,
InvalidModelConfigException,
ModelRepoVariant,
ModelSourceType,
ModelType,
)
from invokeai.backend.model_manager.metadata import (
@ -35,12 +38,14 @@ from invokeai.backend.model_manager.metadata import (
ModelMetadataWithFiles,
RemoteModelFile,
)
from invokeai.backend.model_manager.metadata.metadata_base import CivitaiMetadata, HuggingFaceMetadata
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import Chdir, InvokeAILogger
from invokeai.backend.util.devices import choose_precision, choose_torch_device
from .model_install_base import (
MODEL_SOURCE_TO_TYPE_MAP,
CivitaiModelSource,
HFModelSource,
InstallStatus,
@ -90,7 +95,6 @@ class ModelInstallService(ModelInstallServiceBase):
self._running = False
self._session = session
self._next_job_id = 0
self._metadata_store = record_store.metadata_store # for convenience
@property
def app_config(self) -> InvokeAIAppConfig: # noqa D102
@ -139,6 +143,7 @@ class ModelInstallService(ModelInstallServiceBase):
config = config or {}
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
config["source_type"] = ModelSourceType.Path
return self._register(model_path, config)
def install_path(
@ -148,11 +153,11 @@ class ModelInstallService(ModelInstallServiceBase):
) -> str: # noqa D102
model_path = Path(model_path)
config = config or {}
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
config["key"] = config.get("key", uuid_string())
info: AnyModelConfig = self._probe_model(Path(model_path), config)
if self._app_config.skip_model_hash:
config["hash"] = uuid_string()
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
if preferred_name := config.get("name"):
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
@ -178,7 +183,7 @@ class ModelInstallService(ModelInstallServiceBase):
source: str,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
inplace: bool = False,
inplace: Optional[bool] = False,
) -> ModelInstallJob:
variants = "|".join(ModelRepoVariant.__members__.values())
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
@ -374,15 +379,18 @@ class ModelInstallService(ModelInstallServiceBase):
job.bytes = job.total_bytes
self._signal_job_running(job)
job.config_in["source"] = str(job.source)
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
# enter the metadata, if there is any
if isinstance(job.source_metadata, (CivitaiMetadata, HuggingFaceMetadata)):
job.config_in["source_api_response"] = job.source_metadata.api_response
if isinstance(job.source_metadata, CivitaiMetadata) and job.source_metadata.trigger_phrases:
job.config_in["trigger_phrases"] = job.source_metadata.trigger_phrases
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
else:
key = self.install_path(job.local_path, job.config_in)
job.config_out = self.record_store.get_model(key)
# enter the metadata, if there is any
if job.source_metadata:
self._metadata_store.add_metadata(key, job.source_metadata)
self._signal_job_completed(job)
except InvalidModelConfigException as excp:
@ -468,7 +476,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"Moving {model.name} to {new_path}.")
new_path = self._move_model(old_path, new_path)
model.path = new_path.relative_to(models_dir).as_posix()
self.record_store.update_model(key, model)
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
return model
def _scan_register(self, model: Path) -> bool:
@ -520,24 +528,15 @@ class ModelInstallService(ModelInstallServiceBase):
move(old_path, new_path)
return new_path
def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
info: AnyModelConfig = ModelProbe.probe(Path(model_path))
if config: # used to override probe fields
for key, value in config.items():
setattr(info, key, value)
return info
def _register(
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
) -> str:
# Note that we may be passed a pre-populated AnyModelConfig object,
# in which case the key field should have been populated by the caller (e.g. in `install_path`).
config["key"] = config.get("key", uuid_string())
info = info or ModelProbe.probe(model_path, config)
override_key: Optional[str] = config.get("key") if config else None
config = config or {}
assert info.original_hash # always assigned by probe()
info.key = override_key or info.original_hash
if self._app_config.skip_model_hash:
config["hash"] = uuid_string()
info = info or ModelProbe.probe(model_path, config)
model_path = model_path.absolute()
if model_path.is_relative_to(self.app_config.models_path):
@ -546,11 +545,11 @@ class ModelInstallService(ModelInstallServiceBase):
info.path = model_path.as_posix()
# add 'main' specific fields
if hasattr(info, "config"):
if isinstance(info, CheckpointConfigBase):
# make config relative to our root
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
self.record_store.add_model(info.key, info)
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config_path).resolve()
info.config_path = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
self.record_store.add_model(info)
return info.key
def _next_id(self) -> int:
@ -571,13 +570,15 @@ class ModelInstallService(ModelInstallServiceBase):
source=source,
config_in=config or {},
local_path=Path(source.path),
inplace=source.inplace,
inplace=source.inplace or False,
)
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
if not source.access_token:
self._logger.info("No Civitai access token provided; some models may not be downloadable.")
metadata = CivitaiMetadataFetch(self._session).from_id(str(source.version_id))
metadata = CivitaiMetadataFetch(self._session, self.app_config.get_config().civitai_api_key).from_id(
str(source.version_id)
)
assert isinstance(metadata, ModelMetadataWithFiles)
remote_files = metadata.download_urls(session=self._session)
return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)
@ -605,15 +606,17 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
# URLs from Civitai or HuggingFace will be handled specially
url_patterns = {
r"^https?://civitai.com/": CivitaiMetadataFetch,
r"^https?://huggingface.co/[^/]+/[^/]+$": HuggingFaceMetadataFetch,
}
metadata = None
for pattern, fetcher in url_patterns.items():
if re.match(pattern, str(source.url), re.IGNORECASE):
metadata = fetcher(self._session).from_url(source.url)
break
fetcher = None
try:
fetcher = self.get_fetcher_from_url(str(source.url))
except ValueError:
pass
kwargs: dict[str, Any] = {"session": self._session}
if fetcher is CivitaiMetadataFetch:
kwargs["api_key"] = self._app_config.get_config().civitai_api_key
if fetcher is not None:
metadata = fetcher(**kwargs).from_url(source.url)
self._logger.debug(f"metadata={metadata}")
if metadata and isinstance(metadata, ModelMetadataWithFiles):
remote_files = metadata.download_urls(session=self._session)
@ -628,7 +631,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_remote_model(
self,
source: ModelSource,
source: HFModelSource | CivitaiModelSource | URLModelSource,
remote_files: List[RemoteModelFile],
metadata: Optional[AnyModelRepoMetadata],
config: Optional[Dict[str, Any]],
@ -656,7 +659,7 @@ class ModelInstallService(ModelInstallServiceBase):
# In the event that there is a subfolder specified in the source,
# we need to remove it from the destination path in order to avoid
# creating unwanted subfolders
if hasattr(source, "subfolder") and source.subfolder:
if isinstance(source, HFModelSource) and source.subfolder:
root = Path(remote_files[0].path.parts[0])
subfolder = root / source.subfolder
else:
@ -843,3 +846,11 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"{job.source}: model installation was cancelled")
if self._event_bus:
self._event_bus.emit_model_install_cancelled(str(job.source))
@staticmethod
def get_fetcher_from_url(url: str):
if re.match(r"^https?://civitai.com/", url.lower()):
return CivitaiMetadataFetch
elif re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'")

View File

@ -1,9 +0,0 @@
"""Init file for ModelMetadataStoreService module."""
from .metadata_store_base import ModelMetadataStoreBase
from .metadata_store_sql import ModelMetadataStoreSQL
__all__ = [
"ModelMetadataStoreBase",
"ModelMetadataStoreSQL",
]

View File

@ -1,81 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Storage for Model Metadata
"""
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple
from pydantic import Field
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings
class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
"""A set of changes to apply to model metadata.
Only limited changes are valid:
- `default_settings`: the user-configured default settings for this model
"""
default_settings: Optional[ModelDefaultSettings] = Field(
default=None, description="The user-configured default settings for this model"
)
"""The user-configured default settings for this model"""
class ModelMetadataStoreBase(ABC):
"""Store, search and fetch model metadata retrieved from remote repositories."""
@abstractmethod
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
@abstractmethod
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
@abstractmethod
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
@abstractmethod
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
@abstractmethod
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
@abstractmethod
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
@abstractmethod
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
@abstractmethod
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""

View File

@ -1,223 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""
import sqlite3
from typing import List, Optional, Set, Tuple
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
from .metadata_store_base import ModelMetadataStoreBase
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
"""Store, search and fetch model metadata retrieved from remote repositories."""
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_metadata(
id,
metadata
)
VALUES (?,?);
""",
(
model_key,
json_serialized,
),
)
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
raise UnknownMetadataException from excp
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata FROM model_metadata
WHERE id=?;
""",
(model_key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownMetadataException("model metadata not found")
return ModelMetadataFetchBase.from_json(rows[0])
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT id,metadata FROM model_metadata;
""",
(),
)
rows = self._cursor.fetchall()
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
json_serialized = metadata.model_dump_json() # turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_metadata
SET
metadata=?
WHERE id=?;
""",
(json_serialized, model_key),
)
if self._cursor.rowcount == 0:
raise UnknownMetadataException("model metadata not found")
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_metadata(model_key)
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
self._cursor.execute(
"""--sql
select tag_text from tags;
"""
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
with self._db.lock:
try:
matches: Optional[Set[str]] = None
for tag in tags:
self._cursor.execute(
"""--sql
SELECT a.model_id FROM model_tags AS a,
tags AS b
WHERE a.tag_id=b.tag_id
AND b.tag_text=?;
""",
(tag,),
)
model_keys = {x[0] for x in self._cursor.fetchall()}
if matches is None:
matches = model_keys
matches = matches.intersection(model_keys)
except sqlite3.Error as e:
raise e
return matches if matches else set()
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE author=?;
""",
(author,),
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE name=?;
""",
(name,),
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
"""Update tags for the model referenced by model_key."""
if tags:
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)

View File

@ -6,20 +6,19 @@ Abstract base class for storing and retrieving model configuration records.
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import List, Optional, Set, Union
from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager import (
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..model_metadata import ModelMetadataStoreBase
from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType
class DuplicateModelException(Exception):
@ -60,11 +59,33 @@ class ModelSummary(BaseModel):
tags: Set[str] = Field(description="tags associated with model")
class ModelRecordChanges(BaseModelExcludeNull):
"""A set of changes to apply to a model."""
# Changes applicable to all models
name: Optional[str] = Field(description="Name of the model.", default=None)
path: Optional[str] = Field(description="Path to the model.", default=None)
description: Optional[str] = Field(description="Model description", default=None)
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[ModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
# Checkpoint-specific changes
# TODO(MM2): Should we expose these? Feels footgun-y...
variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
prediction_type: Optional[SchedulerPredictionType] = Field(
description="The prediction type of the model.", default=None
)
upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
class ModelRecordServiceBase(ABC):
"""Abstract base class for storage and retrieval of model configs."""
@abstractmethod
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
"""
Add a model to the database.
@ -88,13 +109,12 @@ class ModelRecordServiceBase(ABC):
pass
@abstractmethod
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
"""
Update the model, returning the updated version.
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
:param key: Unique key for the model to be updated.
:param changes: A set of changes to apply to this model. Changes are validated before being written.
"""
pass
@ -109,40 +129,6 @@ class ModelRecordServiceBase(ABC):
"""
pass
@property
@abstractmethod
def metadata_store(self) -> ModelMetadataStoreBase:
"""Return a ModelMetadataStore initialized on the same database."""
pass
@abstractmethod
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""
Retrieve metadata (if any) from when model was downloaded from a repo.
:param key: Model key
"""
pass
@abstractmethod
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
"""List metadata for all models that have it."""
pass
@abstractmethod
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
"""
Search model metadata for ones with all listed tags and return their corresponding configs.
:param tags: Set of tags to search for. All tags must be present.
"""
pass
@abstractmethod
def list_tags(self) -> Set[str]:
"""Return a unique set of all the model tags in the metadata database."""
pass
@abstractmethod
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
@ -217,21 +203,3 @@ class ModelRecordServiceBase(ABC):
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
)
return model_configs[0]
def rename_model(
self,
key: str,
new_name: str,
) -> AnyModelConfig:
"""
Rename the indicated model. Just a special case of update_model().
In some implementations, renaming the model may involve changing where
it is stored on the filesystem. So this is broken out.
:param key: Model key
:param new_name: New name for model
"""
config = self.get_model(key)
config.name = new_name
return self.update_model(key, config)
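A short sketch of the new update path, assuming `record_store` is a `ModelRecordServiceBase` instance and `key` identifies an installed model; only the fields explicitly set on the changes object are applied.
```
# Sketch: partial, validated update of a model record via ModelRecordChanges.
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges

changes = ModelRecordChanges(
    name="My Renamed Model",
    description="Fine-tuned on portrait photography",
    trigger_phrases={"portrait style"},
)
updated = record_store.update_model(key, changes)  # unset fields are left untouched
print(updated.name)
```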

View File

@ -43,7 +43,7 @@ import json
import sqlite3
from math import ceil
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import List, Optional, Union
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import (
@ -53,12 +53,11 @@ from invokeai.backend.model_manager.config import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import (
DuplicateModelException,
ModelRecordChanges,
ModelRecordOrderBy,
ModelRecordServiceBase,
ModelSummary,
@ -69,7 +68,7 @@ from .model_records_base import (
class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database."""
def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase):
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
@ -78,14 +77,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
super().__init__()
self._db = db
self._cursor = db.conn.cursor()
self._metadata_store = metadata_store
@property
def db(self) -> SqliteDatabase:
"""Return the underlying database."""
return self._db
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
"""
Add a model to the database.
@ -95,23 +93,19 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config object.
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_config (
INSERT INTO models (
id,
original_hash,
config
)
VALUES (?,?,?);
VALUES (?,?);
""",
(
key,
record.original_hash,
json_serialized,
config.key,
config.model_dump_json(),
),
)
self._db.conn.commit()
@ -119,12 +113,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
except sqlite3.IntegrityError as e:
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "model_config.path" in str(e):
msg = f"A model with path '{record.path}' is already installed"
elif "model_config.name" in str(e):
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
else:
msg = f"A model with key '{key}' is already installed"
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
else:
raise e
@ -132,7 +126,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db.conn.rollback()
raise e
return self.get_model(key)
return self.get_model(config.key)
def del_model(self, key: str) -> None:
"""
@ -146,7 +140,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
try:
self._cursor.execute(
"""--sql
DELETE FROM model_config
DELETE FROM models
WHERE id=?;
""",
(key,),
@ -158,21 +152,20 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db.conn.rollback()
raise e
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
"""
Update the model, returning the updated version.
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
record = self.get_model(key)
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
json_serialized = record.model_dump_json()
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config object
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_config
UPDATE models
SET
config=?
WHERE id=?;
@ -199,7 +192,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM model_config
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
@ -220,7 +213,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
"""--sql
select count(*) FROM model_config
select count(*) FROM models
WHERE id=?;
""",
(key,),
@ -246,9 +239,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
If none of the optional filters are passed, will return all
models in the database.
"""
results = []
where_clause = []
bindings = []
where_clause: list[str] = []
bindings: list[str] = []
if model_name:
where_clause.append("name=?")
bindings.append(model_name)
@ -265,14 +257,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
f"""--sql
select config, strftime('%s',updated_at) FROM model_config
SELECT config, strftime('%s',updated_at) FROM models
{where};
""",
tuple(bindings),
)
results = [
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
]
result = self._cursor.fetchall()
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result]
return results
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
@ -281,7 +272,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM model_config
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
@ -292,13 +283,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated original_hash."""
"""Return models with the indicated hash."""
results = []
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM model_config
WHERE original_hash=?;
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
@ -307,83 +298,35 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
]
return results
@property
def metadata_store(self) -> ModelMetadataStoreBase:
"""Return a ModelMetadataStore initialized on the same database."""
return self._metadata_store
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""
Retrieve metadata (if any) from when model was downloaded from a repo.
:param key: Model key
"""
store = self.metadata_store
try:
metadata = store.get_metadata(key)
return metadata
except UnknownMetadataException:
return None
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
"""
Search model metadata for ones with all listed tags and return their corresponding configs.
:param tags: Set of tags to search for. All tags must be present.
"""
store = ModelMetadataStoreSQL(self._db)
keys = store.search_by_tag(tags)
return [self.get_model(x) for x in keys]
def list_tags(self) -> Set[str]:
"""Return a unique set of all the model tags in the metadata database."""
store = ModelMetadataStoreSQL(self._db)
return store.list_tags()
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
"""List metadata for all models that have it."""
store = ModelMetadataStoreSQL(self._db)
return store.list_all_metadata()
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name",
ModelRecordOrderBy.Type: "a.type",
ModelRecordOrderBy.Base: "a.base",
ModelRecordOrderBy.Name: "a.name",
ModelRecordOrderBy.Format: "a.format",
ModelRecordOrderBy.Default: "type, base, format, name",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]:
"""Fix up results so that there are no null values."""
result: Dict[str, Union[str, int, Set[str]]] = {}
for key, item in summary.items():
result[key] = item or ""
result["tags"] = set(json.loads(summary["tags"] or "[]"))
return result
# Lock so that the database isn't updated while we're doing the two queries.
with self._db.lock:
# query1: get the total number of model configs
self._cursor.execute(
"""--sql
select count(*) from model_config;
select count(*) from models;
""",
(),
)
total = int(self._cursor.fetchone()[0])
# query2: fetch key fields from the join of model_config and model_metadata
# query2: fetch key fields
self._cursor.execute(
f"""--sql
SELECT a.id as key, a.type, a.base, a.format, a.name,
json_extract(a.config, '$.description') as description,
json_extract(b.metadata, '$.tags') as tags
FROM model_config AS a
LEFT JOIN model_metadata AS b on a.id=b.id
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
@ -394,7 +337,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
),
)
rows = self._cursor.fetchall()
items = [ModelSummary.model_validate(_fixup(dict(x))) for x in rows]
items = [ModelSummary.model_validate(dict(x)) for x in rows]
return PaginatedResults(
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
)

View File

@ -1,6 +1,35 @@
from abc import ABC, abstractmethod
from threading import Event
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
class SessionRunnerBase(ABC):
"""
Base class for session runner.
"""
@abstractmethod
def start(self, services: InvocationServices, cancel_event: Event) -> None:
"""Starts the session runner"""
pass
@abstractmethod
def run(self, queue_item: SessionQueueItem) -> None:
"""Runs the session"""
pass
@abstractmethod
def complete(self, queue_item: SessionQueueItem) -> None:
"""Completes the session"""
pass
@abstractmethod
def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
"""Runs an already prepared node on the session"""
pass
class SessionProcessorBase(ABC):
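To sketch how a custom runner could satisfy this contract, the hypothetical class below only logs; a real runner would walk the session's execution graph and invoke each prepared node. Import paths follow the modules referenced in the diff above.
```
# Sketch of a custom runner satisfying the SessionRunnerBase contract (illustrative only).
from threading import Event

from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_base import SessionRunnerBase
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem


class LoggingSessionRunner(SessionRunnerBase):
    def start(self, services: InvocationServices, cancel_event: Event) -> None:
        self.services = services
        self.cancel_event = cancel_event

    def run(self, queue_item: SessionQueueItem) -> None:
        self.services.logger.info(f"Running session {queue_item.session_id}")
        # ...iterate the session's prepared nodes here, calling self.run_node() for each...
        self.complete(queue_item)

    def complete(self, queue_item: SessionQueueItem) -> None:
        self.services.logger.info(f"Finished session {queue_item.session_id}")

    def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
        self.services.logger.info(f"Would run node {node_id}")
```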

View File

@ -2,13 +2,14 @@ import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from typing import Optional
from typing import Callable, Optional, Union
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
@ -16,15 +17,164 @@ from invokeai.app.services.shared.invocation_context import InvocationContextDat
from invokeai.app.util.profiler import Profiler
from ..invoker import Invoker
from .session_processor_base import SessionProcessorBase
from .session_processor_base import SessionProcessorBase, SessionRunnerBase
from .session_processor_common import SessionProcessorStatus
class DefaultSessionRunner(SessionRunnerBase):
"""Processes a single session's invocations"""
def __init__(
self,
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
):
self.on_before_run_node = on_before_run_node
self.on_after_run_node = on_after_run_node
def start(self, services: InvocationServices, cancel_event: ThreadEvent):
"""Start the session runner"""
self.services = services
self.cancel_event = cancel_event
def run(self, queue_item: SessionQueueItem):
"""Run the graph"""
if not queue_item.session:
raise ValueError("Queue item has no session")
# Loop over invocations until the session is complete or canceled
while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
# Prepare the next node
invocation = queue_item.session.next()
if invocation is None:
# If there are no more invocations, complete the graph
break
# Build invocation context (the node-facing API)
self.run_node(invocation.id, queue_item)
self.complete(queue_item)
def complete(self, queue_item: SessionQueueItem):
"""Complete the graph"""
self.services.events.emit_graph_execution_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
)
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run before a node is executed"""
# Send starting event
self.services.events.emit_invocation_started(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session_id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
)
if self.on_before_run_node is not None:
self.on_before_run_node(invocation, queue_item)
def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run after a node is executed"""
if self.on_after_run_node is not None:
self.on_after_run_node(invocation, queue_item)
def run_node(self, node_id: str, queue_item: SessionQueueItem):
"""Run a single node in the graph"""
# If this error raises a NodeNotFoundError that's handled by the processor
invocation = queue_item.session.execution_graph.get_node(node_id)
try:
self._on_before_run_node(invocation, queue_item)
data = InvocationContextData(
invocation=invocation,
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
queue_item=queue_item,
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
context = build_invocation_context(
data=data,
services=self.services,
cancel_event=self.cancel_event,
)
# Invoke the node
outputs = invocation.invoke_internal(context=context, services=self.services)
# Save outputs and history
queue_item.session.complete(invocation.id, outputs)
self._on_after_run_node(invocation, queue_item)
# Send complete event on successful runs
self.services.events.emit_invocation_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=data.source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
queue_item.session.set_node_error(invocation.id, error)
self.services.logger.error(
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
)
self.services.logger.error(error)
# Send error event
self.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
error_type=e.__class__.__name__,
error=error,
)
class DefaultSessionProcessor(SessionProcessorBase):
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
"""Processes sessions from the session queue"""
def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None:
super().__init__()
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
def start(
self,
invoker: Invoker,
thread_limit: int = 1,
polling_interval: int = 1,
on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
) -> None:
self._invoker: Invoker = invoker
self._queue_item: Optional[SessionQueueItem] = None
self._invocation: Optional[BaseInvocation] = None
self.on_before_run_session = on_before_run_session
self.on_after_run_session = on_after_run_session
self._resume_event = ThreadEvent()
self._stop_event = ThreadEvent()
@ -59,6 +209,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
"cancel_event": self._cancel_event,
},
)
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event)
self._thread.start()
def stop(self, *args, **kwargs) -> None:
@ -117,131 +268,34 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
# If we have a on_before_run_session callback, call it
if self.on_before_run_session is not None:
self.on_before_run_session(self._queue_item)
# If profiling is enabled, start the profiler
if self._profiler is not None:
self._profiler.start(profile_id=self._queue_item.session_id)
# Prepare invocations and take the first
self._invocation = self._queue_item.session.next()
# Run the graph
self.session_runner.run(queue_item=self._queue_item)
# Loop over invocations until the session is complete or canceled
while self._invocation is not None and not cancel_event.is_set():
# get the source node id to provide to clients (the prepared node id is not as useful)
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
# Send starting event
self._invoker.services.events.emit_invocation_started(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session_id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
try:
with self._invoker.services.performance_statistics.collect_stats(
self._invocation, self._queue_item.session.id
):
# Build invocation context (the node-facing API)
data = InvocationContextData(
invocation=self._invocation,
source_invocation_id=source_invocation_id,
queue_item=self._queue_item,
)
context = build_invocation_context(
data=data,
services=self._invoker.services,
cancel_event=self._cancel_event,
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats()
# Invoke the node
outputs = self._invocation.invoke_internal(
context=context, services=self._invoker.services
)
# Save outputs and history
self._queue_item.session.complete(self._invocation.id, outputs)
# Send complete event
self._invoker.services.events.emit_invocation_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
self._queue_item.session.set_node_error(self._invocation.id, error)
self._invoker.services.logger.error(
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
)
self._invoker.services.logger.error(error)
# Send error event
self._invoker.services.events.emit_invocation_error(
queue_batch_id=self._queue_item.session_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
error_type=e.__class__.__name__,
error=error,
)
pass
# The session is complete if the all invocations are complete or there was an error
if self._queue_item.session.is_complete() or cancel_event.is_set():
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats()
# Set the invocation to None to prepare for the next session
self._invocation = None
else:
# Prepare the next invocation
self._invocation = self._queue_item.session.next()
# If we have a on_after_run_session callback, call it
if self.on_after_run_session is not None:
self.on_after_run_session(self._queue_item)
# The session is complete, immediately poll for next session
self._queue_item = None
@ -275,3 +329,4 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.clear()
self._queue_item = None
self._thread_semaphore.release()
self._invoker.services.logger.debug("Session processor stopped")

View File

@ -9,6 +9,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -35,6 +36,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_4())
migrator.register_migration(build_migration_5())
migrator.register_migration(build_migration_6())
migrator.register_migration(build_migration_7())
migrator.run_migrations()
return db

View File

@ -0,0 +1,88 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration7Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._create_models_table(cursor)
self._drop_old_models_tables(cursor)
def _drop_old_models_tables(self, cursor: sqlite3.Cursor) -> None:
"""Drops the old model_records, model_metadata, model_tags and tags tables."""
tables = ["model_records", "model_metadata", "model_tags", "tags"]
for table in tables:
cursor.execute(f"DROP TABLE IF EXISTS {table};")
def _create_models_table(self, cursor: sqlite3.Cursor) -> None:
"""Creates the v4.0.0 models table."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS models (
id TEXT NOT NULL PRIMARY KEY,
hash TEXT GENERATED ALWAYS as (json_extract(config, '$.hash')) VIRTUAL NOT NULL,
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
description TEXT GENERATED ALWAYS as (json_extract(config, '$.description')) VIRTUAL,
source TEXT GENERATED ALWAYS as (json_extract(config, '$.source')) VIRTUAL NOT NULL,
source_type TEXT GENERATED ALWAYS as (json_extract(config, '$.source_type')) VIRTUAL NOT NULL,
source_api_response TEXT GENERATED ALWAYS as (json_extract(config, '$.source_api_response')) VIRTUAL,
trigger_phrases TEXT GENERATED ALWAYS as (json_extract(config, '$.trigger_phrases')) VIRTUAL,
-- Serialized JSON representation of the whole config object, which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
]
# Add trigger for `updated_at`.
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS models_updated_at
AFTER UPDATE
ON models FOR EACH ROW
BEGIN
UPDATE models SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
]
# Add indexes for searchable fields
indices = [
"CREATE INDEX IF NOT EXISTS base_index ON models(base);",
"CREATE INDEX IF NOT EXISTS type_index ON models(type);",
"CREATE INDEX IF NOT EXISTS name_index ON models(name);",
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON models(path);",
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def build_migration_7() -> Migration:
"""
Build the migration from database version 6 to 7.
This migration does the following:
- Adds the new models table
- Drops the old model_records, model_metadata, model_tags and tags tables.
- TODO(MM2): Migrates model names and descriptions from `models.yaml` to the new table (?).
"""
migration_7 = Migration(
from_version=6,
to_version=7,
callback=Migration7Callback(),
)
return migration_7
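Because the searchable fields are generated columns over the serialized config, they can be filtered and indexed directly with ordinary SQL. A small illustrative sketch follows; the database path and the `base` value are placeholders.
```
# Sketch: query the generated columns of the new models table (paths/values illustrative).
import sqlite3

conn = sqlite3.connect("databases/invokeai.db")
rows = conn.execute(
    "SELECT name, type, base FROM models WHERE base = ? ORDER BY name;",
    ("sd-1",),
).fetchall()
for name, model_type, base in rows:
    print(name, model_type, base)
```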

View File

@ -150,7 +150,7 @@ class MigrateModelYamlToDb1:
""",
(
key,
record.original_hash,
record.hash,
json_serialized,
),
)

View File

@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
See :class:`Migration` for an example.
"""
def __call__(self, cursor: sqlite3.Cursor) -> None: ...
def __call__(self, cursor: sqlite3.Cursor) -> None:
...
class MigrationError(RuntimeError):

View File

@ -1,55 +0,0 @@
import json
from typing import Optional
from pydantic import ValidationError
from invokeai.app.services.shared.graph import Edge
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
"""
Parses raw session string, returning a dict of the graph.
Only the general graph shape is validated; none of the fields are validated.
Any `metadata_accumulator` nodes and edges are removed.
Any validation failure will return None.
"""
graph = json.loads(session_raw).get("graph", None)
# sanity check make sure the graph is at least reasonably shaped
if (
not isinstance(graph, dict)
or "nodes" not in graph
or not isinstance(graph["nodes"], dict)
or "edges" not in graph
or not isinstance(graph["edges"], list)
):
# something has gone terribly awry, return an empty dict
return None
try:
# delete the `metadata_accumulator` node
del graph["nodes"]["metadata_accumulator"]
except KeyError:
# no accumulator node, all good
pass
# delete any edges to or from it
for i, edge in enumerate(graph["edges"]):
try:
# try to parse the edge
Edge(**edge)
except ValidationError:
# something has gone terribly awry, return an empty dict
return None
if (
edge["source"]["node_id"] == "metadata_accumulator"
or edge["destination"]["node_id"] == "metadata_accumulator"
):
del graph["edges"][i]
return graph

View File

@ -25,10 +25,13 @@ from enum import Enum
from typing import Literal, Optional, Type, Union
import torch
from diffusers import ModelMixin
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.util.misc import uuid_string
from ..raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models
@ -56,8 +59,8 @@ class ModelType(str, Enum):
ONNX = "onnx"
Main = "main"
Vae = "vae"
Lora = "lora"
VAE = "vae"
LoRA = "lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
@ -73,9 +76,9 @@ class SubModelType(str, Enum):
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Vae = "vae"
VaeDecoder = "vae_decoder"
VaeEncoder = "vae_encoder"
VAE = "vae"
VAEDecoder = "vae_decoder"
VAEEncoder = "vae_encoder"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
@ -93,8 +96,8 @@ class ModelFormat(str, Enum):
Diffusers = "diffusers"
Checkpoint = "checkpoint"
Lycoris = "lycoris"
Onnx = "onnx"
LyCORIS = "lycoris"
ONNX = "onnx"
Olive = "olive"
EmbeddingFile = "embedding_file"
EmbeddingFolder = "embedding_folder"
@ -112,128 +115,187 @@ class SchedulerPredictionType(str, Enum):
class ModelRepoVariant(str, Enum):
"""Various hugging face variants on the diffusers format."""
DEFAULT = "" # model files without "fp16" or other qualifier - empty str
Default = "" # model files without "fp16" or other qualifier - empty str
FP16 = "fp16"
FP32 = "fp32"
ONNX = "onnx"
OPENVINO = "openvino"
FLAX = "flax"
OpenVINO = "openvino"
Flax = "flax"
class ModelSourceType(str, Enum):
"""Model source type."""
Path = "path"
Url = "url"
HFRepoID = "hf_repo_id"
CivitAI = "civitai"
class ModelDefaultSettings(BaseModel):
vae: str | None
vae_precision: str | None
scheduler: SCHEDULER_NAME_VALUES | None
steps: int | None
cfg_scale: float | None
cfg_rescale_multiplier: float | None
class ModelConfigBase(BaseModel):
"""Base class for model configuration information."""
path: str = Field(description="filesystem path to the model file or directory")
name: str = Field(description="model name")
base: BaseModelType = Field(description="base model")
type: ModelType = Field(description="type of the model")
format: ModelFormat = Field(description="model format")
key: str = Field(description="unique key for model", default="<NOKEY>")
original_hash: Optional[str] = Field(
description="original fasthash of model contents", default=None
) # this is assigned at install time and will not change
current_hash: Optional[str] = Field(
description="current fasthash of model contents", default=None
) # if model is converted or otherwise modified, this will hold updated hash
description: Optional[str] = Field(description="human readable description of the model", default=None)
source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time)
key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
hash: str = Field(description="The hash of the model file(s).")
path: str = Field(
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
)
name: str = Field(description="Name of the model.")
base: BaseModelType = Field(description="The base model.")
description: Optional[str] = Field(description="Model description", default=None)
source: str = Field(description="The original source of the model (path, URL or repo_id).")
source_type: ModelSourceType = Field(description="The type of source")
source_api_response: Optional[str] = Field(
description="The original API response from the source, as stringified JSON.", default=None
)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[ModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
schema["required"].extend(
["key", "base", "type", "format", "original_hash", "current_hash", "source", "last_modified"]
)
schema["required"].extend(["key", "type", "format"])
model_config = ConfigDict(
use_enum_values=False,
validate_assignment=True,
json_schema_extra=json_schema_extra,
)
def update(self, attributes: Dict[str, Any]) -> None:
"""Update the object with fields in dict."""
for key, value in attributes.items():
setattr(self, key, value) # may raise a validation error
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
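
As a quick illustration (not part of this diff) of the new per-model fields, default_settings and trigger_phrases can be populated like this. The scheduler name and numeric values are placeholders; all ModelDefaultSettings fields are nullable but must be supplied explicitly since none have defaults:

```python
from invokeai.backend.model_manager.config import ModelDefaultSettings

defaults = ModelDefaultSettings(
    vae=None,
    vae_precision="fp16",
    scheduler="euler",  # must be one of SCHEDULER_NAME_VALUES; "euler" is assumed to be valid
    steps=30,
    cfg_scale=7.5,
    cfg_rescale_multiplier=0.0,
)

# These would then be attached to a model config, e.g.:
# config.default_settings = defaults
# config.trigger_phrases = {"my-trigger-word"}
```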
class _CheckpointConfig(ModelConfigBase):
class CheckpointConfigBase(ModelConfigBase):
"""Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
config: str = Field(description="path to the checkpoint model config file")
config_path: str = Field(description="path to the checkpoint model config file")
converted_at: Optional[float] = Field(
description="When this model was last converted to diffusers", default_factory=time.time
)
class _DiffusersConfig(ModelConfigBase):
class DiffusersConfigBase(ModelConfigBase):
"""Model config for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
class LoRAConfig(ModelConfigBase):
class LoRALyCORISConfig(ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
type: Literal[ModelType.Lora] = ModelType.Lora
format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
type: Literal[ModelType.LoRA] = ModelType.LoRA
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
class VaeCheckpointConfig(ModelConfigBase):
class LoRADiffusersConfig(ModelConfigBase):
"""Model config for LoRA/Diffusers models."""
type: Literal[ModelType.LoRA] = ModelType.LoRA
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.Diffusers.value}")
class VAECheckpointConfig(CheckpointConfigBase):
"""Model config for standalone VAE models."""
type: Literal[ModelType.Vae] = ModelType.Vae
type: Literal[ModelType.VAE] = ModelType.VAE
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Checkpoint.value}")
class VaeDiffusersConfig(ModelConfigBase):
class VAEDiffusersConfig(ModelConfigBase):
"""Model config for standalone VAE models (diffusers version)."""
type: Literal[ModelType.Vae] = ModelType.Vae
type: Literal[ModelType.VAE] = ModelType.VAE
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
class ControlNetDiffusersConfig(_DiffusersConfig):
class ControlNetDiffusersConfig(DiffusersConfigBase):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")
class ControlNetCheckpointConfig(_CheckpointConfig):
class ControlNetCheckpointConfig(CheckpointConfigBase):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Checkpoint.value}")
class TextualInversionConfig(ModelConfigBase):
class TextualInversionFileConfig(ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
class _MainConfig(ModelConfigBase):
"""Model config for main models."""
class TextualInversionFolderConfig(ModelConfigBase):
"""Model config for textual inversion embeddings."""
vae: Optional[str] = Field(default=None)
variant: ModelVariantType = ModelVariantType.Normal
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
ztsnr_training: bool = False
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
class MainCheckpointConfig(CheckpointConfigBase):
"""Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main
variant: ModelVariantType = ModelVariantType.Normal
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
class MainDiffusersConfig(DiffusersConfigBase):
"""Model config for main diffusers models."""
type: Literal[ModelType.Main] = ModelType.Main
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
class IPAdapterConfig(ModelConfigBase):
"""Model config for IP Adaptor format models."""
@ -242,6 +304,10 @@ class IPAdapterConfig(ModelConfigBase):
image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
class CLIPVisionDiffusersConfig(ModelConfigBase):
"""Model config for ClipVision."""
@ -249,58 +315,65 @@ class CLIPVisionDiffusersConfig(ModelConfigBase):
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")
class T2IConfig(ModelConfigBase):
class T2IAdapterConfig(ModelConfigBase):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
_ControlNetConfig = Annotated[
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
Field(discriminator="format"),
]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
AnyModelConfig = Union[
_MainModelConfig,
_VaeConfig,
_ControlNetConfig,
# ModelConfigBase,
LoRAConfig,
TextualInversionConfig,
IPAdapterConfig,
CLIPVisionDiffusersConfig,
T2IConfig,
def get_model_discriminator_value(v: Any) -> str:
"""
Computes the discriminator value for a model config.
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
"""
format_ = None
type_ = None
if isinstance(v, dict):
format_ = v.get("format")
if isinstance(format_, Enum):
format_ = format_.value
type_ = v.get("type")
if isinstance(type_, Enum):
type_ = type_.value
else:
format_ = v.format.value
type_ = v.type.value
v = f"{type_}.{format_}"
return v
AnyModelConfig = Annotated[
Union[
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
# This is a known issue. Please see:
# https://github.com/tiangolo/fastapi/discussions/9761 and
# https://github.com/tiangolo/fastapi/discussions/9287
# AnyModelConfig = Annotated[
# Union[
# _MainModelConfig,
# _ONNXConfig,
# _VaeConfig,
# _ControlNetConfig,
# LoRAConfig,
# TextualInversionConfig,
# IPAdapterConfig,
# CLIPVisionDiffusersConfig,
# T2IConfig,
# ],
# Field(discriminator="type"),
# ]
class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects."""
@ -332,6 +405,6 @@ class ModelConfigFactory(object):
assert model is not None
if key:
model.key = key
if timestamp:
model.last_modified = timestamp
if isinstance(model, CheckpointConfigBase) and timestamp is not None:
model.converted_at = timestamp
return model # type: ignore
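
A short sketch (not from the diff) of how the callable-discriminator union above behaves: the discriminator derives "type.format" from a raw dict and routes validation to the matching config class. The hash string and paths are illustrative only:

```python
from invokeai.backend.model_manager.config import (
    AnyModelConfigValidator,
    LoRALyCORISConfig,
    get_model_discriminator_value,
)

raw = {
    "hash": "some-hash-value",  # any string; the real value is produced by ModelHash
    "path": "loras/detail_tweaker.safetensors",
    "name": "detail_tweaker",
    "base": "sd-1",
    "type": "lora",
    "format": "lycoris",
    "source": "loras/detail_tweaker.safetensors",
    "source_type": "path",
}

assert get_model_discriminator_value(raw) == "lora.lycoris"
config = AnyModelConfigValidator.validate_python(raw)
assert isinstance(config, LoRALyCORISConfig)
```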

View File

@ -13,6 +13,7 @@ from invokeai.backend.model_manager import (
ModelRepoVariant,
SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
@ -50,7 +51,7 @@ class ModelLoader(ModelLoaderBase):
:param submodel_type: a ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
if model_config.type == "main" and not submodel_type:
if model_config.type is ModelType.Main and not submodel_type:
raise InvalidModelConfigException("submodel_type is required when loading a main model")
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
@ -80,7 +81,7 @@ class ModelLoader(ModelLoaderBase):
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
return self._convert_model(config, model_path, cache_path)
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
return False
def _load_if_needed(
@ -119,7 +120,7 @@ class ModelLoader(ModelLoaderBase):
return calc_model_size_by_fs(
model_path=model_path,
subfolder=submodel_type.value if submodel_type else None,
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
)
# This needs to be implemented in subclasses that handle checkpoints

View File

@ -15,10 +15,8 @@ Use like this:
"""
import hashlib
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Type
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
from ..config import (
AnyModelConfig,
@ -27,8 +25,6 @@ from ..config import (
ModelFormat,
ModelType,
SubModelType,
VaeCheckpointConfig,
VaeDiffusersConfig,
)
from . import ModelLoaderBase
@ -61,6 +57,9 @@ class ModelLoaderRegistryBase(ABC):
"""
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
class ModelLoaderRegistry:
"""
This class allows model loaders to register their type, base and format.
@ -71,10 +70,10 @@ class ModelLoaderRegistry:
@classmethod
def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
"""Define a decorator which registers the subclass of loader."""
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
key = cls._to_registry_key(base, type, format)
if key in cls._registry:
raise Exception(
@ -90,33 +89,15 @@ class ModelLoaderRegistry:
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type."""
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type
key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
implementation = cls._registry.get(key1) or cls._registry.get(key2)
if not implementation:
raise NotImplementedError(
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
)
return implementation, conf2, submodel_type
@classmethod
def _handle_subtype_overrides(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
model_path = Path(config.vae)
config_class = (
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
)
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
submodel_type = None
else:
new_conf = config
return new_conf, submodel_type
return implementation, config, submodel_type
@staticmethod
def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
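
Usage of the registry, as a sketch rather than code from the diff: loaders attach themselves with the register decorator keyed by (base, type, format), and get_implementation later resolves a config to its loader. The subclass below is hypothetical:

```python
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
class MyVAELoader(ModelLoader):
    """Hypothetical loader; registering a (base, type, format) key that is already taken raises an Exception."""


# Resolving later, given an AnyModelConfig instance `config`:
# implementation, config, submodel_type = ModelLoaderRegistry.get_implementation(config, None)
```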

View File

@ -3,8 +3,8 @@
from pathlib import Path
import safetensors
import torch
from safetensors.torch import load_file as safetensors_load_file
from invokeai.backend.model_manager import (
AnyModelConfig,
@ -12,6 +12,7 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
from .. import ModelLoaderRegistry
@ -20,15 +21,15 @@ from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
class ControlnetLoader(GenericDiffusersLoader):
class ControlNetLoader(GenericDiffusersLoader):
"""Class to load ControlNet models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if config.format != ModelFormat.Checkpoint:
if not isinstance(config, CheckpointConfigBase):
return False
elif (
dest_path.exists()
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
@ -37,13 +38,13 @@ class ControlnetLoader(GenericDiffusersLoader):
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}")
raise Exception(f"ControlNet conversion not supported for model type: {config.base}")
else:
assert hasattr(config, "config")
config_file = config.config
assert isinstance(config, CheckpointConfigBase)
config_file = config.config_path
if model_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
checkpoint = safetensors_load_file(model_path, device="cpu")
else:
checkpoint = torch.load(model_path, map_location="cpu")

View File

@ -3,9 +3,10 @@
import sys
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Optional
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from invokeai.backend.model_manager import (
AnyModel,
@ -41,6 +42,7 @@ class GenericDiffusersLoader(ModelLoader):
# TO DO: Add exception handling
def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
"""Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load."""
result = None
if submodel_type:
try:
config = self._load_diffusers_config(model_path, config_name="model_index.json")
@ -64,6 +66,7 @@ class GenericDiffusersLoader(ModelLoader):
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
assert result is not None
return result
# TO DO: Add exception handling
@ -75,7 +78,7 @@ class GenericDiffusersLoader(ModelLoader):
result: ModelMixin = getattr(res_type, class_name)
return result
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> dict[str, Any]:
return ConfigLoader.load_config(model_path, config_name=config_name)
@ -83,8 +86,8 @@ class ConfigLoader(ConfigMixin):
"""Subclass of ConfigMixin for loading diffusers configuration files."""
@classmethod
def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
def load_config(cls, *args: Any, **kwargs: Any) -> dict[str, Any]: # pyright: ignore [reportIncompatibleMethodOverride]
"""Load a diffusrs ConfigMixin configuration."""
cls.config_name = kwargs.pop("config_name")
# Diffusers doesn't provide typing info
# TODO(psyche): the types on this diffusers method are not correct
return super().load_config(*args, **kwargs) # type: ignore
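
A brief sketch (not part of the diff) of the class-name resolution that get_hf_load_class performs with ConfigLoader. The module path and model path are assumptions; the "_class_name" key is the standard diffusers config.json convention:

```python
from pathlib import Path

# Module path assumed for illustration.
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader

config = ConfigLoader.load_config(Path("path/to/model/vae"), config_name="config.json")
class_name = config["_class_name"]  # e.g. "AutoencoderKL"; the loader then resolves it via getattr on the diffusers module
```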

View File

@ -31,7 +31,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
if submodel_type is not None:
raise ValueError("There are no submodels in an IP-Adapter model.")
model = build_ip_adapter(
ip_adapter_ckpt_path=model_path / "ip_adapter.bin",
ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
device=torch.device("cpu"),
dtype=self._torch_dtype,
)

View File

@ -22,8 +22,8 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
from .. import ModelLoader, ModelLoaderRegistry
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
class LoraLoader(ModelLoader):
"""Class to load LoRA models."""

View File

@ -18,7 +18,7 @@ from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.ONNX)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
class OnnyxDiffusersModel(GenericDiffusersLoader):
"""Class to load onnx models."""

View File

@ -4,7 +4,8 @@
from pathlib import Path
from typing import Optional
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from invokeai.backend.model_manager import (
AnyModel,
@ -16,7 +17,7 @@ from invokeai.backend.model_manager import (
ModelVariantType,
SubModelType,
)
from invokeai.backend.model_manager.config import MainCheckpointConfig
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from .. import ModelLoaderRegistry
@ -54,11 +55,11 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
return result
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if config.format != ModelFormat.Checkpoint:
if not isinstance(config, CheckpointConfigBase):
return False
elif (
dest_path.exists()
and (dest_path / "model_index.json").stat().st_mtime >= (config.last_modified or 0.0)
and (dest_path / "model_index.json").stat().st_mtime >= (config.converted_at or 0.0)
and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
@ -73,7 +74,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
)
config_file = config.config
config_file = config.config_path
self._logger.info(f"Converting {model_path} to diffusers format")
convert_ckpt_to_diffusers(

View File

@ -3,9 +3,9 @@
from pathlib import Path
import safetensors
import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_file as safetensors_load_file
from invokeai.backend.model_manager import (
AnyModelConfig,
@ -13,24 +13,25 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class VaeLoader(GenericDiffusersLoader):
"""Class to load VAE models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if config.format != ModelFormat.Checkpoint:
if not isinstance(config, CheckpointConfigBase):
return False
elif (
dest_path.exists()
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
@ -38,16 +39,15 @@ class VaeLoader(GenericDiffusersLoader):
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
# TO DO: check whether sdxl VAE models convert.
# TODO(MM2): check whether sdxl VAE models convert.
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}")
raise Exception(f"VAE conversion not supported for model type: {config.base}")
else:
config_file = (
"v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
)
assert isinstance(config, CheckpointConfigBase)
config_file = config.config_path
if model_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
checkpoint = safetensors_load_file(model_path, device="cpu")
else:
checkpoint = torch.load(model_path, map_location="cpu")
@ -55,7 +55,7 @@ class VaeLoader(GenericDiffusersLoader):
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
ckpt_config = OmegaConf.load(self._app_config.root_path / config_file)
assert isinstance(ckpt_config, DictConfig)
vae_model = convert_ldm_vae_to_diffusers(

View File

@ -16,6 +16,7 @@ from diffusers import AutoPipelineForText2Image
from diffusers.utils import logging as dlogging
from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
from . import (
@ -117,7 +118,6 @@ class ModelMerger(object):
config = self._installer.app_config
store = self._installer.record_store
base_models: Set[BaseModelType] = set()
vae = None
variant = None if self._installer.app_config.full_precision else "fp16"
assert (
@ -134,10 +134,6 @@ class ModelMerger(object):
"normal"
), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
# pick up the first model's vae
if key == model_keys[0]:
vae = info.vae
# tally base models used
base_models.add(info.base)
model_paths.extend([config.models_path / info.path])
@ -163,12 +159,10 @@ class ModelMerger(object):
# update model's config
model_config = self._installer.record_store.get_model(key)
model_config.update(
{
"name": merged_model_name,
"description": f"Merge of models {', '.join(model_names)}",
"vae": vae,
}
model_config.name = merged_model_name
model_config.description = f"Merge of models {', '.join(model_names)}"
self._installer.record_store.update_model(
key, ModelRecordChanges(name=model_config.name, description=model_config.description)
)
self._installer.record_store.update_model(key, model_config)
return model_config
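
The update pattern above replaces the old dict-based config.update(): a caller now sends only the changed fields via ModelRecordChanges. A minimal sketch, with hypothetical names and an untyped record_store parameter:

```python
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges


def rename_merged_model(record_store, key: str) -> None:
    # Only the fields being changed are sent; the record store applies them to the stored config.
    changes = ModelRecordChanges(name="my-merged-model", description="Merge of models A, B")
    record_store.update_model(key, changes)
```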

View File

@ -25,9 +25,7 @@ from .metadata_base import (
AnyModelRepoMetadataValidator,
BaseMetadata,
CivitaiMetadata,
CommercialUsage,
HuggingFaceMetadata,
LicenseRestrictions,
ModelMetadataWithFiles,
RemoteModelFile,
UnknownMetadataException,
@ -38,10 +36,8 @@ __all__ = [
"AnyModelRepoMetadataValidator",
"CivitaiMetadata",
"CivitaiMetadataFetch",
"CommercialUsage",
"HuggingFaceMetadata",
"HuggingFaceMetadataFetch",
"LicenseRestrictions",
"ModelMetadataFetchBase",
"BaseMetadata",
"ModelMetadataWithFiles",

View File

@ -23,22 +23,21 @@ metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words)
"""
import json
import re
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Optional
import requests
from pydantic import TypeAdapter, ValidationError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.backend.model_manager import ModelRepoVariant
from invokeai.backend.model_manager.config import ModelRepoVariant
from ..metadata_base import (
AnyModelRepoMetadata,
CivitaiMetadata,
CommercialUsage,
LicenseRestrictions,
RemoteModelFile,
UnknownMetadataException,
)
@ -52,10 +51,13 @@ CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
StringSetAdapter = TypeAdapter(set[str])
class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""Fetch model metadata from Civitai."""
def __init__(self, session: Optional[Session] = None):
def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
@ -63,6 +65,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
this module without an internet connection.
"""
self._requests = session or requests.Session()
self._api_key = api_key
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
@ -102,22 +105,21 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
May raise an `UnknownMetadataException`.
"""
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(model_url).json()
return self._from_model_json(model_json)
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
return self._from_api_response(model_json)
def _from_model_json(self, model_json: Dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
try:
version_id = version_id or model_json["modelVersions"][0]["id"]
version_id = version_id or api_response["modelVersions"][0]["id"]
except TypeError as excp:
raise UnknownMetadataException from excp
# loop till we find the section containing the version requested
version_sections = [x for x in model_json["modelVersions"] if x["id"] == version_id]
version_sections = [x for x in api_response["modelVersions"] if x["id"] == version_id]
if not version_sections:
raise UnknownMetadataException(f"Version {version_id} not found in model metadata")
version_json = version_sections[0]
safe_thumbnails = [x["url"] for x in version_json["images"] if x["nsfw"] == "None"]
# Civitai has one "primary" file plus others such as VAEs. We only fetch the primary.
primary = [x for x in version_json["files"] if x.get("primary")]
@ -134,36 +136,23 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
url = url + f"?type={primary_file['type']}{metadata_string}"
model_files = [
RemoteModelFile(
url=url,
url=self._get_url_with_api_key(url),
path=Path(primary_file["name"]),
size=int(primary_file["sizeKB"] * 1024),
sha256=primary_file["hashes"]["SHA256"],
)
]
try:
trigger_phrases = StringSetAdapter.validate_python(version_json.get("trainedWords"))
except ValidationError:
trigger_phrases: set[str] = set()
return CivitaiMetadata(
id=model_json["id"],
name=version_json["name"],
version_id=version_json["id"],
version_name=version_json["name"],
created=datetime.fromisoformat(_fix_timezone(version_json["createdAt"])),
updated=datetime.fromisoformat(_fix_timezone(version_json["updatedAt"])),
published=datetime.fromisoformat(_fix_timezone(version_json["publishedAt"])),
base_model_trained_on=version_json["baseModel"], # note - need a dictionary to turn into a BaseModelType
files=model_files,
download_url=version_json["downloadUrl"],
thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
author=model_json["creator"]["username"],
description=model_json["description"],
version_description=version_json["description"] or "",
tags=model_json["tags"],
trained_words=version_json["trainedWords"],
nsfw=model_json["nsfw"],
restrictions=LicenseRestrictions(
AllowNoCredit=model_json["allowNoCredit"],
AllowCommercialUse={CommercialUsage(x) for x in model_json["allowCommercialUse"]},
AllowDerivatives=model_json["allowDerivatives"],
AllowDifferentLicense=model_json["allowDifferentLicense"],
),
trigger_phrases=trigger_phrases,
api_response=json.dumps(version_json),
)
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
@ -174,14 +163,14 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""
if model_id is None:
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
version = self._requests.get(version_url).json()
version = self._requests.get(self._get_url_with_api_key(version_url)).json()
if error := version.get("error"):
raise UnknownMetadataException(error)
model_id = version["modelId"]
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(model_url).json()
return self._from_model_json(model_json, version_id)
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
return self._from_api_response(model_json, version_id)
@classmethod
def from_json(cls, json: str) -> CivitaiMetadata:
@ -189,6 +178,11 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
metadata = CivitaiMetadata.model_validate_json(json)
return metadata
def _get_url_with_api_key(self, url: str) -> str:
if not self._api_key:
return url
def _fix_timezone(date: str) -> str:
return re.sub(r"Z$", "+00:00", date)
if "?" in url:
return f"{url}&token={self._api_key}"
return f"{url}?token={self._api_key}"

View File

@ -13,6 +13,7 @@ metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
print(metadata.tags)
"""
import json
import re
from pathlib import Path
from typing import Optional
@ -23,7 +24,7 @@ from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFo
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.backend.model_manager import ModelRepoVariant
from invokeai.backend.model_manager.config import ModelRepoVariant
from ..metadata_base import (
AnyModelRepoMetadata,
@ -60,6 +61,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
# Little loop which tries fetching a revision corresponding to the selected variant.
# If not available, then set variant to None and get the default.
# If this too fails, raise exception.
model_info = None
while not model_info:
try:
@ -72,23 +74,24 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
else:
variant = None
files: list[RemoteModelFile] = []
_, name = id.split("/")
return HuggingFaceMetadata(
id=model_info.id,
author=model_info.author,
name=name,
last_modified=model_info.last_modified,
tag_dict=model_info.card_data.to_dict() if model_info.card_data else {},
tags=model_info.tags,
files=[
for s in model_info.siblings or []:
assert s.rfilename is not None
assert s.size is not None
files.append(
RemoteModelFile(
url=hf_hub_url(id, x.rfilename, revision=variant),
path=Path(name, x.rfilename),
size=x.size,
sha256=x.lfs.get("sha256") if x.lfs else None,
url=hf_hub_url(id, s.rfilename, revision=variant),
path=Path(name, s.rfilename),
size=s.size,
sha256=s.lfs.get("sha256") if s.lfs else None,
)
for x in model_info.siblings
],
)
return HuggingFaceMetadata(
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str)
)
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
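
A hedged reconstruction (not the actual implementation) of the variant-fallback loop described in the comment above: try the requested variant as a revision, drop back to the default revision if it is missing, and let the error propagate if that also fails. The HfApi call and its arguments are assumptions about how model_info is gathered:

```python
from typing import Optional

from huggingface_hub import HfApi
from huggingface_hub.utils._errors import RevisionNotFoundError

from invokeai.backend.model_manager.config import ModelRepoVariant


def fetch_model_info(repo_id: str, variant: Optional[ModelRepoVariant]):
    model_info = None
    while not model_info:
        try:
            # ModelRepoVariant.Default is the empty string, which maps to the default revision.
            revision = (variant.value or None) if variant else None
            model_info = HfApi().model_info(repo_id, revision=revision, files_metadata=True)
        except RevisionNotFoundError:
            if variant is None:
                raise  # the default revision failed too; give up
            variant = None  # retry against the default revision
    return model_info
```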

View File

@ -14,10 +14,8 @@ versions of these fields are intended to be kept in sync with the
remote repo.
"""
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
from typing import List, Literal, Optional, Union
from huggingface_hub import configure_http_backend, hf_hub_url
from pydantic import BaseModel, Field, TypeAdapter
@ -25,7 +23,6 @@ from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from typing_extensions import Annotated
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.backend.model_manager import ModelRepoVariant
from ..util import select_hf_files
@ -35,31 +32,6 @@ class UnknownMetadataException(Exception):
"""Raised when no metadata is available for a model."""
class CommercialUsage(str, Enum):
"""Type of commercial usage allowed."""
No = "None"
Image = "Image"
Rent = "Rent"
RentCivit = "RentCivit"
Sell = "Sell"
class LicenseRestrictions(BaseModel):
"""Broad categories of licensing restrictions."""
AllowNoCredit: bool = Field(
description="if true, model can be redistributed without crediting author", default=False
)
AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
AllowDifferentLicense: bool = Field(
description="if true, derivatives of this model be redistributed under a different license", default=False
)
AllowCommercialUse: Optional[Set[CommercialUsage] | CommercialUsage] = Field(
description="Type of commercial use allowed if no commercial use is allowed.", default=None
)
class RemoteModelFile(BaseModel):
"""Information about a downloadable file that forms part of a model."""
@ -69,24 +41,10 @@ class RemoteModelFile(BaseModel):
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
class ModelDefaultSettings(BaseModel):
vae: str | None
vae_precision: str | None
scheduler: SCHEDULER_NAME_VALUES | None
steps: int | None
cfg_scale: float | None
cfg_rescale_multiplier: float | None
class ModelMetadataBase(BaseModel):
"""Base class for model metadata information."""
name: str = Field(description="model's name")
author: str = Field(description="model's author")
tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None)
default_settings: Optional[ModelDefaultSettings] = Field(
description="default settings for this model", default=None
)
class BaseMetadata(ModelMetadataBase):
@ -124,60 +82,16 @@ class CivitaiMetadata(ModelMetadataWithFiles):
"""Extended metadata fields provided by Civitai."""
type: Literal["civitai"] = "civitai"
id: int = Field(description="Civitai version identifier")
version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
version_id: int = Field(description="Civitai model version identifier")
created: datetime = Field(description="date the model was created")
updated: datetime = Field(description="date the model was last modified")
published: datetime = Field(description="date the model was published to Civitai")
description: str = Field(description="text description of model; may contain HTML")
version_description: str = Field(
description="text description of the model's reversion; usually change history; may contain HTML"
)
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
download_url: AnyHttpUrl = Field(description="download URL for this model")
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
weight_minmax: Tuple[float, float] = Field(
description="minimum and maximum slider values for a LoRA or other secondary model", default=(-1.0, +2.0)
) # note: For future use
@property
def credit_required(self) -> bool:
"""Return True if you must give credit for derivatives of this model and images generated from it."""
return not self.restrictions.AllowNoCredit
@property
def allow_commercial_use(self) -> bool:
"""Return True if commercial use is allowed."""
if self.restrictions.AllowCommercialUse is None:
return False
else:
# accommodate schema change
acu = self.restrictions.AllowCommercialUse
commercial_usage = acu if isinstance(acu, set) else {acu}
return CommercialUsage.No not in commercial_usage
@property
def allow_derivatives(self) -> bool:
"""Return True if derivatives of this model can be redistributed."""
return self.restrictions.AllowDerivatives
@property
def allow_different_license(self) -> bool:
"""Return true if derivatives of this model can use a different license."""
return self.restrictions.AllowDifferentLicense
trigger_phrases: set[str] = Field(description="Trigger phrases extracted from the API response")
api_response: Optional[str] = Field(description="Response from the Civitai API as stringified JSON", default=None)
class HuggingFaceMetadata(ModelMetadataWithFiles):
"""Extended metadata fields provided by HuggingFace."""
type: Literal["huggingface"] = "huggingface"
id: str = Field(description="huggingface model id")
tag_dict: Dict[str, Any]
last_modified: datetime = Field(description="date of last commit to repo")
id: str = Field(description="The HF model id")
api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
def download_urls(
self,
@ -206,7 +120,7 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
# the next step reads model_index.json to determine which subdirectories belong
# to the model
if Path(f"{prefix}model_index.json") in paths:
url = hf_hub_url(self.id, filename="model_index.json", subfolder=subfolder)
url = hf_hub_url(self.id, filename="model_index.json", subfolder=str(subfolder) if subfolder else None)
resp = session.get(url)
resp.raise_for_status()
submodels = resp.json()

View File

@ -1,221 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""
import sqlite3
from typing import List, Optional, Set, Tuple
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .fetch import ModelMetadataFetchBase
from .metadata_base import AnyModelRepoMetadata, UnknownMetadataException
class ModelMetadataStore:
"""Store, search and fetch model metadata retrieved from remote repositories."""
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_metadata(
id,
metadata
)
VALUES (?,?);
""",
(
model_key,
json_serialized,
),
)
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
raise UnknownMetadataException from excp
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata FROM model_metadata
WHERE id=?;
""",
(model_key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownMetadataException("model metadata not found")
return ModelMetadataFetchBase.from_json(rows[0])
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT id,metadata FROM model_metadata;
""",
(),
)
rows = self._cursor.fetchall()
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
json_serialized = metadata.model_dump_json() # turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_metadata
SET
metadata=?
WHERE id=?;
""",
(json_serialized, model_key),
)
if self._cursor.rowcount == 0:
raise UnknownMetadataException("model metadata not found")
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_metadata(model_key)
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
self._cursor.execute(
"""--sql
select tag_text from tags;
"""
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
with self._db.lock:
try:
matches: Optional[Set[str]] = None
for tag in tags:
self._cursor.execute(
"""--sql
SELECT a.model_id FROM model_tags AS a,
tags AS b
WHERE a.tag_id=b.tag_id
AND b.tag_text=?;
""",
(tag,),
)
model_keys = {x[0] for x in self._cursor.fetchall()}
if matches is None:
matches = model_keys
matches = matches.intersection(model_keys)
except sqlite3.Error as e:
raise e
return matches if matches else set()
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE author=?;
""",
(author,),
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE name=?;
""",
(name,),
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
"""Update tags for the model referenced by model_key."""
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)

View File

@ -8,6 +8,7 @@ import torch
from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.util.util import SilenceWarnings
from .config import (
@ -17,6 +18,7 @@ from .config import (
ModelConfigFactory,
ModelFormat,
ModelRepoVariant,
ModelSourceType,
ModelType,
ModelVariantType,
SchedulerPredictionType,
@ -95,8 +97,8 @@ class ModelProbe(object):
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"LatentConsistencyModelPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae,
"AutoencoderTiny": ModelType.Vae,
"AutoencoderKL": ModelType.VAE,
"AutoencoderTiny": ModelType.VAE,
"ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
"T2IAdapter": ModelType.T2IAdapter,
@ -108,14 +110,6 @@ class ModelProbe(object):
) -> None:
cls.PROBES[format][model_type] = probe_class
@classmethod
def heuristic_probe(
cls,
model_path: Path,
fields: Optional[Dict[str, Any]] = None,
) -> AnyModelConfig:
return cls.probe(model_path, fields)
@classmethod
def probe(
cls,
@ -137,19 +131,21 @@ class ModelProbe(object):
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
model_info = None
model_type = None
if format_type == "diffusers":
if format_type is ModelFormat.Diffusers:
model_type = cls.get_model_type_from_folder(model_path)
else:
model_type = cls.get_model_type_from_checkpoint(model_path)
format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type
format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type
probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class:
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
hash = ModelHash().hash(model_path)
probe = probe_class(model_path)
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
fields["source"] = fields.get("source") or model_path.as_posix()
fields["key"] = fields.get("key", uuid_string())
fields["path"] = model_path.as_posix()
fields["type"] = fields.get("type") or model_type
fields["base"] = fields.get("base") or probe.get_base_type()
@ -161,15 +157,17 @@ class ModelProbe(object):
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["original_hash"] = fields.get("original_hash") or hash
fields["current_hash"] = fields.get("current_hash") or hash
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
# additional fields needed for main and controlnet models
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
fields["config"] = cls._get_checkpoint_config_path(
if (
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
and fields["format"] is ModelFormat.Checkpoint
):
fields["config_path"] = cls._get_checkpoint_config_path(
model_path,
model_type=fields["type"],
base_type=fields["base"],
@ -179,7 +177,7 @@ class ModelProbe(object):
# additional fields needed for main non-checkpoint models
elif fields["type"] == ModelType.Main and fields["format"] in [
ModelFormat.Onnx,
ModelFormat.ONNX,
ModelFormat.Olive,
ModelFormat.Diffusers,
]:
@ -213,11 +211,11 @@ class ModelProbe(object):
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
return ModelType.Main
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
return ModelType.Vae
return ModelType.VAE
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
return ModelType.Lora
return ModelType.LoRA
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
return ModelType.Lora
return ModelType.LoRA
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
return ModelType.ControlNet
elif key in {"emb_params", "string_to_param"}:
@ -239,7 +237,7 @@ class ModelProbe(object):
if (folder_path / f"learned_embeds.{suffix}").exists():
return ModelType.TextualInversion
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
return ModelType.Lora
return ModelType.LoRA
if (folder_path / "unet/model.onnx").exists():
return ModelType.ONNX
if (folder_path / "image_encoder.txt").exists():
@ -285,13 +283,21 @@ class ModelProbe(object):
if possible_conf.exists():
return possible_conf.absolute()
if model_type == ModelType.Main:
if model_type is ModelType.Main:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type]
elif model_type == ModelType.ControlNet:
elif model_type is ModelType.ControlNet:
config_file = (
"../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
"../controlnet/cldm_v15.yaml"
if base_type is BaseModelType.StableDiffusion1
else "../controlnet/cldm_v21.yaml"
)
elif model_type is ModelType.VAE:
config_file = (
"../stable-diffusion/v1-inference.yaml"
if base_type is BaseModelType.StableDiffusion1
else "../stable-diffusion/v2-inference.yaml"
)
else:
raise InvalidModelConfigException(
@ -497,12 +503,12 @@ class FolderProbeBase(ProbeBase):
if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16
if "openvino_model" in x.name:
return ModelRepoVariant.OPENVINO
return ModelRepoVariant.OpenVINO
if "flax_model" in x.name:
return ModelRepoVariant.FLAX
return ModelRepoVariant.Flax
if x.suffix == ".onnx":
return ModelRepoVariant.ONNX
return ModelRepoVariant.DEFAULT
return ModelRepoVariant.Default
class PipelineFolderProbe(FolderProbeBase):
@ -708,8 +714,8 @@ class T2IAdapterFolderProbe(FolderProbeBase):
############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
@ -717,8 +723,8 @@ ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderPro
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)

View File

@ -13,6 +13,7 @@ files_to_download = select_hf_model_files(metadata.files, variant='onnx')
"""
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Set
@ -34,7 +35,7 @@ def filter_files(
The file list can be obtained from the `files` field of HuggingFaceMetadata,
as defined in `invokeai.backend.model_manager.metadata.metadata_base`.
"""
variant = variant or ModelRepoVariant.DEFAULT
variant = variant or ModelRepoVariant.Default
paths: List[Path] = []
root = files[0].parts[0]
@ -73,64 +74,81 @@ def filter_files(
return sorted(_filter_by_variant(paths, variant))
@dataclass
class SubfolderCandidate:
path: Path
score: int
def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
result = set()
basenames: Dict[Path, Path] = {}
result: set[Path] = set()
subfolder_weights: dict[Path, list[SubfolderCandidate]] = {}
for path in files:
if path.suffix in [".onnx", ".pb", ".onnx_data"]:
if variant == ModelRepoVariant.ONNX:
result.add(path)
elif "openvino_model" in path.name:
if variant == ModelRepoVariant.OPENVINO:
if variant == ModelRepoVariant.OpenVINO:
result.add(path)
elif "flax_model" in path.name:
if variant == ModelRepoVariant.FLAX:
if variant == ModelRepoVariant.Flax:
result.add(path)
elif path.suffix in [".json", ".txt"]:
result.add(path)
elif path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"] and variant in [
elif variant in [
ModelRepoVariant.FP16,
ModelRepoVariant.FP32,
ModelRepoVariant.DEFAULT,
]:
parent = path.parent
suffixes = path.suffixes
if len(suffixes) == 2:
variant_label, suffix = suffixes
basename = parent / Path(path.stem).stem
else:
variant_label = ""
suffix = suffixes[0]
basename = parent / path.stem
ModelRepoVariant.Default,
] and path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"]:
# For weights files, we want to select the best one for each subfolder. For example, we may have multiple
# text encoders:
#
# - text_encoder/model.fp16.safetensors
# - text_encoder/model.safetensors
# - text_encoder/pytorch_model.bin
# - text_encoder/pytorch_model.fp16.bin
#
# We prefer safetensors over other file formats, and an exact variant match over a mismatch. We'll score each file based on
# variant and format and select the best one.
if previous := basenames.get(basename):
if (
previous.suffix != ".safetensors" and suffix == ".safetensors"
): # replace non-safetensors with safetensors when available
basenames[basename] = path
if variant_label == f".{variant}":
basenames[basename] = path
elif not variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]:
basenames[basename] = path
else:
basenames[basename] = path
parent = path.parent
score = 0
if path.suffix == ".safetensors":
score += 1
candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or Default.
if candidate_variant_label == f".{variant}" or (
not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]
):
score += 1
if parent not in subfolder_weights:
subfolder_weights[parent] = []
subfolder_weights[parent].append(SubfolderCandidate(path=path, score=score))
else:
continue
for v in basenames.values():
result.add(v)
for candidate_list in subfolder_weights.values():
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
if highest_score_candidate:
result.add(highest_score_candidate.path)
# If one of the architecture-related variants was specified and no files matched other than
# config and text files, then we return an empty set
if (
variant
and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OPENVINO, ModelRepoVariant.FLAX]
and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OpenVINO, ModelRepoVariant.Flax]
and not any(variant.value in x.name for x in result)
):
return set()
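
The comment block above describes the new per-subfolder selection rule. A self-contained sketch of that scoring pass, simplified from the diff (the real filter_files only applies it to weights files for FP16/FP32/Default requests, after the other per-file filters):

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List


@dataclass
class SubfolderCandidate:
    path: Path
    score: int


def pick_best_weights(files: List[Path], variant: str = "fp16") -> List[Path]:
    """Keep the highest-scoring weights file in each subfolder."""
    subfolder_weights: Dict[Path, List[SubfolderCandidate]] = {}
    for path in files:
        score = 0
        if path.suffix == ".safetensors":
            score += 1  # prefer safetensors over .bin/.pt/.ckpt
        # "model.fp16.safetensors" has two suffixes; the first one is the variant label.
        variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None
        if variant_label == f".{variant}" or (variant_label is None and variant in ("fp32", "default")):
            score += 1  # exact variant match, or an unlabeled file when fp32/default was requested
        subfolder_weights.setdefault(path.parent, []).append(SubfolderCandidate(path=path, score=score))
    return [max(candidates, key=lambda c: c.score).path for candidates in subfolder_weights.values()]


print(pick_best_weights([
    Path("text_encoder/model.fp16.safetensors"),
    Path("text_encoder/model.safetensors"),
    Path("text_encoder/pytorch_model.bin"),
    Path("text_encoder/pytorch_model.fp16.bin"),
]))  # -> [PosixPath('text_encoder/model.fp16.safetensors')]
```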

View File

@ -858,9 +858,9 @@ def do_textual_inversion_training(
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
orig_embeds_params[index_no_updates]
)
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[index_no_updates]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:

View File

@ -144,7 +144,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.nextrely = top_of_table
self.lora_models = self.add_model_widgets(
model_type=ModelType.Lora,
model_type=ModelType.LoRA,
window_width=window_width,
)
bottom_of_table = max(bottom_of_table, self.nextrely)

View File

@ -30,7 +30,7 @@
"lint:prettier": "prettier --check .",
"lint:tsc": "tsc --noEmit",
"lint": "concurrently -g -c red,green,yellow,blue,magenta pnpm:lint:*",
"fix": "knip --fix && eslint --fix . && prettier --log-level warn --write .",
"fix": "eslint --fix . && prettier --log-level warn --write .",
"preinstall": "npx only-allow pnpm",
"storybook": "storybook dev -p 6006",
"build-storybook": "storybook build",

View File

@ -304,6 +304,12 @@
"method": "High Resolution Fix Method"
}
},
"prompt": {
"addPromptTrigger": "Add Prompt Trigger",
"compatibleEmbeddings": "Compatible Embeddings",
"noPromptTriggers": "No triggers available",
"noMatchingTriggers": "No matching triggers"
},
"embedding": {
"addEmbedding": "Add Embedding",
"incompatibleModel": "Incompatible base model:",

View File

@ -153,7 +153,7 @@ addFirstListImagesListener(startAppListening);
// Ad-hoc upscale workflow
addUpscaleRequestedListener(startAppListening);
// Dynamic prompts
// Prompts
addDynamicPromptsListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);

View File

@ -7,8 +7,10 @@ import {
selectAllT2IAdapters,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
import { heightChanged, modelChanged, vaeSelected, widthChanged } from 'features/parameters/store/generationSlice';
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models';
@ -24,7 +26,9 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
const log = logger('models');
log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`);
const currentModel = getState().generation.model;
const state = getState();
const currentModel = state.generation.model;
const models = mainModelsAdapterSelectors.selectAll(action.payload);
if (models.length === 0) {
@ -39,6 +43,29 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
return;
}
const defaultModel = state.config.sd.defaultModel;
const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;
if (defaultModelInList) {
const result = zParameterModel.safeParse(defaultModelInList);
if (result.success) {
dispatch(modelChanged(defaultModelInList, currentModel));
const optimalDimension = getOptimalDimension(defaultModelInList);
if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
return;
}
const { width, height } = calculateNewSize(
state.generation.aspectRatio.value,
optimalDimension * optimalDimension
);
dispatch(widthChanged(width));
dispatch(heightChanged(height));
return;
}
}
const result = zParameterModel.safeParse(models[0]);
if (!result.success) {
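
The default-model branch above only resizes the canvas when the current dimensions are not already optimal for the newly selected model. Below is a rough Python sketch of the arithmetic this depends on, assuming (as the call site suggests) that calculateNewSize targets an area of optimalDimension² while preserving the stored aspect ratio; the real TypeScript helpers (calculateNewSize, getIsSizeOptimal) may additionally quantize to the model's grid size.

```python
import math


def calculate_new_size(aspect_ratio: float, target_area: float) -> tuple[int, int]:
    """Pick width/height with the given w/h ratio whose product is roughly target_area."""
    width = math.sqrt(target_area * aspect_ratio)
    height = width / aspect_ratio
    return round(width), round(height)


def maybe_resize(width: int, height: int, aspect_ratio: float, optimal_dimension: int) -> tuple[int, int]:
    # Assumed stand-in for getIsSizeOptimal: accept the current size if its area already matches.
    if width * height == optimal_dimension**2:
        return width, height
    return calculate_new_size(aspect_ratio, optimal_dimension**2)


# Hypothetical example: an SDXL-style default (optimal dimension 1024) selected while
# the canvas is still sized for SD1.5 at 16:9.
print(maybe_resize(512, 512, 16 / 9, 1024))  # -> (1365, 768)
```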

View File

@ -34,13 +34,13 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
return;
}
const metadata = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(currentModel.key)).unwrap();
const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();
if (!metadata || !metadata.default_settings) {
if (!modelConfig || !modelConfig.default_settings) {
return;
}
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = metadata.default_settings;
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = modelConfig.default_settings;
if (vae) {
// we store this as "default" within default settings

View File

@ -14,7 +14,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
const { bytes, total_bytes, id } = action.payload.data;
dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.bytes = bytes;
@ -33,7 +33,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
const { id } = action.payload.data;
dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'completed';
@ -41,7 +41,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
return draft;
})
);
dispatch(api.util.invalidateTags([{ type: 'ModelConfig' }]));
dispatch(api.util.invalidateTags(['Model']));
},
});
@ -51,7 +51,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
const { id, error, error_type } = action.payload.data;
dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'error';

View File

@ -1,21 +0,0 @@
import type { Meta, StoryObj } from '@storybook/react';
import { EmbeddingSelect } from './EmbeddingSelect';
import type { EmbeddingSelectProps } from './types';
const meta: Meta<typeof EmbeddingSelect> = {
title: 'Feature/Prompt/EmbeddingSelect',
tags: ['autodocs'],
component: EmbeddingSelect,
};
export default meta;
type Story = StoryObj<typeof EmbeddingSelect>;
const Component = (props: EmbeddingSelectProps) => {
return <EmbeddingSelect {...props}>Invoke</EmbeddingSelect>;
};
export const Default: Story = {
render: Component,
};

View File

@ -1,67 +0,0 @@
import type { ChakraProps } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import type { EmbeddingSelectProps } from 'features/embedding/types';
import { t } from 'i18next';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import type { TextualInversionModelConfig } from 'services/api/types';
const noOptionsMessage = () => t('embedding.noMatchingEmbedding');
export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const getIsDisabled = useCallback(
(embedding: TextualInversionModelConfig): boolean => {
const isCompatible = currentBaseModel === embedding.base;
const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible;
},
[currentBaseModel]
);
const { data, isLoading } = useGetTextualInversionModelsQuery();
const _onChange = useCallback(
(embedding: TextualInversionModelConfig | null) => {
if (!embedding) {
return;
}
onSelect(embedding.name);
},
[onSelect]
);
const { options, onChange } = useGroupedModelCombobox({
modelEntities: data,
getIsDisabled,
onChange: _onChange,
});
return (
<FormControl>
<Combobox
placeholder={isLoading ? t('common.loading') : t('embedding.addEmbedding')}
defaultMenuIsOpen
autoFocus
value={null}
options={options}
noOptionsMessage={noOptionsMessage}
onChange={onChange}
onMenuClose={onClose}
data-testid="add-embedding"
sx={selectStyles}
/>
</FormControl>
);
});
EmbeddingSelect.displayName = 'EmbeddingSelect';
const selectStyles: ChakraProps['sx'] = {
w: 'full',
};

View File

@ -1,228 +0,0 @@
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import BaseModelSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect';
import BooleanSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect';
import ModelFormatSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect';
import ModelTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect';
import ModelVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect';
import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect';
import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { isNil, omitBy } from 'lodash-es';
import { useCallback, useEffect } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useInstallModelMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
export const AdvancedImport = () => {
const dispatch = useAppDispatch();
const [installModel] = useInstallModelMutation();
const { t } = useTranslation();
const {
register,
handleSubmit,
control,
formState: { errors },
setValue,
resetField,
reset,
watch,
} = useForm<AnyModelConfig>({
defaultValues: {
name: '',
base: 'sd-1',
type: 'main',
path: '',
description: '',
format: 'diffusers',
vae: '',
variant: 'normal',
},
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
(values) => {
installModel({
source: values.path,
config: omitBy(values, isNil),
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelAdded', {
modelName: values.name,
}),
status: 'success',
})
)
);
reset();
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddFailed'),
status: 'error',
})
)
);
}
});
},
[installModel, dispatch, t, reset]
);
const watchedModelType = watch('type');
const watchedModelFormat = watch('format');
useEffect(() => {
if (watchedModelType === 'main') {
setValue('format', 'diffusers');
setValue('repo_variant', '');
setValue('variant', 'normal');
}
if (watchedModelType === 'lora') {
setValue('format', 'lycoris');
} else if (watchedModelType === 'embedding') {
setValue('format', 'embedding_file');
} else if (watchedModelType === 'ip_adapter') {
setValue('format', 'invokeai');
} else {
setValue('format', 'diffusers');
}
resetField('upcast_attention');
resetField('ztsnr_training');
resetField('vae');
resetField('config');
resetField('prediction_type');
resetField('image_encoder_model_id');
}, [watchedModelType, resetField, setValue]);
return (
<ScrollableContent>
<form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" gap={4} width="100%" pb={10}>
<Flex alignItems="flex-end" gap="4">
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.modelType')}</FormLabel>
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
</FormControl>
<Text px="2" fontSize="xs" textAlign="center">
{t('modelManager.advancedImportInfo')}
</Text>
</Flex>
<Flex p={4} borderRadius={4} bg="base.850" height="100%" direction="column" gap="3">
<FormControl isInvalid={Boolean(errors.name)}>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.name')}</FormLabel>
<Input
{...register('name', {
validate: (value) => value.trim().length >= 3 || 'Must be at least 3 characters',
})}
/>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
<Flex>
<FormControl>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.description')}</FormLabel>
<Textarea size="sm" {...register('description')} />
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect control={control} name="base" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('common.format')}</FormLabel>
<ModelFormatSelect control={control} name="format" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.path')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</FormControl>
</Flex>
{watchedModelType === 'main' && (
<>
<Flex gap={4}>
{watchedModelFormat === 'diffusers' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.repoVariant')}</FormLabel>
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
</FormControl>
)}
{watchedModelFormat === 'checkpoint' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
<Input {...register('config')} />
</FormControl>
)}
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</FormControl>
</Flex>
</>
)}
{watchedModelType === 'ip_adapter' && (
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
<Input {...register('image_encoder_model_id')} />
</FormControl>
</Flex>
)}
<Button mt={2} type="submit">
{t('modelManager.addModel')}
</Button>
</Flex>
</Flex>
</form>
</ScrollableContent>
);
};

View File

@ -12,7 +12,7 @@ type SimpleImportModelConfig = {
location: string;
};
export const SimpleImport = () => {
export const InstallModelForm = () => {
const dispatch = useAppDispatch();
const [installModel, { isLoading }] = useInstallModelMutation();

View File

@ -5,19 +5,19 @@ import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { useCallback, useMemo } from 'react';
import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models';
import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models';
import { ImportQueueItem } from './ImportQueueItem';
import { ModelInstallQueueItem } from './ModelInstallQueueItem';
export const ImportQueue = () => {
export const ModelInstallQueue = () => {
const dispatch = useAppDispatch();
const { data } = useGetModelImportsQuery();
const { data } = useListModelInstallsQuery();
const [pruneModelImports] = usePruneModelImportsMutation();
const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation();
const pruneQueue = useCallback(() => {
pruneModelImports()
const pruneCompletedModelInstalls = useCallback(() => {
_pruneCompletedModelInstalls()
.unwrap()
.then((_) => {
dispatch(
@ -41,7 +41,7 @@ export const ImportQueue = () => {
);
}
});
}, [pruneModelImports, dispatch]);
}, [_pruneCompletedModelInstalls, dispatch]);
const pruneAvailable = useMemo(() => {
return data?.some(
@ -53,14 +53,19 @@ export const ImportQueue = () => {
<Flex flexDir="column" p={3} h="full">
<Flex justifyContent="space-between" alignItems="center">
<Text>{t('modelManager.importQueue')}</Text>
<Button size="sm" isDisabled={!pruneAvailable} onClick={pruneQueue} tooltip={t('modelManager.pruneTooltip')}>
<Button
size="sm"
isDisabled={!pruneAvailable}
onClick={pruneCompletedModelInstalls}
tooltip={t('modelManager.pruneTooltip')}
>
{t('modelManager.prune')}
</Button>
</Flex>
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
<ScrollableContent>
<Flex flexDir="column-reverse" gap="2">
{data?.map((model) => <ImportQueueItem key={model.id} model={model} />)}
{data?.map((model) => <ModelInstallQueueItem key={model.id} installJob={model} />)}
</Flex>
</ScrollableContent>
</Box>

View File

@ -6,17 +6,24 @@ import type { ModelInstallStatus } from 'services/api/types';
const STATUSES = {
waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' },
downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
downloads_done: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
completed: { colorScheme: 'green', translationKey: 'queue.completed' },
error: { colorScheme: 'red', translationKey: 'queue.failed' },
cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
};
const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus; errorReason?: string | null }) => {
const ModelInstallQueueBadge = ({
status,
errorReason,
}: {
status?: ModelInstallStatus;
errorReason?: string | null;
}) => {
const { t } = useTranslation();
if (!status || !Object.keys(STATUSES).includes(status)) {
return <></>;
return null;
}
return (
@ -25,4 +32,4 @@ const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus
</Tooltip>
);
};
export default memo(ImportQueueBadge);
export default memo(ModelInstallQueueBadge);

View File

@ -3,15 +3,16 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { isNil } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { PiXBold } from 'react-icons/pi';
import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types';
import ImportQueueBadge from './ImportQueueBadge';
import ModelInstallQueueBadge from './ModelInstallQueueBadge';
type ModelListItemProps = {
model: ModelInstallJob;
installJob: ModelInstallJob;
};
const formatBytes = (bytes: number) => {
@ -26,26 +27,26 @@ const formatBytes = (bytes: number) => {
return `${bytes.toFixed(2)} ${units[i]}`;
};
export const ImportQueueItem = (props: ModelListItemProps) => {
const { model } = props;
export const ModelInstallQueueItem = (props: ModelListItemProps) => {
const { installJob } = props;
const dispatch = useAppDispatch();
const [deleteImportModel] = useDeleteModelImportMutation();
const [deleteImportModel] = useCancelModelInstallMutation();
const source = useMemo(() => {
if (model.source.type === 'hf') {
return model.source as HFModelSource;
} else if (model.source.type === 'local') {
return model.source as LocalModelSource;
} else if (model.source.type === 'url') {
return model.source as URLModelSource;
if (installJob.source.type === 'hf') {
return installJob.source as HFModelSource;
} else if (installJob.source.type === 'local') {
return installJob.source as LocalModelSource;
} else if (installJob.source.type === 'url') {
return installJob.source as URLModelSource;
} else {
return model.source as LocalModelSource;
return installJob.source as LocalModelSource;
}
}, [model.source]);
}, [installJob.source]);
const handleDeleteModelImport = useCallback(() => {
deleteImportModel(model.id)
deleteImportModel(installJob.id)
.unwrap()
.then((_) => {
dispatch(
@ -69,7 +70,7 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
);
}
});
}, [deleteImportModel, model, dispatch]);
}, [deleteImportModel, installJob, dispatch]);
const modelName = useMemo(() => {
switch (source.type) {
@ -85,19 +86,23 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
}, [source]);
const progressValue = useMemo(() => {
if (model.bytes === undefined || model.total_bytes === undefined) {
if (isNil(installJob.bytes) || isNil(installJob.total_bytes)) {
return null;
}
if (installJob.total_bytes === 0) {
return 0;
}
return (model.bytes / model.total_bytes) * 100;
}, [model.bytes, model.total_bytes]);
return (installJob.bytes / installJob.total_bytes) * 100;
}, [installJob.bytes, installJob.total_bytes]);
const progressString = useMemo(() => {
if (model.status !== 'downloading' || model.bytes === undefined || model.total_bytes === undefined) {
if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
return '';
}
return `${formatBytes(model.bytes)} / ${formatBytes(model.total_bytes)}`;
}, [model.bytes, model.total_bytes, model.status]);
return `${formatBytes(installJob.bytes)} / ${formatBytes(installJob.total_bytes)}`;
}, [installJob.bytes, installJob.total_bytes, installJob.status]);
return (
<Flex gap="2" w="full" alignItems="center">
@ -109,19 +114,21 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
<Flex flexDir="column" flex={1}>
<Tooltip label={progressString}>
<Progress
value={progressValue}
isIndeterminate={progressValue === undefined}
value={progressValue ?? 0}
isIndeterminate={progressValue === null}
aria-label={t('accessibility.invokeProgressBar')}
h={2}
/>
</Tooltip>
</Flex>
<Box minW="100px" textAlign="center">
<ImportQueueBadge status={model.status} errorReason={model.error_reason} />
<ModelInstallQueueBadge status={installJob.status} errorReason={installJob.error_reason} />
</Box>
<Box minW="20px">
{(model.status === 'downloading' || model.status === 'waiting' || model.status === 'running') && (
{(installJob.status === 'downloading' ||
installJob.status === 'waiting' ||
installJob.status === 'running') && (
<IconButton
isRound={true}
size="xs"

View File

@ -2,24 +2,24 @@ import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@
import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useLazyScanModelsQuery } from 'services/api/endpoints/models';
import { useLazyScanFolderQuery } from 'services/api/endpoints/models';
import { ScanModelsResults } from './ScanModelsResults';
import { ScanModelsResults } from './ScanFolderResults';
export const ScanModelsForm = () => {
const [scanPath, setScanPath] = useState('');
const [errorMessage, setErrorMessage] = useState('');
const { t } = useTranslation();
const [_scanModels, { isLoading, data }] = useLazyScanModelsQuery();
const [_scanFolder, { isLoading, data }] = useLazyScanFolderQuery();
const handleSubmitScan = useCallback(async () => {
_scanModels({ scan_path: scanPath }).catch((error) => {
const scanFolder = useCallback(async () => {
_scanFolder({ scan_path: scanPath }).catch((error) => {
if (error) {
setErrorMessage(error.data.detail);
}
});
}, [_scanModels, scanPath]);
}, [_scanFolder, scanPath]);
const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
setScanPath(e.target.value);
@ -36,7 +36,7 @@ export const ScanModelsForm = () => {
<Input value={scanPath} onChange={handleSetScanPath} />
</Flex>
<Button onClick={handleSubmitScan} isLoading={isLoading} isDisabled={scanPath.length === 0}>
<Button onClick={scanFolder} isLoading={isLoading} isDisabled={scanPath.length === 0}>
{t('modelManager.scanFolder')}
</Button>
</Flex>

View File

@ -18,7 +18,7 @@ import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models';
import { ScanModelResultItem } from './ScanModelResultItem';
import { ScanModelResultItem } from './ScanFolderResultItem';
type ScanModelResultsProps = {
results: ScanFolderResponse;

View File

@ -1,12 +1,11 @@
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { useTranslation } from 'react-i18next';
import { AdvancedImport } from './AddModelPanel/AdvancedImport';
import { ImportQueue } from './AddModelPanel/ImportQueue/ImportQueue';
import { ScanModelsForm } from './AddModelPanel/ScanModels/ScanModelsForm';
import { SimpleImport } from './AddModelPanel/SimpleImport';
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
export const ImportModels = () => {
export const InstallModels = () => {
const { t } = useTranslation();
return (
<Flex layerStyle="first" p={3} borderRadius="base" w="full" h="full" flexDir="column" gap={2}>
@ -17,15 +16,11 @@ export const ImportModels = () => {
<Tabs variant="collapse" height="100%">
<TabList>
<Tab>{t('common.simple')}</Tab>
<Tab>{t('modelManager.advanced')}</Tab>
<Tab>{t('modelManager.scan')}</Tab>
</TabList>
<TabPanels p={3} height="100%">
<TabPanel>
<SimpleImport />
</TabPanel>
<TabPanel height="100%">
<AdvancedImport />
<InstallModelForm />
</TabPanel>
<TabPanel height="100%">
<ScanModelsForm />
@ -34,7 +29,7 @@ export const ImportModels = () => {
</Tabs>
</Box>
<Box layerStyle="second" borderRadius="base" w="full" h="50%">
<ImportQueue />
<ModelInstallQueue />
</Box>
</Flex>
);

View File

@ -5,7 +5,7 @@ import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { IoFilter } from 'react-icons/io5';
export const MODEL_TYPE_LABELS: { [key: string]: string } = {
const MODEL_TYPE_LABELS: { [key: string]: string } = {
main: 'Main',
lora: 'LoRA',
embedding: 'Textual Inversion',

View File

@ -1,14 +1,14 @@
import { Box } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { ImportModels } from './ImportModels';
import { InstallModels } from './InstallModels';
import { Model } from './ModelPanel/Model';
export const ModelPane = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
return (
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
{selectedModelKey ? <Model key={selectedModelKey} /> : <ImportModels />}
{selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}
</Box>
);
};

View File

@ -1,11 +1,12 @@
import { Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import Loading from 'common/components/Loading/Loading';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { isNil } from 'lodash-es';
import { useMemo } from 'react';
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm';
@ -23,8 +24,9 @@ const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config)
export const DefaultSettings = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { t } = useTranslation();
const { data, isLoading } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
useAppSelector(initialStatesSelector);
@ -59,7 +61,7 @@ export const DefaultSettings = () => {
]);
if (isLoading) {
return <Loading />;
return <Text>{t('common.loading')}</Text>;
}
return <DefaultSettingsForm defaultSettingsDefaults={defaultSettingsDefaults} />;

View File

@ -8,7 +8,7 @@ import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { IoPencil } from 'react-icons/io5';
import { useUpdateModelMetadataMutation } from 'services/api/endpoints/models';
import { useUpdateModelMutation } from 'services/api/endpoints/models';
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
import { DefaultCfgScale } from './DefaultCfgScale';
@ -41,7 +41,7 @@ export const DefaultSettingsForm = ({
const { t } = useTranslation();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation();
const [updateModel, { isLoading }] = useUpdateModelMutation();
const { handleSubmit, control, formState } = useForm<DefaultSettingsFormData>({
defaultValues: defaultSettingsDefaults,
@ -62,7 +62,7 @@ export const DefaultSettingsForm = ({
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
};
editModelMetadata({
updateModel({
key: selectedModelKey,
body: { default_settings: body },
})
@ -90,7 +90,7 @@ export const DefaultSettingsForm = ({
}
});
},
[selectedModelKey, dispatch, editModelMetadata, t]
[selectedModelKey, dispatch, updateModel, t]
);
return (

View File

@ -3,9 +3,9 @@ import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
import type { UpdateModelArg } from 'services/api/endpoints/models';
const options: ComboboxOption[] = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
@ -14,8 +14,12 @@ const options: ComboboxOption[] = [
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
];
const BaseModelSelect = (props: UseControllerProps<AnyModelConfig>) => {
const { field } = useController(props);
type Props = {
control: Control<UpdateModelArg['body']>;
};
const BaseModelSelect = ({ control }: Props) => {
const { field } = useController({ control, name: 'base' });
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {

View File

@ -1,27 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'none', label: '-' },
{ value: 'true', label: 'True' },
{ value: 'false', label: 'False' },
];
const BooleanSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value === 'true');
},
[field]
);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(BooleanSelect);

View File

@ -1,47 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController, useWatch } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const ModelFormatSelect = (props: UseControllerProps<AnyModelConfig>) => {
const { field, formState } = useController(props);
const type = useWatch({ control: props.control, name: 'type' });
const onChange = useCallback<ComboboxOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
const options: ComboboxOption[] = useMemo(() => {
const modelType = type || formState.defaultValues?.type;
if (modelType === 'lora') {
return [
{ value: 'lycoris', label: 'LyCORIS' },
{ value: 'diffusers', label: 'Diffusers' },
];
} else if (modelType === 'embedding') {
return [
{ value: 'embedding_file', label: 'Embedding File' },
{ value: 'embedding_folder', label: 'Embedding Folder' },
];
} else if (modelType === 'ip_adapter') {
return [{ value: 'invokeai', label: 'invokeai' }];
} else {
return [
{ value: 'diffusers', label: 'Diffusers' },
{ value: 'checkpoint', label: 'Checkpoint' },
];
}
}, [type, formState.defaultValues?.type]);
const value = useMemo(() => options.find((o) => o.value === field.value), [options, field.value]);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(ModelFormatSelect);

View File

@ -1,32 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { MODEL_TYPE_LABELS } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'main', label: MODEL_TYPE_LABELS['main'] as string },
{ value: 'lora', label: MODEL_TYPE_LABELS['lora'] as string },
{ value: 'embedding', label: MODEL_TYPE_LABELS['embedding'] as string },
{ value: 'vae', label: MODEL_TYPE_LABELS['vae'] as string },
{ value: 'controlnet', label: MODEL_TYPE_LABELS['controlnet'] as string },
{ value: 'ip_adapter', label: MODEL_TYPE_LABELS['ip_adapter'] as string },
{ value: 't2i_adapater', label: MODEL_TYPE_LABELS['t2i_adapter'] as string },
] as const;
const ModelTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(ModelTypeSelect);

View File

@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
import type { UpdateModelArg } from 'services/api/endpoints/models';
const options: ComboboxOption[] = [
{ value: 'normal', label: 'Normal' },
@ -12,8 +12,12 @@ const options: ComboboxOption[] = [
{ value: 'depth', label: 'Depth' },
];
const ModelVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
type Props = {
control: Control<UpdateModelArg['body']>;
};
const ModelVariantSelect = ({ control }: Props) => {
const { field } = useController({ control, name: 'variant' });
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {

View File

@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
import type { UpdateModelArg } from 'services/api/endpoints/models';
const options: ComboboxOption[] = [
{ value: 'none', label: '-' },
@ -13,8 +13,12 @@ const options: ComboboxOption[] = [
{ value: 'sample', label: 'sample' },
];
const PredictionTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
type Props = {
control: Control<UpdateModelArg['body']>;
};
const PredictionTypeSelect = ({ control }: Props) => {
const { field } = useController({ control, name: 'prediction_type' });
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {

View File

@ -1,27 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'none', label: '-' },
{ value: 'fp16', label: 'fp16' },
{ value: 'fp32', label: 'fp32' },
];
const RepoVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value);
},
[field]
);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(RepoVariantSelect);

View File

@ -1,18 +1,41 @@
import { Flex } from '@invoke-ai/ui-library';
import { Box, Flex } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
import { useMemo } from 'react';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import type { ModelType } from 'services/api/types';
import { TriggerPhrases } from './TriggerPhrases';
const MODEL_TYPE_TRIGGER_PHRASE: ModelType[] = ['main', 'lora'];
export const ModelMetadata = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
const { data } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const shouldShowTriggerPhraseSettings = useMemo(() => {
if (!data?.type) {
return false;
}
return MODEL_TYPE_TRIGGER_PHRASE.includes(data.type);
}, [data]);
const apiResponseFormatted = useMemo(() => {
if (!data?.source_api_response) {
return {};
}
return JSON.parse(data.source_api_response);
}, [data?.source_api_response]);
return (
<>
<Flex flexDir="column" height="full" gap="3">
<DataViewer label="metadata" data={metadata || {}} />
</Flex>
</>
<Flex flexDir="column" height="full" gap="3">
{shouldShowTriggerPhraseSettings && (
<Box layerStyle="second" borderRadius="base" p={3}>
<TriggerPhrases />
</Box>
)}
<DataViewer label="metadata" data={apiResponseFormatted} />
</Flex>
);
};

View File

@ -0,0 +1,106 @@
import {
Button,
Flex,
FormControl,
FormErrorMessage,
Input,
Tag,
TagCloseButton,
TagLabel,
} from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { ModelListHeader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelListHeader';
import type { ChangeEvent } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
export const TriggerPhrases = () => {
const { t } = useTranslation();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data: modelConfig } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const [phrase, setPhrase] = useState('');
const [updateModel, { isLoading }] = useUpdateModelMutation();
const handlePhraseChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setPhrase(e.target.value);
}, []);
const triggerPhrases = useMemo(() => {
return modelConfig?.trigger_phrases || [];
}, [modelConfig?.trigger_phrases]);
const errors = useMemo(() => {
const errors = [];
if (phrase.length && triggerPhrases.includes(phrase)) {
errors.push('Phrase is already in list');
}
return errors;
}, [phrase, triggerPhrases]);
const addTriggerPhrase = useCallback(async () => {
if (!selectedModelKey) {
return;
}
if (!phrase.length || triggerPhrases.includes(phrase)) {
return;
}
await updateModel({
key: selectedModelKey,
body: { trigger_phrases: [...triggerPhrases, phrase] },
}).unwrap();
setPhrase('');
}, [updateModel, selectedModelKey, phrase, triggerPhrases]);
const removeTriggerPhrase = useCallback(
async (phraseToRemove: string) => {
if (!selectedModelKey) {
return;
}
const filteredPhrases = triggerPhrases.filter((p) => p !== phraseToRemove);
await updateModel({ key: selectedModelKey, body: { trigger_phrases: filteredPhrases } }).unwrap();
},
[updateModel, selectedModelKey, triggerPhrases]
);
return (
<Flex flexDir="column" w="full" gap="5">
<ModelListHeader title={t('modelManager.triggerPhrases')} />
<form>
<FormControl w="full" isInvalid={Boolean(errors.length)}>
<Flex flexDir="column" w="full">
<Flex gap="3" alignItems="center" w="full">
<Input value={phrase} onChange={handlePhraseChange} placeholder={t('modelManager.typePhraseHere')} />
<Button
type="submit"
onClick={addTriggerPhrase}
isDisabled={Boolean(errors.length)}
isLoading={isLoading}
>
{t('common.add')}
</Button>
</Flex>
{!!errors.length && errors.map((error) => <FormErrorMessage key={error}>{error}</FormErrorMessage>)}
</Flex>
</FormControl>
</form>
<Flex gap="4" flexWrap="wrap" mt="3" mb="3">
{triggerPhrases.map((phrase, index) => (
<Tag size="md" key={index}>
<TagLabel>{phrase}</TagLabel>
<TagCloseButton onClick={removeTriggerPhrase.bind(null, phrase)} isDisabled={isLoading} />
</Tag>
))}
</Flex>
</Flex>
);
};
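
The TriggerPhrases component above validates a new phrase client-side (non-empty, not already present) and then sends the whole updated list to updateModel. A small Python sketch of that merge logic, separated from the React plumbing; the phrase values are hypothetical:

```python
from typing import List


def add_trigger_phrase(existing: List[str], phrase: str) -> List[str]:
    """Return the updated trigger-phrase list, ignoring empty or duplicate input (mirrors the UI guard)."""
    if not phrase or phrase in existing:
        return existing
    return [*existing, phrase]


def remove_trigger_phrase(existing: List[str], phrase_to_remove: str) -> List[str]:
    return [p for p in existing if p != phrase_to_remove]


phrases = add_trigger_phrase([], "cinematic lighting")
phrases = add_trigger_phrase(phrases, "cinematic lighting")  # duplicate -> unchanged
print(phrases)  # ['cinematic lighting']
print(remove_trigger_phrase(phrases, "cinematic lighting"))  # []
```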

View File

@ -13,7 +13,7 @@ import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
import { useConvertModelMutation } from 'services/api/endpoints/models';
import type { CheckpointModelConfig } from 'services/api/types';
interface ModelConvertProps {
@ -24,7 +24,7 @@ export const ModelConvert = (props: ModelConvertProps) => {
const { model } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [convertModel, { isLoading }] = useConvertMainModelsMutation();
const [convertModel, { isLoading }] = useConvertModelMutation();
const { isOpen, onOpen, onClose } = useDisclosure();
const modelConvertHandler = useCallback(() => {

View File

@ -1,5 +1,6 @@
import {
Button,
Checkbox,
Flex,
FormControl,
FormErrorMessage,
@ -19,66 +20,27 @@ import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { UpdateModelArg } from 'services/api/endpoints/models';
import { useGetModelConfigQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
import BaseModelSelect from './Fields/BaseModelSelect';
import BooleanSelect from './Fields/BooleanSelect';
import ModelFormatSelect from './Fields/ModelFormatSelect';
import ModelTypeSelect from './Fields/ModelTypeSelect';
import ModelVariantSelect from './Fields/ModelVariantSelect';
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
import RepoVariantSelect from './Fields/RepoVariantSelect';
export const ModelEdit = () => {
const dispatch = useAppDispatch();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelsMutation();
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
const { t } = useTranslation();
// const modelData = useMemo(() => {
// if (!data) {
// return null;
// }
// const modelFormat = data.format;
// const modelType = data.type;
// if (modelType === 'main') {
// if (modelFormat === 'diffusers') {
// return data as DiffusersModelConfig;
// } else if (modelFormat === 'checkpoint') {
// return data as CheckpointModelConfig;
// }
// }
// switch (modelType) {
// case 'lora':
// return data as LoRAModelConfig;
// case 'embedding':
// return data as TextualInversionModelConfig;
// case 't2i_adapter':
// return data as T2IAdapterModelConfig;
// case 'ip_adapter':
// return data as IPAdapterModelConfig;
// case 'controlnet':
// return data as ControlNetModelConfig;
// case 'vae':
// return data as VAEModelConfig;
// default:
// return null;
// }
// }, [data]);
const {
register,
handleSubmit,
control,
formState: { errors },
reset,
watch,
} = useForm<UpdateModelArg['body']>({
defaultValues: {
...data,
@ -86,10 +48,7 @@ export const ModelEdit = () => {
mode: 'onChange',
});
const watchedModelType = watch('type');
const watchedModelFormat = watch('format');
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
(values) => {
if (!data?.key) {
return;
@ -143,33 +102,31 @@ export const ModelEdit = () => {
return (
<Flex flexDir="column" h="full">
<form onSubmit={handleSubmit(onSubmit)}>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
<FormLabel hidden={true}>{t('modelManager.modelName')}</FormLabel>
<Input
{...register('name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
validate: (value) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
})}
size="lg"
/>
<Flex gap={2}>
<Button size="sm" onClick={handleClickCancel}>
{t('common.cancel')}
</Button>
<Button
size="sm"
colorScheme="invokeYellow"
onClick={handleSubmit(onSubmit)}
isLoading={isSubmitting}
isDisabled={Boolean(Object.keys(errors).length)}
>
{t('common.save')}
</Button>
</Flex>
</Flex>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</FormControl>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</FormControl>
<Button size="sm" onClick={handleClickCancel}>
{t('common.cancel')}
</Button>
<Button
size="sm"
colorScheme="invokeYellow"
onClick={handleSubmit(onSubmit)}
isLoading={isSubmitting}
isDisabled={Boolean(Object.keys(errors).length)}
>
{t('common.save')}
</Button>
</Flex>
<Flex flexDir="column" gap={3} mt="4">
<Flex>
@ -184,76 +141,22 @@ export const ModelEdit = () => {
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect control={control} name="base" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.modelType')}</FormLabel>
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
<BaseModelSelect control={control} />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('common.format')}</FormLabel>
<ModelFormatSelect control={control} name="format" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.path')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</FormControl>
</Flex>
{watchedModelType === 'main' && (
<>
<Flex gap={4}>
{watchedModelFormat === 'diffusers' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.repoVariant')}</FormLabel>
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
</FormControl>
)}
{watchedModelFormat === 'checkpoint' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
<Input {...register('config')} />
</FormControl>
)}
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</FormControl>
</Flex>
</>
)}
{watchedModelType === 'ip_adapter' && (
{data.type === 'main' && data.format === 'checkpoint' && (
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
<Input {...register('image_encoder_model_id')} />
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<Checkbox {...register('upcast_attention')} />
</FormControl>
</Flex>
)}

View File

@ -91,26 +91,19 @@ export const ModelView = () => {
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
</Flex>
{modelData.type === 'main' && (
<>
<Flex gap={2}>
{modelData.format === 'diffusers' && (
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
)}
{modelData.format === 'checkpoint' && (
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
)}
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
<ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
</Flex>
</>
<Flex gap={2}>
{modelData.format === 'diffusers' && modelData.repo_variant && (
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
)}
{modelData.format === 'checkpoint' && (
<>
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config_path} />
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</>
)}
</Flex>
)}
{modelData.type === 'ip_adapter' && (
<Flex gap={2}>


@ -1,10 +1,10 @@
import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
import { usePrompt } from 'features/embedding/usePrompt';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { setNegativePrompt } from 'features/parameters/store/generationSlice';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
import { usePrompt } from 'features/prompt/usePrompt';
import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
@ -19,19 +19,14 @@ export const ParamNegativePrompt = memo(() => {
},
[dispatch]
);
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown } = usePrompt({
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown } = usePrompt({
prompt,
textareaRef,
onChange: _onChange,
});
return (
<EmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={onSelectEmbedding}
width={textareaRef.current?.clientWidth}
>
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
<Box pos="relative">
<Textarea
id="negativePrompt"
@ -45,10 +40,10 @@ export const ParamNegativePrompt = memo(() => {
variant="darkFilled"
/>
<PromptOverlayButtonWrapper>
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
</PromptOverlayButtonWrapper>
</Box>
</EmbeddingPopover>
</PromptPopover>
);
});


@ -1,11 +1,11 @@
import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ShowDynamicPromptsPreviewButton } from 'features/dynamicPrompts/components/ShowDynamicPromptsPreviewButton';
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
import { usePrompt } from 'features/embedding/usePrompt';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { setPositivePrompt } from 'features/parameters/store/generationSlice';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
import { usePrompt } from 'features/prompt/usePrompt';
import { SDXLConcatButton } from 'features/sdxl/components/SDXLPrompts/SDXLConcatButton';
import { memo, useCallback, useRef } from 'react';
import type { HotkeyCallback } from 'react-hotkeys-hook';
@ -25,7 +25,7 @@ export const ParamPositivePrompt = memo(() => {
},
[dispatch]
);
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown, onFocus } = usePrompt({
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
prompt,
textareaRef: textareaRef,
onChange: handleChange,
@ -42,12 +42,7 @@ export const ParamPositivePrompt = memo(() => {
useHotkeys('alt+a', focus, []);
return (
<EmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={onSelectEmbedding}
width={textareaRef.current?.clientWidth}
>
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
<Box pos="relative">
<Textarea
id="prompt"
@ -61,12 +56,12 @@ export const ParamPositivePrompt = memo(() => {
variant="darkFilled"
/>
<PromptOverlayButtonWrapper>
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
{baseModel === 'sdxl' && <SDXLConcatButton />}
<ShowDynamicPromptsPreviewButton />
</PromptOverlayButtonWrapper>
</Box>
</EmbeddingPopover>
</PromptPopover>
);
});


@ -16,7 +16,6 @@ import type {
ParameterScheduler,
ParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { configChanged } from 'features/system/store/configSlice';
import { clamp } from 'lodash-es';
@ -210,26 +209,6 @@ export const generationSlice = createSlice({
},
extraReducers: (builder) => {
builder.addCase(configChanged, (state, action) => {
const defaultModel = action.payload.sd?.defaultModel;
if (defaultModel && !state.model) {
const [base_model, model_type, model_name] = defaultModel.split('/');
const result = zParameterModel.safeParse({
model_name,
base_model,
model_type,
});
if (result.success) {
state.model = result.data;
const optimalDimension = getOptimalDimension(result.data);
state.width = optimalDimension;
state.height = optimalDimension;
}
}
if (action.payload.sd?.scheduler) {
state.scheduler = action.payload.sd.scheduler;
}


@ -8,15 +8,15 @@ type Props = {
onOpen: () => void;
};
export const AddEmbeddingButton = memo((props: Props) => {
export const AddPromptTriggerButton = memo((props: Props) => {
const { onOpen, isOpen } = props;
const { t } = useTranslation();
return (
<Tooltip label={t('embedding.addEmbedding')}>
<Tooltip label={t('prompt.addPromptTrigger')}>
<IconButton
variant="promptOverlay"
isDisabled={isOpen}
aria-label={t('embedding.addEmbedding')}
aria-label={t('prompt.addPromptTrigger')}
icon={<PiCodeBold />}
onClick={onOpen}
/>
@ -24,4 +24,4 @@ export const AddEmbeddingButton = memo((props: Props) => {
);
});
AddEmbeddingButton.displayName = 'AddEmbeddingButton';
AddPromptTriggerButton.displayName = 'AddPromptTriggerButton';


@ -1,9 +1,9 @@
import { Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
import { EmbeddingSelect } from 'features/embedding/EmbeddingSelect';
import type { EmbeddingPopoverProps } from 'features/embedding/types';
import { PromptTriggerSelect } from 'features/prompt/PromptTriggerSelect';
import type { PromptPopoverProps } from 'features/prompt/types';
import { memo } from 'react';
export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
export const PromptPopover = memo((props: PromptPopoverProps) => {
const { onSelect, isOpen, onClose, width, children } = props;
return (
@ -14,7 +14,7 @@ export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
openDelay={0}
closeDelay={0}
closeOnBlur={true}
returnFocusOnClose={true}
returnFocusOnClose={false}
isLazy
>
<PopoverAnchor>{children}</PopoverAnchor>
@ -27,11 +27,11 @@ export const EmbeddingPopover = memo((props: EmbeddingPopoverProps) => {
borderStyle="solid"
>
<PopoverBody p={0} width={`calc(${width}px - 0.25rem)`}>
<EmbeddingSelect onClose={onClose} onSelect={onSelect} />
<PromptTriggerSelect onClose={onClose} onSelect={onSelect} />
</PopoverBody>
</PopoverContent>
</Popover>
);
});
EmbeddingPopover.displayName = 'EmbeddingPopover';
PromptPopover.displayName = 'PromptPopover';


@ -0,0 +1,21 @@
import type { Meta, StoryObj } from '@storybook/react';
import { PromptTriggerSelect } from './PromptTriggerSelect';
import type { PromptTriggerSelectProps } from './types';
const meta: Meta<typeof PromptTriggerSelect> = {
title: 'Feature/Prompt/PromptTriggerSelect',
tags: ['autodocs'],
component: PromptTriggerSelect,
};
export default meta;
type Story = StoryObj<typeof PromptTriggerSelect>;
const Component = (props: PromptTriggerSelectProps) => {
return <PromptTriggerSelect {...props}>Invoke</PromptTriggerSelect>;
};
export const Default: Story = {
render: Component,
};


@ -0,0 +1,104 @@
import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import { selectLoraSlice } from 'features/lora/store/loraSlice';
import type { PromptTriggerSelectProps } from 'features/prompt/types';
import { t } from 'i18next';
import { flatten, map } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import {
loraModelsAdapterSelectors,
textualInversionModelsAdapterSelectors,
useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery,
} from 'services/api/endpoints/models';
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
const selectLoRAs = createMemoizedSelector(selectLoraSlice, (loras) => loras.loras);
export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSelectProps) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const addedLoRAs = useAppSelector(selectLoRAs);
const { data: loraModels, isLoading: isLoadingLoRAs } = useGetLoRAModelsQuery();
const { data: tiModels, isLoading: isLoadingTIs } = useGetTextualInversionModelsQuery();
const _onChange = useCallback<ComboboxOnChange>(
(v) => {
if (!v) {
onSelect('');
return;
}
onSelect(v.value);
},
[onSelect]
);
const options = useMemo(() => {
const _options: GroupBase<ComboboxOption>[] = [];
if (tiModels) {
const embeddingOptions = textualInversionModelsAdapterSelectors
.selectAll(tiModels)
.filter((ti) => ti.base === currentBaseModel)
.map((model) => ({ label: model.name, value: `<${model.name}>` }));
if (embeddingOptions.length > 0) {
_options.push({
label: t('prompt.compatibleEmbeddings'),
options: embeddingOptions,
});
}
}
if (loraModels) {
const triggerPhraseOptions = loraModelsAdapterSelectors
.selectAll(loraModels)
.filter((lora) => map(addedLoRAs, (l) => l.model.key).includes(lora.key))
.map((lora) => {
if (lora.trigger_phrases) {
return lora.trigger_phrases.map((triggerPhrase) => ({ label: triggerPhrase, value: triggerPhrase }));
}
return [];
})
.flatMap((x) => x);
if (triggerPhraseOptions.length > 0) {
_options.push({
label: t('modelManager.triggerPhrases'),
options: flatten(triggerPhraseOptions),
});
}
}
return _options;
}, [tiModels, loraModels, t, currentBaseModel, addedLoRAs]);
return (
<FormControl>
<Combobox
placeholder={isLoadingLoRAs || isLoadingTIs ? t('common.loading') : t('prompt.addPromptTrigger')}
defaultMenuIsOpen
autoFocus
value={null}
options={options}
noOptionsMessage={noOptionsMessage}
onChange={_onChange}
onMenuClose={onClose}
data-testid="add-prompt-trigger"
sx={selectStyles}
/>
</FormControl>
);
});
PromptTriggerSelect.displayName = 'PromptTriggerSelect';
const selectStyles: ChakraProps['sx'] = {
w: 'full',
};
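The grouped options built above pair a display label with the text to insert: embedding names are wrapped in angle brackets so they land as TI triggers, while LoRA trigger phrases are inserted verbatim. A rough sketch of the resulting shape, with hypothetical model names and group labels:

import type { ComboboxOption } from '@invoke-ai/ui-library';
import type { GroupBase } from 'chakra-react-select';

// Illustrative only: one compatible embedding and one added LoRA trigger phrase.
// The group labels come from translation keys, so the strings here are assumptions.
const exampleOptions: GroupBase<ComboboxOption>[] = [
  {
    label: 'Compatible Embeddings',
    options: [{ label: 'easynegative', value: '<easynegative>' }],
  },
  {
    label: 'Trigger Phrases',
    options: [{ label: 'pixel art style', value: 'pixel art style' }],
  },
];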


@ -1,12 +1,12 @@
import type { PropsWithChildren } from 'react';
export type EmbeddingSelectProps = {
export type PromptTriggerSelectProps = {
onSelect: (v: string) => void;
onClose: () => void;
};
export type EmbeddingPopoverProps = PropsWithChildren &
EmbeddingSelectProps & {
export type PromptPopoverProps = PropsWithChildren &
PromptTriggerSelectProps & {
isOpen: boolean;
width?: number | string;
};


@ -4,13 +4,13 @@ import type { ChangeEventHandler, KeyboardEventHandler, RefObject } from 'react'
import { useCallback } from 'react';
import { flushSync } from 'react-dom';
type UseInsertEmbeddingArg = {
type UseInsertTriggerArg = {
prompt: string;
textareaRef: RefObject<HTMLTextAreaElement>;
onChange: (v: string) => void;
};
export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertEmbeddingArg) => {
export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertTriggerArg) => {
const { isOpen, onClose, onOpen } = useDisclosure();
const onChange: ChangeEventHandler<HTMLTextAreaElement> = useCallback(
@ -20,13 +20,13 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
[_onChange]
);
const insertEmbedding = useCallback(
const insertTrigger = useCallback(
(v: string) => {
if (!textareaRef.current) {
return;
}
// this is where we insert the TI trigger
// this is where we insert the trigger
const caret = textareaRef.current.selectionStart;
if (isNil(caret)) {
@ -35,13 +35,9 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
let newPrompt = prompt.slice(0, caret);
if (newPrompt[newPrompt.length - 1] !== '<') {
newPrompt += '<';
}
newPrompt += `${v}`;
newPrompt += `${v}>`;
// we insert the cursor after the `>`
// we insert the cursor after the end of trigger
const finalCaretPos = newPrompt.length;
newPrompt += prompt.slice(caret);
@ -51,7 +47,7 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
_onChange(newPrompt);
});
// set the caret position to just after the TI trigger
// set the cursor position to just after the trigger
textareaRef.current.selectionStart = finalCaretPos;
textareaRef.current.selectionEnd = finalCaretPos;
},
@ -62,17 +58,17 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
textareaRef.current?.focus();
}, [textareaRef]);
const handleClose = useCallback(() => {
const handleClosePopover = useCallback(() => {
onClose();
onFocus();
}, [onFocus, onClose]);
const onSelectEmbedding = useCallback(
const onSelect = useCallback(
(v: string) => {
insertEmbedding(v);
handleClose();
insertTrigger(v);
handleClosePopover();
},
[handleClose, insertEmbedding]
[handleClosePopover, insertTrigger]
);
const onKeyDown: KeyboardEventHandler<HTMLTextAreaElement> = useCallback(
@ -90,7 +86,7 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
isOpen,
onClose,
onOpen,
onSelectEmbedding,
onSelect,
onKeyDown,
onFocus,
};
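Condensed to its essentials, the insertion step splits the prompt at the caret, appends the selected trigger, and records where the cursor should land before re-appending the remainder. A minimal sketch of that behaviour (the helper and sample values below are illustrative, not part of the change):

// Sketch only: insert a trigger string at the caret and compute the new caret position,
// omitting the textarea ref and flushSync wiring used by the real hook.
const insertAtCaret = (prompt: string, trigger: string, caret: number) => {
  const before = prompt.slice(0, caret) + trigger;
  const caretAfter = before.length; // cursor lands just after the inserted trigger
  return { prompt: before + prompt.slice(caret), caretAfter };
};

// insertAtCaret('a photo of ', '<easynegative>', 11)
// -> { prompt: 'a photo of <easynegative>', caretAfter: 25 }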


@ -1,9 +1,9 @@
import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
import { usePrompt } from 'features/embedding/usePrompt';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
import { usePrompt } from 'features/prompt/usePrompt';
import { setNegativeStylePromptSDXL } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback, useRef } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
@ -20,7 +20,7 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
},
[dispatch]
);
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown, onFocus } = usePrompt({
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
prompt,
textareaRef: textareaRef,
onChange: handleChange,
@ -29,12 +29,7 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
useHotkeys('alt+a', onFocus, []);
return (
<EmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={onSelectEmbedding}
width={textareaRef.current?.clientWidth}
>
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
<Box pos="relative">
<Textarea
id="prompt"
@ -48,10 +43,10 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
variant="darkFilled"
/>
<PromptOverlayButtonWrapper>
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
</PromptOverlayButtonWrapper>
</Box>
</EmbeddingPopover>
</PromptPopover>
);
});


@ -1,9 +1,9 @@
import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { AddEmbeddingButton } from 'features/embedding/AddEmbeddingButton';
import { EmbeddingPopover } from 'features/embedding/EmbeddingPopover';
import { usePrompt } from 'features/embedding/usePrompt';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
import { usePrompt } from 'features/prompt/usePrompt';
import { setPositiveStylePromptSDXL } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
@ -19,19 +19,14 @@ export const ParamSDXLPositiveStylePrompt = memo(() => {
},
[dispatch]
);
const { onChange, isOpen, onClose, onOpen, onSelectEmbedding, onKeyDown } = usePrompt({
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown } = usePrompt({
prompt,
textareaRef: textareaRef,
onChange: handleChange,
});
return (
<EmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={onSelectEmbedding}
width={textareaRef.current?.clientWidth}
>
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
<Box pos="relative">
<Textarea
id="prompt"
@ -45,10 +40,10 @@ export const ParamSDXLPositiveStylePrompt = memo(() => {
variant="darkFilled"
/>
<PromptOverlayButtonWrapper>
<AddEmbeddingButton isOpen={isOpen} onOpen={onOpen} />
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
</PromptOverlayButtonWrapper>
</Box>
</EmbeddingPopover>
</PromptPopover>
);
});


@ -1,7 +1,6 @@
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import type { JSONObject } from 'common/types';
import queryString from 'query-string';
import type { operations, paths } from 'services/api/schema';
import type {
@ -24,49 +23,33 @@ export type UpdateModelArg = {
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
};
type UpdateModelMetadataArg = {
key: paths['/api/v2/models/i/{key}/metadata']['patch']['parameters']['path']['key'];
body: paths['/api/v2/models/i/{key}/metadata']['patch']['requestBody']['content']['application/json'];
};
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type UpdateModelMetadataResponse =
paths['/api/v2/models/i/{key}/metadata']['patch']['responses']['200']['content']['application/json'];
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
type GetModelMetadataResponse =
paths['/api/v2/models/i/{key}/metadata']['get']['responses']['200']['content']['application/json'];
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
type DeleteMainModelArg = {
type DeleteModelArg = {
key: string;
};
type DeleteMainModelResponse = void;
type DeleteModelResponse = void;
type ConvertMainModelResponse =
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
type InstallModelArg = {
source: paths['/api/v2/models/install']['post']['parameters']['query']['source'];
access_token?: paths['/api/v2/models/install']['post']['parameters']['query']['access_token'];
// TODO(MM2): This is typed as `Optional[Dict[str, Any]]` in backend...
config?: JSONObject;
// config: NonNullable<paths['/api/v2/models/install']['post']['requestBody']>['content']['application/json'];
};
type InstallModelResponse = paths['/api/v2/models/install']['post']['responses']['201']['content']['application/json'];
type ListImportModelsResponse =
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
type ListModelInstallsResponse =
paths['/api/v2/models/install']['get']['responses']['200']['content']['application/json'];
type DeleteImportModelsResponse =
paths['/api/v2/models/import/{id}']['delete']['responses']['201']['content']['application/json'];
type CancelModelInstallResponse =
paths['/api/v2/models/install/{id}']['delete']['responses']['201']['content']['application/json'];
type PruneModelImportsResponse =
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
type PruneCompletedModelInstallsResponse =
paths['/api/v2/models/install']['delete']['responses']['200']['content']['application/json'];
export type ScanFolderResponse =
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
@ -83,6 +66,7 @@ const loraModelsAdapter = createEntityAdapter<LoRAModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
@ -102,6 +86,10 @@ const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelC
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors(
undefined,
getSelectorsOptions
);
const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
@ -146,31 +134,7 @@ const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
query: (base_models) => {
const params: ListModelsArg = {
model_type: 'main',
base_models,
};
const query = queryString.stringify(params, { arrayFormat: 'none' });
return buildModelsUrl(`?${query}`);
},
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getModelMetadata: build.query<GetModelMetadataResponse, string>({
query: (key) => {
return buildModelsUrl(`i/${key}/metadata`);
},
providesTags: ['Model'],
}),
updateModels: build.mutation<UpdateModelResponse, UpdateModelArg>({
updateModel: build.mutation<UpdateModelResponse, UpdateModelArg>({
query: ({ key, body }) => {
return {
url: buildModelsUrl(`i/${key}`),
@ -180,28 +144,17 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
updateModelMetadata: build.mutation<UpdateModelMetadataResponse, UpdateModelMetadataArg>({
query: ({ key, body }) => {
return {
url: buildModelsUrl(`i/${key}/metadata`),
method: 'PATCH',
body: body,
};
},
invalidatesTags: ['Model'],
}),
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
query: ({ source, config, access_token }) => {
query: ({ source }) => {
return {
url: buildModelsUrl('install'),
params: { source, access_token },
params: { source },
method: 'POST',
body: config,
};
},
invalidatesTags: ['Model', 'ModelImports'],
invalidatesTags: ['Model', 'ModelInstalls'],
}),
deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
deleteModels: build.mutation<DeleteModelResponse, DeleteModelArg>({
query: ({ key }) => {
return {
url: buildModelsUrl(`i/${key}`),
@ -210,7 +163,7 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
convertMainModels: build.mutation<ConvertMainModelResponse, string>({
convertModel: build.mutation<ConvertMainModelResponse, string>({
query: (key) => {
return {
url: buildModelsUrl(`convert/${key}`),
@ -253,6 +206,57 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
scanFolder: build.query<ScanFolderResponse, ScanFolderArg>({
query: (arg) => {
const folderQueryStr = arg ? queryString.stringify(arg, {}) : '';
return {
url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
};
},
}),
listModelInstalls: build.query<ListModelInstallsResponse, void>({
query: () => {
return {
url: buildModelsUrl('install'),
};
},
providesTags: ['ModelInstalls'],
}),
cancelModelInstall: build.mutation<CancelModelInstallResponse, number>({
query: (id) => {
return {
url: buildModelsUrl(`install/${id}`),
method: 'DELETE',
};
},
invalidatesTags: ['ModelInstalls'],
}),
pruneCompletedModelInstalls: build.mutation<PruneCompletedModelInstallsResponse, void>({
query: () => {
return {
url: buildModelsUrl('install'),
method: 'DELETE',
};
},
invalidatesTags: ['ModelInstalls'],
}),
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
query: (base_models) => {
const params: ListModelsArg = {
model_type: 'main',
base_models,
};
const query = queryString.stringify(params, { arrayFormat: 'none' });
return buildModelsUrl(`?${query}`);
},
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
@ -313,40 +317,6 @@ export const modelsApi = api.injectEndpoints({
});
},
}),
scanModels: build.query<ScanFolderResponse, ScanFolderArg>({
query: (arg) => {
const folderQueryStr = arg ? queryString.stringify(arg, {}) : '';
return {
url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
};
},
}),
getModelImports: build.query<ListImportModelsResponse, void>({
query: () => {
return {
url: buildModelsUrl(`import`),
};
},
providesTags: ['ModelImports'],
}),
deleteModelImport: build.mutation<DeleteImportModelsResponse, number>({
query: (id) => {
return {
url: buildModelsUrl(`import/${id}`),
method: 'DELETE',
};
},
invalidatesTags: ['ModelImports'],
}),
pruneModelImports: build.mutation<PruneModelImportsResponse, void>({
query: () => {
return {
url: buildModelsUrl('import'),
method: 'PATCH',
};
},
invalidatesTags: ['ModelImports'],
}),
}),
});
@ -360,16 +330,14 @@ export const {
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
useDeleteModelsMutation,
useUpdateModelsMutation,
useUpdateModelMutation,
useInstallModelMutation,
useConvertMainModelsMutation,
useConvertModelMutation,
useSyncModelsMutation,
useLazyScanModelsQuery,
useGetModelImportsQuery,
useGetModelMetadataQuery,
useDeleteModelImportMutation,
usePruneModelImportsMutation,
useUpdateModelMetadataMutation,
useLazyScanFolderQuery,
useListModelInstallsQuery,
useCancelModelInstallMutation,
usePruneCompletedModelInstallsMutation,
} = modelsApi;
const upsertModelConfigs = (
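For orientation, a hypothetical consumer of the renamed hooks; the fields passed in `body` are assumptions based on the model record changes exercised elsewhere in this PR, not a definitive contract:

import {
  useCancelModelInstallMutation,
  useListModelInstallsQuery,
  useUpdateModelMutation,
} from 'services/api/endpoints/models';

// Hypothetical hook showing the renamed endpoints in use.
const useExampleModelActions = (key: string) => {
  const [updateModel] = useUpdateModelMutation();
  const [cancelInstall] = useCancelModelInstallMutation();
  const { data: installJobs } = useListModelInstallsQuery();

  // UpdateModelArg is `{ key, body }`; body is the PATCH request body.
  const rename = (name: string) => updateModel({ key, body: { name } });
  // cancelModelInstall takes the numeric install job id.
  const cancel = (id: number) => cancelInstall(id);

  return { rename, cancel, pendingInstalls: installJobs?.length ?? 0 };
};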


@ -28,7 +28,7 @@ export const tagTypes = [
'InvocationCacheStatus',
'Model',
'ModelConfig',
'ModelImports',
'ModelInstalls',
'T2IAdapterModel',
'MainModel',
'VaeModel',

File diff suppressed because one or more lines are too long


@ -43,14 +43,13 @@ export type ControlField = S['ControlField'];
// Model Configs
// TODO(MM2): Can we make key required in the pydantic model?
export type LoRAModelConfig = S['LoRAConfig'];
export type LoRAModelConfig = S['LoRADiffusersConfig'] | S['LoRALyCORISConfig'];
// TODO(MM2): Can we rename this from Vae -> VAE
export type VAEModelConfig = S['VaeCheckpointConfig'] | S['VaeDiffusersConfig'];
export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
export type IPAdapterModelConfig = S['IPAdapterConfig'];
// TODO(MM2): Can we rename this to T2IAdapterConfig
export type T2IAdapterModelConfig = S['T2IConfig'];
export type TextualInversionModelConfig = S['TextualInversionConfig'];
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
export type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
export type DiffusersModelConfig = S['MainDiffusersConfig'];
export type CheckpointModelConfig = S['MainCheckpointConfig'];
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
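Several of these config types are now unions discriminated by the model's format, so consumers can narrow on `format` before touching format-specific fields. A small hypothetical sketch (the import path is assumed):

import type { VAEModelConfig } from 'services/api/types';

// Hypothetical helper: narrow a VAE config union by its format discriminator.
const describeVAEConfig = (config: VAEModelConfig): string => {
  if (config.format === 'checkpoint') {
    return `checkpoint VAE at ${config.path}`;
  }
  // otherwise it is the diffusers variant
  return `diffusers VAE at ${config.path}`;
};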


@ -187,7 +187,7 @@ version = { attr = "invokeai.version.__version__" }
#=== Begin: PyTest and Coverage
[tool.pytest.ini_options]
addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers --timeout 60 -m \"not slow\""
addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers -m \"not slow\""
markers = [
"slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\".",
"timeout: Marks the timeout override."


@ -59,12 +59,11 @@ def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase,
def test_registration_meta_override_succeed(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
store = mm2_installer.record_store
key = mm2_installer.register_path(
embedding_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash", "key": "xyzzy"}
embedding_file, {"name": "banana_sushi", "source": "fake/repo_id", "key": "xyzzy"}
)
model_record = store.get_model(key)
assert model_record.name == "banana_sushi"
assert model_record.source == "fake/repo_id"
assert model_record.current_hash == "New Hash"
assert model_record.key == "xyzzy"


@ -3,28 +3,28 @@ Test the refactored model config classes.
"""
from hashlib import sha256
from typing import Any
from typing import Any, Optional
import pytest
from pydantic import ValidationError
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import (
DuplicateModelException,
ModelRecordOrderBy,
ModelRecordServiceBase,
ModelRecordServiceSQL,
UnknownModelException,
)
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager.config import (
BaseModelType,
MainCheckpointConfig,
MainDiffusersConfig,
ModelFormat,
ModelSourceType,
ModelType,
TextualInversionConfig,
VaeDiffusersConfig,
TextualInversionFileConfig,
VAEDiffusersConfig,
)
from invokeai.backend.model_manager.metadata import BaseMetadata
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.fixtures.sqlite_database import create_mock_sqlite_database
@ -37,90 +37,76 @@ def store(
config = InvokeAIAppConfig(root=datadir)
logger = InvokeAILogger.get_logger(config=config)
db = create_mock_sqlite_database(config, logger)
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
return ModelRecordServiceSQL(db)
def example_config() -> TextualInversionConfig:
return TextualInversionConfig(
def example_ti_config(key: Optional[str] = None) -> TextualInversionFileConfig:
config = TextualInversionFileConfig(
source="test/source/",
source_type=ModelSourceType.Path,
path="/tmp/pokemon.bin",
name="old name",
base=BaseModelType("sd-1"),
type=ModelType("embedding"),
format="embedding_file",
original_hash="ABC123",
base=BaseModelType.StableDiffusion1,
type=ModelType.TextualInversion,
format=ModelFormat.EmbeddingFile,
hash="ABC123",
)
if key is not None:
config.key = key
return config
def test_type(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", config)
config = example_ti_config("key1")
store.add_model(config)
config1 = store.get_model("key1")
assert type(config1) == TextualInversionConfig
assert isinstance(config1, TextualInversionFileConfig)
def test_add(store: ModelRecordServiceBase):
raw = {
"path": "/tmp/foo.ckpt",
"name": "model1",
"base": BaseModelType("sd-1"),
"type": "main",
"config": "/tmp/foo.yaml",
"variant": "normal",
"format": "checkpoint",
"original_hash": "111222333444",
}
store.add_model("key1", raw)
config1 = store.get_model("key1")
assert config1 is not None
assert type(config1) == MainCheckpointConfig
assert config1.base == BaseModelType("sd-1")
assert config1.name == "model1"
assert config1.original_hash == "111222333444"
assert config1.current_hash is None
def test_dup(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", example_config())
def test_raises_on_violating_uniqueness(store: ModelRecordServiceBase):
# Models have a uniqueness constraint by their name, base and type
config1 = example_ti_config("key1")
config2 = config1.model_copy(deep=True)
config2.key = "key2"
store.add_model(config1)
with pytest.raises(DuplicateModelException):
store.add_model("key1", config)
store.add_model(config1)
with pytest.raises(DuplicateModelException):
store.add_model("key2", config)
store.add_model(config2)
def test_update(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", config)
def test_model_records_updates_model(store: ModelRecordServiceBase):
config = example_ti_config("key1")
store.add_model(config)
config = store.get_model("key1")
assert config.name == "old name"
config.name = "new name"
store.update_model("key1", config)
new_name = "new name"
changes = ModelRecordChanges(name=new_name)
store.update_model(config.key, changes)
new_config = store.get_model("key1")
assert new_config.name == "new name"
assert new_config.name == new_name
def test_rename(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", config)
def test_model_records_rejects_invalid_changes(store: ModelRecordServiceBase):
config = example_ti_config("key1")
store.add_model(config)
config = store.get_model("key1")
assert config.name == "old name"
store.rename_model("key1", "new name")
new_config = store.get_model("key1")
assert new_config.name == "new name"
# upcast_attention is an invalid field for TIs
changes = ModelRecordChanges(upcast_attention=True)
with pytest.raises(ValidationError):
store.update_model(config.key, changes)
def test_unknown_key(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", config)
config = example_ti_config("key1")
store.add_model(config)
with pytest.raises(UnknownModelException):
store.update_model("unknown_key", config)
store.update_model("unknown_key", ModelRecordChanges())
def test_delete(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", config)
config = example_ti_config("key1")
store.add_model(config)
config = store.get_model("key1")
store.del_model("key1")
with pytest.raises(UnknownModelException):
@ -128,49 +114,58 @@ def test_delete(store: ModelRecordServiceBase):
def test_exists(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", config)
config = example_ti_config("key1")
store.add_model(config)
assert store.exists("key1")
assert not store.exists("key2")
def test_filter(store: ModelRecordServiceBase):
config1 = MainDiffusersConfig(
key="config1",
path="/tmp/config1",
name="config1",
base=BaseModelType("sd-1"),
type=ModelType("main"),
original_hash="CONFIG1HASH",
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
hash="CONFIG1HASH",
source="test/source",
source_type=ModelSourceType.Path,
)
config2 = MainDiffusersConfig(
key="config2",
path="/tmp/config2",
name="config2",
base=BaseModelType("sd-1"),
type=ModelType("main"),
original_hash="CONFIG2HASH",
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
hash="CONFIG2HASH",
source="test/source",
source_type=ModelSourceType.Path,
)
config3 = VaeDiffusersConfig(
config3 = VAEDiffusersConfig(
key="config3",
path="/tmp/config3",
name="config3",
base=BaseModelType("sd-2"),
type=ModelType("vae"),
original_hash="CONFIG3HASH",
type=ModelType.VAE,
hash="CONFIG3HASH",
source="test/source",
source_type=ModelSourceType.Path,
)
for c in config1, config2, config3:
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
matches = store.search_by_attr(model_type=ModelType("main"))
store.add_model(c)
matches = store.search_by_attr(model_type=ModelType.Main)
assert len(matches) == 2
assert matches[0].name in {"config1", "config2"}
matches = store.search_by_attr(model_type=ModelType("vae"))
matches = store.search_by_attr(model_type=ModelType.VAE)
assert len(matches) == 1
assert matches[0].name == "config3"
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
assert matches[0].key == "config3"
assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back
matches = store.search_by_hash("CONFIG1HASH")
assert len(matches) == 1
assert matches[0].original_hash == "CONFIG1HASH"
assert matches[0].hash == "CONFIG1HASH"
matches = store.all_models()
assert len(matches) == 3
@ -179,143 +174,116 @@ def test_filter(store: ModelRecordServiceBase):
def test_unique(store: ModelRecordServiceBase):
config1 = MainDiffusersConfig(
path="/tmp/config1",
base=BaseModelType("sd-1"),
type=ModelType("main"),
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
name="nonuniquename",
original_hash="CONFIG1HASH",
hash="CONFIG1HASH",
source="test/source/",
source_type=ModelSourceType.Path,
)
config2 = MainDiffusersConfig(
path="/tmp/config2",
base=BaseModelType("sd-2"),
type=ModelType("main"),
type=ModelType.Main,
name="nonuniquename",
original_hash="CONFIG1HASH",
hash="CONFIG1HASH",
source="test/source/",
source_type=ModelSourceType.Path,
)
config3 = VaeDiffusersConfig(
config3 = VAEDiffusersConfig(
path="/tmp/config3",
base=BaseModelType("sd-2"),
type=ModelType("vae"),
type=ModelType.VAE,
name="nonuniquename",
original_hash="CONFIG1HASH",
hash="CONFIG1HASH",
source="test/source/",
source_type=ModelSourceType.Path,
)
config4 = MainDiffusersConfig(
path="/tmp/config4",
base=BaseModelType("sd-1"),
type=ModelType("main"),
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
name="nonuniquename",
original_hash="CONFIG1HASH",
hash="CONFIG1HASH",
source="test/source/",
source_type=ModelSourceType.Path,
)
# config1, config2 and config3 are compatible because they have unique combos
# of name, type and base
for c in config1, config2, config3:
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
c.key = sha256(c.path.encode("utf-8")).hexdigest()
store.add_model(c)
# config4 clashes with config1 and should raise an integrity error
with pytest.raises(DuplicateModelException):
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), config4)
config4.key = sha256(config4.path.encode("utf-8")).hexdigest()
store.add_model(config4)
def test_filter_2(store: ModelRecordServiceBase):
config1 = MainDiffusersConfig(
path="/tmp/config1",
name="config1",
base=BaseModelType("sd-1"),
type=ModelType("main"),
original_hash="CONFIG1HASH",
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
hash="CONFIG1HASH",
source="test/source/",
source_type=ModelSourceType.Path,
)
config2 = MainDiffusersConfig(
path="/tmp/config2",
name="config2",
base=BaseModelType("sd-1"),
type=ModelType("main"),
original_hash="CONFIG2HASH",
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
hash="CONFIG2HASH",
source="test/source/",
source_type=ModelSourceType.Path,
)
config3 = MainDiffusersConfig(
path="/tmp/config3",
name="dup_name1",
base=BaseModelType("sd-2"),
type=ModelType("main"),
original_hash="CONFIG3HASH",
type=ModelType.Main,
hash="CONFIG3HASH",
source="test/source/",
source_type=ModelSourceType.Path,
)
config4 = MainDiffusersConfig(
path="/tmp/config4",
name="dup_name1",
base=BaseModelType("sdxl"),
type=ModelType("main"),
original_hash="CONFIG3HASH",
type=ModelType.Main,
hash="CONFIG3HASH",
source="test/source/",
source_type=ModelSourceType.Path,
)
config5 = VaeDiffusersConfig(
config5 = VAEDiffusersConfig(
path="/tmp/config5",
name="dup_name1",
base=BaseModelType("sd-1"),
type=ModelType("vae"),
original_hash="CONFIG3HASH",
base=BaseModelType.StableDiffusion1,
type=ModelType.VAE,
hash="CONFIG3HASH",
source="test/source/",
source_type=ModelSourceType.Path,
)
for c in config1, config2, config3, config4, config5:
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
store.add_model(c)
matches = store.search_by_attr(
model_type=ModelType("main"),
model_type=ModelType.Main,
model_name="dup_name1",
)
assert len(matches) == 2
matches = store.search_by_attr(
base_model=BaseModelType("sd-1"),
model_type=ModelType("main"),
base_model=BaseModelType.StableDiffusion1,
model_type=ModelType.Main,
)
assert len(matches) == 2
matches = store.search_by_attr(
base_model=BaseModelType("sd-1"),
model_type=ModelType("vae"),
base_model=BaseModelType.StableDiffusion1,
model_type=ModelType.VAE,
model_name="dup_name1",
)
assert len(matches) == 1
def test_summary(mm2_record_store: ModelRecordServiceSQL) -> None:
# The fixture provides us with five configs.
for x in range(1, 5):
key = f"test_config_{x}"
name = f"name_{x}"
author = f"author_{x}"
tags = {f"tag{y}" for y in range(1, x)}
mm2_record_store.metadata_store.add_metadata(
model_key=key, metadata=BaseMetadata(name=name, author=author, tags=tags)
)
# sanity check that the tags sent in all right
assert mm2_record_store.get_metadata("test_config_3").tags == {"tag1", "tag2"}
assert mm2_record_store.get_metadata("test_config_4").tags == {"tag1", "tag2", "tag3"}
# get summary
summary1 = mm2_record_store.list_models(page=0, per_page=100)
assert summary1.page == 0
assert summary1.pages == 1
assert summary1.per_page == 100
assert summary1.total == 5
assert len(summary1.items) == 5
assert summary1.items[0].name == "test5" # lora / sd-1 / diffusers / test5
# find test_config_3
config3 = [x for x in summary1.items if x.key == "test_config_3"][0]
assert config3.description == "This is test 3"
assert config3.tags == {"tag1", "tag2"}
# find test_config_5
config5 = [x for x in summary1.items if x.key == "test_config_5"][0]
assert config5.tags == set()
assert config5.description == ""
# test paging
summary2 = mm2_record_store.list_models(page=1, per_page=2)
assert summary2.page == 1
assert summary2.per_page == 2
assert summary2.pages == 3
assert summary1.items[2].name == summary2.items[0].name
# test sorting
summary = mm2_record_store.list_models(page=0, per_page=100, order_by=ModelRecordOrderBy.Name)
print(summary.items)
assert summary.items[0].name == "model1"
assert summary.items[-1].name == "test5"


@ -18,12 +18,17 @@ from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.backend.model_manager.config import (
BaseModelType,
LoRADiffusersConfig,
MainCheckpointConfig,
MainDiffusersConfig,
ModelFormat,
ModelSourceType,
ModelType,
ModelVariantType,
VAEDiffusersConfig,
)
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
from invokeai.backend.util.logging import InvokeAILogger
@ -107,11 +112,6 @@ def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> Downloa
return download_queue
@pytest.fixture
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
return mm2_record_store.metadata_store
@pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
ram_cache = ModelCache(
@ -137,7 +137,7 @@ def mm2_installer(
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger)
events = DummyEventService()
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
store = ModelRecordServiceSQL(db)
installer = ModelInstallService(
app_config=mm2_app_config,
@ -160,61 +160,71 @@ def mm2_installer(
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
logger = InvokeAILogger.get_logger(config=mm2_app_config)
db = create_mock_sqlite_database(mm2_app_config, logger)
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
store = ModelRecordServiceSQL(db)
# add five simple config records to the database
raw1 = {
"path": "/tmp/foo1",
"format": ModelFormat("diffusers"),
"name": "test2",
"base": BaseModelType("sd-2"),
"type": ModelType("vae"),
"original_hash": "111222333444",
"source": "stabilityai/sdxl-vae",
}
raw2 = {
"path": "/tmp/foo2.ckpt",
"name": "model1",
"format": ModelFormat("checkpoint"),
"base": BaseModelType("sd-1"),
"type": "main",
"config": "/tmp/foo.yaml",
"variant": "normal",
"original_hash": "111222333444",
"source": "https://civitai.com/models/206883/split",
}
raw3 = {
"path": "/tmp/foo3",
"format": ModelFormat("diffusers"),
"name": "test3",
"base": BaseModelType("sdxl"),
"type": ModelType("main"),
"original_hash": "111222333444",
"source": "author3/model3",
"description": "This is test 3",
}
raw4 = {
"path": "/tmp/foo4",
"format": ModelFormat("diffusers"),
"name": "test4",
"base": BaseModelType("sdxl"),
"type": ModelType("lora"),
"original_hash": "111222333444",
"source": "author4/model4",
}
raw5 = {
"path": "/tmp/foo5",
"format": ModelFormat("diffusers"),
"name": "test5",
"base": BaseModelType("sd-1"),
"type": ModelType("lora"),
"original_hash": "111222333444",
"source": "author4/model5",
}
store.add_model("test_config_1", raw1)
store.add_model("test_config_2", raw2)
store.add_model("test_config_3", raw3)
store.add_model("test_config_4", raw4)
store.add_model("test_config_5", raw5)
config1 = VAEDiffusersConfig(
key="test_config_1",
path="/tmp/foo1",
format=ModelFormat.Diffusers,
name="test2",
base=BaseModelType.StableDiffusion2,
type=ModelType.VAE,
hash="111222333444",
source="stabilityai/sdxl-vae",
source_type=ModelSourceType.HFRepoID,
)
config2 = MainCheckpointConfig(
key="test_config_2",
path="/tmp/foo2.ckpt",
name="model1",
format=ModelFormat.Checkpoint,
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
config_path="/tmp/foo.yaml",
variant=ModelVariantType.Normal,
hash="111222333444",
source="https://civitai.com/models/206883/split",
source_type=ModelSourceType.CivitAI,
)
config3 = MainDiffusersConfig(
key="test_config_3",
path="/tmp/foo3",
format=ModelFormat.Diffusers,
name="test3",
base=BaseModelType.StableDiffusionXL,
type=ModelType.Main,
hash="111222333444",
source="author3/model3",
description="This is test 3",
source_type=ModelSourceType.HFRepoID,
)
config4 = LoRADiffusersConfig(
key="test_config_4",
path="/tmp/foo4",
format=ModelFormat.Diffusers,
name="test4",
base=BaseModelType.StableDiffusionXL,
type=ModelType.LoRA,
hash="111222333444",
source="author4/model4",
source_type=ModelSourceType.HFRepoID,
)
config5 = LoRADiffusersConfig(
key="test_config_5",
path="/tmp/foo5",
format=ModelFormat.Diffusers,
name="test5",
base=BaseModelType.StableDiffusion1,
type=ModelType.LoRA,
hash="111222333444",
source="author4/model5",
source_type=ModelSourceType.HFRepoID,
)
store.add_model(config1)
store.add_model(config2)
store.add_model(config3)
store.add_model(config4)
store.add_model(config5)
return store


@ -1,202 +0,0 @@
"""
Test model metadata fetching and storage.
"""
import datetime
from pathlib import Path
import pytest
from pydantic.networks import HttpUrl
from requests.sessions import Session
from invokeai.app.services.model_metadata import ModelMetadataStoreBase
from invokeai.backend.model_manager.config import ModelRepoVariant
from invokeai.backend.model_manager.metadata import (
CivitaiMetadata,
CivitaiMetadataFetch,
CommercialUsage,
HuggingFaceMetadata,
HuggingFaceMetadataFetch,
UnknownMetadataException,
)
from invokeai.backend.model_manager.util import select_hf_files
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None:
tags = {"text-to-image", "diffusers"}
input_metadata = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
tags=tags,
id="stabilityai/sdxl-vae",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
mm2_metadata_store.add_metadata("test_config_1", input_metadata)
output_metadata = mm2_metadata_store.get_metadata("test_config_1")
assert input_metadata == output_metadata
with pytest.raises(UnknownMetadataException):
mm2_metadata_store.add_metadata("unknown_key", input_metadata)
assert mm2_metadata_store.list_tags() == tags
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStoreBase) -> None:
input_metadata = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
tags={"text-to-image", "diffusers"},
id="stabilityai/sdxl-vae",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
mm2_metadata_store.add_metadata("test_config_1", input_metadata)
input_metadata.name = "new-name"
mm2_metadata_store.update_metadata("test_config_1", input_metadata)
output_metadata = mm2_metadata_store.get_metadata("test_config_1")
assert output_metadata.name == "new-name"
assert input_metadata == output_metadata
def test_metadata_search(mm2_metadata_store: ModelMetadataStoreBase) -> None:
metadata1 = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
tags={"text-to-image", "diffusers"},
id="stabilityai/sdxl-vae",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata2 = HuggingFaceMetadata(
name="model2",
author="stabilityai",
tags={"text-to-image", "diffusers", "community-contributed"},
id="author2/model2",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata3 = HuggingFaceMetadata(
name="model3",
author="author3",
tags={"text-to-image", "checkpoint", "community-contributed"},
id="author3/model3",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
mm2_metadata_store.add_metadata("test_config_1", metadata1)
mm2_metadata_store.add_metadata("test_config_2", metadata2)
mm2_metadata_store.add_metadata("test_config_3", metadata3)
matches = mm2_metadata_store.search_by_author("stabilityai")
assert len(matches) == 2
assert "test_config_1" in matches
assert "test_config_2" in matches
matches = mm2_metadata_store.search_by_author("Sherlock Holmes")
assert not matches
matches = mm2_metadata_store.search_by_name("model3")
assert len(matches) == 1
assert "test_config_3" in matches
matches = mm2_metadata_store.search_by_tag({"text-to-image"})
assert len(matches) == 3
matches = mm2_metadata_store.search_by_tag({"text-to-image", "diffusers"})
assert len(matches) == 2
assert "test_config_1" in matches
assert "test_config_2" in matches
matches = mm2_metadata_store.search_by_tag({"checkpoint", "community-contributed"})
assert len(matches) == 1
assert "test_config_3" in matches
# does the tag table update correctly?
matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
assert not matches
assert mm2_metadata_store.list_tags() == {"text-to-image", "diffusers", "community-contributed", "checkpoint"}
metadata3.tags.add("licensed-for-commercial-use")
mm2_metadata_store.update_metadata("test_config_3", metadata3)
assert mm2_metadata_store.list_tags() == {
"text-to-image",
"diffusers",
"community-contributed",
"checkpoint",
"licensed-for-commercial-use",
}
matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
assert len(matches) == 1
def test_metadata_civitai_fetch(mm2_session: Session) -> None:
fetcher = CivitaiMetadataFetch(mm2_session)
metadata = fetcher.from_url(HttpUrl("https://civitai.com/models/215485/SDXL-turbo"))
assert isinstance(metadata, CivitaiMetadata)
assert metadata.id == 215485
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
assert CommercialUsage("RentCivit") in metadata.restrictions.AllowCommercialUse
assert metadata.version_id == 242807
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
def test_metadata_hf_fetch(mm2_session: Session) -> None:
fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
assert isinstance(metadata, HuggingFaceMetadata)
assert metadata.author == "test_author" # this is not the same as the original
assert metadata.files
assert metadata.tags == {
"diffusers",
"onnx",
"safetensors",
"text-to-image",
"license:other",
"has_space",
"diffusers:StableDiffusionXLPipeline",
"region:us",
}
def test_metadata_hf_filter(mm2_session: Session) -> None:
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
assert isinstance(metadata, HuggingFaceMetadata)
files = [x.path for x in metadata.files]
fp16_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16"))
assert Path("sdxl-turbo/text_encoder/model.fp16.safetensors") in fp16_files
assert Path("sdxl-turbo/text_encoder/model.safetensors") not in fp16_files
fp32_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp32"))
assert Path("sdxl-turbo/text_encoder/model.safetensors") in fp32_files
assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in fp32_files
onnx_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("onnx"))
assert Path("sdxl-turbo/text_encoder/model.onnx") in onnx_files
assert Path("sdxl-turbo/text_encoder/model.safetensors") not in onnx_files
default_files = select_hf_files.filter_files(files)
assert Path("sdxl-turbo/text_encoder/model.safetensors") in default_files
assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in default_files
openvino_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("openvino"))
print(openvino_files)
assert len(openvino_files) == 0
flax_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("flax"))
print(flax_files)
assert not flax_files
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(
HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo-nofp16")
)
assert isinstance(metadata, HuggingFaceMetadata)
files = [x.path for x in metadata.files]
filtered_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16"))
assert (
Path("sdxl-turbo-nofp16/text_encoder/model.safetensors") in filtered_files
) # confirm that default is returned
assert Path("sdxl-turbo-nofp16/text_encoder/model.16.safetensors") not in filtered_files
def test_metadata_hf_urls(mm2_session: Session) -> None:
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
assert isinstance(metadata, HuggingFaceMetadata)


@ -104,7 +104,7 @@ def sdxl_base_files() -> List[Path]:
],
),
(
ModelRepoVariant.DEFAULT,
ModelRepoVariant.Default,
[
"model_index.json",
"scheduler/scheduler_config.json",
@ -129,7 +129,7 @@ def sdxl_base_files() -> List[Path]:
],
),
(
ModelRepoVariant.OPENVINO,
ModelRepoVariant.OpenVINO,
[
"model_index.json",
"scheduler/scheduler_config.json",
@ -211,7 +211,7 @@ def sdxl_base_files() -> List[Path]:
],
),
(
ModelRepoVariant.FLAX,
ModelRepoVariant.Flax,
[
"model_index.json",
"scheduler/scheduler_config.json",
@ -235,7 +235,94 @@ def sdxl_base_files() -> List[Path]:
),
],
)
def test_select(sdxl_base_files: List[Path], variant: ModelRepoVariant, expected_list: List[Path]) -> None:
def test_select(sdxl_base_files: List[Path], variant: ModelRepoVariant, expected_list: List[str]) -> None:
print(f"testing variant {variant}")
filtered_files = filter_files(sdxl_base_files, variant)
assert set(filtered_files) == {Path(x) for x in expected_list}
@pytest.fixture
def sd15_test_files() -> list[Path]:
return [
Path(f)
for f in [
"feature_extractor/preprocessor_config.json",
"safety_checker/config.json",
"safety_checker/model.fp16.safetensors",
"safety_checker/model.safetensors",
"safety_checker/pytorch_model.bin",
"safety_checker/pytorch_model.fp16.bin",
"scheduler/scheduler_config.json",
"text_encoder/config.json",
"text_encoder/model.fp16.safetensors",
"text_encoder/model.safetensors",
"text_encoder/pytorch_model.bin",
"text_encoder/pytorch_model.fp16.bin",
"tokenizer/merges.txt",
"tokenizer/special_tokens_map.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"unet/config.json",
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
"unet/diffusion_pytorch_model.non_ema.bin",
"unet/diffusion_pytorch_model.non_ema.safetensors",
"unet/diffusion_pytorch_model.safetensors",
"vae/config.json",
"vae/diffusion_pytorch_model.bin",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"vae/diffusion_pytorch_model.safetensors",
]
]
@pytest.mark.parametrize(
"variant,expected_files",
[
(
ModelRepoVariant.FP16,
[
"feature_extractor/preprocessor_config.json",
"safety_checker/config.json",
"safety_checker/model.fp16.safetensors",
"scheduler/scheduler_config.json",
"text_encoder/config.json",
"text_encoder/model.fp16.safetensors",
"tokenizer/merges.txt",
"tokenizer/special_tokens_map.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"unet/config.json",
"unet/diffusion_pytorch_model.fp16.safetensors",
"vae/config.json",
"vae/diffusion_pytorch_model.fp16.safetensors",
],
),
(
ModelRepoVariant.FP32,
[
"feature_extractor/preprocessor_config.json",
"safety_checker/config.json",
"safety_checker/model.safetensors",
"scheduler/scheduler_config.json",
"text_encoder/config.json",
"text_encoder/model.safetensors",
"tokenizer/merges.txt",
"tokenizer/special_tokens_map.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"unet/config.json",
"unet/diffusion_pytorch_model.safetensors",
"vae/config.json",
"vae/diffusion_pytorch_model.safetensors",
],
),
],
)
def test_select_multiple_weights(
sd15_test_files: list[Path], variant: ModelRepoVariant, expected_files: list[str]
) -> None:
filtered_files = filter_files(sd15_test_files, variant)
assert set(filtered_files) == {Path(f) for f in expected_files}


@ -21,7 +21,7 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat
base_type = probe.get_base_type()
assert base_type == expected_type
repo_variant = probe.get_repo_variant()
assert repo_variant == ModelRepoVariant.DEFAULT
assert repo_variant == ModelRepoVariant.Default
def test_repo_variant(datadir: Path):