Compare commits

...

94 Commits

Author SHA1 Message Date
cd3f5f30dc Run ruff format 2024-03-05 16:38:55 -05:00
71ee28ac12 Refactor session runner, move profiling back to processor, create abstract class for session runners, create path for passing in custom session runner to default session processor 2024-03-05 16:01:47 -05:00
46c904d08a Rename graph processor to session runner to better describe what it's doing, add before/after callbacks for sessions 2024-03-05 16:01:47 -05:00
7d5a88b69d Move graph processor into session_processor_default 2024-03-05 16:01:47 -05:00
afa4df1991 Separate the logic that actually runs a graph in the session_processor into its own class 2024-03-05 16:01:47 -05:00
e30cb4b52f updates for defaultModel (#5866)
* move defaultModel logic to modelsLoaded and update to work for key instead of name/base/type string

* lint fix

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2024-03-05 09:55:22 -05:00
ba1f6bf926 chore: lint 2024-03-05 23:50:19 +11:00
4a9cca6c2d fix(ui): format model API response data 2024-03-05 23:50:19 +11:00
b0275700b3 refactor(ui): compute prompt trigger options in the component
We can derive the valid trigger options in the component without needing to lift the options list into global state.
2024-03-05 23:50:19 +11:00
8319aca5f9 chore(ui): typegen 2024-03-05 23:50:19 +11:00
51a604f907 pkg(ui): do not fix knip in lint:fix script 2024-03-05 23:50:19 +11:00
7515d73628 make trigger phrases a list of options and add lora name as description to appear in dropdown 2024-03-05 23:50:19 +11:00
2c453aa531 fix type error 2024-03-05 23:50:19 +11:00
2cca6e4c76 check if lora is enabled before adding trigger phrases 2024-03-05 23:50:19 +11:00
ef171e890a use a listener to recalculate trigger phrases when model or lora list changes 2024-03-05 23:50:19 +11:00
caafbf2f0d only show trigger phrase settings on main and lora 2024-03-05 23:50:19 +11:00
2db5eaf907 lint fix 2024-03-05 23:50:19 +11:00
f234bf6256 cleanup 2024-03-05 23:50:19 +11:00
cfa78b4052 adapt embedding popover to work for trigger phrases also 2024-03-05 23:50:19 +11:00
ba1dd4b02b UI in MM to create trigger phrases 2024-03-05 23:50:19 +11:00
bcf58cac59 feat(mm): add config to skip model hash
This is useful for when you are using a memory DB and do not want to wait for all models to be hashed on startup.
2024-03-05 23:50:19 +11:00
e866d90ab2 tidy(mm): remove unused method on probe 2024-03-05 23:50:19 +11:00
e8797787cf fix(mm): fix incorrect calls to update_model 2024-03-05 23:50:19 +11:00
0082ecb22b feat(mm): add path to ModelRecordChanges 2024-03-05 23:50:19 +11:00
656839fcd1 fix(mm): fix typing on heuristic_import 2024-03-05 23:50:19 +11:00
99407c899f feat(ui): update UI to use new model config backend
- Update all queries
- Remove Advanced Add
- Removed un-editable, internal-only model attributes from model edit UI (e.g. format, repo variant, model type)
- Update model tags so the list refreshes when a model installs
- Rename some queries, components, variables, types to match backend
- Fix divide-by-zero in install queue
2024-03-05 23:50:19 +11:00
48119d9010 revert(mm): restore convert route 2024-03-05 23:50:19 +11:00
7c9128b253 tidy(mm): use canonical capitalization for all model-related enums, classes
For example, "Lora" -> "LoRA", "Vae" -> "VAE".
2024-03-05 23:50:19 +11:00
4f9bb00275 tidy(api): tidy mm routes
Rename MM routes to be consistent:
- "import" -> "install"
- "model_record" -> "model"

Comment several unused routes while I work (may end up removing them?):
- list model summary (we use the search route instead)
- add model record
- convert model
- merge models
2024-03-05 23:50:19 +11:00
78895b3e80 fix(mm): add missing inplace parameter to model install abc 2024-03-05 23:50:19 +11:00
3030a34b88 fix(mm): make type and format required in openapi schema for model config 2024-03-05 23:50:19 +11:00
58fa9c2fac fix(mm): do not allow extra fields on ModelRecordChanges 2024-03-05 23:50:19 +11:00
a8b6635050 fix(mm): make key required in openapi schema for model config 2024-03-05 23:50:19 +11:00
6829610a71 tests: rename "example_config" -> "example_it_config" 2024-03-05 23:50:19 +11:00
5551cf8ac4 feat(mm): revise update_model to use ModelRecordChanges 2024-03-05 23:50:19 +11:00
37b969d339 tidy(mm): add default_settings to model config 2024-03-05 23:50:19 +11:00
c953e61294 tidy(mm): "trigger_words" -> "trigger_phrases" 2024-03-05 23:50:19 +11:00
93dd3c848e tidy(mm): remove unused code in select_hf_files.py 2024-03-05 23:50:19 +11:00
02bde7bb75 tests: fix test_hf_model_select::test_select_multiple_weights on windows 2024-03-05 23:50:19 +11:00
3391c19926 chore: ruff 2024-03-05 23:50:19 +11:00
0f60b1ced4 fix(mm): use .value for model config discriminators
There is a breaking change in python 3.11 related to how enums with `str` as a mixin are formatted. This appears to have not caused any grief for us until now.

Re-jigger the discriminator setup to use `.value` so everything works on both python 3.10 and 3.11.
2024-03-05 23:50:19 +11:00
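Editor's illustration (not code from this commit): a minimal sketch of reading a discriminator through the enum's `.value`, which yields the same plain string on Python 3.10 and 3.11 regardless of how str-mixin enum members are formatted. The class and field names here are hypothetical.

```
from enum import Enum
from typing import Any


class ModelFormat(str, Enum):
    Diffusers = "diffusers"
    Checkpoint = "checkpoint"


def discriminator_value(config: Any) -> str:
    """Return the raw discriminator string for a config dict or object."""
    fmt = config.get("format") if isinstance(config, dict) else getattr(config, "format")
    # .value is always the plain string ("diffusers"); formatting the member
    # itself is what changed between Python 3.10 and 3.11.
    return fmt.value if isinstance(fmt, Enum) else str(fmt)


assert discriminator_value({"format": ModelFormat.Diffusers}) == "diffusers"
assert discriminator_value({"format": "checkpoint"}) == "checkpoint"
```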
44c40d7d1a refactor(mm): remove unused metadata logic, fix tests
- Metadata is merged with the config. We can simplify the MM substantially and remove the handling for metadata.
- Per discussion, we don't have an ETA for frontend implementation of tags, and with the realization that the tags from CivitAI are largely useless, there's no reason to keep tags in the MM right now. When we are ready to implement tags on the frontend, we can refer back to the implementation here and use it if it supports the design.
- Fix all tests.
2024-03-05 23:50:19 +11:00
0b9a212363 tests: remove 60s timeout for tests
This makes it very difficult to troubleshoot tests. Our github actions now have timeouts, so there's no risk of a test stalling for ages.
2024-03-05 23:50:19 +11:00
c3aa985c93 refactor(mm): get metadata working 2024-03-05 23:50:19 +11:00
7cb0da1f66 refactor(mm): wip schema changes 2024-03-05 23:50:19 +11:00
3534366146 fix(mm): fix extraneous downloaded files in diffusers
Sometimes, diffusers model components (tokenizer, unet, etc.) have multiple weights files in the same directory.

In this situation, we assume the files are different versions of the same weights. For example, we may have multiple
formats (`.bin`, `.safetensors`) with different precisions. When downloading model files, we want to select only
the best of these files for the requested format and precision/variant.

The previous logic assumed that each model weights file would have the same base filename, but this assumption was
not always true. The logic is revised to score each file and choose the best-scoring file, resulting in only a single
file being downloaded for each submodel/subdirectory.
2024-03-05 23:50:19 +11:00
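To make the selection rule concrete, here is a rough, hypothetical sketch of per-file scoring; the weights and filenames are illustrative only, not the implementation in this commit.

```
from pathlib import Path


def score_weights_file(path: Path, variant: str | None = "fp16") -> int:
    """Illustrative scoring: higher is better for the requested format/variant."""
    score = 0
    if path.suffix == ".safetensors":
        score += 2          # prefer safetensors over .bin
    elif path.suffix == ".bin":
        score += 1
    if variant and f".{variant}." in path.name:
        score += 1          # e.g. "diffusion_pytorch_model.fp16.safetensors"
    return score


def pick_best(candidates: list[Path]) -> Path:
    # Keep a single winner per submodel directory, as described above.
    return max(candidates, key=score_weights_file)


files = [
    Path("unet/diffusion_pytorch_model.bin"),
    Path("unet/diffusion_pytorch_model.fp16.safetensors"),
]
print(pick_best(files))  # unet/diffusion_pytorch_model.fp16.safetensors
```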
f2b5f8753f tidy(mm): remove json_schema_extra from config - not needed 2024-03-05 23:50:19 +11:00
f13f5984c0 fix(mm): update db schema & migration 2024-03-05 23:50:19 +11:00
94e1e64296 chore: ruff 2024-03-05 23:50:19 +11:00
2411bf53c0 tidy(mm): better descriptions for model configs 2024-03-05 23:50:19 +11:00
9378e47a06 feat(mm): add source_type to model configs 2024-03-05 23:50:19 +11:00
4471ea8ad1 refactor(mm): simplify model metadata schemas 2024-03-05 23:50:19 +11:00
2c835fd550 refactor(mm): WIP db schema 2024-03-05 23:50:19 +11:00
61b737bb9f tidy(mm): remove update method from ModelConfigBase
It's only used in the soon-to-be-removed model merge logic
2024-03-05 23:50:19 +11:00
a8cd3dfc99 refactor(mm): add models table (schema WIP), rename "original_hash" -> "hash" 2024-03-05 23:50:19 +11:00
0cce582f2f tidy(mm): remove current_hash 2024-03-05 23:50:19 +11:00
4347d1c7f7 tests(mm): fix some objects in tests 2024-03-05 23:50:19 +11:00
bd4fd9693d tidy(mm): rename ckpt "last_modified" -> "converted_at"
Clarify what this timestamp means
2024-03-05 23:50:19 +11:00
9b40c28144 tidy(mm): rename ckpt "config" -> "config_path" 2024-03-05 23:50:19 +11:00
16a5d718bf fix(mm): add config field to ckpt vaes 2024-03-05 23:50:19 +11:00
76cbc745e1 refactor(mm): add CheckpointConfigBase for all ckpt models 2024-03-05 23:50:19 +11:00
0a614943f6 fix(mm): fix broken get_model_discriminator_value 2024-03-05 23:50:19 +11:00
e426096d32 fix(mm): misc typing fixes for model loaders 2024-03-05 23:50:19 +11:00
c561cd751f fix(mm): use correct import path for ConfigMixin, ModelMixin 2024-03-05 23:50:19 +11:00
af9298f0ef tidy(mm): tidy class names in config.py 2024-03-05 23:50:19 +11:00
5b74117836 fix(mm): use generic for model loader registry
This preserves the typing for classes using the decorator
2024-03-05 23:50:19 +11:00
38474c9797 fix(mm): use correct import path for ModelMixin 2024-03-05 23:50:19 +11:00
b880a31039 refactor(mm): remove ztsnr_training field on _MainConfig
This is used to determine the CFG Rescale Multiplier setting. We'll handle this in the UI as a default setting.
2024-03-05 23:50:19 +11:00
dd31bc4586 refactor(mm): remove vae field on _MainConfig
We will handle default VAE selection in the UI.
2024-03-05 23:50:19 +11:00
316573df2d feat(mm): use callable discriminator for AnyModelConfig union 2024-03-05 23:50:19 +11:00
8b34f5298c Default model settings (#5850)
* UI in MM to create trigger phrases

* add scheduler and vaePrecision to config

* UI for configuring default settings for models

* hook MM default model settings up to API

* add button to set default settings in parameters

* pull out trigger phrases

* back-end for default settings

* lint

* remove log

* ruff

* ruff format

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2024-03-04 09:39:03 -05:00
893bcd16fc Next: Allow in place local installs of models 2024-03-04 23:11:41 +11:00
f6028a4c61 Log a stack trace for invocation errors. 2024-03-04 23:01:56 +11:00
264aee3ffa translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2024-03-04 21:39:46 +11:00
4deb60f365 translationBot(ui): update translation (Italian)
Currently translated at 98.0% (1442 of 1470 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-03-04 21:39:46 +11:00
B N
f2d5fb176f translationBot(ui): update translation (German)
Currently translated at 80.4% (1183 of 1470 strings)

Co-authored-by: B N <berndnieschalk@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2024-03-04 21:39:46 +11:00
94005b5501 add button to navigate to model manager if tab is enabled 2024-03-03 19:50:50 -05:00
02dc1a8780 consolidate tabs for main model and concepts in generation panel 2024-03-03 19:50:50 -05:00
ef958568ac Update Transformers 4.37.2 -> 4.38.2 2024-03-03 19:41:33 -05:00
48e323d887 docs: added both create mask nodes to defaultNodes 2024-03-03 12:58:47 -05:00
735857479d fix(canvas): use corrected mask for pasteback 2024-03-03 12:58:47 -05:00
2f372d9b18 tests(mm): update tests to reflect using UUID for key 2024-03-03 14:32:14 +11:00
554d175792 feat(mm): improved model hash class
- Use memory view for hashlib algorithms (closer to python 3.11's filehash API in hashlib)
- Remove `sha1_fast` (realized it doesn't even hash the whole file, it just does the first block)
- Add support for custom file filters
- Update docstrings
- Update tests
2024-03-03 14:32:14 +11:00
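As an illustration of the buffer-reuse idea mentioned above (a sketch patterned after hashlib's block-hashing style, not the actual `ModelHash` code):

```
import hashlib
from pathlib import Path


def hash_file(path: Path, algorithm: str = "sha256", blocksize: int = 2**20) -> str:
    """Hash a file in fixed-size blocks, reusing one buffer via a memoryview."""
    hasher = hashlib.new(algorithm)
    buffer = bytearray(blocksize)
    view = memoryview(buffer)
    with open(path, "rb") as f:
        while True:
            n = f.readinto(buffer)
            if not n:
                break
            hasher.update(view[:n])  # feed only the bytes actually read
    return hasher.hexdigest()
```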
ae99428883 fix(mm): use UUIDv4 for key
This changes the functionality of this PR to only use the updated hashing for model hashes with a UUID for the key.
2024-03-03 14:32:14 +11:00
863ce00712 tests(mm): add tests for ModelHash 2024-03-03 14:32:14 +11:00
86982f3059 feat(mm): make ModelHash instantiatable, taking an algorithm as arg 2024-03-03 14:32:14 +11:00
ec8ed530a7 feat(mm): modularize ModelHash to facilitate testing 2024-03-03 14:32:14 +11:00
982076d7d7 feat(mm): add hashing algos to ModelHash
- Some algos are slow, so it is now just called ModelHash
- Added all hashlib algos, plus BLAKE3 and the fast (but incorrect) SHA1 algo
2024-03-03 14:32:14 +11:00
2e4672f931 feat(mm): make hash.py a script for testing 2024-03-03 14:32:14 +11:00
908e915a71 feat(mm): use blake3 for hashing 2024-03-03 14:32:14 +11:00
a72056e0df make model key assignment deterministic
- When installing, model keys are now calculated from the model contents.
- .safetensors, .ckpt and other single file models are hashed with sha1
- The contents of diffusers directories are hashed using imohash (faster)

fixup yaml->sql db migration script to assign deterministic key

- this commit also detects and assigns the correct image encoder for
  ip adapter models.
2024-03-03 14:32:14 +11:00
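A simplified, hypothetical sketch of the single-file case (SHA-1 of the file contents becomes the key); the imohash-based handling of diffusers directories described above is omitted.

```
import hashlib
from pathlib import Path


def deterministic_key(model_path: Path) -> str:
    """Derive a stable key from a single-file model's contents."""
    digest = hashlib.sha1()
    with open(model_path, "rb") as f:
        for block in iter(lambda: f.read(2**20), b""):
            digest.update(block)
    # Installing the same .safetensors/.ckpt file twice now yields the same key,
    # unlike the previous randomly generated keys.
    return digest.hexdigest()
```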
d8d7ddf43a Remove attention map saving (#5845)
## What type of PR is this? (check all applicable)

- [x] Refactor
- [ ] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because

## Description

Attention map saving was a feature that existed a long time ago in
Invoke (>1 year ago). This PR strips out a bunch of dead code that still
remains from that feature and is polluting our diffusion implementation.

This change should not have any functional effect on the app.

## QA Instructions, Screenshots, Recordings

I did a quick smoke test of SD and SDXL image generation. All of the
deleted code was unused, so the risk should be relatively low.

## Merge Plan

- [x] Change target branch to `main` before merging.

## Added/updated tests?

- [ ] Yes
- [x] No: This PR just deletes a bunch of unused code.
2024-03-02 11:15:25 -05:00
cc45007dc4 Remove unused code for attention map saving. 2024-03-02 08:25:41 -05:00
73bec56c59 Delete unused functions from shared_invokeai_diffusion.py. 2024-03-02 08:25:41 -05:00
132 changed files with 4002 additions and 4958 deletions

View File

@@ -32,7 +32,6 @@ model. These are the:
Responsible for loading a model from disk
into RAM and VRAM and getting it ready for inference.
## Location of the Code
The four main services can be found in
@@ -67,19 +66,17 @@ provides the following fields:
| `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
| `base_model` | BaseModelType | The base model that the model is compatible with |
| `path` | str | Location of model on disk |
-| `original_hash` | str | Hash of the model when it was first installed |
-| `current_hash` | str | Most recent hash of the model's contents |
+| `hash` | str | Hash of the model |
| `description` | str | Human-readable description of the model (optional) |
| `source` | str | Model's source URL or repo id (optional) |
The `key` is a unique 32-character random ID which was generated at
-install time. The `original_hash` field stores a hash of the model's
+install time. The `hash` field stores a hash of the model's
contents at install time obtained by sampling several parts of the
model's files using the `imohash` library. Over the course of the
model's lifetime it may be transformed in various ways, such as
changing its precision or converting it from a .safetensors to a
-diffusers model. When this happens, `original_hash` is unchanged, but
-`current_hash` is updated to indicate the current contents.
+diffusers model.
`ModelType`, `ModelFormat` and `BaseModelType` are string enums that
are defined in `invokeai.backend.model_manager.config`. They are also
@@ -94,7 +91,6 @@ The `path` field can be absolute or relative. If relative, it is taken
to be relative to the `models_dir` setting in the user's
`invokeai.yaml` file.
### CheckpointConfig
This adds support for checkpoint configurations, and adds the
@@ -217,20 +213,20 @@ for use in the InvokeAI web server. Its signature is:
```
def open(
    cls,
    config: InvokeAIAppConfig,
    conn: Optional[sqlite3.Connection] = None,
    lock: Optional[threading.Lock] = None
) -> Union[ModelRecordServiceSQL, ModelRecordServiceFile]:
```
The way it works is as follows:
1. Retrieve the value of the `model_config_db` option from the user's
   `invokeai.yaml` config file.
2. If `model_config_db` is `auto` (the default), then:
   * Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object
     opened on the passed connection and lock.
   * Open up a new connection to `databases/invokeai.db` if `conn`
     and/or `lock` are missing (see note below).
3. If `model_config_db` is a Path, then use `from_db_file`
   to return the appropriate type of ModelRecordService.
@@ -255,7 +251,7 @@ store = ModelRecordServiceBase.open(config, db_conn, lock)
Configurations can be retrieved in several ways.
#### get_model(key) -> AnyModelConfig
The basic functionality is to call the record store object's
`get_model()` method with the desired model's unique key. It returns
@@ -272,28 +268,28 @@ print(model_conf.path)
If the key is unrecognized, this call raises an
`UnknownModelException`.
#### exists(key) -> AnyModelConfig
Returns True if a model with the given key exists in the databsae.
#### search_by_path(path) -> AnyModelConfig
Returns the configuration of the model whose path is `path`. The path
is matched using a simple string comparison and won't correctly match
models referred to by different paths (e.g. using symbolic links).
#### search_by_name(name, base, type) -> List[AnyModelConfig]
This method searches for models that match some combination of `name`,
`BaseType` and `ModelType`. Calling without any arguments will return
all the models in the database.
#### all_models() -> List[AnyModelConfig]
Return all the model configs in the database. Exactly equivalent to
calling `search_by_name()` with no arguments.
#### search_by_tag(tags) -> List[AnyModelConfig]
`tags` is a list of strings. This method returns a list of model
configs that contain all of the given tags. Examples:
@@ -312,11 +308,11 @@ commercializable_models = [x for x in store.all_models() \
    if x.license.contains('allowCommercialUse=Sell')]
```
#### version() -> str
Returns the version of the database, currently at `3.2`
#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase
This method exists to ease the transition from the previous version of
the model manager, in which `get_model()` took the three arguments
@@ -337,7 +333,7 @@ model and pass its key to `get_model()`.
Several methods allow you to create and update stored model config
records.
#### add_model(key, config) -> AnyModelConfig
Given a key and a configuration, this will add the model's
configuration record to the database. `config` can either be a subclass of
@@ -352,7 +348,7 @@ model with the same key is already in the database, or an
`InvalidModelConfigException` if a dict was passed and Pydantic
experienced a parse or validation error.
### update_model(key, config) -> AnyModelConfig
Given a key and a configuration, this will update the model
configuration record in the database. `config` can be either a
@@ -370,31 +366,31 @@ The `ModelInstallService` class implements the
shop for all your model install needs. It provides the following
functionality:
* Registering a model config record for a model already located on the
  local filesystem, without moving it or changing its path.
* Installing a model alreadiy located on the local filesystem, by
  moving it into the InvokeAI root directory under the
  `models` folder (or wherever config parameter `models_dir`
  specifies).
* Probing of models to determine their type, base type and other key
  information.
* Interface with the InvokeAI event bus to provide status updates on
  the download, installation and registration process.
* Downloading a model from an arbitrary URL and installing it in
  `models_dir`.
* Special handling for Civitai model URLs which allow the user to
  paste in a model page's URL or download link
* Special handling for HuggingFace repo_ids to recursively download
  the contents of the repository, paying attention to alternative
  variants such as fp16.
* Saving tags and other metadata about the model into the invokeai database
  when fetching from a repo that provides that type of information,
  (currently only Civitai and HuggingFace).
@@ -427,8 +423,8 @@ queue.start()
installer = ModelInstallService(app_config=config,
                                record_store=record_store,
                                download_queue=queue
                                )
installer.start()
```
@@ -443,7 +439,6 @@ required parameters:
| `metadata_store` | Optional[ModelMetadataStore] | Metadata storage object |
|`session` | Optional[requests.Session] | Swap in a different Session object (usually for debugging) |
Once initialized, the installer will provide the following methods:
#### install_job = installer.heuristic_import(source, [config], [access_token])
@@ -457,12 +452,12 @@ The `source` is a string that can be any of these forms
1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`)
2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
3. A HuggingFace repo_id with any of the following formats:
   * `model/name` -- entire model
   * `model/name:fp32` -- entire model, using the fp32 variant
   * `model/name:fp16:vae` -- vae submodel, using the fp16 variant
   * `model/name::vae` -- vae submodel, using default precision
   * `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
   * `model/name::path/to/model.safetensors` -- an individual model file, default variant
Note that by specifying a relative path to the top of the HuggingFace
repo, you can download and install arbitrary models files.
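For orientation, a minimal hypothetical call built from the signature and source formats documented above; it assumes the `installer` constructed earlier, and the repo id is just an example.

```
install_job = installer.heuristic_import("stabilityai/sdxl-turbo::vae")  # format 3 above
installer.wait_for_installs(timeout=120)  # block until queued installs finish
print(install_job.config_out.key if install_job.complete else install_job.error_type)
```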
@@ -491,9 +486,9 @@ following illustrates basic usage:
```
from invokeai.app.services.model_install import (
    LocalModelSource,
    HFModelSource,
    URLModelSource,
)

source1 = LocalModelSource(path='/opt/models/sushi.safetensors') # a local safetensors file
@@ -513,12 +508,12 @@ for source in [source1, source2, source3, source4, source5, source6, source7]:
source2job = installer.wait_for_installs(timeout=120)
for source in sources:
    job = source2job[source]
    if job.complete:
        model_config = job.config_out
        model_key = model_config.key
        print(f"{source} installed as {model_key}")
    elif job.errored:
        print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")
```
@@ -566,7 +561,6 @@ details.
This is used for a model that is located on a locally-accessible Posix
filesystem, such as a local disk or networked fileshare.
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `path` | str | Path | None | Path to the model file or directory |
@@ -625,7 +619,6 @@ HuggingFace has the most complicated `ModelSource` structure:
| `subfolder` | Path | None | Look for the model in a subfolder of the repo. |
| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |
The `repo_id` is the repository ID, such as `stabilityai/sdxl-turbo`.
The `variant` is one of the various diffusers formats that HuggingFace
@@ -661,7 +654,6 @@ in. To download these files, you must provide an
`HfFolder.get_token()` will be called to fill it in with the cached
one.
#### Monitoring the install job process
When you create an install job with `import_model()`, it launches the
@@ -682,7 +674,6 @@ The `ModelInstallJob` class has the following structure:
| `error_type` | `str` | Name of the exception that led to an error status |
| `error` | `str` | Traceback of the error |
If the `event_bus` argument was provided, events will also be
broadcast to the InvokeAI event bus. The events will appear on the bus
as an event of type `EventServiceBase.model_event`, a timestamp and
@@ -702,14 +693,13 @@ following keys:
| `total_bytes` | int | Total size of all the files that make up the model |
| `parts` | List[Dict]| Information on the progress of the individual files that make up the model |
The parts is a list of dictionaries that give information on each of
the components pieces of the download. The dictionary's keys are
`source`, `local_path`, `bytes` and `total_bytes`, and correspond to
the like-named keys in the main event.
Note that downloading events will not be issued for local models, and
that downloading events occur _before_ the running event.
##### `model_install_running`
@@ -752,7 +742,6 @@ properties: `waiting`, `downloading`, `running`, `complete`, `errored`
and `cancelled`, as well as `in_terminal_state`. The last will return
True if the job is in the complete, errored or cancelled states.
#### Model configuration and probing
The install service uses the `invokeai.backend.model_manager.probe`
@@ -774,11 +763,11 @@ attributes. Here is an example of setting the
```
install_job = installer.import_model(
    source=HFModelSource(repo_id='stabilityai/stable-diffusion-2-1',variant='fp32'),
    config=dict(
        prediction_type=SchedulerPredictionType('v_prediction')
        name='stable diffusion 2 base model',
    )
)
```
### Other installer methods
@@ -862,7 +851,6 @@ This method is similar to `unregister()`, but also unconditionally
deletes the corresponding model weights file(s), regardless of whether
they are inside or outside the InvokeAI models hierarchy.
#### path = installer.download_and_cache(remote_source, [access_token], [timeout])
This utility routine will download the model file located at source,
@@ -974,7 +962,7 @@ is in its lifecycle. Values are defined in the string enum
`DownloadJobStatus`, a symbol available from
`invokeai.app.services.download_manager`. Possible values are:
| **Value** | **String Value** | **Description** |
|--------------|---------------------|-------------------|
| `IDLE` | idle | Job created, but not submitted to the queue |
| `ENQUEUED` | enqueued | Job is patiently waiting on the queue |
@@ -1040,11 +1028,11 @@ While a job is being downloaded, the queue will emit events at
periodic intervals. A typical series of events during a successful
download session will look like this:
* enqueued
* running
* running
* running
* completed
There will be a single enqueued event, followed by one or more running
events, and finally one `completed`, `error` or `cancelled`
@@ -1053,12 +1041,12 @@ events.
It is possible for a caller to pause download temporarily, in which
case the events may look something like this:
* enqueued
* running
* running
* paused
* running
* completed
The download queue logs when downloads start and end (unless `quiet`
is set to True at initialization time) but doesn't log any progress
@@ -1120,11 +1108,11 @@ A typical initialization sequence will look like:
from invokeai.app.services.download_manager import DownloadQueueService

def log_download_event(job: DownloadJobBase):
    logger.info(f'job={job.id}: status={job.status}')

queue = DownloadQueueService(
    event_handlers=[log_download_event]
)
```
Event handlers can be provided to the queue at initialization time as
@@ -1155,9 +1143,9 @@ To use the former method, follow this example:
```
job = DownloadJobRemoteSource(
    source='http://www.civitai.com/models/13456',
    destination='/tmp/models/',
    event_handlers=[my_handler1, my_handler2], # if desired
)
queue.submit_download_job(job, start=True)
```
@@ -1172,11 +1160,11 @@ To have the queue create the job for you, follow this example instead:
```
job = queue.create_download_job(
    source='http://www.civitai.com/models/13456',
    destdir='/tmp/models/',
    filename='my_model.safetensors',
    event_handlers=[my_handler1, my_handler2], # if desired
    start=True,
)
```
The `filename` argument forces the downloader to use the specified
@@ -1187,7 +1175,6 @@ and is equivalent to manually specifying a destination of
Here is the full list of arguments that can be provided to
`create_download_job()`:
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `source` | Union[str, Path, AnyHttpUrl] | | Download remote or local source |
@@ -1275,7 +1262,7 @@ for getting the model to run. For example "author" is metadata, while
"type", "base" and "format" are not. The latter fields are part of the
model's config, as defined in `invokeai.backend.model_manager.config`.
### Example Usage
```
from invokeai.backend.model_manager.metadata import (
@@ -1328,7 +1315,6 @@ This is the common base class for metadata:
| `author` | str | Model's author |
| `tags` | Set[str] | Model tags |
Note that the model config record also has a `name` field. It is
intended that the config record version be locally customizable, while
the metadata version is read-only. However, enforcing this is expected
@@ -1348,7 +1334,6 @@ This descends from `ModelMetadataBase` and adds the following fields:
| `last_modified`| datetime | Date of last commit of this model to the repo |
| `files` | List[Path] | List of the files in the model repo |
#### `CivitaiMetadata`
This descends from `ModelMetadataBase` and adds the following fields:
@@ -1415,7 +1400,6 @@ testing suite to avoid hitting the internet.
The HuggingFace and Civitai fetcher subclasses add additional
repo-specific fetching methods:
#### HuggingFaceMetadataFetch
This overrides its base class `from_json()` method to return a
@@ -1434,7 +1418,6 @@ retrieves its metadata. Functionally equivalent to `from_id()`, the
only difference is that it returna a `CivitaiMetadata` object rather
than an `AnyModelRepoMetadata`.
### Metadata Storage
The `ModelMetadataStore` provides a simple facility to store model
@@ -1535,16 +1518,16 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegist
config = InvokeAIAppConfig.get_config()
ram_cache = ModelCache(
    max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
)
convert_cache = ModelConvertCache(
    cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
)
loader = ModelLoadService(
    app_config=config,
    ram_cache=ram_cache,
    convert_cache=convert_cache,
    registry=ModelLoaderRegistry
)
```
@@ -1567,7 +1550,6 @@ The returned `LoadedModel` object contains a copy of the configuration
record returned by the model record `get_model()` method, as well as
the in-memory loaded model:
| **Attribute Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
@@ -1581,7 +1563,6 @@ return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are obvious.
`LoadedModel` acts as a context manager. The context loads the model
into the execution device (e.g. VRAM on CUDA systems), locks the model
in the execution device for the duration of the context, and returns
@@ -1590,14 +1571,14 @@ the model. Use it like this:
```
model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with model_info as vae:
    image = vae.decode(latents)[0]
```
`get_model_by_key()` may raise any of the following exceptions:
* `UnknownModelException` -- key not in database
* `ModelNotFoundException` -- key in database but model not found at path
* `NotImplementedException` -- the loader doesn't know how to load this type of model
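A small illustrative sketch of guarding a load with these exceptions (the import locations for the exception classes are assumed, not shown in this diff):

```
try:
    model_info = loader.get_model_by_key(key, SubModelType('vae'))
except UnknownModelException:
    ...  # key is not in the database
except ModelNotFoundException:
    ...  # record exists, but the weights are missing on disk
else:
    with model_info as vae:
        image = vae.decode(latents)[0]
```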
### Emitting model loading events
@@ -1609,15 +1590,15 @@ following payload:
```
payload=dict(
    queue_id=queue_id,
    queue_item_id=queue_item_id,
    queue_batch_id=queue_batch_id,
    graph_execution_state_id=graph_execution_state_id,
    model_key=model_key,
    submodel_type=submodel,
    hash=model_info.hash,
    location=str(model_info.location),
    precision=str(model_info.precision),
)
```
@@ -1724,6 +1705,7 @@ object, or in `context.services.model_manager` from within an
invocation.
In the examples below, we have retrieved the manager using:
```
mm = ApiDependencies.invoker.services.model_manager
```

View File

@@ -19,6 +19,8 @@ their descriptions.
| Conditioning Primitive | A conditioning tensor primitive value |
| Content Shuffle Processor | Applies content shuffle processing to image |
| ControlNet | Collects ControlNet info to pass to other nodes |
+| Create Denoise Mask | Converts a greyscale or transparency image into a mask for denoising. |
+| Create Gradient Mask | Creates a mask for Gradient ("soft", "differential") inpainting that gradually expands during denoising. Improves edge coherence. |
| Denoise Latents | Denoises noisy latents to decodable images |
| Divide Integers | Divides two numbers |
| Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator |

View File

@@ -26,7 +26,6 @@ from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.model_manager.model_manager_default import ModelManagerService
-from ..services.model_metadata import ModelMetadataStoreSQL
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
@@ -93,10 +92,9 @@ class ApiDependencies:
    ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
)
download_queue_service = DownloadQueueService(event_bus=events)
-model_metadata_service = ModelMetadataStoreSQL(db=db)
model_manager = ModelManagerService.build_model_manager(
    app_config=configuration,
-   model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
+   model_record_service=ModelRecordServiceSQL(db=db),
    download_queue=download_queue_service,
    events=events,
)

View File

@@ -3,9 +3,7 @@
import pathlib
import shutil
-from hashlib import sha1
-from random import randbytes
-from typing import Any, Dict, List, Optional, Set
+from typing import Any, Dict, List, Optional
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
@@ -15,13 +13,10 @@ from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_records import (
-    DuplicateModelException,
    InvalidModelException,
-    ModelRecordOrderBy,
-    ModelSummary,
    UnknownModelException,
)
-from invokeai.app.services.shared.pagination import PaginatedResults
+from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
from invokeai.backend.model_manager.config import (
    AnyModelConfig,
    BaseModelType,
@@ -30,8 +25,6 @@ from invokeai.backend.model_manager.config import (
    ModelType,
    SubModelType,
)
-from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
-from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.search import ModelSearch
from ..dependencies import ApiDependencies
@@ -47,15 +40,6 @@ class ModelsList(BaseModel):
    model_config = ConfigDict(use_enum_values=True)
-class ModelTagSet(BaseModel):
-    """Return tags for a set of models."""
-    key: str
-    name: str
-    author: str
-    tags: Set[str]
##############################################################################
# These are example inputs and outputs that are used in places where Swagger
# is unable to generate a correct example.
@@ -66,19 +50,16 @@ example_model_config = {
    "base": "sd-1",
    "type": "main",
    "format": "checkpoint",
-   "config": "string",
+   "config_path": "string",
    "key": "string",
-   "original_hash": "string",
-   "current_hash": "string",
+   "hash": "string",
    "description": "string",
    "source": "string",
-   "last_modified": 0,
+   "converted_at": 0,
-   "vae": "string",
    "variant": "normal",
    "prediction_type": "epsilon",
    "repo_variant": "fp16",
    "upcast_attention": False,
-   "ztsnr_training": False,
}

example_model_input = {
@@ -87,50 +68,12 @@ example_model_input = {
    "base": "sd-1",
    "type": "main",
    "format": "checkpoint",
-   "config": "configs/stable-diffusion/v1-inference.yaml",
+   "config_path": "configs/stable-diffusion/v1-inference.yaml",
    "description": "Model description",
    "vae": None,
    "variant": "normal",
}
-example_model_metadata = {
-    "name": "ip_adapter_sd_image_encoder",
-    "author": "InvokeAI",
-    "tags": [
-        "transformers",
-        "safetensors",
-        "clip_vision_model",
-        "endpoints_compatible",
-        "region:us",
-        "has_space",
-        "license:apache-2.0",
-    ],
-    "files": [
-        {
-            "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
-            "path": "ip_adapter_sd_image_encoder/README.md",
-            "size": 628,
-            "sha256": None,
-        },
-        {
-            "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
-            "path": "ip_adapter_sd_image_encoder/config.json",
-            "size": 560,
-            "sha256": None,
-        },
-        {
-            "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
-            "path": "ip_adapter_sd_image_encoder/model.safetensors",
-            "size": 2528373448,
-            "sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
-        },
-    ],
-    "type": "huggingface",
-    "id": "InvokeAI/ip_adapter_sd_image_encoder",
-    "tag_dict": {"license": "apache-2.0"},
-    "last_modified": "2023-09-23T17:33:25Z",
-}
##############################################################################
# ROUTES
##############################################################################
@@ -210,48 +153,16 @@ async def get_model_record(
    raise HTTPException(status_code=404, detail=str(e))
# @model_manager_router.get("/summary", operation_id="list_model_summary")
# async def list_model_summary(
#     page: int = Query(default=0, description="The page to get"),
#     per_page: int = Query(default=10, description="The number of models per page"),
#     order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
# ) -> PaginatedResults[ModelSummary]:
#     """Gets a page of model summary data."""
#     record_store = ApiDependencies.invoker.services.model_manager.store
#     results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
#     return results
-@model_manager_router.get(
-    "/i/{key}/metadata",
-    operation_id="get_model_metadata",
-    responses={
-        200: {
-            "description": "The model metadata was retrieved successfully",
-            "content": {"application/json": {"example": example_model_metadata}},
-        },
-        400: {"description": "Bad request"},
-    },
-)
-async def get_model_metadata(
-    key: str = Path(description="Key of the model repo metadata to fetch."),
-) -> Optional[AnyModelRepoMetadata]:
-    """Get a model metadata object."""
-    record_store = ApiDependencies.invoker.services.model_manager.store
-    result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
-    return result
-@model_manager_router.get(
-    "/tags",
-    operation_id="list_tags",
-)
-async def list_tags() -> Set[str]:
-    """Get a unique set of all the model tags."""
-    record_store = ApiDependencies.invoker.services.model_manager.store
-    result: Set[str] = record_store.list_tags()
-    return result
class FoundModel(BaseModel):
@@ -323,19 +234,6 @@ async def scan_for_models(
    return scan_results
-@model_manager_router.get(
-    "/tags/search",
-    operation_id="search_by_metadata_tags",
-)
-async def search_by_metadata_tags(
-    tags: Set[str] = Query(default=None, description="Tags to search for"),
-) -> ModelsList:
-    """Get a list of models."""
-    record_store = ApiDependencies.invoker.services.model_manager.store
-    results = record_store.search_by_metadata_tag(tags)
-    return ModelsList(models=results)
@model_manager_router.patch(
"/i/{key}",
operation_id="update_model_record",
@@ -352,15 +250,13 @@ async def search_by_metadata_tags(
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
info: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
],
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
"""Update a model's config."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
try:
model_response: AnyModelConfig = record_store.update_model(key, config=info)
model_response: AnyModelConfig = record_store.update_model(key, changes=changes)
logger.info(f"Updated model: {key}")
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@@ -372,14 +268,14 @@ async def update_model_record(
@model_manager_router.delete(
"/i/{key}",
operation_id="del_model_record",
operation_id="delete_model",
responses={
204: {"description": "Model deleted successfully"},
404: {"description": "Model not found"},
},
status_code=204,
)
async def del_model_record(
async def delete_model(
key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
"""
@@ -400,42 +296,39 @@ async def del_model_record(
raise HTTPException(status_code=404, detail=str(e))
@model_manager_router.post(
"/i/",
operation_id="add_model_record",
responses={
201: {
"description": "The model added successfully",
"content": {"application/json": {"example": example_model_config}},
},
409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"},
},
status_code=201,
)
async def add_model_record(
config: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
],
) -> AnyModelConfig:
"""Add a model using the configuration information appropriate for its type."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
if config.key == "<NOKEY>":
config.key = sha1(randbytes(100)).hexdigest()
logger.info(f"Created model {config.key} for {config.name}")
try:
record_store.add_model(config.key, config)
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
# now fetch it out
result: AnyModelConfig = record_store.get_model(config.key)
return result
# @model_manager_router.post(
# "/i/",
# operation_id="add_model_record",
# responses={
# 201: {
# "description": "The model added successfully",
# "content": {"application/json": {"example": example_model_config}},
# },
# 409: {"description": "There is already a model corresponding to this path or repo_id"},
# 415: {"description": "Unrecognized file/folder format"},
# },
# status_code=201,
# )
# async def add_model_record(
# config: Annotated[
# AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
# ],
# ) -> AnyModelConfig:
# """Add a model using the configuration information appropriate for its type."""
# logger = ApiDependencies.invoker.services.logger
# record_store = ApiDependencies.invoker.services.model_manager.store
# try:
# record_store.add_model(config)
# except DuplicateModelException as e:
# logger.error(str(e))
# raise HTTPException(status_code=409, detail=str(e))
# except InvalidModelException as e:
# logger.error(str(e))
# raise HTTPException(status_code=415)
# # now fetch it out
# result: AnyModelConfig = record_store.get_model(config.key)
# return result
@model_manager_router.post(
@@ -451,6 +344,7 @@ async def add_model_record(
)
async def install_model(
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
# TODO(MM2): Can we type this?
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
@@ -493,6 +387,7 @@ async def install_model(
source=source,
config=config,
access_token=access_token,
inplace=bool(inplace),
)
logger.info(f"Started installation of {source}")
except UnknownModelException as e:
@@ -508,10 +403,10 @@ async def install_model(
@model_manager_router.get(
"/import",
"/install",
operation_id="list_model_install_jobs",
operation_id="list_model_installs",
)
async def list_model_install_jobs() -> List[ModelInstallJob]:
async def list_model_installs() -> List[ModelInstallJob]:
"""Return the list of model install jobs.
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
@@ -525,9 +420,8 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
* "cancelled" -- Job was cancelled before completion.
Once completed, information about the model such as its size, base
model, type, and metadata can be retrieved from the `config_out`
field. For multi-file models such as diffusers, information on individual files
can be retrieved from `download_parts`.
model and type can be retrieved from the `config_out` field. For multi-file models such as diffusers,
information on individual files can be retrieved from `download_parts`.
See the example and schema below for more information.
"""
@@ -536,7 +430,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
@model_manager_router.get(
"/import/{id}",
"/install/{id}",
operation_id="get_model_install_job",
responses={
200: {"description": "Success"},
@@ -556,7 +450,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
@model_manager_router.delete(
"/import/{id}",
"/install/{id}",
operation_id="cancel_model_install_job",
responses={
201: {"description": "The job was cancelled successfully"},
@@ -574,8 +468,8 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
installer.cancel_job(job)
@model_manager_router.patch(
"/import",
@model_manager_router.delete(
"/install",
operation_id="prune_model_install_jobs",
responses={
204: {"description": "All completed and errored jobs have been pruned"},
@@ -654,7 +548,8 @@ async def convert_model(
# temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name
model_config.name = f"{original_name}.DELETE"
store.update_model(key, config=model_config)
changes = ModelRecordChanges(name=model_config.name)
store.update_model(key, changes=changes)
# install the diffusers
try:
@@ -663,7 +558,7 @@ async def convert_model(
config={
"name": original_name,
"description": model_config.description,
"original_hash": model_config.original_hash,
"hash": model_config.hash,
"source": model_config.source,
},
)
@@ -671,10 +566,6 @@ async def convert_model(
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
# get the original metadata
if orig_metadata := store.get_metadata(key):
store.metadata_store.add_metadata(new_key, orig_metadata)
# delete the original safetensors file
installer.delete(key)
@@ -686,66 +577,66 @@ async def convert_model(
return new_config
@model_manager_router.put(
"/merge",
operation_id="merge",
responses={
200: {
"description": "Model converted successfully",
"content": {"application/json": {"example": example_model_config}},
},
400: {"description": "Bad request"},
404: {"description": "Model not found"},
409: {"description": "There is already a model registered at this location"},
},
)
async def merge(
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
force: bool = Body(
description="Force merging of models created with different versions of diffusers",
default=False,
),
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
merge_dest_directory: Optional[str] = Body(
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
default=None,
),
) -> AnyModelConfig:
"""
Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
```
Argument Description [default]
-------- ----------------------
keys List of 2-3 model keys to merge together. All models must use the same base type.
merged_model_name Name for the merged model [Concat model names]
alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
merge_dest_directory Specify a directory to store the merged model in [models directory]
```
"""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
installer = ApiDependencies.invoker.services.model_manager.install
merger = ModelMerger(installer)
model_names = [installer.record_store.get_model(x).name for x in keys]
response = merger.merge_diffusion_models_and_save(
model_keys=keys,
merged_model_name=merged_model_name or "+".join(model_names),
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory=dest,
)
except UnknownModelException:
raise HTTPException(
status_code=404,
detail=f"One or more of the models '{keys}' not found",
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
# @model_manager_router.put(
# "/merge",
# operation_id="merge",
# responses={
# 200: {
# "description": "Model converted successfully",
# "content": {"application/json": {"example": example_model_config}},
# },
# 400: {"description": "Bad request"},
# 404: {"description": "Model not found"},
# 409: {"description": "There is already a model registered at this location"},
# },
# )
# async def merge(
# keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
# merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
# alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
# force: bool = Body(
# description="Force merging of models created with different versions of diffusers",
# default=False,
# ),
# interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
# merge_dest_directory: Optional[str] = Body(
# description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
# default=None,
# ),
# ) -> AnyModelConfig:
# """
# Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
# ```
# Argument Description [default]
# -------- ----------------------
# keys List of 2-3 model keys to merge together. All models must use the same base type.
# merged_model_name Name for the merged model [Concat model names]
# alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
# force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
# interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
# merge_dest_directory Specify a directory to store the merged model in [models directory]
# ```
# """
# logger = ApiDependencies.invoker.services.logger
# try:
# logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
# dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
# installer = ApiDependencies.invoker.services.model_manager.install
# merger = ModelMerger(installer)
# model_names = [installer.record_store.get_model(x).name for x in keys]
# response = merger.merge_diffusion_models_and_save(
# model_keys=keys,
# merged_model_name=merged_model_name or "+".join(model_names),
# alpha=alpha,
# interp=interp,
# force=force,
# merge_dest_directory=dest,
# )
# except UnknownModelException:
# raise HTTPException(
# status_code=404,
# detail=f"One or more of the models '{keys}' not found",
# )
# except ValueError as e:
# raise HTTPException(status_code=400, detail=str(e))
# return response
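A short client-side sketch of the reshaped model-manager API as it stands after these changes: install now accepts an `inplace` query parameter, and updates take a partial changes body rather than a full config. The host, route prefix, model key, and file path below are assumptions for illustration only.

import requests

BASE = "http://localhost:9090/api/v2/models"  # hypothetical host and route prefix

# Install a local model in place instead of copying it into the models directory.
requests.post(
    f"{BASE}/install",
    params={"source": "/data/checkpoints/my-model.safetensors", "inplace": True},  # hypothetical path
    json={"name": "My Model"},
)

# Patch an existing record with a ModelRecordChanges-shaped body; only the listed
# fields are changed, and they are validated server-side before being written.
requests.patch(f"{BASE}/i/some-model-key", json={"description": "A better description"})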

View File

@@ -173,6 +173,16 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
)
@invocation_output("gradient_mask_output")
class GradientMaskOutput(BaseInvocationOutput):
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
expanded_mask_area: ImageField = OutputField(
description="Image representing the total gradient area of the mask. For paste-back purposes."
)
@invocation(
"create_gradient_mask",
title="Create Gradient Mask",
@@ -193,38 +203,42 @@ class CreateGradientMaskInvocation(BaseInvocation):
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
if self.coherence_mode == "Box Blur":
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
else: # Gaussian Blur OR Staged
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
if self.edge_radius > 0:
if self.coherence_mode == "Box Blur":
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
else: # Gaussian Blur OR Staged
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
# redistribute blur so that the edges are 0 and blur out to 1
# redistribute blur so that the original edges are 0 and blur outwards to 1
blur_tensor = (blur_tensor - 0.5) * 2
threshold = 1 - self.minimum_denoise
if self.coherence_mode == "Staged":
# wherever the blur_tensor is less than fully masked, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
if self.coherence_mode == "Staged":
# wherever the blur_tensor is masked to any degree, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
# multiply original mask to force actually masked regions to 0
blur_tensor = mask_tensor * blur_tensor
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
return DenoiseMaskOutput.build(
mask_name=mask_name,
masked_latents_name=None,
gradient=True,
# compute a [0, 1] mask from the blur_tensor
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
expanded_image_dto = context.images.save(expanded_mask_image)
return GradientMaskOutput(
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
)
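A minimal stand-alone sketch of the remap arithmetic used in the invocation above, assuming a blurred mask already scaled to [0, 1]; the sample values are illustrative only.

import torch

# Rescale [0, 1] blur values to [-1, 1], then drop anything strictly between the
# minimum-denoise threshold and 1 down to the threshold (the non-"Staged" branch above).
blur = torch.tensor([[0.0, 0.5, 0.9, 0.95, 1.0]])
minimum_denoise = 0.2

remapped = (blur - 0.5) * 2      # -> [-1.0, 0.0, 0.8, 0.9, 1.0]
threshold = 1 - minimum_denoise  # -> 0.8
clamped = torch.where((remapped > threshold) & (remapped < 1), threshold, remapped)
print(clamped)                   # -> [-1.0, 0.0, 0.8, 0.8, 1.0]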
@@ -775,10 +789,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_end=self.denoising_end,
)
(
result_latents,
result_attention_map_saver,
) = pipeline.latents_from_embeddings(
result_latents = pipeline.latents_from_embeddings(
latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,

View File

@@ -133,7 +133,7 @@ class MainModelLoaderInvocation(BaseInvocation):
vae=VaeField(
vae=ModelInfo(
key=key,
submodel_type=SubModelType.Vae,
submodel_type=SubModelType.VAE,
),
),
)

View File

@@ -85,7 +85,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.Vae,
submodel_type=SubModelType.VAE,
),
),
)
@@ -142,7 +142,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.Vae,
submodel_type=SubModelType.VAE,
),
),
)

View File

@@ -256,6 +256,7 @@ class InvokeAIAppConfig(InvokeAISettings):
profile_graphs : bool = Field(default=False, description="Enable graph profiling", json_schema_extra=Categories.Development)
profile_prefix : Optional[str] = Field(default=None, description="An optional prefix for profile output files.", json_schema_extra=Categories.Development)
profiles_dir : Path = Field(default=Path('profiles'), description="Directory for graph profiles", json_schema_extra=Categories.Development)
skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce startup time.", json_schema_extra=Categories.Development)
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
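A minimal sketch of how the new flag could be exercised programmatically; the import path is an assumption, and in practice the setting would normally come from invokeai.yaml or the command line.

from invokeai.app.services.config import InvokeAIAppConfig  # import path is an assumption

config = InvokeAIAppConfig(skip_model_hash=True)
assert config.skip_model_hash  # installs will assign a UUID instead of hashing model files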

View File

@@ -18,10 +18,9 @@ from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..model_metadata import ModelMetadataStoreBase
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
@@ -151,6 +150,13 @@ ModelSource = Annotated[
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
CivitaiModelSource: ModelSourceType.CivitAI,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
@@ -260,7 +266,6 @@ class ModelInstallServiceBase(ABC):
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
metadata_store: ModelMetadataStoreBase,
event_bus: Optional["EventServiceBase"] = None,
):
"""
@ -347,6 +352,7 @@ class ModelInstallServiceBase(ABC):
source: str,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
) -> ModelInstallJob:
r"""Install the indicated model using heuristics to interpret user intentions.
@ -392,7 +398,7 @@ class ModelInstallServiceBase(ABC):
will override corresponding autoassigned probe fields in the
model's config record. Use it to override
`name`, `description`, `base_type`, `model_type`, `format`,
`prediction_type`, `image_size`, and/or `ztsnr_training`.
`prediction_type`, and/or `image_size`.
This will download the model located at `source`,
probe it, and install it into the models directory.
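A minimal usage sketch of the extended signature, assuming an installer obtained from the running app's services; the variable name, path, and config values are illustrative only.

# Illustrative only: `installer` stands for the app's ModelInstallService instance.
job = installer.heuristic_import(
    source="/data/checkpoints/juggernaut.safetensors",  # hypothetical local path
    config={"name": "Juggernaut", "description": "SD-1.5 checkpoint"},
    inplace=True,  # register the file where it is instead of copying it into the models dir
)
print(job.id, job.status)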

View File

@@ -7,7 +7,6 @@ import time
from hashlib import sha256
from pathlib import Path
from queue import Empty, Queue
from random import randbytes
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Set, Union
@@ -21,11 +20,15 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
CheckpointConfigBase,
InvalidModelConfigException,
ModelRepoVariant,
ModelSourceType,
ModelType,
)
from invokeai.backend.model_manager.metadata import ( from invokeai.backend.model_manager.metadata import (
@@ -35,12 +38,14 @@ from invokeai.backend.model_manager.metadata import (
ModelMetadataWithFiles,
RemoteModelFile,
)
from invokeai.backend.model_manager.metadata.metadata_base import CivitaiMetadata, HuggingFaceMetadata
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import Chdir, InvokeAILogger
from invokeai.backend.util.devices import choose_precision, choose_torch_device
from .model_install_base import (
MODEL_SOURCE_TO_TYPE_MAP,
CivitaiModelSource,
HFModelSource,
InstallStatus,
@@ -90,7 +95,6 @@ class ModelInstallService(ModelInstallServiceBase):
self._running = False
self._session = session
self._next_job_id = 0
self._metadata_store = record_store.metadata_store # for convenience
@property
def app_config(self) -> InvokeAIAppConfig: # noqa D102
@@ -139,6 +143,7 @@ class ModelInstallService(ModelInstallServiceBase):
config = config or {}
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
config["source_type"] = ModelSourceType.Path
return self._register(model_path, config)
def install_path(
@ -148,11 +153,11 @@ class ModelInstallService(ModelInstallServiceBase):
) -> str: # noqa D102
model_path = Path(model_path)
config = config or {}
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
config["key"] = config.get("key", self._create_key())
info: AnyModelConfig = self._probe_model(Path(model_path), config)
if self._app_config.skip_model_hash:
config["hash"] = uuid_string()
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
if preferred_name := config.get("name"):
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
@ -178,13 +183,14 @@ class ModelInstallService(ModelInstallServiceBase):
source: str,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
) -> ModelInstallJob:
variants = "|".join(ModelRepoVariant.__members__.values())
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
source_obj: Optional[StringLikeSource] = None
if Path(source).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source))
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
elif match := re.match(hf_repoid_re, source):
source_obj = HFModelSource(
repo_id=match.group(1),
@ -373,15 +379,18 @@ class ModelInstallService(ModelInstallServiceBase):
job.bytes = job.total_bytes
self._signal_job_running(job)
job.config_in["source"] = str(job.source)
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
# enter the metadata, if there is any
if isinstance(job.source_metadata, (CivitaiMetadata, HuggingFaceMetadata)):
job.config_in["source_api_response"] = job.source_metadata.api_response
if isinstance(job.source_metadata, CivitaiMetadata) and job.source_metadata.trigger_phrases:
job.config_in["trigger_phrases"] = job.source_metadata.trigger_phrases
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
else:
key = self.install_path(job.local_path, job.config_in)
job.config_out = self.record_store.get_model(key)
# enter the metadata, if there is any
if job.source_metadata:
self._metadata_store.add_metadata(key, job.source_metadata)
self._signal_job_completed(job)
except InvalidModelConfigException as excp:
@ -467,7 +476,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"Moving {model.name} to {new_path}.") self._logger.info(f"Moving {model.name} to {new_path}.")
new_path = self._move_model(old_path, new_path) new_path = self._move_model(old_path, new_path)
model.path = new_path.relative_to(models_dir).as_posix() model.path = new_path.relative_to(models_dir).as_posix()
self.record_store.update_model(key, model) self.record_store.update_model(key, ModelRecordChanges(path=model.path))
return model return model
def _scan_register(self, model: Path) -> bool: def _scan_register(self, model: Path) -> bool:
@ -519,22 +528,14 @@ class ModelInstallService(ModelInstallServiceBase):
move(old_path, new_path) move(old_path, new_path)
return new_path return new_path
def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
info: AnyModelConfig = ModelProbe.probe(Path(model_path))
if config: # used to override probe fields
for key, value in config.items():
setattr(info, key, value)
return info
def _create_key(self) -> str:
return sha256(randbytes(100)).hexdigest()[0:32]
def _register(
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
) -> str:
# Note that we may be passed a pre-populated AnyModelConfig object,
# in which case the key field should have been populated by the caller (e.g. in `install_path`).
config["key"] = config.get("key", self._create_key())
config = config or {}
if self._app_config.skip_model_hash:
config["hash"] = uuid_string()
info = info or ModelProbe.probe(model_path, config)
model_path = model_path.absolute()
@ -544,11 +545,11 @@ class ModelInstallService(ModelInstallServiceBase):
info.path = model_path.as_posix()
# add 'main' specific fields
if hasattr(info, "config"):
if isinstance(info, CheckpointConfigBase):
# make config relative to our root
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config_path).resolve()
info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
info.config_path = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
self.record_store.add_model(info.key, info)
self.record_store.add_model(info)
return info.key
def _next_id(self) -> int: def _next_id(self) -> int:
@ -569,13 +570,15 @@ class ModelInstallService(ModelInstallServiceBase):
source=source,
config_in=config or {},
local_path=Path(source.path),
inplace=source.inplace,
inplace=source.inplace or False,
)
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
if not source.access_token:
self._logger.info("No Civitai access token provided; some models may not be downloadable.")
metadata = CivitaiMetadataFetch(self._session).from_id(str(source.version_id))
metadata = CivitaiMetadataFetch(self._session, self.app_config.get_config().civitai_api_key).from_id(
str(source.version_id)
)
assert isinstance(metadata, ModelMetadataWithFiles)
remote_files = metadata.download_urls(session=self._session)
return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)
@ -603,15 +606,17 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
# URLs from Civitai or HuggingFace will be handled specially
url_patterns = {
r"^https?://civitai.com/": CivitaiMetadataFetch,
r"^https?://huggingface.co/[^/]+/[^/]+$": HuggingFaceMetadataFetch,
}
metadata = None
for pattern, fetcher in url_patterns.items():
if re.match(pattern, str(source.url), re.IGNORECASE):
metadata = fetcher(self._session).from_url(source.url)
break
fetcher = None
try:
fetcher = self.get_fetcher_from_url(str(source.url))
except ValueError:
pass
kwargs: dict[str, Any] = {"session": self._session}
if fetcher is CivitaiMetadataFetch:
kwargs["api_key"] = self._app_config.get_config().civitai_api_key
if fetcher is not None:
metadata = fetcher(**kwargs).from_url(source.url)
self._logger.debug(f"metadata={metadata}")
if metadata and isinstance(metadata, ModelMetadataWithFiles):
remote_files = metadata.download_urls(session=self._session)
@ -626,7 +631,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_remote_model(
self,
source: ModelSource,
source: HFModelSource | CivitaiModelSource | URLModelSource,
remote_files: List[RemoteModelFile],
metadata: Optional[AnyModelRepoMetadata],
config: Optional[Dict[str, Any]],
@ -654,7 +659,7 @@ class ModelInstallService(ModelInstallServiceBase):
# In the event that there is a subfolder specified in the source,
# we need to remove it from the destination path in order to avoid
# creating unwanted subfolders
if hasattr(source, "subfolder") and source.subfolder:
if isinstance(source, HFModelSource) and source.subfolder:
root = Path(remote_files[0].path.parts[0])
subfolder = root / source.subfolder
else:
@ -841,3 +846,11 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"{job.source}: model installation was cancelled") self._logger.info(f"{job.source}: model installation was cancelled")
if self._event_bus: if self._event_bus:
self._event_bus.emit_model_install_cancelled(str(job.source)) self._event_bus.emit_model_install_cancelled(str(job.source))
@staticmethod
def get_fetcher_from_url(url: str):
if re.match(r"^https?://civitai.com/", url.lower()):
return CivitaiMetadataFetch
elif re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'")
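A small sketch of how the dispatch behaves, mirroring the helper just above; the sample URLs are hypothetical.

fetcher_cls = ModelInstallService.get_fetcher_from_url("https://civitai.com/models/12345")
assert fetcher_cls is CivitaiMetadataFetch

try:
    ModelInstallService.get_fetcher_from_url("https://example.com/some-model.safetensors")
except ValueError:
    # plain URLs have no metadata fetcher; the installer falls back to a direct download
    pass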

View File

@ -1,9 +0,0 @@
"""Init file for ModelMetadataStoreService module."""
from .metadata_store_base import ModelMetadataStoreBase
from .metadata_store_sql import ModelMetadataStoreSQL
__all__ = [
"ModelMetadataStoreBase",
"ModelMetadataStoreSQL",
]

View File

@ -1,65 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Storage for Model Metadata
"""
from abc import ABC, abstractmethod
from typing import List, Set, Tuple
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class ModelMetadataStoreBase(ABC):
"""Store, search and fetch model metadata retrieved from remote repositories."""
@abstractmethod
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
@abstractmethod
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
@abstractmethod
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
@abstractmethod
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
@abstractmethod
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
@abstractmethod
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
@abstractmethod
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
@abstractmethod
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""

View File

@ -1,222 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""
import sqlite3
from typing import List, Optional, Set, Tuple
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
from .metadata_store_base import ModelMetadataStoreBase
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
"""Store, search and fetch model metadata retrieved from remote repositories."""
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_metadata(
id,
metadata
)
VALUES (?,?);
""",
(
model_key,
json_serialized,
),
)
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
raise UnknownMetadataException from excp
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata FROM model_metadata
WHERE id=?;
""",
(model_key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownMetadataException("model metadata not found")
return ModelMetadataFetchBase.from_json(rows[0])
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT id,metadata FROM model_metadata;
""",
(),
)
rows = self._cursor.fetchall()
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
json_serialized = metadata.model_dump_json() # turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_metadata
SET
metadata=?
WHERE id=?;
""",
(json_serialized, model_key),
)
if self._cursor.rowcount == 0:
raise UnknownMetadataException("model metadata not found")
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_metadata(model_key)
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
self._cursor.execute(
"""--sql
select tag_text from tags;
"""
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
with self._db.lock:
try:
matches: Optional[Set[str]] = None
for tag in tags:
self._cursor.execute(
"""--sql
SELECT a.model_id FROM model_tags AS a,
tags AS b
WHERE a.tag_id=b.tag_id
AND b.tag_text=?;
""",
(tag,),
)
model_keys = {x[0] for x in self._cursor.fetchall()}
if matches is None:
matches = model_keys
matches = matches.intersection(model_keys)
except sqlite3.Error as e:
raise e
return matches if matches else set()
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE author=?;
""",
(author,),
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE name=?;
""",
(name,),
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
"""Update tags for the model referenced by model_key."""
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)

View File

@@ -6,20 +6,19 @@ Abstract base class for storing and retrieving model configuration records.
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import List, Optional, Set, Union
from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager import (
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType
from ..model_metadata import ModelMetadataStoreBase
class DuplicateModelException(Exception): class DuplicateModelException(Exception):
@ -60,11 +59,33 @@ class ModelSummary(BaseModel):
tags: Set[str] = Field(description="tags associated with model") tags: Set[str] = Field(description="tags associated with model")
class ModelRecordChanges(BaseModelExcludeNull):
"""A set of changes to apply to a model."""
# Changes applicable to all models
name: Optional[str] = Field(description="Name of the model.", default=None)
path: Optional[str] = Field(description="Path to the model.", default=None)
description: Optional[str] = Field(description="Model description", default=None)
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[ModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
# Checkpoint-specific changes
# TODO(MM2): Should we expose these? Feels footgun-y...
variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
prediction_type: Optional[SchedulerPredictionType] = Field(
description="The prediction type of the model.", default=None
)
upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
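A minimal usage sketch of the new changes object with the record store, assuming a `ModelRecordServiceBase` implementation is at hand; the variable name, model key, and field values are illustrative only.

# Illustrative only: `record_store` stands for a ModelRecordServiceBase implementation.
changes = ModelRecordChanges(
    name="My fine-tune",                # hypothetical new name
    description="Renamed via the changes API",
    trigger_phrases={"my-trigger"},
)
updated = record_store.update_model("some-model-key", changes=changes)
print(updated.name)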
class ModelRecordServiceBase(ABC):
"""Abstract base class for storage and retrieval of model configs."""
@abstractmethod
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
"""
Add a model to the database.
@ -88,13 +109,12 @@ class ModelRecordServiceBase(ABC):
pass
@abstractmethod
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
"""
Update the model, returning the updated version.
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
:param key: Unique key for the model to be updated.
:param changes: A set of changes to apply to this model. Changes are validated before being written.
"""
pass
@ -109,40 +129,6 @@ class ModelRecordServiceBase(ABC):
""" """
pass pass
@property
@abstractmethod
def metadata_store(self) -> ModelMetadataStoreBase:
"""Return a ModelMetadataStore initialized on the same database."""
pass
@abstractmethod
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""
Retrieve metadata (if any) from when model was downloaded from a repo.
:param key: Model key
"""
pass
@abstractmethod
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
"""List metadata for all models that have it."""
pass
@abstractmethod
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
"""
Search model metadata for ones with all listed tags and return their corresponding configs.
:param tags: Set of tags to search for. All tags must be present.
"""
pass
@abstractmethod
def list_tags(self) -> Set[str]:
"""Return a unique set of all the model tags in the metadata database."""
pass
@abstractmethod @abstractmethod
def list_models( def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
@ -217,21 +203,3 @@ class ModelRecordServiceBase(ABC):
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'." f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
) )
return model_configs[0] return model_configs[0]
def rename_model(
self,
key: str,
new_name: str,
) -> AnyModelConfig:
"""
Rename the indicated model. Just a special case of update_model().
In some implementations, renaming the model may involve changing where
it is stored on the filesystem. So this is broken out.
:param key: Model key
:param new_name: New name for model
"""
config = self.get_model(key)
config.name = new_name
return self.update_model(key, config)

View File

@ -43,7 +43,7 @@ import json
import sqlite3 import sqlite3
from math import ceil from math import ceil
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union from typing import List, Optional, Union
from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
@ -53,12 +53,11 @@ from invokeai.backend.model_manager.config import (
ModelFormat, ModelFormat,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
from ..shared.sqlite.sqlite_database import SqliteDatabase from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import ( from .model_records_base import (
DuplicateModelException, DuplicateModelException,
ModelRecordChanges,
ModelRecordOrderBy, ModelRecordOrderBy,
ModelRecordServiceBase, ModelRecordServiceBase,
ModelSummary, ModelSummary,
@ -69,7 +68,7 @@ from .model_records_base import (
class ModelRecordServiceSQL(ModelRecordServiceBase): class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database.""" """Implementation of the ModelConfigStore ABC using a SQL database."""
def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase): def __init__(self, db: SqliteDatabase):
""" """
Initialize a new object from preexisting sqlite3 connection and threading lock objects. Initialize a new object from preexisting sqlite3 connection and threading lock objects.
@ -78,14 +77,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
super().__init__() super().__init__()
self._db = db self._db = db
self._cursor = db.conn.cursor() self._cursor = db.conn.cursor()
self._metadata_store = metadata_store
@property @property
def db(self) -> SqliteDatabase: def db(self) -> SqliteDatabase:
"""Return the underlying database.""" """Return the underlying database."""
return self._db return self._db
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig: def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
""" """
Add a model to the database. Add a model to the database.
@ -95,23 +93,19 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions. Can raise DuplicateModelException and InvalidModelConfigException exceptions.
""" """
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock: with self._db.lock:
try: try:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
INSERT INTO model_config ( INSERT INTO models (
id, id,
original_hash,
config config
) )
VALUES (?,?,?); VALUES (?,?);
""", """,
( (
key, config.key,
record.original_hash, config.model_dump_json(),
json_serialized,
), ),
) )
self._db.conn.commit() self._db.conn.commit()
@ -119,12 +113,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
except sqlite3.IntegrityError as e: except sqlite3.IntegrityError as e:
self._db.conn.rollback() self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e): if "UNIQUE constraint failed" in str(e):
if "model_config.path" in str(e): if "models.path" in str(e):
msg = f"A model with path '{record.path}' is already installed" msg = f"A model with path '{config.path}' is already installed"
elif "model_config.name" in str(e): elif "models.name" in str(e):
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed" msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
else: else:
msg = f"A model with key '{key}' is already installed" msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e raise DuplicateModelException(msg) from e
else: else:
raise e raise e
@ -132,7 +126,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db.conn.rollback() self._db.conn.rollback()
raise e raise e
return self.get_model(key) return self.get_model(config.key)
def del_model(self, key: str) -> None: def del_model(self, key: str) -> None:
""" """
@ -146,7 +140,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
try: try:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
DELETE FROM model_config DELETE FROM models
WHERE id=?; WHERE id=?;
""", """,
(key,), (key,),
@ -158,21 +152,20 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db.conn.rollback() self._db.conn.rollback()
raise e raise e
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig: def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
"""
Update the model, returning the updated version.

:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config object
json_serialized = record.model_dump_json() # and turn it into a json string.
record = self.get_model(key)

# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
json_serialized = record.model_dump_json()
with self._db.lock: with self._db.lock:
try: try:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
UPDATE model_config UPDATE models
SET SET
config=? config=?
WHERE id=?; WHERE id=?;
@ -199,7 +192,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock: with self._db.lock:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
SELECT config, strftime('%s',updated_at) FROM model_config SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?; WHERE id=?;
""", """,
(key,), (key,),
@ -220,7 +213,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock: with self._db.lock:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
select count(*) FROM model_config select count(*) FROM models
WHERE id=?; WHERE id=?;
""", """,
(key,), (key,),
@ -246,9 +239,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
If none of the optional filters are passed, will return all If none of the optional filters are passed, will return all
models in the database. models in the database.
""" """
results = [] where_clause: list[str] = []
where_clause = [] bindings: list[str] = []
bindings = []
if model_name: if model_name:
where_clause.append("name=?") where_clause.append("name=?")
bindings.append(model_name) bindings.append(model_name)
@ -265,14 +257,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock: with self._db.lock:
self._cursor.execute( self._cursor.execute(
f"""--sql f"""--sql
select config, strftime('%s',updated_at) FROM model_config SELECT config, strftime('%s',updated_at) FROM models
{where}; {where};
""", """,
tuple(bindings), tuple(bindings),
) )
results = [
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
]
result = self._cursor.fetchall()
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result]
return results return results
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
@ -281,7 +272,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock: with self._db.lock:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
SELECT config, strftime('%s',updated_at) FROM model_config SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?; WHERE path=?;
""", """,
(str(path),), (str(path),),
@ -292,13 +283,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
return results return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]: def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated original_hash.""" """Return models with the indicated hash."""
results = [] results = []
with self._db.lock: with self._db.lock:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
SELECT config, strftime('%s',updated_at) FROM model_config SELECT config, strftime('%s',updated_at) FROM models
WHERE original_hash=?; WHERE hash=?;
""", """,
(hash,), (hash,),
) )
@ -307,83 +298,35 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
] ]
return results return results
@property
def metadata_store(self) -> ModelMetadataStoreBase:
"""Return a ModelMetadataStore initialized on the same database."""
return self._metadata_store
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""
Retrieve metadata (if any) from when model was downloaded from a repo.
:param key: Model key
"""
store = self.metadata_store
try:
metadata = store.get_metadata(key)
return metadata
except UnknownMetadataException:
return None
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
"""
Search model metadata for ones with all listed tags and return their corresponding configs.
:param tags: Set of tags to search for. All tags must be present.
"""
store = ModelMetadataStoreSQL(self._db)
keys = store.search_by_tag(tags)
return [self.get_model(x) for x in keys]
def list_tags(self) -> Set[str]:
"""Return a unique set of all the model tags in the metadata database."""
store = ModelMetadataStoreSQL(self._db)
return store.list_tags()
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
"""List metadata for all models that have it."""
store = ModelMetadataStoreSQL(self._db)
return store.list_all_metadata()
def list_models( def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
) -> PaginatedResults[ModelSummary]: ) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database.""" """Return a paginated summary listing of each model in the database."""
assert isinstance(order_by, ModelRecordOrderBy)
ordering = { ordering = {
ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name", ModelRecordOrderBy.Default: "type, base, format, name",
ModelRecordOrderBy.Type: "a.type", ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "a.base", ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "a.name", ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "a.format", ModelRecordOrderBy.Format: "format",
} }
def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]:
"""Fix up results so that there are no null values."""
result: Dict[str, Union[str, int, Set[str]]] = {}
for key, item in summary.items():
result[key] = item or ""
result["tags"] = set(json.loads(summary["tags"] or "[]"))
return result
# Lock so that the database isn't updated while we're doing the two queries. # Lock so that the database isn't updated while we're doing the two queries.
with self._db.lock: with self._db.lock:
# query1: get the total number of model configs # query1: get the total number of model configs
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
select count(*) from model_config; select count(*) from models;
""", """,
(), (),
) )
total = int(self._cursor.fetchone()[0]) total = int(self._cursor.fetchone()[0])
# query2: fetch key fields from the join of model_config and model_metadata # query2: fetch key fields
self._cursor.execute( self._cursor.execute(
f"""--sql f"""--sql
SELECT a.id as key, a.type, a.base, a.format, a.name, SELECT config
json_extract(a.config, '$.description') as description, FROM models
json_extract(b.metadata, '$.tags') as tags
FROM model_config AS a
LEFT JOIN model_metadata AS b on a.id=b.id
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ? LIMIT ?
OFFSET ?; OFFSET ?;
@ -394,7 +337,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
), ),
) )
rows = self._cursor.fetchall() rows = self._cursor.fetchall()
items = [ModelSummary.model_validate(_fixup(dict(x))) for x in rows] items = [ModelSummary.model_validate(dict(x)) for x in rows]
return PaginatedResults( return PaginatedResults(
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
) )

View File

@ -1,6 +1,35 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from threading import Event
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
class SessionRunnerBase(ABC):
"""
Base class for session runner.
"""
@abstractmethod
def start(self, services: InvocationServices, cancel_event: Event) -> None:
"""Starts the session runner"""
pass
@abstractmethod
def run(self, queue_item: SessionQueueItem) -> None:
"""Runs the session"""
pass
@abstractmethod
def complete(self, queue_item: SessionQueueItem) -> None:
"""Completes the session"""
pass
@abstractmethod
def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
"""Runs an already prepared node on the session"""
pass
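
For illustration, a custom runner only needs to satisfy the four abstract methods above; events, stats and profiling are up to the implementation. A minimal sketch that walks the session graph without doing any real work:

```py
# Sketch only: a bare-bones runner built against the abstract interface above.
class LoggingSessionRunner(SessionRunnerBase):
    def start(self, services: InvocationServices, cancel_event: Event) -> None:
        self._services = services
        self._cancel_event = cancel_event

    def run(self, queue_item: SessionQueueItem) -> None:
        # Walk prepared invocations until the session is complete or canceled.
        while not (queue_item.session.is_complete() or self._cancel_event.is_set()):
            invocation = queue_item.session.next()
            if invocation is None:
                break
            self.run_node(invocation.id, queue_item)
        self.complete(queue_item)

    def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
        self._services.logger.info(f"would run node {node_id}")

    def complete(self, queue_item: SessionQueueItem) -> None:
        self._services.logger.info(f"session {queue_item.session_id} complete")
```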
class SessionProcessorBase(ABC): class SessionProcessorBase(ABC):

View File

@ -2,13 +2,14 @@ import traceback
from contextlib import suppress from contextlib import suppress
from threading import BoundedSemaphore, Thread from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent from threading import Event as ThreadEvent
from typing import Optional from typing import Callable, Optional, Union
from fastapi_events.handlers.local import local_handler from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_common import CanceledException from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
@ -16,15 +17,164 @@ from invokeai.app.services.shared.invocation_context import InvocationContextDat
from invokeai.app.util.profiler import Profiler from invokeai.app.util.profiler import Profiler
from ..invoker import Invoker from ..invoker import Invoker
from .session_processor_base import SessionProcessorBase from .session_processor_base import SessionProcessorBase, SessionRunnerBase
from .session_processor_common import SessionProcessorStatus from .session_processor_common import SessionProcessorStatus
class DefaultSessionRunner(SessionRunnerBase):
"""Processes a single session's invocations"""
def __init__(
self,
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
):
self.on_before_run_node = on_before_run_node
self.on_after_run_node = on_after_run_node
def start(self, services: InvocationServices, cancel_event: ThreadEvent):
"""Start the session runner"""
self.services = services
self.cancel_event = cancel_event
def run(self, queue_item: SessionQueueItem):
"""Run the graph"""
if not queue_item.session:
raise ValueError("Queue item has no session")
# Loop over invocations until the session is complete or canceled
while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
# Prepare the next node
invocation = queue_item.session.next()
if invocation is None:
# If there are no more invocations, complete the graph
break
# Run the node (run_node builds the invocation context, the node-facing API)
self.run_node(invocation.id, queue_item)
self.complete(queue_item)
def complete(self, queue_item: SessionQueueItem):
"""Complete the graph"""
self.services.events.emit_graph_execution_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
)
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run before a node is executed"""
# Send starting event
self.services.events.emit_invocation_started(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session_id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
)
if self.on_before_run_node is not None:
self.on_before_run_node(invocation, queue_item)
def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run after a node is executed"""
if self.on_after_run_node is not None:
self.on_after_run_node(invocation, queue_item)
def run_node(self, node_id: str, queue_item: SessionQueueItem):
"""Run a single node in the graph"""
# If this raises a NodeNotFoundError, it is handled by the processor
invocation = queue_item.session.execution_graph.get_node(node_id)
try:
self._on_before_run_node(invocation, queue_item)
data = InvocationContextData(
invocation=invocation,
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
queue_item=queue_item,
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
context = build_invocation_context(
data=data,
services=self.services,
cancel_event=self.cancel_event,
)
# Invoke the node
outputs = invocation.invoke_internal(context=context, services=self.services)
# Save outputs and history
queue_item.session.complete(invocation.id, outputs)
self._on_after_run_node(invocation, queue_item)
# Send complete event on successful runs
self.services.events.emit_invocation_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=data.source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
queue_item.session.set_node_error(invocation.id, error)
self.services.logger.error(
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
)
self.services.logger.error(error)
# Send error event
self.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
error_type=e.__class__.__name__,
error=error,
)
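
The `CanceledException` branch above relies on long-running work polling the shared cancel event between steps. A hedged sketch of that cooperation, with a hypothetical `do_one_step` standing in for the real per-step work:

```py
# Illustrative only: a long-running loop that cooperates with cancellation.
def run_steps(num_steps: int, cancel_event: ThreadEvent) -> None:
    for step in range(num_steps):
        if cancel_event.is_set():
            # Unwinds to run_node(), which swallows it; the outer run() loop
            # then sees the cancel event and stops scheduling further nodes.
            raise CanceledException()
        do_one_step(step)  # hypothetical per-step work
```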
class DefaultSessionProcessor(SessionProcessorBase): class DefaultSessionProcessor(SessionProcessorBase):
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None: """Processes sessions from the session queue"""
def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None:
super().__init__()
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
def start(
self,
invoker: Invoker,
thread_limit: int = 1,
polling_interval: int = 1,
on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
) -> None:
self._invoker: Invoker = invoker self._invoker: Invoker = invoker
self._queue_item: Optional[SessionQueueItem] = None self._queue_item: Optional[SessionQueueItem] = None
self._invocation: Optional[BaseInvocation] = None self._invocation: Optional[BaseInvocation] = None
self.on_before_run_session = on_before_run_session
self.on_after_run_session = on_after_run_session
self._resume_event = ThreadEvent() self._resume_event = ThreadEvent()
self._stop_event = ThreadEvent() self._stop_event = ThreadEvent()
@ -59,6 +209,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
"cancel_event": self._cancel_event, "cancel_event": self._cancel_event,
}, },
) )
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event)
self._thread.start() self._thread.start()
def stop(self, *args, **kwargs) -> None: def stop(self, *args, **kwargs) -> None:
@ -117,130 +268,34 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear() cancel_event.clear()
# If we have a on_before_run_session callback, call it
if self.on_before_run_session is not None:
self.on_before_run_session(self._queue_item)
# If profiling is enabled, start the profiler # If profiling is enabled, start the profiler
if self._profiler is not None: if self._profiler is not None:
self._profiler.start(profile_id=self._queue_item.session_id) self._profiler.start(profile_id=self._queue_item.session_id)
# Prepare invocations and take the first # Run the graph
self._invocation = self._queue_item.session.next() self.session_runner.run(queue_item=self._queue_item)
# Loop over invocations until the session is complete or canceled # If we are profiling, stop the profiler and dump the profile & stats
while self._invocation is not None and not cancel_event.is_set(): if self._profiler:
# get the source node id to provide to clients (the prepared node id is not as useful) profile_path = self._profiler.stop()
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id] stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
# Send starting event graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
self._invoker.services.events.emit_invocation_started(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session_id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
) )
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
try: # we don't care about that - suppress the error.
with self._invoker.services.performance_statistics.collect_stats( with suppress(GESStatsNotFoundError):
self._invocation, self._queue_item.session.id self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
): self._invoker.services.performance_statistics.reset_stats()
# Build invocation context (the node-facing API)
data = InvocationContextData(
invocation=self._invocation,
source_invocation_id=source_invocation_id,
queue_item=self._queue_item,
)
context = build_invocation_context(
data=data,
services=self._invoker.services,
cancel_event=self._cancel_event,
)
# Invoke the node # If we have a on_after_run_session callback, call it
outputs = self._invocation.invoke_internal( if self.on_after_run_session is not None:
context=context, services=self._invoker.services self.on_after_run_session(self._queue_item)
)
# Save outputs and history
self._queue_item.session.complete(self._invocation.id, outputs)
# Send complete event
self._invoker.services.events.emit_invocation_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
self._queue_item.session.set_node_error(self._invocation.id, error)
self._invoker.services.logger.error(
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
)
# Send error event
self._invoker.services.events.emit_invocation_error(
queue_batch_id=self._queue_item.session_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
error_type=e.__class__.__name__,
error=error,
)
pass
# The session is complete if the all invocations are complete or there was an error
if self._queue_item.session.is_complete() or cancel_event.is_set():
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats()
# Set the invocation to None to prepare for the next session
self._invocation = None
else:
# Prepare the next invocation
self._invocation = self._queue_item.session.next()
# The session is complete, immediately poll for next session # The session is complete, immediately poll for next session
self._queue_item = None self._queue_item = None
@ -274,3 +329,4 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.clear() poll_now_event.clear()
self._queue_item = None self._queue_item = None
self._thread_semaphore.release() self._thread_semaphore.release()
self._invoker.services.logger.debug("Session processor stopped")
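
With the runner factored out, callers can compose the processor from a custom runner plus per-node and per-session callbacks. A sketch of how this might be wired up, assuming an `invoker` already exists (the callback bodies are placeholders):

```py
# Sketch only: composing the new hooks. Signatures follow the definitions above.
def log_node(invocation: BaseInvocation, queue_item: SessionQueueItem) -> bool:
    print(f"node {invocation.id} in session {queue_item.session_id}")
    return True

def log_session(queue_item: SessionQueueItem) -> bool:
    print(f"queue item {queue_item.item_id}")
    return True

runner = DefaultSessionRunner(on_before_run_node=log_node, on_after_run_node=log_node)
processor = DefaultSessionProcessor(session_runner=runner)
processor.start(
    invoker,
    on_before_run_session=log_session,
    on_after_run_session=log_session,
)
```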

View File

@ -9,6 +9,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -35,6 +36,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_4()) migrator.register_migration(build_migration_4())
migrator.register_migration(build_migration_5()) migrator.register_migration(build_migration_5())
migrator.register_migration(build_migration_6()) migrator.register_migration(build_migration_6())
migrator.register_migration(build_migration_7())
migrator.run_migrations() migrator.run_migrations()
return db return db

View File

@ -0,0 +1,88 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration7Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._create_models_table(cursor)
self._drop_old_models_tables(cursor)
def _drop_old_models_tables(self, cursor: sqlite3.Cursor) -> None:
"""Drops the old model_records, model_metadata, model_tags and tags tables."""
tables = ["model_records", "model_metadata", "model_tags", "tags"]
for table in tables:
cursor.execute(f"DROP TABLE IF EXISTS {table};")
def _create_models_table(self, cursor: sqlite3.Cursor) -> None:
"""Creates the v4.0.0 models table."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS models (
id TEXT NOT NULL PRIMARY KEY,
hash TEXT GENERATED ALWAYS as (json_extract(config, '$.hash')) VIRTUAL NOT NULL,
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
description TEXT GENERATED ALWAYS as (json_extract(config, '$.description')) VIRTUAL,
source TEXT GENERATED ALWAYS as (json_extract(config, '$.source')) VIRTUAL NOT NULL,
source_type TEXT GENERATED ALWAYS as (json_extract(config, '$.source_type')) VIRTUAL NOT NULL,
source_api_response TEXT GENERATED ALWAYS as (json_extract(config, '$.source_api_response')) VIRTUAL,
trigger_phrases TEXT GENERATED ALWAYS as (json_extract(config, '$.trigger_phrases')) VIRTUAL,
-- Serialized JSON representation of the whole config object, which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
]
# Add trigger for `updated_at`.
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS models_updated_at
AFTER UPDATE
ON models FOR EACH ROW
BEGIN
UPDATE models SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
]
# Add indexes for searchable fields
indices = [
"CREATE INDEX IF NOT EXISTS base_index ON models(base);",
"CREATE INDEX IF NOT EXISTS type_index ON models(type);",
"CREATE INDEX IF NOT EXISTS name_index ON models(name);",
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON models(path);",
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def build_migration_7() -> Migration:
"""
Build the migration from database version 6 to 7.
This migration does the following:
- Adds the new models table
- Drops the old model_records, model_metadata, model_tags and tags tables.
- TODO(MM2): Migrates model names and descriptions from `models.yaml` to the new table (?).
"""
migration_7 = Migration(
from_version=6,
to_version=7,
callback=Migration7Callback(),
)
return migration_7
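
Because every searchable column is generated from the `config` JSON, writers only insert the `id`/`config` pair and SQLite derives the rest. A standalone sketch against an in-memory database (the config values are hypothetical, `Migration7Callback` is assumed importable from this module, and generated columns require SQLite 3.31+):

```py
import json
import sqlite3

conn = sqlite3.connect(":memory:")
Migration7Callback()(conn.cursor())  # create the models table, trigger and indexes

config = {
    "hash": "abc123", "base": "sd-1", "type": "lora", "format": "lycoris",
    "path": "loras/foo", "name": "foo", "source": "loras/foo", "source_type": "path",
}
conn.execute("INSERT INTO models (id, config) VALUES (?, ?);", ("key-1", json.dumps(config)))

# The generated columns (and their indexes) are immediately queryable.
print(conn.execute("SELECT name, base, type FROM models WHERE path = 'loras/foo';").fetchone())
# -> ('foo', 'sd-1', 'lora')
```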

View File

@ -3,7 +3,6 @@
import json import json
import sqlite3 import sqlite3
from hashlib import sha1
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@ -22,7 +21,7 @@ from invokeai.backend.model_manager.config import (
ModelConfigFactory, ModelConfigFactory,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.hash import FastModelHash from invokeai.backend.model_manager.hash import ModelHash
ModelsValidator = TypeAdapter(AnyModelConfig) ModelsValidator = TypeAdapter(AnyModelConfig)
@ -73,19 +72,27 @@ class MigrateModelYamlToDb1:
base_type, model_type, model_name = str(model_key).split("/") base_type, model_type, model_name = str(model_key).split("/")
try: try:
hash = FastModelHash.hash(self.config.models_path / stanza.path) hash = ModelHash().hash(self.config.models_path / stanza.path)
except OSError: except OSError:
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.") self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
continue continue
assert isinstance(model_key, str)
new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base"] = BaseModelType(base_type) stanza["base"] = BaseModelType(base_type)
stanza["type"] = ModelType(model_type) stanza["type"] = ModelType(model_type)
stanza["name"] = model_name stanza["name"] = model_name
stanza["original_hash"] = hash stanza["original_hash"] = hash
stanza["current_hash"] = hash stanza["current_hash"] = hash
new_key = hash # deterministic key assignment
# special case for ip adapters, which need the new `image_encoder_model_id` field
if stanza["type"] == ModelType.IPAdapter:
try:
stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
self.config.models_path / stanza.path
)
except OSError:
self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
continue
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094 new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
@ -95,7 +102,7 @@ class MigrateModelYamlToDb1:
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}") self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
self._update_model(key, new_config) self._update_model(key, new_config)
else: else:
self.logger.info(f"Adding model {model_name} with key {model_key}") self.logger.info(f"Adding model {model_name} with key {new_key}")
self._add_model(new_key, new_config) self._add_model(new_key, new_config)
except DuplicateModelException: except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database") self.logger.warning(f"Model {model_name} is already in the database")
@ -143,9 +150,14 @@ class MigrateModelYamlToDb1:
""", """,
( (
key, key,
record.original_hash, record.hash,
json_serialized, json_serialized,
), ),
) )
except sqlite3.IntegrityError as exc: except sqlite3.IntegrityError as exc:
raise DuplicateModelException(f"{record.name}: model is already in database") from exc raise DuplicateModelException(f"{record.name}: model is already in database") from exc
def _get_image_encoder_model_id(self, model_path: Path) -> str:
with open(model_path / "image_encoder.txt") as f:
encoder = f.read()
return encoder.strip()

View File

@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
See :class:`Migration` for an example. See :class:`Migration` for an example.
""" """
def __call__(self, cursor: sqlite3.Cursor) -> None: ... def __call__(self, cursor: sqlite3.Cursor) -> None:
...
class MigrationError(RuntimeError): class MigrationError(RuntimeError):

View File

@ -1,55 +0,0 @@
import json
from typing import Optional
from pydantic import ValidationError
from invokeai.app.services.shared.graph import Edge
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
"""
Parses raw session string, returning a dict of the graph.
Only the general graph shape is validated; none of the fields are validated.
Any `metadata_accumulator` nodes and edges are removed.
Any validation failure will return None.
"""
graph = json.loads(session_raw).get("graph", None)
# sanity check make sure the graph is at least reasonably shaped
if (
not isinstance(graph, dict)
or "nodes" not in graph
or not isinstance(graph["nodes"], dict)
or "edges" not in graph
or not isinstance(graph["edges"], list)
):
# something has gone terribly awry, return an empty dict
return None
try:
# delete the `metadata_accumulator` node
del graph["nodes"]["metadata_accumulator"]
except KeyError:
# no accumulator node, all good
pass
# delete any edges to or from it
for i, edge in enumerate(graph["edges"]):
try:
# try to parse the edge
Edge(**edge)
except ValidationError:
# something has gone terribly awry, return an empty dict
return None
if (
edge["source"]["node_id"] == "metadata_accumulator"
or edge["destination"]["node_id"] == "metadata_accumulator"
):
del graph["edges"][i]
return graph

View File

@ -25,10 +25,13 @@ from enum import Enum
from typing import Literal, Optional, Type, Union from typing import Literal, Optional, Type, Union
import torch import torch
from diffusers import ModelMixin from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict from typing_extensions import Annotated, Any, Dict
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.util.misc import uuid_string
from ..raw_model import RawModel from ..raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models # ModelMixin is the base class for all diffusers and transformers models
@ -56,8 +59,8 @@ class ModelType(str, Enum):
ONNX = "onnx" ONNX = "onnx"
Main = "main" Main = "main"
Vae = "vae" VAE = "vae"
Lora = "lora" LoRA = "lora"
ControlNet = "controlnet" # used by model_probe ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding" TextualInversion = "embedding"
IPAdapter = "ip_adapter" IPAdapter = "ip_adapter"
@ -73,9 +76,9 @@ class SubModelType(str, Enum):
TextEncoder2 = "text_encoder_2" TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer" Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2" Tokenizer2 = "tokenizer_2"
Vae = "vae" VAE = "vae"
VaeDecoder = "vae_decoder" VAEDecoder = "vae_decoder"
VaeEncoder = "vae_encoder" VAEEncoder = "vae_encoder"
Scheduler = "scheduler" Scheduler = "scheduler"
SafetyChecker = "safety_checker" SafetyChecker = "safety_checker"
@ -93,8 +96,8 @@ class ModelFormat(str, Enum):
Diffusers = "diffusers" Diffusers = "diffusers"
Checkpoint = "checkpoint" Checkpoint = "checkpoint"
Lycoris = "lycoris" LyCORIS = "lycoris"
Onnx = "onnx" ONNX = "onnx"
Olive = "olive" Olive = "olive"
EmbeddingFile = "embedding_file" EmbeddingFile = "embedding_file"
EmbeddingFolder = "embedding_folder" EmbeddingFolder = "embedding_folder"
@ -112,128 +115,187 @@ class SchedulerPredictionType(str, Enum):
class ModelRepoVariant(str, Enum): class ModelRepoVariant(str, Enum):
"""Various hugging face variants on the diffusers format.""" """Various hugging face variants on the diffusers format."""
DEFAULT = "" # model files without "fp16" or other qualifier - empty str Default = "" # model files without "fp16" or other qualifier - empty str
FP16 = "fp16" FP16 = "fp16"
FP32 = "fp32" FP32 = "fp32"
ONNX = "onnx" ONNX = "onnx"
OPENVINO = "openvino" OpenVINO = "openvino"
FLAX = "flax" Flax = "flax"
class ModelSourceType(str, Enum):
"""Model source type."""
Path = "path"
Url = "url"
HFRepoID = "hf_repo_id"
CivitAI = "civitai"
class ModelDefaultSettings(BaseModel):
vae: str | None
vae_precision: str | None
scheduler: SCHEDULER_NAME_VALUES | None
steps: int | None
cfg_scale: float | None
cfg_rescale_multiplier: float | None
class ModelConfigBase(BaseModel): class ModelConfigBase(BaseModel):
"""Base class for model configuration information.""" """Base class for model configuration information."""
path: str = Field(description="filesystem path to the model file or directory") key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
name: str = Field(description="model name") hash: str = Field(description="The hash of the model file(s).")
base: BaseModelType = Field(description="base model") path: str = Field(
type: ModelType = Field(description="type of the model") description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
format: ModelFormat = Field(description="model format") )
key: str = Field(description="unique key for model", default="<NOKEY>") name: str = Field(description="Name of the model.")
original_hash: Optional[str] = Field( base: BaseModelType = Field(description="The base model.")
description="original fasthash of model contents", default=None description: Optional[str] = Field(description="Model description", default=None)
) # this is assigned at install time and will not change source: str = Field(description="The original source of the model (path, URL or repo_id).")
current_hash: Optional[str] = Field( source_type: ModelSourceType = Field(description="The type of source")
description="current fasthash of model contents", default=None source_api_response: Optional[str] = Field(
) # if model is converted or otherwise modified, this will hold updated hash description="The original API response from the source, as stringified JSON.", default=None
description: Optional[str] = Field(description="human readable description of the model", default=None) )
source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None) trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time) default_settings: Optional[ModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
@staticmethod @staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
schema["required"].extend( schema["required"].extend(["key", "type", "format"])
["key", "base", "type", "format", "original_hash", "current_hash", "source", "last_modified"]
)
model_config = ConfigDict( model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
use_enum_values=False,
validate_assignment=True,
json_schema_extra=json_schema_extra,
)
def update(self, attributes: Dict[str, Any]) -> None:
"""Update the object with fields in dict."""
for key, value in attributes.items():
setattr(self, key, value) # may raise a validation error
class _CheckpointConfig(ModelConfigBase): class CheckpointConfigBase(ModelConfigBase):
"""Model config for checkpoint-style models.""" """Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
config: str = Field(description="path to the checkpoint model config file") config_path: str = Field(description="path to the checkpoint model config file")
converted_at: Optional[float] = Field(
description="When this model was last converted to diffusers", default_factory=time.time
)
class _DiffusersConfig(ModelConfigBase): class DiffusersConfigBase(ModelConfigBase):
"""Model config for diffusers-style models.""" """Model config for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
class LoRAConfig(ModelConfigBase): class LoRALyCORISConfig(ModelConfigBase):
"""Model config for LoRA/Lycoris models.""" """Model config for LoRA/Lycoris models."""
type: Literal[ModelType.Lora] = ModelType.Lora type: Literal[ModelType.LoRA] = ModelType.LoRA
format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers] format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
class VaeCheckpointConfig(ModelConfigBase): class LoRADiffusersConfig(ModelConfigBase):
"""Model config for LoRA/Diffusers models."""
type: Literal[ModelType.LoRA] = ModelType.LoRA
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.Diffusers.value}")
class VAECheckpointConfig(CheckpointConfigBase):
"""Model config for standalone VAE models.""" """Model config for standalone VAE models."""
type: Literal[ModelType.Vae] = ModelType.Vae type: Literal[ModelType.VAE] = ModelType.VAE
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Checkpoint.value}")
class VaeDiffusersConfig(ModelConfigBase):
class VAEDiffusersConfig(ModelConfigBase):
"""Model config for standalone VAE models (diffusers version).""" """Model config for standalone VAE models (diffusers version)."""
type: Literal[ModelType.Vae] = ModelType.Vae type: Literal[ModelType.VAE] = ModelType.VAE
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
class ControlNetDiffusersConfig(_DiffusersConfig):
class ControlNetDiffusersConfig(DiffusersConfigBase):
"""Model config for ControlNet models (diffusers version).""" """Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")
class ControlNetCheckpointConfig(_CheckpointConfig):
class ControlNetCheckpointConfig(CheckpointConfigBase):
"""Model config for ControlNet models (diffusers version).""" """Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Checkpoint.value}")
class TextualInversionConfig(ModelConfigBase):
class TextualInversionFileConfig(ModelConfigBase):
"""Model config for textual inversion embeddings.""" """Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder] format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
class _MainConfig(ModelConfigBase): class TextualInversionFolderConfig(ModelConfigBase):
"""Model config for main models.""" """Model config for textual inversion embeddings."""
vae: Optional[str] = Field(default=None) type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
variant: ModelVariantType = ModelVariantType.Normal format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False @staticmethod
ztsnr_training: bool = False def get_tag() -> Tag:
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
class MainCheckpointConfig(_CheckpointConfig, _MainConfig): class MainCheckpointConfig(CheckpointConfigBase):
"""Model config for main checkpoint models.""" """Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main type: Literal[ModelType.Main] = ModelType.Main
variant: ModelVariantType = ModelVariantType.Normal
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
class MainDiffusersConfig(_DiffusersConfig, _MainConfig): class MainDiffusersConfig(DiffusersConfigBase):
"""Model config for main diffusers models.""" """Model config for main diffusers models."""
type: Literal[ModelType.Main] = ModelType.Main type: Literal[ModelType.Main] = ModelType.Main
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
class IPAdapterConfig(ModelConfigBase): class IPAdapterConfig(ModelConfigBase):
"""Model config for IP Adaptor format models.""" """Model config for IP Adaptor format models."""
@ -242,6 +304,10 @@ class IPAdapterConfig(ModelConfigBase):
image_encoder_model_id: str image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI] format: Literal[ModelFormat.InvokeAI]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
class CLIPVisionDiffusersConfig(ModelConfigBase): class CLIPVisionDiffusersConfig(ModelConfigBase):
"""Model config for ClipVision.""" """Model config for ClipVision."""
@ -249,58 +315,65 @@ class CLIPVisionDiffusersConfig(ModelConfigBase):
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers] format: Literal[ModelFormat.Diffusers]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")
class T2IConfig(ModelConfigBase):
class T2IAdapterConfig(ModelConfigBase):
"""Model config for T2I.""" """Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers] format: Literal[ModelFormat.Diffusers]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
_ControlNetConfig = Annotated[
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
Field(discriminator="format"),
]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
AnyModelConfig = Union[ def get_model_discriminator_value(v: Any) -> str:
_MainModelConfig, """
_VaeConfig, Computes the discriminator value for a model config.
_ControlNetConfig, https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
# ModelConfigBase, """
LoRAConfig, format_ = None
TextualInversionConfig, type_ = None
IPAdapterConfig, if isinstance(v, dict):
CLIPVisionDiffusersConfig, format_ = v.get("format")
T2IConfig, if isinstance(format_, Enum):
format_ = format_.value
type_ = v.get("type")
if isinstance(type_, Enum):
type_ = type_.value
else:
format_ = v.format.value
type_ = v.type.value
v = f"{type_}.{format_}"
return v
AnyModelConfig = Annotated[
Union[
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
] ]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig) AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
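
The callable discriminator means raw dicts and config instances resolve to the same `type.format` tag before pydantic picks a concrete class. For example, with the enum values defined above:

```py
# Both spellings map to the tag that selects LoRALyCORISConfig.
get_model_discriminator_value({"type": "lora", "format": "lycoris"})                    # "lora.lycoris"
get_model_discriminator_value({"type": ModelType.LoRA, "format": ModelFormat.LyCORIS})  # "lora.lycoris"
```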
# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
# This is a known issue. Please see:
# https://github.com/tiangolo/fastapi/discussions/9761 and
# https://github.com/tiangolo/fastapi/discussions/9287
# AnyModelConfig = Annotated[
# Union[
# _MainModelConfig,
# _ONNXConfig,
# _VaeConfig,
# _ControlNetConfig,
# LoRAConfig,
# TextualInversionConfig,
# IPAdapterConfig,
# CLIPVisionDiffusersConfig,
# T2IConfig,
# ],
# Field(discriminator="type"),
# ]
class ModelConfigFactory(object): class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects.""" """Class for parsing config dicts into StableDiffusion Config obects."""
@ -332,6 +405,6 @@ class ModelConfigFactory(object):
assert model is not None assert model is not None
if key: if key:
model.key = key model.key = key
if timestamp: if isinstance(model, CheckpointConfigBase) and timestamp is not None:
model.last_modified = timestamp model.converted_at = timestamp
return model # type: ignore return model # type: ignore

View File

@ -11,56 +11,175 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
import hashlib import hashlib
import os import os
from pathlib import Path from pathlib import Path
from typing import Dict, Union from typing import Callable, Literal, Optional, Union
from imohash import hashfile from blake3 import blake3
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
ALGORITHM = Literal[
"md5",
"sha1",
"sha224",
"sha256",
"sha384",
"sha512",
"blake2b",
"blake2s",
"sha3_224",
"sha3_256",
"sha3_384",
"sha3_512",
"shake_128",
"shake_256",
"blake3",
]
class FastModelHash(object): class ModelHash:
"""FastModelHash obect provides one public class method, hash().""" """
Creates a hash of a model using a specified algorithm.
@classmethod Args:
def hash(cls, model_location: Union[str, Path]) -> str: algorithm: Hashing algorithm to use. Defaults to BLAKE3.
""" file_filter: A function that takes a file name and returns True if the file should be included in the hash.
Return hexdigest string for model located at model_location.
:param model_location: Path to the model If the model is a single file, it is hashed directly using the provided algorithm.
"""
model_location = Path(model_location) If the model is a directory, each model weights file in the directory is hashed using the provided algorithm.
if model_location.is_file():
return cls._hash_file(model_location) Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth
elif model_location.is_dir():
return cls._hash_dir(model_location) The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
that directory hashes are never weaker than the file hashes.
Usage:
```py
# BLAKE3 hash
ModelHash().hash("path/to/some/model.safetensors")
# MD5
ModelHash("md5").hash("path/to/model/dir/")
```
"""
def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
if algorithm == "blake3":
self._hash_file = self._blake3
elif algorithm in hashlib.algorithms_available:
self._hash_file = self._get_hashlib(algorithm)
else: else:
raise OSError(f"Not a valid file or directory: {model_location}") raise ValueError(f"Algorithm {algorithm} not available")
@classmethod self._file_filter = file_filter or self._default_file_filter
def _hash_file(cls, model_location: Union[str, Path]) -> str:
def hash(self, model_path: Union[str, Path]) -> str:
""" """
Fasthash a single file and return its hexdigest. Return hexdigest of hash of model located at model_path using the algorithm provided at class instantiation.
:param model_location: Path to the model file If model_path is a directory, the hash is computed by hashing the hashes of all model files in the
directory. The final composite hash is always computed using BLAKE3.
Args:
model_path: Path to the model
Returns:
str: Hexdigest of the hash of the model
""" """
# we return md5 hash of the filehash to make it shorter
# cryptographic security not needed here
return hashlib.md5(hashfile(model_location)).hexdigest()
@classmethod model_path = Path(model_path)
def _hash_dir(cls, model_location: Union[str, Path]) -> str: if model_path.is_file():
components: Dict[str, str] = {} return self._hash_file(model_path)
elif model_path.is_dir():
return self._hash_dir(model_path)
else:
raise OSError(f"Not a valid file or directory: {model_path}")
for root, _dirs, files in os.walk(model_location): def _hash_dir(self, dir: Path) -> str:
for file in files: """Compute the hash for all files in a directory and return a hexdigest.
# only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted.
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
continue
path = (Path(root) / file).as_posix()
fast_hash = cls._hash_file(path)
components.update({path: fast_hash})
# hash all the model hashes together, using alphabetic file order Args:
md5 = hashlib.md5() dir: Path to the directory
for _path, fast_hash in sorted(components.items()):
md5.update(fast_hash.encode("utf-8")) Returns:
return md5.hexdigest() str: Hexdigest of the hash of the directory
"""
model_component_paths = self._get_file_paths(dir, self._file_filter)
component_hashes: list[str] = []
for component in sorted(model_component_paths):
component_hashes.append(self._hash_file(component))
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
# for the composite hash
composite_hasher = blake3()
for h in component_hashes:
composite_hasher.update(h.encode("utf-8"))
return composite_hasher.hexdigest()
@staticmethod
def _get_file_paths(model_path: Path, file_filter: Callable[[str], bool]) -> list[Path]:
"""Return a list of all model files in the directory.
Args:
model_path: Path to the model
file_filter: Function that takes a file name and returns True if the file should be included in the list.
Returns:
List of all model files in the directory
"""
files: list[Path] = []
for root, _dirs, _files in os.walk(model_path):
for file in _files:
if file_filter(file):
files.append(Path(root, file))
return files
@staticmethod
def _blake3(file_path: Path) -> str:
"""Hashes a file using BLAKE3
Args:
file_path: Path to the file to hash
Returns:
Hexdigest of the hash of the file
"""
file_hasher = blake3(max_threads=blake3.AUTO)
file_hasher.update_mmap(file_path)
return file_hasher.hexdigest()
@staticmethod
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
"""Factory function that returns a function to hash a file with the given algorithm.
Args:
algorithm: Hashing algorithm to use
Returns:
A function that hashes a file using the given algorithm
"""
def hashlib_hasher(file_path: Path) -> str:
"""Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory."""
hasher = hashlib.new(algorithm)
buffer = bytearray(128 * 1024)
mv = memoryview(buffer)
with open(file_path, "rb", buffering=0) as f:
while n := f.readinto(mv):
hasher.update(mv[:n])
return hasher.hexdigest()
return hashlib_hasher
@staticmethod
def _default_file_filter(file_path: str) -> bool:
"""A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
Args:
file_path: Path to the file
Returns:
True if the file matches the given extensions, otherwise False
"""
return file_path.endswith(MODEL_FILE_EXTENSIONS)
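A minimal usage sketch of the hashing helper above; the import path and the BLAKE3 default are inferences from this diff (the probe code below calls ModelHash().hash(...)), not something the PR states explicitly:

# Usage sketch, assuming ModelHash lives in invokeai.backend.model_manager.hash
# and that the default constructor uses BLAKE3, as the methods above suggest.
from pathlib import Path

from invokeai.backend.model_manager.hash import ModelHash

# A single checkpoint file is hashed directly with the configured algorithm.
file_digest = ModelHash().hash(Path("models/sd-1.5.safetensors"))

# For a diffusers folder, each weight file is hashed and the per-file digests
# are combined into one BLAKE3 composite digest.
dir_digest = ModelHash().hash(Path("models/sd-1.5-diffusers/"))

print(file_digest, dir_digest)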

View File

@ -13,6 +13,7 @@ from invokeai.backend.model_manager import (
     ModelRepoVariant,
     SubModelType,
 )
+from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
@ -50,7 +51,7 @@ class ModelLoader(ModelLoaderBase):
         :param submodel_type: an ModelType enum indicating the portion of
             the model to retrieve (e.g. ModelType.Vae)
         """
-        if model_config.type == "main" and not submodel_type:
+        if model_config.type is ModelType.Main and not submodel_type:
             raise InvalidModelConfigException("submodel_type is required when loading a main model")

         model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
@ -80,7 +81,7 @@ class ModelLoader(ModelLoaderBase):
             self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
             return self._convert_model(config, model_path, cache_path)

-    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
+    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
         return False

     def _load_if_needed(
@ -119,7 +120,7 @@ class ModelLoader(ModelLoaderBase):
         return calc_model_size_by_fs(
             model_path=model_path,
             subfolder=submodel_type.value if submodel_type else None,
-            variant=config.repo_variant if hasattr(config, "repo_variant") else None,
+            variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
         )

     # This needs to be implemented in subclasses that handle checkpoints

View File

@ -15,10 +15,8 @@ Use like this:
 """

-import hashlib
 from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import Callable, Dict, Optional, Tuple, Type
+from typing import Callable, Dict, Optional, Tuple, Type, TypeVar

 from ..config import (
     AnyModelConfig,
@ -27,8 +25,6 @@ from ..config import (
     ModelFormat,
     ModelType,
     SubModelType,
-    VaeCheckpointConfig,
-    VaeDiffusersConfig,
 )
 from . import ModelLoaderBase
@ -61,6 +57,9 @@ class ModelLoaderRegistryBase(ABC):
     """

+TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
+

 class ModelLoaderRegistry:
     """
     This class allows model loaders to register their type, base and format.
@ -71,10 +70,10 @@ class ModelLoaderRegistry:
     @classmethod
     def register(
         cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
-    ) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
+    ) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
         """Define a decorator which registers the subclass of loader."""

-        def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
+        def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
             key = cls._to_registry_key(base, type, format)
             if key in cls._registry:
                 raise Exception(
@ -90,33 +89,15 @@
         cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
     ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
         """Get subclass of ModelLoaderBase registered to handle base and type."""
-        # We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
-        conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
-
-        key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format)  # for a specific base type
-        key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format)  # with wildcard Any
+        key1 = cls._to_registry_key(config.base, config.type, config.format)  # for a specific base type
+        key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format)  # with wildcard Any
         implementation = cls._registry.get(key1) or cls._registry.get(key2)
         if not implementation:
             raise NotImplementedError(
                 f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
             )
-        return implementation, conf2, submodel_type
-
-    @classmethod
-    def _handle_subtype_overrides(
-        cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
-    ) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
-        if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
-            model_path = Path(config.vae)
-            config_class = (
-                VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
-            )
-            hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
-            new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
-            submodel_type = None
-        else:
-            new_conf = config
-        return new_conf, submodel_type
+        return implementation, config, submodel_type

     @staticmethod
     def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
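For context, a small sketch of how a loader registers itself after this change; because register() is now generic over TModelLoader, the decorator preserves the concrete subclass type for type checkers. MyVaeLoader is an illustrative name, and the absolute import paths are guesses based on the relative imports shown in this diff:

# Illustrative registration sketch; MyVaeLoader is hypothetical.
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.load import ModelLoaderRegistry
from invokeai.backend.model_manager.load.load_base import ModelLoaderBase


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
class MyVaeLoader(ModelLoaderBase):  # hypothetical subclass
    ...


# The decorated name keeps its concrete type: the decorator returns
# Type[MyVaeLoader] rather than collapsing it to Type[ModelLoaderBase].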

View File

@ -3,8 +3,8 @@
 from pathlib import Path

-import safetensors
 import torch
+from safetensors.torch import load_file as safetensors_load_file

 from invokeai.backend.model_manager import (
     AnyModelConfig,
@ -12,6 +12,7 @@ from invokeai.backend.model_manager import (
     ModelFormat,
     ModelType,
 )
+from invokeai.backend.model_manager.config import CheckpointConfigBase
 from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers

 from .. import ModelLoaderRegistry
@ -20,15 +21,15 @@ from .generic_diffusers import GenericDiffusersLoader

 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
-class ControlnetLoader(GenericDiffusersLoader):
+class ControlNetLoader(GenericDiffusersLoader):
     """Class to load ControlNet models."""

     def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
-        if config.format != ModelFormat.Checkpoint:
+        if not isinstance(config, CheckpointConfigBase):
             return False
         elif (
             dest_path.exists()
-            and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
+            and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
             and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
         ):
             return False
@ -37,13 +38,13 @@ class ControlnetLoader(GenericDiffusersLoader):
     def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
         if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
-            raise Exception(f"Vae conversion not supported for model type: {config.base}")
+            raise Exception(f"ControlNet conversion not supported for model type: {config.base}")
         else:
-            assert hasattr(config, "config")
-            config_file = config.config
+            assert isinstance(config, CheckpointConfigBase)
+            config_file = config.config_path

         if model_path.suffix == ".safetensors":
-            checkpoint = safetensors.torch.load_file(model_path, device="cpu")
+            checkpoint = safetensors_load_file(model_path, device="cpu")
         else:
             checkpoint = torch.load(model_path, map_location="cpu")
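The suffix-based checkpoint loading above is a recurring pattern in these loaders. A standalone sketch of it, with an illustrative helper name that is not part of the PR:

# Standalone sketch of the checkpoint-loading pattern used above.
from pathlib import Path
from typing import Any, Dict

import torch
from safetensors.torch import load_file as safetensors_load_file


def load_state_dict(model_path: Path) -> Dict[str, Any]:
    if model_path.suffix == ".safetensors":
        # safetensors files are loaded with the dedicated loader
        checkpoint = safetensors_load_file(model_path, device="cpu")
    else:
        # legacy .ckpt/.pt/.pth files go through torch.load
        checkpoint = torch.load(model_path, map_location="cpu")
    # some checkpoints nest the weights under a "state_dict" key
    return checkpoint.get("state_dict", checkpoint)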

View File

@ -3,9 +3,10 @@
import sys import sys
from pathlib import Path from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Optional

-from diffusers import ConfigMixin, ModelMixin
+from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
AnyModel, AnyModel,
@ -41,6 +42,7 @@ class GenericDiffusersLoader(ModelLoader):
# TO DO: Add exception handling # TO DO: Add exception handling
def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin: def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
"""Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load.""" """Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load."""
result = None
if submodel_type: if submodel_type:
try: try:
config = self._load_diffusers_config(model_path, config_name="model_index.json") config = self._load_diffusers_config(model_path, config_name="model_index.json")
@ -64,6 +66,7 @@ class GenericDiffusersLoader(ModelLoader):
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json") raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
except KeyError as e: except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
assert result is not None
return result return result
# TO DO: Add exception handling # TO DO: Add exception handling
@ -75,7 +78,7 @@ class GenericDiffusersLoader(ModelLoader):
         result: ModelMixin = getattr(res_type, class_name)
         return result

-    def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
+    def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> dict[str, Any]:
         return ConfigLoader.load_config(model_path, config_name=config_name)
@ -83,8 +86,8 @@ class ConfigLoader(ConfigMixin):
     """Subclass of ConfigMixin for loading diffusers configuration files."""

     @classmethod
-    def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
+    def load_config(cls, *args: Any, **kwargs: Any) -> dict[str, Any]:  # pyright: ignore [reportIncompatibleMethodOverride]
         """Load a diffusrs ConfigMixin configuration."""
         cls.config_name = kwargs.pop("config_name")
-        # Diffusers doesn't provide typing info
+        # TODO(psyche): the types on this diffusers method are not correct
         return super().load_config(*args, **kwargs)  # type: ignore

View File

@ -31,7 +31,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
         if submodel_type is not None:
             raise ValueError("There are no submodels in an IP-Adapter model.")
         model = build_ip_adapter(
-            ip_adapter_ckpt_path=model_path / "ip_adapter.bin",
+            ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
             device=torch.device("cpu"),
             dtype=self._torch_dtype,
         )

View File

@ -22,8 +22,8 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
from .. import ModelLoader, ModelLoaderRegistry from .. import ModelLoader, ModelLoaderRegistry
-@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
-@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
 class LoraLoader(ModelLoader):
     """Class to load LoRA models."""

View File

@ -18,7 +18,7 @@ from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader from .generic_diffusers import GenericDiffusersLoader
-@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.ONNX)
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
 class OnnyxDiffusersModel(GenericDiffusersLoader):
     """Class to load onnx models."""

View File

@ -4,7 +4,8 @@
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
-from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
AnyModel, AnyModel,
@ -16,7 +17,7 @@ from invokeai.backend.model_manager import (
ModelVariantType, ModelVariantType,
SubModelType, SubModelType,
) )
-from invokeai.backend.model_manager.config import MainCheckpointConfig
+from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from .. import ModelLoaderRegistry from .. import ModelLoaderRegistry
@ -54,11 +55,11 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
return result return result
     def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
-        if config.format != ModelFormat.Checkpoint:
+        if not isinstance(config, CheckpointConfigBase):
             return False
         elif (
             dest_path.exists()
-            and (dest_path / "model_index.json").stat().st_mtime >= (config.last_modified or 0.0)
+            and (dest_path / "model_index.json").stat().st_mtime >= (config.converted_at or 0.0)
             and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
         ):
             return False
@ -73,7 +74,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
) )
-        config_file = config.config
+        config_file = config.config_path
self._logger.info(f"Converting {model_path} to diffusers format") self._logger.info(f"Converting {model_path} to diffusers format")
convert_ckpt_to_diffusers( convert_ckpt_to_diffusers(

View File

@ -3,9 +3,9 @@
from pathlib import Path from pathlib import Path
import safetensors
import torch import torch
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_file as safetensors_load_file
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
AnyModelConfig, AnyModelConfig,
@ -13,24 +13,25 @@ from invokeai.backend.model_manager import (
ModelFormat, ModelFormat,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
from .. import ModelLoaderRegistry from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader from .generic_diffusers import GenericDiffusersLoader
-@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
-@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
-@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
+@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
+@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
 class VaeLoader(GenericDiffusersLoader):
     """Class to load VAE models."""

     def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
-        if config.format != ModelFormat.Checkpoint:
+        if not isinstance(config, CheckpointConfigBase):
             return False
         elif (
             dest_path.exists()
-            and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
+            and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
             and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
         ):
             return False
@ -38,16 +39,15 @@ class VaeLoader(GenericDiffusersLoader):
return True return True
     def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
-        # TO DO: check whether sdxl VAE models convert.
+        # TODO(MM2): check whether sdxl VAE models convert.
         if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
-            raise Exception(f"Vae conversion not supported for model type: {config.base}")
+            raise Exception(f"VAE conversion not supported for model type: {config.base}")
         else:
-            config_file = (
-                "v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
-            )
+            assert isinstance(config, CheckpointConfigBase)
+            config_file = config.config_path

         if model_path.suffix == ".safetensors":
-            checkpoint = safetensors.torch.load_file(model_path, device="cpu")
+            checkpoint = safetensors_load_file(model_path, device="cpu")
         else:
             checkpoint = torch.load(model_path, map_location="cpu")
@ -55,7 +55,7 @@ class VaeLoader(GenericDiffusersLoader):
if "state_dict" in checkpoint: if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"] checkpoint = checkpoint["state_dict"]
-        ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
+        ckpt_config = OmegaConf.load(self._app_config.root_path / config_file)
assert isinstance(ckpt_config, DictConfig) assert isinstance(ckpt_config, DictConfig)
vae_model = convert_ldm_vae_to_diffusers( vae_model = convert_ldm_vae_to_diffusers(

View File

@ -16,6 +16,7 @@ from diffusers import AutoPipelineForText2Image
from diffusers.utils import logging as dlogging from diffusers.utils import logging as dlogging
from invokeai.app.services.model_install import ModelInstallServiceBase from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.util.devices import choose_torch_device, torch_dtype from invokeai.backend.util.devices import choose_torch_device, torch_dtype
from . import ( from . import (
@ -117,7 +118,6 @@ class ModelMerger(object):
config = self._installer.app_config config = self._installer.app_config
store = self._installer.record_store store = self._installer.record_store
base_models: Set[BaseModelType] = set() base_models: Set[BaseModelType] = set()
vae = None
variant = None if self._installer.app_config.full_precision else "fp16" variant = None if self._installer.app_config.full_precision else "fp16"
assert ( assert (
@ -134,10 +134,6 @@ class ModelMerger(object):
"normal" "normal"
), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged" ), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
# pick up the first model's vae
if key == model_keys[0]:
vae = info.vae
# tally base models used # tally base models used
base_models.add(info.base) base_models.add(info.base)
model_paths.extend([config.models_path / info.path]) model_paths.extend([config.models_path / info.path])
@ -163,12 +159,10 @@ class ModelMerger(object):
         # update model's config
         model_config = self._installer.record_store.get_model(key)
-        model_config.update(
-            {
-                "name": merged_model_name,
-                "description": f"Merge of models {', '.join(model_names)}",
-                "vae": vae,
-            }
-        )
-        self._installer.record_store.update_model(key, model_config)
+        model_config.name = merged_model_name
+        model_config.description = f"Merge of models {', '.join(model_names)}"
+        self._installer.record_store.update_model(
+            key, ModelRecordChanges(name=model_config.name, description=model_config.description)
+        )
         return model_config
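A hedged sketch of the new update pattern shown above: partial edits now go through a ModelRecordChanges object rather than mutating and re-saving the whole config. The record_store and model_key names are placeholders for an existing record store service and model key:

# Sketch only; record_store and model_key are assumed to exist in scope.
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges

changes = ModelRecordChanges(name="my-merged-model", description="Merge of models A, B")
record_store.update_model(model_key, changes)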

View File

@ -25,9 +25,7 @@ from .metadata_base import (
AnyModelRepoMetadataValidator, AnyModelRepoMetadataValidator,
BaseMetadata, BaseMetadata,
CivitaiMetadata, CivitaiMetadata,
CommercialUsage,
HuggingFaceMetadata, HuggingFaceMetadata,
LicenseRestrictions,
ModelMetadataWithFiles, ModelMetadataWithFiles,
RemoteModelFile, RemoteModelFile,
UnknownMetadataException, UnknownMetadataException,
@ -38,10 +36,8 @@ __all__ = [
"AnyModelRepoMetadataValidator", "AnyModelRepoMetadataValidator",
"CivitaiMetadata", "CivitaiMetadata",
"CivitaiMetadataFetch", "CivitaiMetadataFetch",
"CommercialUsage",
"HuggingFaceMetadata", "HuggingFaceMetadata",
"HuggingFaceMetadataFetch", "HuggingFaceMetadataFetch",
"LicenseRestrictions",
"ModelMetadataFetchBase", "ModelMetadataFetchBase",
"BaseMetadata", "BaseMetadata",
"ModelMetadataWithFiles", "ModelMetadataWithFiles",

View File

@ -23,22 +23,21 @@ metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words) print(metadata.trained_words)
""" """
import json
import re import re
from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Optional
import requests import requests
from pydantic import TypeAdapter, ValidationError
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from requests.sessions import Session from requests.sessions import Session
from invokeai.backend.model_manager import ModelRepoVariant from invokeai.backend.model_manager.config import ModelRepoVariant
from ..metadata_base import ( from ..metadata_base import (
AnyModelRepoMetadata, AnyModelRepoMetadata,
CivitaiMetadata, CivitaiMetadata,
CommercialUsage,
LicenseRestrictions,
RemoteModelFile, RemoteModelFile,
UnknownMetadataException, UnknownMetadataException,
) )
@ -52,10 +51,13 @@ CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/" CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
StringSetAdapter = TypeAdapter(set[str])
class CivitaiMetadataFetch(ModelMetadataFetchBase): class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""Fetch model metadata from Civitai.""" """Fetch model metadata from Civitai."""
-    def __init__(self, session: Optional[Session] = None):
+    def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None):
""" """
Initialize the fetcher with an optional requests.sessions.Session object. Initialize the fetcher with an optional requests.sessions.Session object.
@ -63,6 +65,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
this module without an internet connection. this module without an internet connection.
""" """
self._requests = session or requests.Session() self._requests = session or requests.Session()
self._api_key = api_key
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
""" """
@ -102,22 +105,21 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
May raise an `UnknownMetadataException`. May raise an `UnknownMetadataException`.
""" """
         model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
-        model_json = self._requests.get(model_url).json()
-        return self._from_model_json(model_json)
+        model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
+        return self._from_api_response(model_json)

-    def _from_model_json(self, model_json: Dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
+    def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
         try:
-            version_id = version_id or model_json["modelVersions"][0]["id"]
+            version_id = version_id or api_response["modelVersions"][0]["id"]
         except TypeError as excp:
             raise UnknownMetadataException from excp

         # loop till we find the section containing the version requested
-        version_sections = [x for x in model_json["modelVersions"] if x["id"] == version_id]
+        version_sections = [x for x in api_response["modelVersions"] if x["id"] == version_id]
         if not version_sections:
             raise UnknownMetadataException(f"Version {version_id} not found in model metadata")

         version_json = version_sections[0]
-        safe_thumbnails = [x["url"] for x in version_json["images"] if x["nsfw"] == "None"]
# Civitai has one "primary" file plus others such as VAEs. We only fetch the primary. # Civitai has one "primary" file plus others such as VAEs. We only fetch the primary.
primary = [x for x in version_json["files"] if x.get("primary")] primary = [x for x in version_json["files"] if x.get("primary")]
@ -134,36 +136,23 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
url = url + f"?type={primary_file['type']}{metadata_string}" url = url + f"?type={primary_file['type']}{metadata_string}"
model_files = [ model_files = [
RemoteModelFile( RemoteModelFile(
-                url=url,
+                url=self._get_url_with_api_key(url),
path=Path(primary_file["name"]), path=Path(primary_file["name"]),
size=int(primary_file["sizeKB"] * 1024), size=int(primary_file["sizeKB"] * 1024),
sha256=primary_file["hashes"]["SHA256"], sha256=primary_file["hashes"]["SHA256"],
) )
] ]
try:
trigger_phrases = StringSetAdapter.validate_python(version_json.get("trainedWords"))
except ValidationError:
trigger_phrases: set[str] = set()
         return CivitaiMetadata(
-            id=model_json["id"],
             name=version_json["name"],
-            version_id=version_json["id"],
-            version_name=version_json["name"],
-            created=datetime.fromisoformat(_fix_timezone(version_json["createdAt"])),
-            updated=datetime.fromisoformat(_fix_timezone(version_json["updatedAt"])),
-            published=datetime.fromisoformat(_fix_timezone(version_json["publishedAt"])),
-            base_model_trained_on=version_json["baseModel"],  # note - need a dictionary to turn into a BaseModelType
             files=model_files,
-            download_url=version_json["downloadUrl"],
-            thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
-            author=model_json["creator"]["username"],
-            description=model_json["description"],
-            version_description=version_json["description"] or "",
-            tags=model_json["tags"],
-            trained_words=version_json["trainedWords"],
-            nsfw=model_json["nsfw"],
-            restrictions=LicenseRestrictions(
-                AllowNoCredit=model_json["allowNoCredit"],
-                AllowCommercialUse={CommercialUsage(x) for x in model_json["allowCommercialUse"]},
-                AllowDerivatives=model_json["allowDerivatives"],
-                AllowDifferentLicense=model_json["allowDifferentLicense"],
-            ),
+            trigger_phrases=trigger_phrases,
+            api_response=json.dumps(version_json),
         )
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata: def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
@ -174,14 +163,14 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
""" """
if model_id is None: if model_id is None:
             version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
-            version = self._requests.get(version_url).json()
+            version = self._requests.get(self._get_url_with_api_key(version_url)).json()
             if error := version.get("error"):
                 raise UnknownMetadataException(error)
             model_id = version["modelId"]

         model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
-        model_json = self._requests.get(model_url).json()
-        return self._from_model_json(model_json, version_id)
+        model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
+        return self._from_api_response(model_json, version_id)
@classmethod @classmethod
def from_json(cls, json: str) -> CivitaiMetadata: def from_json(cls, json: str) -> CivitaiMetadata:
@ -189,6 +178,11 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
metadata = CivitaiMetadata.model_validate_json(json) metadata = CivitaiMetadata.model_validate_json(json)
return metadata return metadata
+    def _get_url_with_api_key(self, url: str) -> str:
+        if not self._api_key:
+            return url
+
+        if "?" in url:
+            return f"{url}&token={self._api_key}"
+
+        return f"{url}?token={self._api_key}"

-def _fix_timezone(date: str) -> str:
-    return re.sub(r"Z$", "+00:00", date)

View File

@ -13,6 +13,7 @@ metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
print(metadata.tags) print(metadata.tags)
""" """
import json
import re import re
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@ -23,7 +24,7 @@ from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFo
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from requests.sessions import Session from requests.sessions import Session
from invokeai.backend.model_manager import ModelRepoVariant from invokeai.backend.model_manager.config import ModelRepoVariant
from ..metadata_base import ( from ..metadata_base import (
AnyModelRepoMetadata, AnyModelRepoMetadata,
@ -60,6 +61,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
# Little loop which tries fetching a revision corresponding to the selected variant. # Little loop which tries fetching a revision corresponding to the selected variant.
# If not available, then set variant to None and get the default. # If not available, then set variant to None and get the default.
# If this too fails, raise exception. # If this too fails, raise exception.
model_info = None model_info = None
while not model_info: while not model_info:
try: try:
@ -72,23 +74,24 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
else: else:
variant = None variant = None
+        files: list[RemoteModelFile] = []

         _, name = id.split("/")

-        return HuggingFaceMetadata(
-            id=model_info.id,
-            author=model_info.author,
-            name=name,
-            last_modified=model_info.last_modified,
-            tag_dict=model_info.card_data.to_dict() if model_info.card_data else {},
-            tags=model_info.tags,
-            files=[
-                RemoteModelFile(
-                    url=hf_hub_url(id, x.rfilename, revision=variant),
-                    path=Path(name, x.rfilename),
-                    size=x.size,
-                    sha256=x.lfs.get("sha256") if x.lfs else None,
-                )
-                for x in model_info.siblings
-            ],
-        )
+        for s in model_info.siblings or []:
+            assert s.rfilename is not None
+            assert s.size is not None
+            files.append(
+                RemoteModelFile(
+                    url=hf_hub_url(id, s.rfilename, revision=variant),
+                    path=Path(name, s.rfilename),
+                    size=s.size,
+                    sha256=s.lfs.get("sha256") if s.lfs else None,
+                )
+            )
+
+        return HuggingFaceMetadata(
+            id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str)
+        )
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
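A usage sketch for the HuggingFace fetcher after this change: it now records the raw API response and builds one RemoteModelFile per repo file from model_info.siblings. The import path is an assumption based on the package layout; the repo URL comes from the module docstring above:

# Usage sketch; import path assumed.
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch

fetcher = HuggingFaceMetadataFetch()
metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
for f in metadata.files:
    print(f.path, f.size, f.url)  # one entry per repo file, sized via the siblings listing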

View File

@ -14,10 +14,8 @@ versions of these fields are intended to be kept in sync with the
remote repo. remote repo.
""" """
from datetime import datetime
from enum import Enum
from pathlib import Path from pathlib import Path
-from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
+from typing import List, Literal, Optional, Union
from huggingface_hub import configure_http_backend, hf_hub_url from huggingface_hub import configure_http_backend, hf_hub_url
from pydantic import BaseModel, Field, TypeAdapter from pydantic import BaseModel, Field, TypeAdapter
@ -34,31 +32,6 @@ class UnknownMetadataException(Exception):
"""Raised when no metadata is available for a model.""" """Raised when no metadata is available for a model."""
class CommercialUsage(str, Enum):
"""Type of commercial usage allowed."""
No = "None"
Image = "Image"
Rent = "Rent"
RentCivit = "RentCivit"
Sell = "Sell"
class LicenseRestrictions(BaseModel):
"""Broad categories of licensing restrictions."""
AllowNoCredit: bool = Field(
description="if true, model can be redistributed without crediting author", default=False
)
AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
AllowDifferentLicense: bool = Field(
description="if true, derivatives of this model be redistributed under a different license", default=False
)
AllowCommercialUse: Optional[Set[CommercialUsage] | CommercialUsage] = Field(
description="Type of commercial use allowed if no commercial use is allowed.", default=None
)
class RemoteModelFile(BaseModel): class RemoteModelFile(BaseModel):
"""Information about a downloadable file that forms part of a model.""" """Information about a downloadable file that forms part of a model."""
@ -72,8 +45,6 @@ class ModelMetadataBase(BaseModel):
"""Base class for model metadata information.""" """Base class for model metadata information."""
name: str = Field(description="model's name") name: str = Field(description="model's name")
author: str = Field(description="model's author")
tags: Set[str] = Field(description="tags provided by model source")
class BaseMetadata(ModelMetadataBase): class BaseMetadata(ModelMetadataBase):
@ -111,60 +82,16 @@ class CivitaiMetadata(ModelMetadataWithFiles):
"""Extended metadata fields provided by Civitai.""" """Extended metadata fields provided by Civitai."""
type: Literal["civitai"] = "civitai" type: Literal["civitai"] = "civitai"
id: int = Field(description="Civitai version identifier") trigger_phrases: set[str] = Field(description="Trigger phrases extracted from the API response")
version_name: str = Field(description="Version identifier, such as 'V2-alpha'") api_response: Optional[str] = Field(description="Response from the Civitai API as stringified JSON", default=None)
version_id: int = Field(description="Civitai model version identifier")
created: datetime = Field(description="date the model was created")
updated: datetime = Field(description="date the model was last modified")
published: datetime = Field(description="date the model was published to Civitai")
description: str = Field(description="text description of model; may contain HTML")
version_description: str = Field(
description="text description of the model's reversion; usually change history; may contain HTML"
)
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
download_url: AnyHttpUrl = Field(description="download URL for this model")
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
weight_minmax: Tuple[float, float] = Field(
description="minimum and maximum slider values for a LoRA or other secondary model", default=(-1.0, +2.0)
) # note: For future use
@property
def credit_required(self) -> bool:
"""Return True if you must give credit for derivatives of this model and images generated from it."""
return not self.restrictions.AllowNoCredit
@property
def allow_commercial_use(self) -> bool:
"""Return True if commercial use is allowed."""
if self.restrictions.AllowCommercialUse is None:
return False
else:
# accommodate schema change
acu = self.restrictions.AllowCommercialUse
commercial_usage = acu if isinstance(acu, set) else {acu}
return CommercialUsage.No not in commercial_usage
@property
def allow_derivatives(self) -> bool:
"""Return True if derivatives of this model can be redistributed."""
return self.restrictions.AllowDerivatives
@property
def allow_different_license(self) -> bool:
"""Return true if derivatives of this model can use a different license."""
return self.restrictions.AllowDifferentLicense
class HuggingFaceMetadata(ModelMetadataWithFiles): class HuggingFaceMetadata(ModelMetadataWithFiles):
"""Extended metadata fields provided by HuggingFace.""" """Extended metadata fields provided by HuggingFace."""
type: Literal["huggingface"] = "huggingface" type: Literal["huggingface"] = "huggingface"
id: str = Field(description="huggingface model id") id: str = Field(description="The HF model id")
tag_dict: Dict[str, Any] api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
last_modified: datetime = Field(description="date of last commit to repo")
def download_urls( def download_urls(
self, self,
@ -193,7 +120,7 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
# the next step reads model_index.json to determine which subdirectories belong # the next step reads model_index.json to determine which subdirectories belong
# to the model # to the model
if Path(f"{prefix}model_index.json") in paths: if Path(f"{prefix}model_index.json") in paths:
url = hf_hub_url(self.id, filename="model_index.json", subfolder=subfolder) url = hf_hub_url(self.id, filename="model_index.json", subfolder=str(subfolder) if subfolder else None)
resp = session.get(url) resp = session.get(url)
resp.raise_for_status() resp.raise_for_status()
submodels = resp.json() submodels = resp.json()
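For reference, hf_hub_url simply composes the canonical resolve URL for a repo file; a tiny sketch of the call used above, with an illustrative repo id:

# Small sketch of the hf_hub_url call used above.
from huggingface_hub import hf_hub_url

url = hf_hub_url("stabilityai/sdxl-turbo", filename="model_index.json", subfolder=None)
# -> https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/model_index.json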

View File

@ -1,221 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""
import sqlite3
from typing import List, Optional, Set, Tuple
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .fetch import ModelMetadataFetchBase
from .metadata_base import AnyModelRepoMetadata, UnknownMetadataException
class ModelMetadataStore:
"""Store, search and fetch model metadata retrieved from remote repositories."""
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_metadata(
id,
metadata
)
VALUES (?,?);
""",
(
model_key,
json_serialized,
),
)
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
raise UnknownMetadataException from excp
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata FROM model_metadata
WHERE id=?;
""",
(model_key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownMetadataException("model metadata not found")
return ModelMetadataFetchBase.from_json(rows[0])
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT id,metadata FROM model_metadata;
""",
(),
)
rows = self._cursor.fetchall()
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
json_serialized = metadata.model_dump_json() # turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_metadata
SET
metadata=?
WHERE id=?;
""",
(json_serialized, model_key),
)
if self._cursor.rowcount == 0:
raise UnknownMetadataException("model metadata not found")
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_metadata(model_key)
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
self._cursor.execute(
"""--sql
select tag_text from tags;
"""
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
with self._db.lock:
try:
matches: Optional[Set[str]] = None
for tag in tags:
self._cursor.execute(
"""--sql
SELECT a.model_id FROM model_tags AS a,
tags AS b
WHERE a.tag_id=b.tag_id
AND b.tag_text=?;
""",
(tag,),
)
model_keys = {x[0] for x in self._cursor.fetchall()}
if matches is None:
matches = model_keys
matches = matches.intersection(model_keys)
except sqlite3.Error as e:
raise e
return matches if matches else set()
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE author=?;
""",
(author,),
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE name=?;
""",
(name,),
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
"""Update tags for the model referenced by model_key."""
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)

View File

@ -8,6 +8,7 @@ import torch
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.util.util import SilenceWarnings from invokeai.backend.util.util import SilenceWarnings
from .config import ( from .config import (
@ -17,11 +18,12 @@ from .config import (
ModelConfigFactory, ModelConfigFactory,
ModelFormat, ModelFormat,
ModelRepoVariant, ModelRepoVariant,
ModelSourceType,
ModelType, ModelType,
ModelVariantType, ModelVariantType,
SchedulerPredictionType, SchedulerPredictionType,
) )
-from .hash import FastModelHash
+from .hash import ModelHash
from .util.model_util import lora_token_vector_length, read_checkpoint_meta from .util.model_util import lora_token_vector_length, read_checkpoint_meta
CkptType = Dict[str, Any] CkptType = Dict[str, Any]
@ -95,8 +97,8 @@ class ModelProbe(object):
"StableDiffusionXLImg2ImgPipeline": ModelType.Main, "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main, "StableDiffusionXLInpaintPipeline": ModelType.Main,
"LatentConsistencyModelPipeline": ModelType.Main, "LatentConsistencyModelPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae, "AutoencoderKL": ModelType.VAE,
"AutoencoderTiny": ModelType.Vae, "AutoencoderTiny": ModelType.VAE,
"ControlNetModel": ModelType.ControlNet, "ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision, "CLIPVisionModelWithProjection": ModelType.CLIPVision,
"T2IAdapter": ModelType.T2IAdapter, "T2IAdapter": ModelType.T2IAdapter,
@ -108,14 +110,6 @@ class ModelProbe(object):
) -> None: ) -> None:
cls.PROBES[format][model_type] = probe_class cls.PROBES[format][model_type] = probe_class
@classmethod
def heuristic_probe(
cls,
model_path: Path,
fields: Optional[Dict[str, Any]] = None,
) -> AnyModelConfig:
return cls.probe(model_path, fields)
@classmethod @classmethod
def probe( def probe(
cls, cls,
@ -137,19 +131,21 @@ class ModelProbe(object):
         format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
         model_info = None
         model_type = None
-        if format_type == "diffusers":
+        if format_type is ModelFormat.Diffusers:
             model_type = cls.get_model_type_from_folder(model_path)
         else:
             model_type = cls.get_model_type_from_checkpoint(model_path)
-            format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type
+            format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type

         probe_class = cls.PROBES[format_type].get(model_type)
         if not probe_class:
             raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")

-        hash = FastModelHash.hash(model_path)
-
         probe = probe_class(model_path)

+        fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
+        fields["source"] = fields.get("source") or model_path.as_posix()
+        fields["key"] = fields.get("key", uuid_string())
         fields["path"] = model_path.as_posix()
         fields["type"] = fields.get("type") or model_type
         fields["base"] = fields.get("base") or probe.get_base_type()
@ -161,15 +157,17 @@
             fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
         )
         fields["format"] = fields.get("format") or probe.get_format()
-        fields["original_hash"] = fields.get("original_hash") or hash
-        fields["current_hash"] = fields.get("current_hash") or hash
+        fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)

-        if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
+        if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
             fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()

         # additional fields needed for main and controlnet models
-        if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
-            fields["config"] = cls._get_checkpoint_config_path(
+        if (
+            fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
+            and fields["format"] is ModelFormat.Checkpoint
+        ):
+            fields["config_path"] = cls._get_checkpoint_config_path(
                 model_path,
                 model_type=fields["type"],
                 base_type=fields["base"],
@ -179,7 +177,7 @@ class ModelProbe(object):
# additional fields needed for main non-checkpoint models # additional fields needed for main non-checkpoint models
elif fields["type"] == ModelType.Main and fields["format"] in [ elif fields["type"] == ModelType.Main and fields["format"] in [
-            ModelFormat.Onnx,
+            ModelFormat.ONNX,
ModelFormat.Olive, ModelFormat.Olive,
ModelFormat.Diffusers, ModelFormat.Diffusers,
]: ]:
@ -213,11 +211,11 @@ class ModelProbe(object):
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}): if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
return ModelType.Main return ModelType.Main
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}): elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
-                return ModelType.Vae
+                return ModelType.VAE
             elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
-                return ModelType.Lora
+                return ModelType.LoRA
             elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
-                return ModelType.Lora
+                return ModelType.LoRA
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}): elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
return ModelType.ControlNet return ModelType.ControlNet
elif key in {"emb_params", "string_to_param"}: elif key in {"emb_params", "string_to_param"}:
@ -239,7 +237,7 @@ class ModelProbe(object):
if (folder_path / f"learned_embeds.{suffix}").exists(): if (folder_path / f"learned_embeds.{suffix}").exists():
return ModelType.TextualInversion return ModelType.TextualInversion
if (folder_path / f"pytorch_lora_weights.{suffix}").exists(): if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
-            return ModelType.Lora
+            return ModelType.LoRA
if (folder_path / "unet/model.onnx").exists(): if (folder_path / "unet/model.onnx").exists():
return ModelType.ONNX return ModelType.ONNX
if (folder_path / "image_encoder.txt").exists(): if (folder_path / "image_encoder.txt").exists():
@ -285,13 +283,21 @@ class ModelProbe(object):
if possible_conf.exists(): if possible_conf.exists():
return possible_conf.absolute() return possible_conf.absolute()
-        if model_type == ModelType.Main:
+        if model_type is ModelType.Main:
             config_file = LEGACY_CONFIGS[base_type][variant_type]
             if isinstance(config_file, dict):  # need another tier for sd-2.x models
                 config_file = config_file[prediction_type]
-        elif model_type == ModelType.ControlNet:
+        elif model_type is ModelType.ControlNet:
             config_file = (
-                "../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
+                "../controlnet/cldm_v15.yaml"
+                if base_type is BaseModelType.StableDiffusion1
+                else "../controlnet/cldm_v21.yaml"
+            )
+        elif model_type is ModelType.VAE:
+            config_file = (
+                "../stable-diffusion/v1-inference.yaml"
+                if base_type is BaseModelType.StableDiffusion1
+                else "../stable-diffusion/v2-inference.yaml"
             )
else: else:
raise InvalidModelConfigException( raise InvalidModelConfigException(
@ -497,12 +503,12 @@ class FolderProbeBase(ProbeBase):
if ".fp16" in x.suffixes: if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16 return ModelRepoVariant.FP16
if "openvino_model" in x.name: if "openvino_model" in x.name:
return ModelRepoVariant.OPENVINO return ModelRepoVariant.OpenVINO
if "flax_model" in x.name: if "flax_model" in x.name:
return ModelRepoVariant.FLAX return ModelRepoVariant.Flax
if x.suffix == ".onnx": if x.suffix == ".onnx":
return ModelRepoVariant.ONNX return ModelRepoVariant.ONNX
return ModelRepoVariant.DEFAULT return ModelRepoVariant.Default
class PipelineFolderProbe(FolderProbeBase): class PipelineFolderProbe(FolderProbeBase):
@ -708,8 +714,8 @@ class T2IAdapterFolderProbe(FolderProbeBase):
############## register probe classes ###### ############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe) ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe) ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe) ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe) ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
@ -717,8 +723,8 @@ ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderPro
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe) ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)

View File

@@ -13,6 +13,7 @@ files_to_download = select_hf_model_files(metadata.files, variant='onnx')
"""
import re
+from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Set
@@ -34,7 +35,7 @@ def filter_files(
    The file list can be obtained from the `files` field of HuggingFaceMetadata,
    as defined in `invokeai.backend.model_manager.metadata.metadata_base`.
    """
-   variant = variant or ModelRepoVariant.DEFAULT
+   variant = variant or ModelRepoVariant.Default
    paths: List[Path] = []
    root = files[0].parts[0]
@@ -73,64 +74,81 @@ def filter_files(
    return sorted(_filter_by_variant(paths, variant))
+@dataclass
+class SubfolderCandidate:
+    path: Path
+    score: int
def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
    """Select the proper variant files from a list of HuggingFace repo_id paths."""
-   result = set()
-   basenames: Dict[Path, Path] = {}
+   result: set[Path] = set()
+   subfolder_weights: dict[Path, list[SubfolderCandidate]] = {}
    for path in files:
        if path.suffix in [".onnx", ".pb", ".onnx_data"]:
            if variant == ModelRepoVariant.ONNX:
                result.add(path)
        elif "openvino_model" in path.name:
-           if variant == ModelRepoVariant.OPENVINO:
+           if variant == ModelRepoVariant.OpenVINO:
                result.add(path)
        elif "flax_model" in path.name:
-           if variant == ModelRepoVariant.FLAX:
+           if variant == ModelRepoVariant.Flax:
                result.add(path)
        elif path.suffix in [".json", ".txt"]:
            result.add(path)
-       elif path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"] and variant in [
-           ModelRepoVariant.FP16,
-           ModelRepoVariant.FP32,
-           ModelRepoVariant.DEFAULT,
-       ]:
-           parent = path.parent
-           suffixes = path.suffixes
-           if len(suffixes) == 2:
-               variant_label, suffix = suffixes
-               basename = parent / Path(path.stem).stem
-           else:
-               variant_label = ""
-               suffix = suffixes[0]
-               basename = parent / path.stem
-           if previous := basenames.get(basename):
-               if (
-                   previous.suffix != ".safetensors" and suffix == ".safetensors"
-               ):  # replace non-safetensors with safetensors when available
-                   basenames[basename] = path
-               if variant_label == f".{variant}":
-                   basenames[basename] = path
-               elif not variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]:
-                   basenames[basename] = path
-           else:
-               basenames[basename] = path
+       elif variant in [
+           ModelRepoVariant.FP16,
+           ModelRepoVariant.FP32,
+           ModelRepoVariant.Default,
+       ] and path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"]:
+           # For weights files, we want to select the best one for each subfolder. For example, we may have multiple
+           # text encoders:
+           #
+           # - text_encoder/model.fp16.safetensors
+           # - text_encoder/model.safetensors
+           # - text_encoder/pytorch_model.bin
+           # - text_encoder/pytorch_model.fp16.bin
+           #
+           # We prefer safetensors over other file formats and an exact variant match. We'll score each file based on
+           # variant and format and select the best one.
+           parent = path.parent
+           score = 0
+           if path.suffix == ".safetensors":
+               score += 1
+           candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None
+           # Some special handling is needed here if there is not an exact match and if we cannot infer the variant
+           # from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
+           if candidate_variant_label == f".{variant}" or (
+               not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]
+           ):
+               score += 1
+           if parent not in subfolder_weights:
+               subfolder_weights[parent] = []
+           subfolder_weights[parent].append(SubfolderCandidate(path=path, score=score))
        else:
            continue
-   for v in basenames.values():
-       result.add(v)
+   for candidate_list in subfolder_weights.values():
+       highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
+       if highest_score_candidate:
+           result.add(highest_score_candidate.path)
    # If one of the architecture-related variants was specified and no files matched other than
    # config and text files then we return an empty list
    if (
        variant
-       and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OPENVINO, ModelRepoVariant.FLAX]
+       and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OpenVINO, ModelRepoVariant.Flax]
        and not any(variant.value in x.name for x in result)
    ):
        return set()
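A small, self-contained sketch of the per-subfolder scoring described in the comments above: each weights file earns a point for being a `.safetensors` file and a point for matching the requested variant label, and the highest-scoring candidate in each subfolder is kept. The helper name and the standalone form are illustrative; the real logic lives in `_filter_by_variant`.

    from dataclasses import dataclass
    from pathlib import Path

    @dataclass
    class _Candidate:
        path: Path
        score: int

    def pick_best_per_subfolder(paths: list[Path], variant_label: str = ".fp16") -> set[Path]:
        per_folder: dict[Path, list[_Candidate]] = {}
        for p in paths:
            score = 0
            if p.suffix == ".safetensors":
                score += 1  # prefer safetensors over .bin/.pt/.ckpt
            label = p.suffixes[0] if len(p.suffixes) == 2 else None
            if label == variant_label:
                score += 1  # exact variant match, e.g. "model.fp16.safetensors"
            per_folder.setdefault(p.parent, []).append(_Candidate(p, score))
        return {max(cands, key=lambda c: c.score).path for cands in per_folder.values()}

    files = [
        Path("text_encoder/model.fp16.safetensors"),
        Path("text_encoder/model.safetensors"),
        Path("text_encoder/pytorch_model.bin"),
        Path("text_encoder/pytorch_model.fp16.bin"),
    ]
    # the fp16 safetensors file scores 2 and wins the text_encoder subfolder
    assert pick_best_per_subfolder(files) == {Path("text_encoder/model.fp16.safetensors")}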

View File

@@ -4,13 +4,11 @@ Initialization file for the invokeai.backend.stable_diffusion package
from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline  # noqa: F401
from .diffusion import InvokeAIDiffuserComponent  # noqa: F401
-from .diffusion.cross_attention_map_saving import AttentionMapSaver  # noqa: F401
from .seamless import set_seamless  # noqa: F401
__all__ = [
    "PipelineIntermediateState",
    "StableDiffusionGeneratorPipeline",
    "InvokeAIDiffuserComponent",
-   "AttentionMapSaver",
    "set_seamless",
]

View File

@@ -12,7 +12,6 @@ import torch
import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
@@ -26,9 +25,9 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
+from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..util import auto_detect_slice_size, normalize_device
-from .diffusion import AttentionMapSaver, InvokeAIDiffuserComponent
@dataclass
@@ -39,7 +38,6 @@ class PipelineIntermediateState:
    timestep: int
    latents: torch.Tensor
    predicted_original: Optional[torch.Tensor] = None
-   attention_map_saver: Optional[AttentionMapSaver] = None
@dataclass
@@ -190,19 +188,6 @@ class T2IAdapterData:
    end_step_percent: float = Field(default=1.0)
-@dataclass
-class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
-    r"""
-    Output class for InvokeAI's Stable Diffusion pipeline.
-    Args:
-        attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
-        after generation completes. Optional.
-    """
-    attention_map_saver: Optional[AttentionMapSaver]
class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.
@@ -343,9 +328,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        masked_latents: Optional[torch.Tensor] = None,
        gradient_mask: Optional[bool] = False,
        seed: Optional[int] = None,
-   ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
+   ) -> torch.Tensor:
        if init_timestep.shape[0] == 0:
-           return latents, None
+           return latents
        if additional_guidance is None:
            additional_guidance = []
@@ -385,7 +370,7 @@
            additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
        try:
-           latents, attention_map_saver = self.generate_latents_from_embeddings(
+           latents = self.generate_latents_from_embeddings(
                latents,
                timesteps,
                conditioning_data,
@@ -402,7 +387,7 @@
        if mask is not None and not gradient_mask:
            latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
-       return latents, attention_map_saver
+       return latents
    def generate_latents_from_embeddings(
        self,
@@ -415,16 +400,15 @@
        ip_adapter_data: Optional[list[IPAdapterData]] = None,
        t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
        callback: Callable[[PipelineIntermediateState], None] = None,
-   ):
+   ) -> torch.Tensor:
        self._adjust_memory_efficient_attention(latents)
        if additional_guidance is None:
            additional_guidance = []
        batch_size = latents.shape[0]
-       attention_map_saver: Optional[AttentionMapSaver] = None
        if timesteps.shape[0] == 0:
-           return latents, attention_map_saver
+           return latents
        ip_adapter_unet_patcher = None
        extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
@@ -432,7 +416,6 @@
            attn_ctx = self.invokeai_diffuser.custom_attention_context(
                self.invokeai_diffuser.model,
                extra_conditioning_info=extra_conditioning_info,
-               step_count=len(self.scheduler.timesteps),
            )
            self.use_ip_adapter = False
        elif ip_adapter_data is not None:
@@ -483,13 +466,6 @@
            predicted_original = getattr(step_output, "pred_original_sample", None)
-           # TODO resuscitate attention map saving
-           # if i == len(timesteps)-1 and extra_conditioning_info is not None:
-           #     eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
-           #     attention_map_token_ids = range(1, eos_token_index)
-           #     attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
-           #     self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
            if callback is not None:
                callback(
                    PipelineIntermediateState(
@@ -499,11 +475,10 @@
                        timestep=int(t),
                        latents=latents,
                        predicted_original=predicted_original,
-                       attention_map_saver=attention_map_saver,
                    )
                )
-       return latents, attention_map_saver
+       return latents
    @torch.inference_mode()
    def step(

View File

@@ -2,6 +2,4 @@
Initialization file for invokeai.models.diffusion
"""
-from .cross_attention_control import InvokeAICrossAttentionMixin  # noqa: F401
-from .cross_attention_map_saving import AttentionMapSaver  # noqa: F401
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent  # noqa: F401

View File

@@ -3,19 +3,13 @@
import enum
-import math
from dataclasses import dataclass, field
-from typing import Callable, Optional
+from typing import Optional
-import diffusers
-import psutil
import torch
from compel.cross_attention_control import Arguments
-from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, SlicedAttnProcessor
+from diffusers.models.attention_processor import Attention, SlicedAttnProcessor
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-from torch import nn
-import invokeai.backend.util.logging as logger
from ...util import torch_dtype
@@ -25,72 +19,14 @@ class CrossAttentionType(enum.Enum):
    TOKENS = 2
-class Context:
-    cross_attention_mask: Optional[torch.Tensor]
-    cross_attention_index_map: Optional[torch.Tensor]
-    class Action(enum.Enum):
-        NONE = 0
-        SAVE = (1,)
-        APPLY = 2
-    def __init__(self, arguments: Arguments, step_count: int):
+class CrossAttnControlContext:
+    def __init__(self, arguments: Arguments):
        """
        :param arguments: Arguments for the cross-attention control process
-       :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
        """
-       self.cross_attention_mask = None
-       self.cross_attention_index_map = None
-       self.self_cross_attention_action = Context.Action.NONE
-       self.tokens_cross_attention_action = Context.Action.NONE
+       self.cross_attention_mask: Optional[torch.Tensor] = None
+       self.cross_attention_index_map: Optional[torch.Tensor] = None
        self.arguments = arguments
-       self.step_count = step_count
-       self.self_cross_attention_module_identifiers = []
-       self.tokens_cross_attention_module_identifiers = []
-       self.saved_cross_attention_maps = {}
-       self.clear_requests(cleanup=True)
def register_cross_attention_modules(self, model):
for name, _module in get_cross_attention_modules(model, CrossAttentionType.SELF):
if name in self.self_cross_attention_module_identifiers:
raise AssertionError(f"name {name} cannot appear more than once")
self.self_cross_attention_module_identifiers.append(name)
for name, _module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
if name in self.tokens_cross_attention_module_identifiers:
raise AssertionError(f"name {name} cannot appear more than once")
self.tokens_cross_attention_module_identifiers.append(name)
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
if cross_attention_type == CrossAttentionType.SELF:
self.self_cross_attention_action = Context.Action.SAVE
else:
self.tokens_cross_attention_action = Context.Action.SAVE
def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
if cross_attention_type == CrossAttentionType.SELF:
self.self_cross_attention_action = Context.Action.APPLY
else:
self.tokens_cross_attention_action = Context.Action.APPLY
def is_tokens_cross_attention(self, module_identifier) -> bool:
return module_identifier in self.tokens_cross_attention_module_identifiers
def get_should_save_maps(self, module_identifier: str) -> bool:
if module_identifier in self.self_cross_attention_module_identifiers:
return self.self_cross_attention_action == Context.Action.SAVE
elif module_identifier in self.tokens_cross_attention_module_identifiers:
return self.tokens_cross_attention_action == Context.Action.SAVE
return False
def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
if module_identifier in self.self_cross_attention_module_identifiers:
return self.self_cross_attention_action == Context.Action.APPLY
elif module_identifier in self.tokens_cross_attention_module_identifiers:
return self.tokens_cross_attention_action == Context.Action.APPLY
return False
    def get_active_cross_attention_control_types_for_step(
        self, percent_through: float = None
@@ -111,219 +47,8 @@ class Context:
        to_control.append(CrossAttentionType.TOKENS)
        return to_control
def save_slice(
self,
identifier: str,
slice: torch.Tensor,
dim: Optional[int],
offset: int,
slice_size: Optional[int],
):
if identifier not in self.saved_cross_attention_maps:
self.saved_cross_attention_maps[identifier] = {
"dim": dim,
"slice_size": slice_size,
"slices": {offset or 0: slice},
}
else:
self.saved_cross_attention_maps[identifier]["slices"][offset or 0] = slice
    def get_slice(
self,
identifier: str,
requested_dim: Optional[int],
requested_offset: int,
slice_size: int,
):
saved_attention_dict = self.saved_cross_attention_maps[identifier]
if requested_dim is None:
if saved_attention_dict["dim"] is not None:
raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
return saved_attention_dict["slices"][0]
if saved_attention_dict["dim"] == requested_dim:
if slice_size != saved_attention_dict["slice_size"]:
raise RuntimeError(
f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}"
)
return saved_attention_dict["slices"][requested_offset]
if saved_attention_dict["dim"] is None:
whole_saved_attention = saved_attention_dict["slices"][0]
if requested_dim == 0:
return whole_saved_attention[requested_offset : requested_offset + slice_size]
elif requested_dim == 1:
return whole_saved_attention[:, requested_offset : requested_offset + slice_size]
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
if saved_attention is None:
return None, None
return saved_attention["dim"], saved_attention["slice_size"]
def clear_requests(self, cleanup=True):
self.tokens_cross_attention_action = Context.Action.NONE
self.self_cross_attention_action = Context.Action.NONE
if cleanup:
self.saved_cross_attention_maps = {}
def offload_saved_attention_slices_to_cpu(self):
for _key, map_dict in self.saved_cross_attention_maps.items():
for offset, slice in map_dict["slices"].items():
map_dict[offset] = slice.to("cpu")
class InvokeAICrossAttentionMixin:
"""
Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
and dymamic slicing strategy selection.
"""
def __init__(self):
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
self.attention_slice_wrangler = None
self.slicing_strategy_getter = None
self.attention_slice_calculated_callback = None
def set_attention_slice_wrangler(
self,
wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]],
):
"""
Set custom attention calculator to be called when attention is calculated
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
which returns either the suggested_attention_slice or an adjusted equivalent.
`module` is the current Attention module for which the callback is being invoked.
`suggested_attention_slice` is the default-calculated attention slice
`dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
Pass None to use the default attention calculation.
:return:
"""
self.attention_slice_wrangler = wrangler
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]):
self.slicing_strategy_getter = getter
def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
self.attention_slice_calculated_callback = callback
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
# calculate attention scores
# attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
attention_scores = torch.baddbmm(
torch.empty(
query.shape[0],
query.shape[1],
key.shape[1],
dtype=query.dtype,
device=query.device,
),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
# calculate attention slice by taking the best scores for each latent pixel
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
attention_slice_wrangler = self.attention_slice_wrangler
if attention_slice_wrangler is not None:
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
else:
attention_slice = default_attention_slice
if self.attention_slice_calculated_callback is not None:
self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
hidden_states = torch.bmm(attention_slice, value)
return hidden_states
def einsum_op_slice_dim0(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
return r
def einsum_op_slice_dim1(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
return r
def einsum_op_mps_v1(self, q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
return self.einsum_op_slice_dim1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v):
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
return self.einsum_op_slice_dim0(q, k, v, 1)
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
return self.einsum_lowest_level(q, k, v, None, None, None)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]:
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
def einsum_op_cuda(self, q, k, v):
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
slicing_strategy_getter = self.slicing_strategy_getter
if slicing_strategy_getter is not None:
(dim, slice_size) = slicing_strategy_getter(self)
if dim is not None:
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
if dim == 0:
return self.einsum_op_slice_dim0(q, k, v, slice_size)
elif dim == 1:
return self.einsum_op_slice_dim1(q, k, v, slice_size)
# fallback for when there is no saved strategy, or saved strategy does not slice
mem_free_total = get_mem_free_total(q.device)
# Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def get_invokeai_attention_mem_efficient(self, q, k, v):
if q.device.type == "cuda":
# print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
return self.einsum_op_cuda(q, k, v)
if q.device.type == "mps" or q.device.type == "cpu":
if self.mem_total_gb >= 32:
return self.einsum_op_mps_v1(q, k, v)
return self.einsum_op_mps_v2(q, k, v)
# Smaller slices are faster due to L2/L3/SLC caches.
# Tested on i7 with 8MB L3 cache.
return self.einsum_op_tensor_mem(q, k, v, 32)
def restore_default_cross_attention(
model,
is_running_diffusers: bool,
restore_attention_processor: Optional[AttentionProcessor] = None,
):
if is_running_diffusers:
unet = model
unet.set_attn_processor(restore_attention_processor or AttnProcessor())
else:
remove_attention_function(model)
-def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
+def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: CrossAttnControlContext):
    """
    Inject attention parameters and functions into the passed in model to enable cross attention editing.
@@ -362,170 +87,6 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionMode
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size)) unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
cross_attention_class: type = InvokeAIDiffusersCrossAttention
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
attention_module_tuples = [
(name, module)
for name, module in model.named_modules()
if isinstance(module, cross_attention_class) and which_attn in name
]
cross_attention_modules_in_model_count = len(attention_module_tuples)
expected_count = 16
if cross_attention_modules_in_model_count != expected_count:
# non-fatal error but .swap() won't work.
logger.error(
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching "
"failed or some assumption has changed about the structure of the model itself. Please fix the "
f"monkey-patching, and/or update the {expected_count} above to an appropriate number, and/or find and "
"inform someone who knows what it means. This error is non-fatal, but it is likely that .swap() and "
"attention map display will not work properly until it is fixed."
)
return attention_module_tuples
def inject_attention_function(unet, context: Context):
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
def attention_slice_wrangler(module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size):
# memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
attention_slice = suggested_attention_slice
if context.get_should_save_maps(module.identifier):
# print(module.identifier, "saving suggested_attention_slice of shape",
# suggested_attention_slice.shape, "dim", dim, "offset", offset)
slice_to_save = attention_slice.to("cpu") if dim is not None else attention_slice
context.save_slice(
module.identifier,
slice_to_save,
dim=dim,
offset=offset,
slice_size=slice_size,
)
elif context.get_should_apply_saved_maps(module.identifier):
# print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
# slice may have been offloaded to CPU
saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
if context.is_tokens_cross_attention(module.identifier):
index_map = context.cross_attention_index_map
remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
this_attention_slice = suggested_attention_slice
mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
saved_mask = mask
this_mask = 1 - mask
attention_slice = remapped_saved_attention_slice * saved_mask + this_attention_slice * this_mask
else:
# just use everything
attention_slice = saved_attention_slice
return attention_slice
cross_attention_modules = get_cross_attention_modules(
unet, CrossAttentionType.TOKENS
) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
for identifier, module in cross_attention_modules:
module.identifier = identifier
try:
module.set_attention_slice_wrangler(attention_slice_wrangler)
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) # noqa: B023
except AttributeError as e:
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
else:
raise
def remove_attention_function(unet):
cross_attention_modules = get_cross_attention_modules(
unet, CrossAttentionType.TOKENS
) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
for _identifier, module in cross_attention_modules:
try:
# clear wrangler callback
module.set_attention_slice_wrangler(None)
module.set_slicing_strategy_getter(None)
except AttributeError as e:
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
else:
raise
def is_attribute_error_about(error: AttributeError, attribute: str):
if hasattr(error, "name"): # Python 3.10
return error.name == attribute
else: # Python 3.9
return attribute in str(error)
def get_mem_free_total(device):
# only on cuda
if not torch.cuda.is_available():
return None
stats = torch.cuda.memory_stats(device)
mem_active = stats["active_bytes.all.current"]
mem_reserved = stats["reserved_bytes.all.current"]
mem_free_cuda, _ = torch.cuda.mem_get_info(device)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
return mem_free_total
class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
InvokeAICrossAttentionMixin.__init__(self)
def _attention(self, query, key, value, attention_mask=None):
# default_result = super()._attention(query, key, value)
if attention_mask is not None:
print(f"{type(self).__name__} ignoring passed-in attention_mask")
attention_result = self.get_invokeai_attention_mem_efficient(query, key, value)
hidden_states = self.reshape_batch_dim_to_heads(attention_result)
return hidden_states
## 🧨diffusers implementation follows
"""
# base implementation
class AttnProcessor:
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
"""
@dataclass
class SwapCrossAttnContext:
    modified_text_embeddings: torch.Tensor
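A minimal sketch of how the slimmed-down context object gets installed, mirroring the `custom_attention_context` call site that appears later in this diff. The context-manager helper and its name are hypothetical; the imports and the UNet processor calls match what the diff itself uses.

    from contextlib import contextmanager

    from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel

    from invokeai.backend.stable_diffusion.diffusion.cross_attention_control import (
        CrossAttnControlContext,
        setup_cross_attention_control_attention_processors,
    )

    @contextmanager
    def swap_attention_processors(unet: UNet2DConditionModel, arguments):
        # hypothetical helper: install the swap attention processors for the duration
        # of a denoising run, then always restore the original processors
        old_attn_processors = unet.attn_processors
        try:
            context = CrossAttnControlContext(arguments=arguments)
            setup_cross_attention_control_attention_processors(unet, context)
            yield context
        finally:
            unet.set_attn_processor(old_attn_processors)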

View File

@@ -1,100 +0,0 @@
import math
from typing import Optional
import torch
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms.functional import resize as tv_resize
class AttentionMapSaver:
def __init__(self, token_ids: range, latents_shape: torch.Size):
self.token_ids = token_ids
self.latents_shape = latents_shape
# self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
self.collated_maps: dict[str, torch.Tensor] = {}
def clear_maps(self):
self.collated_maps = {}
def add_attention_maps(self, maps: torch.Tensor, key: str):
"""
Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any).
:param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77).
:param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match.
:return: None
"""
key_and_size = f"{key}_{maps.shape[1]}"
# extract desired tokens
maps = maps[:, :, self.token_ids]
# merge attention heads to a single map per token
maps = torch.sum(maps, 0)
# store
if key_and_size not in self.collated_maps:
self.collated_maps[key_and_size] = torch.zeros_like(maps, device="cpu")
self.collated_maps[key_and_size] += maps.cpu()
def write_maps_to_disk(self, path: str):
pil_image = self.get_stacked_maps_image()
if pil_image is not None:
pil_image.save(path, "PNG")
def get_stacked_maps_image(self) -> Optional[Image.Image]:
"""
Scale all collected attention maps to the same size, blend them together and return as an image.
:return: An image containing a vertical stack of blended attention maps, one for each requested token.
"""
num_tokens = len(self.token_ids)
if num_tokens == 0:
return None
latents_height = self.latents_shape[0]
latents_width = self.latents_shape[1]
merged = None
for _key, maps in self.collated_maps.items():
# maps has shape [(H*W), N] for N tokens
# but we want [N, H, W]
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
this_maps_height = int(float(latents_height) * this_scale_factor)
this_maps_width = int(float(latents_width) * this_scale_factor)
# and we need to do some dimension juggling
maps = torch.reshape(
torch.swapdims(maps, 0, 1),
[num_tokens, this_maps_height, this_maps_width],
)
# scale to output size if necessary
if this_scale_factor != 1:
maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
# normalize
maps_min = torch.min(maps)
maps_range = torch.max(maps) - maps_min
# print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
maps_normalized = (maps - maps_min) / maps_range
# expand to (-0.1, 1.1) and clamp
maps_normalized_expanded = maps_normalized * 1.1 - 0.05
maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
# merge together, producing a vertical stack
maps_stacked = torch.reshape(
maps_normalized_expanded_clamped,
[num_tokens * latents_height, latents_width],
)
if merged is None:
merged = maps_stacked
else:
# screen blend
merged = 1 - (1 - maps_stacked) * (1 - merged)
if merged is None:
return None
merged_bytes = merged.mul(0xFF).byte()
return Image.fromarray(merged_bytes.numpy(), mode="L")

View File

@@ -17,13 +17,11 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
)
from .cross_attention_control import (
-   Context,
    CrossAttentionType,
+   CrossAttnControlContext,
    SwapCrossAttnContext,
-   get_cross_attention_modules,
    setup_cross_attention_control_attention_processors,
)
-from .cross_attention_map_saving import AttentionMapSaver
ModelForwardCallback: TypeAlias = Union[
    # x, t, conditioning, Optional[cross-attention kwargs]
@@ -69,14 +67,12 @@ class InvokeAIDiffuserComponent:
        self,
        unet: UNet2DConditionModel,
        extra_conditioning_info: Optional[ExtraConditioningInfo],
-       step_count: int,
    ):
        old_attn_processors = unet.attn_processors
        try:
-           self.cross_attention_control_context = Context(
+           self.cross_attention_control_context = CrossAttnControlContext(
                arguments=extra_conditioning_info.cross_attention_control_args,
-               step_count=step_count,
            )
            setup_cross_attention_control_attention_processors(
                unet,
@@ -87,27 +83,6 @@ class InvokeAIDiffuserComponent:
        finally:
            self.cross_attention_control_context = None
            unet.set_attn_processor(old_attn_processors)
# TODO resuscitate attention map saving
# self.remove_attention_map_saving()
def setup_attention_map_saving(self, saver: AttentionMapSaver):
def callback(slice, dim, offset, slice_size, key):
if dim is not None:
# sliced tokens attention map saving is not implemented
return
saver.add_attention_maps(slice, key)
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
for identifier, module in tokens_cross_attention_modules:
key = "down" if identifier.startswith("down") else "up" if identifier.startswith("up") else "mid"
module.set_attention_slice_calculated_callback(
lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key)
)
def remove_attention_map_saving(self):
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
for _, module in tokens_cross_attention_modules:
module.set_attention_slice_calculated_callback(None)
    def do_controlnet_step(
        self,
@@ -592,54 +567,3 @@ class InvokeAIDiffuserComponent:
        self.last_percent_through = percent_through
        return latents.to(device=dev)
# todo: make this work
@classmethod
def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2) # aka sigmas
deltas = None
uncond_latents = None
weighted_cond_list = (
c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)]
)
# below is fugly omg
conditionings = [uc] + [c for c, weight in weighted_cond_list]
weights = [1] + [weight for c, weight in weighted_cond_list]
chunk_count = math.ceil(len(conditionings) / 2)
deltas = None
for chunk_index in range(chunk_count):
offset = chunk_index * 2
chunk_size = min(2, len(conditionings) - offset)
if chunk_size == 1:
c_in = conditionings[offset]
latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
latents_b = None
else:
c_in = torch.cat(conditionings[offset : offset + 2])
latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
# first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining
if chunk_index == 0:
uncond_latents = latents_a
deltas = latents_b - uncond_latents
else:
deltas = torch.cat((deltas, latents_a - uncond_latents))
if latents_b is not None:
deltas = torch.cat((deltas, latents_b - uncond_latents))
# merge the weighted deltas together into a single merged delta
per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
normalize = False
if normalize:
per_delta_weights /= torch.sum(per_delta_weights)
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
return uncond_latents + deltas_merged * global_guidance_scale

View File

@@ -858,9 +858,9 @@ def do_textual_inversion_training(
        # Let's make sure we don't update any embedding weights besides the newly added token
        index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
        with torch.no_grad():
-           accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
-               orig_embeds_params[index_no_updates]
-           )
+           accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
+               index_no_updates
+           ] = orig_embeds_params[index_no_updates]
        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:

View File

@@ -144,7 +144,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
        self.nextrely = top_of_table
        self.lora_models = self.add_model_widgets(
-           model_type=ModelType.Lora,
+           model_type=ModelType.LoRA,
            window_width=window_width,
        )
        bottom_of_table = max(bottom_of_table, self.nextrely)

View File

@@ -30,7 +30,7 @@
    "lint:prettier": "prettier --check .",
    "lint:tsc": "tsc --noEmit",
    "lint": "concurrently -g -c red,green,yellow,blue,magenta pnpm:lint:*",
-   "fix": "knip --fix && eslint --fix . && prettier --log-level warn --write .",
+   "fix": "eslint --fix . && prettier --log-level warn --write .",
    "preinstall": "npx only-allow pnpm",
    "storybook": "storybook dev -p 6006",
    "build-storybook": "storybook build",

View File

@@ -134,8 +134,6 @@
    "loadMore": "Mehr laden",
    "noImagesInGallery": "Keine Bilder in der Galerie",
    "loading": "Lade",
-   "preparingDownload": "bereite Download vor",
-   "preparingDownloadFailed": "Problem beim Download vorbereiten",
    "deleteImage": "Lösche Bild",
    "copy": "Kopieren",
    "download": "Runterladen",
@@ -967,7 +965,7 @@
    "resumeFailed": "Problem beim Fortsetzen des Prozesses",
    "pruneFailed": "Problem beim leeren der Warteschlange",
    "pauseTooltip": "Prozess anhalten",
-   "back": "Hinten",
+   "back": "Ende",
    "resumeSucceeded": "Prozess wird fortgesetzt",
    "resumeTooltip": "Prozess wieder aufnehmen",
    "time": "Zeit",

View File

@@ -78,6 +78,7 @@
    "aboutDesc": "Using Invoke for work? Check out:",
    "aboutHeading": "Own Your Creative Power",
    "accept": "Accept",
+   "add": "Add",
    "advanced": "Advanced",
    "advancedOptions": "Advanced Options",
    "ai": "ai",
@@ -303,6 +304,12 @@
      "method": "High Resolution Fix Method"
    }
  },
+ "prompt": {
+   "addPromptTrigger": "Add Prompt Trigger",
+   "compatibleEmbeddings": "Compatible Embeddings",
+   "noPromptTriggers": "No triggers available",
+   "noMatchingTriggers": "No matching triggers"
+ },
  "embedding": {
    "addEmbedding": "Add Embedding",
    "incompatibleModel": "Incompatible base model:",
@@ -734,6 +741,8 @@
    "customConfig": "Custom Config",
    "customConfigFileLocation": "Custom Config File Location",
    "customSaveLocation": "Custom Save Location",
+   "defaultSettings": "Default Settings",
+   "defaultSettingsSaved": "Default Settings Saved",
    "delete": "Delete",
    "deleteConfig": "Delete Config",
    "deleteModel": "Delete Model",
@@ -768,6 +777,7 @@
    "mergedModelName": "Merged Model Name",
    "mergedModelSaveLocation": "Save Location",
    "mergeModels": "Merge Models",
+   "metadata": "Metadata",
    "model": "Model",
    "modelAdded": "Model Added",
    "modelConversionFailed": "Model Conversion Failed",
@@ -839,9 +849,12 @@
    "statusConverting": "Converting",
    "syncModels": "Sync Models",
    "syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. This is generally handy in cases where you add models to the InvokeAI root folder or autoimport directory after the application has booted.",
+   "triggerPhrases": "Trigger Phrases",
+   "typePhraseHere": "Type phrase here",
    "upcastAttention": "Upcast Attention",
    "updateModel": "Update Model",
    "useCustomConfig": "Use Custom Config",
+   "useDefaultSettings": "Use Default Settings",
    "v1": "v1",
    "v2_768": "v2 (768px)",
    "v2_base": "v2 (512px)",
@@ -860,6 +873,7 @@
  "models": {
    "addLora": "Add LoRA",
    "allLoRAsAdded": "All LoRAs added",
+   "concepts": "Concepts",
    "loraAlreadyAdded": "LoRA already added",
    "esrganModel": "ESRGAN Model",
    "loading": "loading",

View File

@@ -505,8 +505,6 @@
    "seamLowThreshold": "Bajo",
    "coherencePassHeader": "Parámetros de la coherencia",
    "compositingSettingsHeader": "Ajustes de la composición",
-   "coherenceSteps": "Pasos",
-   "coherenceStrength": "Fuerza",
    "patchmatchDownScaleSize": "Reducir a escala",
    "coherenceMode": "Modo"
  },

View File

@@ -114,7 +114,8 @@
    "checkpoint": "Checkpoint",
    "safetensors": "Safetensors",
    "ai": "ia",
-   "file": "File"
+   "file": "File",
+   "toResolve": "Da risolvere"
  },
  "gallery": {
    "generations": "Generazioni",
@@ -142,8 +143,6 @@
    "copy": "Copia",
    "download": "Scarica",
    "setCurrentImage": "Imposta come immagine corrente",
-   "preparingDownload": "Preparazione del download",
-   "preparingDownloadFailed": "Problema durante la preparazione del download",
    "downloadSelection": "Scarica gli elementi selezionati",
    "noImageSelected": "Nessuna immagine selezionata",
    "deleteSelection": "Elimina la selezione",
@@ -609,8 +608,6 @@
    "seamLowThreshold": "Basso",
    "seamHighThreshold": "Alto",
    "coherencePassHeader": "Passaggio di coerenza",
-   "coherenceSteps": "Passi",
-   "coherenceStrength": "Forza",
    "compositingSettingsHeader": "Impostazioni di composizione",
    "patchmatchDownScaleSize": "Ridimensiona",
    "coherenceMode": "Modalità",
@@ -1400,19 +1397,6 @@
      "Regola la maschera."
    ]
  },
"compositingCoherenceSteps": {
"heading": "Passi",
"paragraphs": [
"Numero di passi utilizzati nel Passaggio di Coerenza.",
"Simile ai passi di generazione."
]
},
"compositingBlur": {
"heading": "Sfocatura",
"paragraphs": [
"Il raggio di sfocatura della maschera."
]
},
"compositingCoherenceMode": { "compositingCoherenceMode": {
"heading": "Modalità", "heading": "Modalità",
"paragraphs": [ "paragraphs": [
@ -1431,13 +1415,6 @@
"Un secondo ciclo di riduzione del rumore aiuta a comporre l'immagine Inpaint/Outpaint." "Un secondo ciclo di riduzione del rumore aiuta a comporre l'immagine Inpaint/Outpaint."
] ]
}, },
"compositingStrength": {
"heading": "Forza",
"paragraphs": [
"Quantità di rumore aggiunta per il Passaggio di Coerenza.",
"Simile alla forza di riduzione del rumore."
]
},
"paramNegativeConditioning": { "paramNegativeConditioning": {
"paragraphs": [ "paragraphs": [
"Il processo di generazione evita i concetti nel prompt negativo. Utilizzatelo per escludere qualità o oggetti dall'output.", "Il processo di generazione evita i concetti nel prompt negativo. Utilizzatelo per escludere qualità o oggetti dall'output.",

View File

@@ -123,8 +123,6 @@
    "autoSwitchNewImages": "새로운 이미지로 자동 전환",
    "loading": "불러오는 중",
    "unableToLoad": "갤러리를 로드할 수 없음",
-   "preparingDownload": "다운로드 준비",
-   "preparingDownloadFailed": "다운로드 준비 중 발생한 문제",
    "singleColumnLayout": "단일 열 레이아웃",
    "image": "이미지",
    "loadMore": "더 불러오기",

View File

@@ -97,8 +97,6 @@
    "featuresWillReset": "Als je deze afbeelding verwijdert, dan worden deze functies onmiddellijk teruggezet.",
    "loading": "Bezig met laden",
    "unableToLoad": "Kan galerij niet laden",
-   "preparingDownload": "Bezig met voorbereiden van download",
-   "preparingDownloadFailed": "Fout bij voorbereiden van download",
    "downloadSelection": "Download selectie",
    "currentlyInUse": "Deze afbeelding is momenteel in gebruik door de volgende functies:",
    "copy": "Kopieer",
@@ -535,8 +533,6 @@
    "coherencePassHeader": "Coherentiestap",
    "maskBlur": "Vervaag",
    "maskBlurMethod": "Vervagingsmethode",
-   "coherenceSteps": "Stappen",
-   "coherenceStrength": "Sterkte",
    "seamHighThreshold": "Hoog",
    "seamLowThreshold": "Laag",
    "invoke": {
@@ -1139,13 +1135,6 @@
      "Een afbeeldingsgrootte (in aantal pixels) equivalent aan 512x512 wordt aanbevolen voor SD1.5-modellen. Een grootte-equivalent van 1024x1024 wordt aanbevolen voor SDXL-modellen."
    ]
  },
"compositingCoherenceSteps": {
"heading": "Stappen",
"paragraphs": [
"Het aantal te gebruiken ontruisingsstappen in de coherentiefase.",
"Gelijk aan de hoofdparameter Stappen."
]
},
"dynamicPrompts": { "dynamicPrompts": {
"paragraphs": [ "paragraphs": [
"Dynamische prompts vormt een enkele prompt om in vele.", "Dynamische prompts vormt een enkele prompt om in vele.",
@ -1160,12 +1149,6 @@
], ],
"heading": "VAE" "heading": "VAE"
}, },
"compositingBlur": {
"heading": "Vervaging",
"paragraphs": [
"De vervagingsstraal van het masker."
]
},
"paramIterations": { "paramIterations": {
"paragraphs": [ "paragraphs": [
"Het aantal te genereren afbeeldingen.", "Het aantal te genereren afbeeldingen.",
@ -1240,13 +1223,6 @@
], ],
"heading": "Ontruisingssterkte" "heading": "Ontruisingssterkte"
}, },
"compositingStrength": {
"heading": "Sterkte",
"paragraphs": [
"Ontruisingssterkte voor de coherentiefase.",
"Gelijk aan de parameter Ontruisingssterkte Afbeelding naar afbeelding."
]
},
"paramNegativeConditioning": { "paramNegativeConditioning": {
"paragraphs": [ "paragraphs": [
"Het genereerproces voorkomt de gegeven begrippen in de negatieve prompt. Gebruik dit om bepaalde zaken of voorwerpen uit te sluiten van de uitvoerafbeelding.", "Het genereerproces voorkomt de gegeven begrippen in de negatieve prompt. Gebruik dit om bepaalde zaken of voorwerpen uit te sluiten van de uitvoerafbeelding.",

View File

@@ -143,8 +143,6 @@
    "problemDeletingImagesDesc": "Не удалось удалить одно или несколько изображений",
    "loading": "Загрузка",
    "unableToLoad": "Невозможно загрузить галерею",
-   "preparingDownload": "Подготовка к скачиванию",
-   "preparingDownloadFailed": "Проблема с подготовкой к скачиванию",
    "image": "изображение",
    "drop": "перебросить",
    "problemDeletingImages": "Проблема с удалением изображений",
@@ -612,9 +610,7 @@
    "maskBlurMethod": "Метод размытия",
    "seamLowThreshold": "Низкий",
    "seamHighThreshold": "Высокий",
-   "coherenceSteps": "Шагов",
    "coherencePassHeader": "Порог Coherence",
-   "coherenceStrength": "Сила",
    "compositingSettingsHeader": "Настройки компоновки",
    "invoke": {
      "noNodesInGraph": "Нет узлов в графе",
@@ -1321,13 +1317,6 @@
      "Размер изображения (в пикселях), эквивалентный 512x512, рекомендуется для моделей SD1.5, а размер, эквивалентный 1024x1024, рекомендуется для моделей SDXL."
    ]
  },
"compositingCoherenceSteps": {
"heading": "Шаги",
"paragraphs": [
"Количество шагов снижения шума, используемых при прохождении когерентности.",
"То же, что и основной параметр «Шаги»."
]
},
"dynamicPrompts": { "dynamicPrompts": {
"paragraphs": [ "paragraphs": [
"Динамические запросы превращают одно приглашение на множество.", "Динамические запросы превращают одно приглашение на множество.",
@ -1342,12 +1331,6 @@
], ],
"heading": "VAE" "heading": "VAE"
}, },
"compositingBlur": {
"heading": "Размытие",
"paragraphs": [
"Радиус размытия маски."
]
},
"paramIterations": { "paramIterations": {
"paragraphs": [ "paragraphs": [
"Количество изображений, которые нужно сгенерировать.", "Количество изображений, которые нужно сгенерировать.",
@ -1422,13 +1405,6 @@
], ],
"heading": "Шумоподавление" "heading": "Шумоподавление"
}, },
"compositingStrength": {
"heading": "Сила",
"paragraphs": [
null,
"То же, что параметр «Сила шумоподавления img2img»."
]
},
"paramNegativeConditioning": { "paramNegativeConditioning": {
"paragraphs": [ "paragraphs": [
"Stable Diffusion пытается избежать указанных в отрицательном запросе концепций. Используйте это, чтобы исключить качества или объекты из вывода.", "Stable Diffusion пытается избежать указанных в отрицательном запросе концепций. Используйте это, чтобы исключить качества или объекты из вывода.",

View File

@ -355,7 +355,6 @@
"starImage": "Yıldız Koy", "starImage": "Yıldız Koy",
"download": "İndir", "download": "İndir",
"deleteSelection": "Seçileni Sil", "deleteSelection": "Seçileni Sil",
"preparingDownloadFailed": "İndirme Hazırlanırken Sorun",
"problemDeletingImages": "Görsel Silmede Sorun", "problemDeletingImages": "Görsel Silmede Sorun",
"featuresWillReset": "Bu görseli silerseniz, o özellikler resetlenecektir.", "featuresWillReset": "Bu görseli silerseniz, o özellikler resetlenecektir.",
"galleryImageResetSize": "Boyutu Resetle", "galleryImageResetSize": "Boyutu Resetle",
@ -377,7 +376,6 @@
"setCurrentImage": "Çalışma Görseli Yap", "setCurrentImage": "Çalışma Görseli Yap",
"unableToLoad": "Galeri Yüklenemedi", "unableToLoad": "Galeri Yüklenemedi",
"downloadSelection": "Seçileni İndir", "downloadSelection": "Seçileni İndir",
"preparingDownload": "İndirmeye Hazırlanıyor",
"singleColumnLayout": "Tek Sütun Düzen", "singleColumnLayout": "Tek Sütun Düzen",
"generations": ıktılar", "generations": ıktılar",
"showUploads": "Yüklenenleri Göster", "showUploads": "Yüklenenleri Göster",
@ -723,7 +721,6 @@
"clipSkip": "CLIP Atlama", "clipSkip": "CLIP Atlama",
"randomizeSeed": "Rastgele Tohum", "randomizeSeed": "Rastgele Tohum",
"cfgScale": "CFG Ölçeği", "cfgScale": "CFG Ölçeği",
"coherenceStrength": "Etki",
"controlNetControlMode": "Yönetim Kipi", "controlNetControlMode": "Yönetim Kipi",
"general": "Genel", "general": "Genel",
"img2imgStrength": "Görselden Görsel Ölçüsü", "img2imgStrength": "Görselden Görsel Ölçüsü",
@ -793,7 +790,6 @@
"cfgRescaleMultiplier": "CFG Rescale Çarpanı", "cfgRescaleMultiplier": "CFG Rescale Çarpanı",
"cfgRescale": "CFG Rescale", "cfgRescale": "CFG Rescale",
"coherencePassHeader": "Uyum Geçişi", "coherencePassHeader": "Uyum Geçişi",
"coherenceSteps": "Adım",
"infillMethod": "Doldurma Yöntemi", "infillMethod": "Doldurma Yöntemi",
"maskBlurMethod": "Bulandırma Yöntemi", "maskBlurMethod": "Bulandırma Yöntemi",
"steps": "Adım", "steps": "Adım",

View File

@ -136,8 +136,6 @@
"copy": "复制", "copy": "复制",
"download": "下载", "download": "下载",
"setCurrentImage": "设为当前图像", "setCurrentImage": "设为当前图像",
"preparingDownload": "准备下载",
"preparingDownloadFailed": "准备下载时出现问题",
"downloadSelection": "下载所选内容", "downloadSelection": "下载所选内容",
"noImageSelected": "无选中的图像", "noImageSelected": "无选中的图像",
"deleteSelection": "删除所选内容", "deleteSelection": "删除所选内容",
@ -616,11 +614,9 @@
"incompatibleBaseModelForControlAdapter": "有 #{{number}} 个 Control Adapter 模型与主模型不兼容。" "incompatibleBaseModelForControlAdapter": "有 #{{number}} 个 Control Adapter 模型与主模型不兼容。"
}, },
"patchmatchDownScaleSize": "缩小", "patchmatchDownScaleSize": "缩小",
"coherenceSteps": "步数",
"clipSkip": "CLIP 跳过层", "clipSkip": "CLIP 跳过层",
"compositingSettingsHeader": "合成设置", "compositingSettingsHeader": "合成设置",
"useCpuNoise": "使用 CPU 噪声", "useCpuNoise": "使用 CPU 噪声",
"coherenceStrength": "强度",
"enableNoiseSettings": "启用噪声设置", "enableNoiseSettings": "启用噪声设置",
"coherenceMode": "模式", "coherenceMode": "模式",
"cpuNoise": "CPU 噪声", "cpuNoise": "CPU 噪声",
@ -1402,19 +1398,6 @@
"图像尺寸(单位:像素)建议 SD 1.5 模型使用等效 512x512 的尺寸SDXL 模型使用等效 1024x1024 的尺寸。" "图像尺寸(单位:像素)建议 SD 1.5 模型使用等效 512x512 的尺寸SDXL 模型使用等效 1024x1024 的尺寸。"
] ]
}, },
"compositingCoherenceSteps": {
"heading": "步数",
"paragraphs": [
"一致性层中使用的去噪步数。",
"与主参数中的步数相同。"
]
},
"compositingBlur": {
"heading": "模糊",
"paragraphs": [
"遮罩模糊半径。"
]
},
"noiseUseCPU": { "noiseUseCPU": {
"heading": "使用 CPU 噪声", "heading": "使用 CPU 噪声",
"paragraphs": [ "paragraphs": [
@ -1467,13 +1450,6 @@
"第二轮去噪有助于合成内补/外扩图像。" "第二轮去噪有助于合成内补/外扩图像。"
] ]
}, },
"compositingStrength": {
"heading": "强度",
"paragraphs": [
"一致性层使用的去噪强度。",
"去噪强度与图生图的参数相同。"
]
},
"paramNegativeConditioning": { "paramNegativeConditioning": {
"paragraphs": [ "paragraphs": [
"生成过程会避免生成负向提示词中的概念。使用此选项来使输出排除部分质量或对象。", "生成过程会避免生成负向提示词中的概念。使用此选项来使输出排除部分质量或对象。",

View File

@ -55,6 +55,8 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested'; import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store'; import type { AppDispatch, RootState } from 'app/store/store';
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>; export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
@ -151,5 +153,7 @@ addFirstListImagesListener(startAppListening);
// Ad-hoc upscale workflwo // Ad-hoc upscale workflwo
addUpscaleRequestedListener(startAppListening); addUpscaleRequestedListener(startAppListening);
// Dynamic prompts // Prompts
addDynamicPromptsListener(startAppListening); addDynamicPromptsListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);

View File

@ -7,8 +7,10 @@ import {
selectAllT2IAdapters, selectAllT2IAdapters,
} from 'features/controlAdapters/store/controlAdaptersSlice'; } from 'features/controlAdapters/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice'; import { loraRemoved } from 'features/lora/store/loraSlice';
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice'; import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
import { heightChanged, modelChanged, vaeSelected, widthChanged } from 'features/parameters/store/generationSlice';
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas'; import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice'; import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es'; import { forEach, some } from 'lodash-es';
import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models'; import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models';
@ -24,7 +26,9 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
const log = logger('models'); const log = logger('models');
log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`); log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`);
const currentModel = getState().generation.model; const state = getState();
const currentModel = state.generation.model;
const models = mainModelsAdapterSelectors.selectAll(action.payload); const models = mainModelsAdapterSelectors.selectAll(action.payload);
if (models.length === 0) { if (models.length === 0) {
@ -39,6 +43,29 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
return; return;
} }
const defaultModel = state.config.sd.defaultModel;
const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;
if (defaultModelInList) {
const result = zParameterModel.safeParse(defaultModelInList);
if (result.success) {
dispatch(modelChanged(defaultModelInList, currentModel));
const optimalDimension = getOptimalDimension(defaultModelInList);
if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
return;
}
const { width, height } = calculateNewSize(
state.generation.aspectRatio.value,
optimalDimension * optimalDimension
);
dispatch(widthChanged(width));
dispatch(heightChanged(height));
return;
}
}
const result = zParameterModel.safeParse(models[0]); const result = zParameterModel.safeParse(models[0]);
if (!result.success) { if (!result.success) {
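For context, a minimal sketch of the sizing math the listener above appears to rely on, assuming calculateNewSize picks a width/height pair that keeps the requested aspect ratio while targeting an area of optimalDimension squared; the helper's name comes from the import in this diff, but its exact behavior is an assumption, not verified against the repo.

// Hypothetical sketch of the sizing helper used above.
// Assumption: it returns a width/height pair (multiples of 8) that matches
// the given aspect ratio and whose area is close to the target area.
const calculateNewSizeSketch = (
  aspectRatio: number, // width / height
  targetArea: number // e.g. optimalDimension * optimalDimension
): { width: number; height: number } => {
  // width * height = targetArea and width / height = aspectRatio
  // => height = sqrt(targetArea / aspectRatio), width = height * aspectRatio
  const height = Math.sqrt(targetArea / aspectRatio);
  const width = height * aspectRatio;
  // Round to the nearest multiple of 8, as SD models expect.
  const roundTo8 = (v: number) => Math.max(8, Math.round(v / 8) * 8);
  return { width: roundTo8(width), height: roundTo8(height) };
};

// Example: an SDXL model (optimal dimension 1024) with a 16:9 canvas.
const { width, height } = calculateNewSizeSketch(16 / 9, 1024 * 1024);
// width ≈ 1368, height ≈ 768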

View File

@ -0,0 +1,96 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setDefaultSettings } from 'features/parameters/store/actions';
import {
setCfgRescaleMultiplier,
setCfgScale,
setScheduler,
setSteps,
vaePrecisionChanged,
vaeSelected,
} from 'features/parameters/store/generationSlice';
import {
isParameterCFGRescaleMultiplier,
isParameterCFGScale,
isParameterPrecision,
isParameterScheduler,
isParameterSteps,
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { map } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: setDefaultSettings,
effect: async (action, { dispatch, getState }) => {
const state = getState();
const currentModel = state.generation.model;
if (!currentModel) {
return;
}
const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();
if (!modelConfig || !modelConfig.default_settings) {
return;
}
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = modelConfig.default_settings;
if (vae) {
// we store this as "default" within default settings
// to distinguish it from no default set
if (vae === 'default') {
dispatch(vaeSelected(null));
} else {
const { data } = modelsApi.endpoints.getVaeModels.select()(state);
const vaeArray = map(data?.entities);
const validVae = vaeArray.find((model) => model.key === vae);
const result = zParameterVAEModel.safeParse(validVae);
if (!result.success) {
return;
}
dispatch(vaeSelected(result.data));
}
}
if (vae_precision) {
if (isParameterPrecision(vae_precision)) {
dispatch(vaePrecisionChanged(vae_precision));
}
}
if (cfg_scale) {
if (isParameterCFGScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale));
}
}
if (cfg_rescale_multiplier) {
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
}
}
if (steps) {
if (isParameterSteps(steps)) {
dispatch(setSteps(steps));
}
}
if (scheduler) {
if (isParameterScheduler(scheduler)) {
dispatch(setScheduler(scheduler));
}
}
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
},
});
};
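A hedged sketch of how this listener could be triggered from UI code, assuming setDefaultSettings is a payload-less action creator; the hook below is illustrative only and is not part of this diff.

// Hypothetical trigger for the listener above (assumes the action takes no
// payload; the hook name is illustrative, not from the repo).
import { useAppDispatch } from 'app/store/storeHooks';
import { setDefaultSettings } from 'features/parameters/store/actions';
import { useCallback } from 'react';

export const useApplyDefaultSettings = () => {
  const dispatch = useAppDispatch();
  // Dispatching the action lets the listener fetch the selected model's
  // default settings and apply them to the generation parameters.
  return useCallback(() => dispatch(setDefaultSettings()), [dispatch]);
};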

View File

@ -14,7 +14,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
const { bytes, total_bytes, id } = action.payload.data; const { bytes, total_bytes, id } = action.payload.data;
dispatch( dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => { modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id); const modelImport = draft.find((m) => m.id === id);
if (modelImport) { if (modelImport) {
modelImport.bytes = bytes; modelImport.bytes = bytes;
@ -33,7 +33,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
const { id } = action.payload.data; const { id } = action.payload.data;
dispatch( dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => { modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id); const modelImport = draft.find((m) => m.id === id);
if (modelImport) { if (modelImport) {
modelImport.status = 'completed'; modelImport.status = 'completed';
@ -41,7 +41,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
return draft; return draft;
}) })
); );
dispatch(api.util.invalidateTags([{ type: 'ModelConfig' }])); dispatch(api.util.invalidateTags(['Model']));
}, },
}); });
@ -51,7 +51,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
const { id, error, error_type } = action.payload.data; const { id, error, error_type } = action.payload.data;
dispatch( dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => { modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id); const modelImport = draft.find((m) => m.id === id);
if (modelImport) { if (modelImport) {
modelImport.status = 'error'; modelImport.status = 'error';

View File

@ -1,4 +1,5 @@
import type { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; import type { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import type { InvokeTabName } from 'features/ui/store/tabMap'; import type { InvokeTabName } from 'features/ui/store/tabMap';
import type { O } from 'ts-toolbelt'; import type { O } from 'ts-toolbelt';
@ -82,6 +83,8 @@ export type AppConfig = {
guidance: NumericalParameterConfig; guidance: NumericalParameterConfig;
cfgRescaleMultiplier: NumericalParameterConfig; cfgRescaleMultiplier: NumericalParameterConfig;
img2imgStrength: NumericalParameterConfig; img2imgStrength: NumericalParameterConfig;
scheduler?: ParameterScheduler;
vaePrecision?: ParameterPrecision;
// Canvas // Canvas
boundingBoxHeight: NumericalParameterConfig; // initial value comes from model boundingBoxHeight: NumericalParameterConfig; // initial value comes from model
boundingBoxWidth: NumericalParameterConfig; // initial value comes from model boundingBoxWidth: NumericalParameterConfig; // initial value comes from model

View File

@ -1,21 +0,0 @@
import type { Meta, StoryObj } from '@storybook/react';
import { EmbeddingSelect } from './EmbeddingSelect';
import type { EmbeddingSelectProps } from './types';
const meta: Meta<typeof EmbeddingSelect> = {
title: 'Feature/Prompt/EmbeddingSelect',
tags: ['autodocs'],
component: EmbeddingSelect,
};
export default meta;
type Story = StoryObj<typeof EmbeddingSelect>;
const Component = (props: EmbeddingSelectProps) => {
return <EmbeddingSelect {...props}>Invoke</EmbeddingSelect>;
};
export const Default: Story = {
render: Component,
};

View File

@ -1,67 +0,0 @@
import type { ChakraProps } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import type { EmbeddingSelectProps } from 'features/embedding/types';
import { t } from 'i18next';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import type { TextualInversionModelConfig } from 'services/api/types';
const noOptionsMessage = () => t('embedding.noMatchingEmbedding');
export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const getIsDisabled = useCallback(
(embedding: TextualInversionModelConfig): boolean => {
const isCompatible = currentBaseModel === embedding.base;
const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible;
},
[currentBaseModel]
);
const { data, isLoading } = useGetTextualInversionModelsQuery();
const _onChange = useCallback(
(embedding: TextualInversionModelConfig | null) => {
if (!embedding) {
return;
}
onSelect(embedding.name);
},
[onSelect]
);
const { options, onChange } = useGroupedModelCombobox({
modelEntities: data,
getIsDisabled,
onChange: _onChange,
});
return (
<FormControl>
<Combobox
placeholder={isLoading ? t('common.loading') : t('embedding.addEmbedding')}
defaultMenuIsOpen
autoFocus
value={null}
options={options}
noOptionsMessage={noOptionsMessage}
onChange={onChange}
onMenuClose={onClose}
data-testid="add-embedding"
sx={selectStyles}
/>
</FormControl>
);
});
EmbeddingSelect.displayName = 'EmbeddingSelect';
const selectStyles: ChakraProps['sx'] = {
w: 'full',
};

View File

@ -59,7 +59,7 @@ const LoRASelect = () => {
return ( return (
<FormControl isDisabled={!options.length}> <FormControl isDisabled={!options.length}>
<InformationalPopover feature="lora"> <InformationalPopover feature="lora">
<FormLabel>{t('models.lora')} </FormLabel> <FormLabel>{t('models.concepts')} </FormLabel>
</InformationalPopover> </InformationalPopover>
<Combobox <Combobox
placeholder={placeholder} placeholder={placeholder}

View File

@ -1,228 +0,0 @@
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import BaseModelSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect';
import BooleanSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect';
import ModelFormatSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect';
import ModelTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect';
import ModelVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect';
import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect';
import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { isNil, omitBy } from 'lodash-es';
import { useCallback, useEffect } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useInstallModelMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
export const AdvancedImport = () => {
const dispatch = useAppDispatch();
const [installModel] = useInstallModelMutation();
const { t } = useTranslation();
const {
register,
handleSubmit,
control,
formState: { errors },
setValue,
resetField,
reset,
watch,
} = useForm<AnyModelConfig>({
defaultValues: {
name: '',
base: 'sd-1',
type: 'main',
path: '',
description: '',
format: 'diffusers',
vae: '',
variant: 'normal',
},
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
(values) => {
installModel({
source: values.path,
config: omitBy(values, isNil),
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelAdded', {
modelName: values.name,
}),
status: 'success',
})
)
);
reset();
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddFailed'),
status: 'error',
})
)
);
}
});
},
[installModel, dispatch, t, reset]
);
const watchedModelType = watch('type');
const watchedModelFormat = watch('format');
useEffect(() => {
if (watchedModelType === 'main') {
setValue('format', 'diffusers');
setValue('repo_variant', '');
setValue('variant', 'normal');
}
if (watchedModelType === 'lora') {
setValue('format', 'lycoris');
} else if (watchedModelType === 'embedding') {
setValue('format', 'embedding_file');
} else if (watchedModelType === 'ip_adapter') {
setValue('format', 'invokeai');
} else {
setValue('format', 'diffusers');
}
resetField('upcast_attention');
resetField('ztsnr_training');
resetField('vae');
resetField('config');
resetField('prediction_type');
resetField('image_encoder_model_id');
}, [watchedModelType, resetField, setValue]);
return (
<ScrollableContent>
<form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" gap={4} width="100%" pb={10}>
<Flex alignItems="flex-end" gap="4">
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.modelType')}</FormLabel>
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
</FormControl>
<Text px="2" fontSize="xs" textAlign="center">
{t('modelManager.advancedImportInfo')}
</Text>
</Flex>
<Flex p={4} borderRadius={4} bg="base.850" height="100%" direction="column" gap="3">
<FormControl isInvalid={Boolean(errors.name)}>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.name')}</FormLabel>
<Input
{...register('name', {
validate: (value) => value.trim().length >= 3 || 'Must be at least 3 characters',
})}
/>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
<Flex>
<FormControl>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.description')}</FormLabel>
<Textarea size="sm" {...register('description')} />
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect control={control} name="base" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('common.format')}</FormLabel>
<ModelFormatSelect control={control} name="format" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.path')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</FormControl>
</Flex>
{watchedModelType === 'main' && (
<>
<Flex gap={4}>
{watchedModelFormat === 'diffusers' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.repoVariant')}</FormLabel>
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
</FormControl>
)}
{watchedModelFormat === 'checkpoint' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
<Input {...register('config')} />
</FormControl>
)}
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</FormControl>
</Flex>
</>
)}
{watchedModelType === 'ip_adapter' && (
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
<Input {...register('image_encoder_model_id')} />
</FormControl>
</Flex>
)}
<Button mt={2} type="submit">
{t('modelManager.addModel')}
</Button>
</Flex>
</Flex>
</form>
</ScrollableContent>
);
};

View File

@ -12,7 +12,7 @@ type SimpleImportModelConfig = {
location: string; location: string;
}; };
export const SimpleImport = () => { export const InstallModelForm = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [installModel, { isLoading }] = useInstallModelMutation(); const [installModel, { isLoading }] = useInstallModelMutation();

View File

@ -5,19 +5,19 @@ import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next'; import { t } from 'i18next';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models'; import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models';
import { ImportQueueItem } from './ImportQueueItem'; import { ModelInstallQueueItem } from './ModelInstallQueueItem';
export const ImportQueue = () => { export const ModelInstallQueue = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data } = useGetModelImportsQuery(); const { data } = useListModelInstallsQuery();
const [pruneModelImports] = usePruneModelImportsMutation(); const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation();
const pruneQueue = useCallback(() => { const pruneCompletedModelInstalls = useCallback(() => {
pruneModelImports() _pruneCompletedModelInstalls()
.unwrap() .unwrap()
.then((_) => { .then((_) => {
dispatch( dispatch(
@ -41,7 +41,7 @@ export const ImportQueue = () => {
); );
} }
}); });
}, [pruneModelImports, dispatch]); }, [_pruneCompletedModelInstalls, dispatch]);
const pruneAvailable = useMemo(() => { const pruneAvailable = useMemo(() => {
return data?.some( return data?.some(
@ -53,14 +53,19 @@ export const ImportQueue = () => {
<Flex flexDir="column" p={3} h="full"> <Flex flexDir="column" p={3} h="full">
<Flex justifyContent="space-between" alignItems="center"> <Flex justifyContent="space-between" alignItems="center">
<Text>{t('modelManager.importQueue')}</Text> <Text>{t('modelManager.importQueue')}</Text>
<Button size="sm" isDisabled={!pruneAvailable} onClick={pruneQueue} tooltip={t('modelManager.pruneTooltip')}> <Button
size="sm"
isDisabled={!pruneAvailable}
onClick={pruneCompletedModelInstalls}
tooltip={t('modelManager.pruneTooltip')}
>
{t('modelManager.prune')} {t('modelManager.prune')}
</Button> </Button>
</Flex> </Flex>
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full"> <Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
<ScrollableContent> <ScrollableContent>
<Flex flexDir="column-reverse" gap="2"> <Flex flexDir="column-reverse" gap="2">
{data?.map((model) => <ImportQueueItem key={model.id} model={model} />)} {data?.map((model) => <ModelInstallQueueItem key={model.id} installJob={model} />)}
</Flex> </Flex>
</ScrollableContent> </ScrollableContent>
</Box> </Box>

View File

@ -6,17 +6,24 @@ import type { ModelInstallStatus } from 'services/api/types';
const STATUSES = { const STATUSES = {
waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' }, waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' },
downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' }, downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
downloads_done: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' }, running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
completed: { colorScheme: 'green', translationKey: 'queue.completed' }, completed: { colorScheme: 'green', translationKey: 'queue.completed' },
error: { colorScheme: 'red', translationKey: 'queue.failed' }, error: { colorScheme: 'red', translationKey: 'queue.failed' },
cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' }, cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
}; };
const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus; errorReason?: string | null }) => { const ModelInstallQueueBadge = ({
status,
errorReason,
}: {
status?: ModelInstallStatus;
errorReason?: string | null;
}) => {
const { t } = useTranslation(); const { t } = useTranslation();
if (!status) { if (!status || !Object.keys(STATUSES).includes(status)) {
return <></>; return null;
} }
return ( return (
@ -25,4 +32,4 @@ const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus
</Tooltip> </Tooltip>
); );
}; };
export default memo(ImportQueueBadge); export default memo(ModelInstallQueueBadge);

View File

@ -3,15 +3,16 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next'; import { t } from 'i18next';
import { isNil } from 'lodash-es';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { PiXBold } from 'react-icons/pi'; import { PiXBold } from 'react-icons/pi';
import { useDeleteModelImportMutation } from 'services/api/endpoints/models'; import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types'; import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types';
import ImportQueueBadge from './ImportQueueBadge'; import ModelInstallQueueBadge from './ModelInstallQueueBadge';
type ModelListItemProps = { type ModelListItemProps = {
model: ModelInstallJob; installJob: ModelInstallJob;
}; };
const formatBytes = (bytes: number) => { const formatBytes = (bytes: number) => {
@ -26,26 +27,26 @@ const formatBytes = (bytes: number) => {
return `${bytes.toFixed(2)} ${units[i]}`; return `${bytes.toFixed(2)} ${units[i]}`;
}; };
export const ImportQueueItem = (props: ModelListItemProps) => { export const ModelInstallQueueItem = (props: ModelListItemProps) => {
const { model } = props; const { installJob } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [deleteImportModel] = useDeleteModelImportMutation(); const [deleteImportModel] = useCancelModelInstallMutation();
const source = useMemo(() => { const source = useMemo(() => {
if (model.source.type === 'hf') { if (installJob.source.type === 'hf') {
return model.source as HFModelSource; return installJob.source as HFModelSource;
} else if (model.source.type === 'local') { } else if (installJob.source.type === 'local') {
return model.source as LocalModelSource; return installJob.source as LocalModelSource;
} else if (model.source.type === 'url') { } else if (installJob.source.type === 'url') {
return model.source as URLModelSource; return installJob.source as URLModelSource;
} else { } else {
return model.source as LocalModelSource; return installJob.source as LocalModelSource;
} }
}, [model.source]); }, [installJob.source]);
const handleDeleteModelImport = useCallback(() => { const handleDeleteModelImport = useCallback(() => {
deleteImportModel(model.id) deleteImportModel(installJob.id)
.unwrap() .unwrap()
.then((_) => { .then((_) => {
dispatch( dispatch(
@ -69,7 +70,7 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
); );
} }
}); });
}, [deleteImportModel, model, dispatch]); }, [deleteImportModel, installJob, dispatch]);
const modelName = useMemo(() => { const modelName = useMemo(() => {
switch (source.type) { switch (source.type) {
@ -85,19 +86,23 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
}, [source]); }, [source]);
const progressValue = useMemo(() => { const progressValue = useMemo(() => {
if (model.bytes === undefined || model.total_bytes === undefined) { if (isNil(installJob.bytes) || isNil(installJob.total_bytes)) {
return null;
}
if (installJob.total_bytes === 0) {
return 0; return 0;
} }
return (model.bytes / model.total_bytes) * 100; return (installJob.bytes / installJob.total_bytes) * 100;
}, [model.bytes, model.total_bytes]); }, [installJob.bytes, installJob.total_bytes]);
const progressString = useMemo(() => { const progressString = useMemo(() => {
if (model.status !== 'downloading' || model.bytes === undefined || model.total_bytes === undefined) { if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
return ''; return '';
} }
return `${formatBytes(model.bytes)} / ${formatBytes(model.total_bytes)}`; return `${formatBytes(installJob.bytes)} / ${formatBytes(installJob.total_bytes)}`;
}, [model.bytes, model.total_bytes, model.status]); }, [installJob.bytes, installJob.total_bytes, installJob.status]);
return ( return (
<Flex gap="2" w="full" alignItems="center"> <Flex gap="2" w="full" alignItems="center">
@ -109,19 +114,21 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
<Flex flexDir="column" flex={1}> <Flex flexDir="column" flex={1}>
<Tooltip label={progressString}> <Tooltip label={progressString}>
<Progress <Progress
value={progressValue} value={progressValue ?? 0}
isIndeterminate={progressValue === undefined} isIndeterminate={progressValue === null}
aria-label={t('accessibility.invokeProgressBar')} aria-label={t('accessibility.invokeProgressBar')}
h={2} h={2}
/> />
</Tooltip> </Tooltip>
</Flex> </Flex>
<Box minW="100px" textAlign="center"> <Box minW="100px" textAlign="center">
<ImportQueueBadge status={model.status} errorReason={model.error_reason} /> <ModelInstallQueueBadge status={installJob.status} errorReason={installJob.error_reason} />
</Box> </Box>
<Box minW="20px"> <Box minW="20px">
{(model.status === 'downloading' || model.status === 'waiting' || model.status === 'running') && ( {(installJob.status === 'downloading' ||
installJob.status === 'waiting' ||
installJob.status === 'running') && (
<IconButton <IconButton
isRound={true} isRound={true}
size="xs" size="xs"

View File

@ -2,24 +2,24 @@ import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@
import type { ChangeEventHandler } from 'react'; import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react'; import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useLazyScanModelsQuery } from 'services/api/endpoints/models'; import { useLazyScanFolderQuery } from 'services/api/endpoints/models';
import { ScanModelsResults } from './ScanModelsResults'; import { ScanModelsResults } from './ScanFolderResults';
export const ScanModelsForm = () => { export const ScanModelsForm = () => {
const [scanPath, setScanPath] = useState(''); const [scanPath, setScanPath] = useState('');
const [errorMessage, setErrorMessage] = useState(''); const [errorMessage, setErrorMessage] = useState('');
const { t } = useTranslation(); const { t } = useTranslation();
const [_scanModels, { isLoading, data }] = useLazyScanModelsQuery(); const [_scanFolder, { isLoading, data }] = useLazyScanFolderQuery();
const handleSubmitScan = useCallback(async () => { const scanFolder = useCallback(async () => {
_scanModels({ scan_path: scanPath }).catch((error) => { _scanFolder({ scan_path: scanPath }).catch((error) => {
if (error) { if (error) {
setErrorMessage(error.data.detail); setErrorMessage(error.data.detail);
} }
}); });
}, [_scanModels, scanPath]); }, [_scanFolder, scanPath]);
const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => { const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
setScanPath(e.target.value); setScanPath(e.target.value);
@ -36,7 +36,7 @@ export const ScanModelsForm = () => {
<Input value={scanPath} onChange={handleSetScanPath} /> <Input value={scanPath} onChange={handleSetScanPath} />
</Flex> </Flex>
<Button onClick={handleSubmitScan} isLoading={isLoading} isDisabled={scanPath.length === 0}> <Button onClick={scanFolder} isLoading={isLoading} isDisabled={scanPath.length === 0}>
{t('modelManager.scanFolder')} {t('modelManager.scanFolder')}
</Button> </Button>
</Flex> </Flex>

View File

@ -18,7 +18,7 @@ import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi'; import { PiXBold } from 'react-icons/pi';
import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models'; import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models';
import { ScanModelResultItem } from './ScanModelResultItem'; import { ScanModelResultItem } from './ScanFolderResultItem';
type ScanModelResultsProps = { type ScanModelResultsProps = {
results: ScanFolderResponse; results: ScanFolderResponse;

View File

@ -1,12 +1,11 @@
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library'; import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { AdvancedImport } from './AddModelPanel/AdvancedImport'; import { InstallModelForm } from './AddModelPanel/InstallModelForm';
import { ImportQueue } from './AddModelPanel/ImportQueue/ImportQueue'; import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
import { ScanModelsForm } from './AddModelPanel/ScanModels/ScanModelsForm'; import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
import { SimpleImport } from './AddModelPanel/SimpleImport';
export const ImportModels = () => { export const InstallModels = () => {
const { t } = useTranslation(); const { t } = useTranslation();
return ( return (
<Flex layerStyle="first" p={3} borderRadius="base" w="full" h="full" flexDir="column" gap={2}> <Flex layerStyle="first" p={3} borderRadius="base" w="full" h="full" flexDir="column" gap={2}>
@ -17,15 +16,11 @@ export const ImportModels = () => {
<Tabs variant="collapse" height="100%"> <Tabs variant="collapse" height="100%">
<TabList> <TabList>
<Tab>{t('common.simple')}</Tab> <Tab>{t('common.simple')}</Tab>
<Tab>{t('modelManager.advanced')}</Tab>
<Tab>{t('modelManager.scan')}</Tab> <Tab>{t('modelManager.scan')}</Tab>
</TabList> </TabList>
<TabPanels p={3} height="100%"> <TabPanels p={3} height="100%">
<TabPanel> <TabPanel>
<SimpleImport /> <InstallModelForm />
</TabPanel>
<TabPanel height="100%">
<AdvancedImport />
</TabPanel> </TabPanel>
<TabPanel height="100%"> <TabPanel height="100%">
<ScanModelsForm /> <ScanModelsForm />
@ -34,7 +29,7 @@ export const ImportModels = () => {
</Tabs> </Tabs>
</Box> </Box>
<Box layerStyle="second" borderRadius="base" w="full" h="50%"> <Box layerStyle="second" borderRadius="base" w="full" h="50%">
<ImportQueue /> <ModelInstallQueue />
</Box> </Box>
</Flex> </Flex>
); );

View File

@ -5,7 +5,7 @@ import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { IoFilter } from 'react-icons/io5'; import { IoFilter } from 'react-icons/io5';
export const MODEL_TYPE_LABELS: { [key: string]: string } = { const MODEL_TYPE_LABELS: { [key: string]: string } = {
main: 'Main', main: 'Main',
lora: 'LoRA', lora: 'LoRA',
embedding: 'Textual Inversion', embedding: 'Textual Inversion',

View File

@ -1,14 +1,14 @@
import { Box } from '@invoke-ai/ui-library'; import { Box } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { ImportModels } from './ImportModels'; import { InstallModels } from './InstallModels';
import { Model } from './ModelPanel/Model'; import { Model } from './ModelPanel/Model';
export const ModelPane = () => { export const ModelPane = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
return ( return (
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full"> <Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
{selectedModelKey ? <Model /> : <ImportModels />} {selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}
</Box> </Box>
); );
}; };

View File

@ -0,0 +1,68 @@
import { Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { isNil } from 'lodash-es';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm';
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision } = config.sd;
return {
initialSteps: steps.initial,
initialCfg: guidance.initial,
initialScheduler: scheduler,
initialCfgRescaleMultiplier: cfgRescaleMultiplier.initial,
initialVaePrecision: vaePrecision,
};
});
export const DefaultSettings = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { t } = useTranslation();
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
useAppSelector(initialStatesSelector);
const defaultSettingsDefaults = useMemo(() => {
return {
vae: { isEnabled: !isNil(data?.default_settings?.vae), value: data?.default_settings?.vae || 'default' },
vaePrecision: {
isEnabled: !isNil(data?.default_settings?.vae_precision),
value: data?.default_settings?.vae_precision || initialVaePrecision || 'fp32',
},
scheduler: {
isEnabled: !isNil(data?.default_settings?.scheduler),
value: data?.default_settings?.scheduler || initialScheduler || 'euler',
},
steps: { isEnabled: !isNil(data?.default_settings?.steps), value: data?.default_settings?.steps || initialSteps },
cfgScale: {
isEnabled: !isNil(data?.default_settings?.cfg_scale),
value: data?.default_settings?.cfg_scale || initialCfg,
},
cfgRescaleMultiplier: {
isEnabled: !isNil(data?.default_settings?.cfg_rescale_multiplier),
value: data?.default_settings?.cfg_rescale_multiplier || initialCfgRescaleMultiplier,
},
};
}, [
data?.default_settings,
initialSteps,
initialCfg,
initialScheduler,
initialCfgRescaleMultiplier,
initialVaePrecision,
]);
if (isLoading) {
return <Text>{t('common.loading')}</Text>;
}
return <DefaultSettingsForm defaultSettingsDefaults={defaultSettingsDefaults} />;
};

View File

@ -0,0 +1,72 @@
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
type DefaultCfgRescaleMultiplierType = DefaultSettingsFormData['cfgRescaleMultiplier'];
export function DefaultCfgRescaleMultiplier(props: UseControllerProps<DefaultSettingsFormData>) {
const { field } = useController(props);
const sliderMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMin);
const sliderMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMax);
const numberInputMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMin);
const numberInputMax = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.numberInputMax);
const coarseStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.coarseStep);
const fineStep = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.fineStep);
const { t } = useTranslation();
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
const onChange = useCallback(
(v: number) => {
const updatedValue = {
...(field.value as DefaultCfgRescaleMultiplierType),
value: v,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => {
return (field.value as DefaultCfgRescaleMultiplierType).value;
}, [field.value]);
const isDisabled = useMemo(() => {
return !(field.value as DefaultCfgRescaleMultiplierType).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={1} alignItems="flex-start">
<InformationalPopover feature="paramCFGRescaleMultiplier">
<FormLabel>{t('parameters.cfgRescaleMultiplier')}</FormLabel>
</InformationalPopover>
<Flex w="full" gap={1}>
<CompositeSlider
value={value}
min={sliderMin}
max={sliderMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
marks={marks}
isDisabled={isDisabled}
/>
<CompositeNumberInput
value={value}
min={numberInputMin}
max={numberInputMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
}

View File

@ -0,0 +1,72 @@
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
type DefaultCfgType = DefaultSettingsFormData['cfgScale'];
export function DefaultCfgScale(props: UseControllerProps<DefaultSettingsFormData>) {
const { field } = useController(props);
const sliderMin = useAppSelector((s) => s.config.sd.guidance.sliderMin);
const sliderMax = useAppSelector((s) => s.config.sd.guidance.sliderMax);
const numberInputMin = useAppSelector((s) => s.config.sd.guidance.numberInputMin);
const numberInputMax = useAppSelector((s) => s.config.sd.guidance.numberInputMax);
const coarseStep = useAppSelector((s) => s.config.sd.guidance.coarseStep);
const fineStep = useAppSelector((s) => s.config.sd.guidance.fineStep);
const { t } = useTranslation();
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
const onChange = useCallback(
(v: number) => {
const updatedValue = {
...(field.value as DefaultCfgType),
value: v,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => {
return (field.value as DefaultCfgType).value;
}, [field.value]);
const isDisabled = useMemo(() => {
return !(field.value as DefaultCfgType).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={1} alignItems="flex-start">
<InformationalPopover feature="paramCFGScale">
<FormLabel>{t('parameters.cfgScale')}</FormLabel>
</InformationalPopover>
<Flex w="full" gap={1}>
<CompositeSlider
value={value}
min={sliderMin}
max={sliderMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
marks={marks}
isDisabled={isDisabled}
/>
<CompositeNumberInput
value={value}
min={numberInputMin}
max={numberInputMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
}

View File

@ -0,0 +1,50 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants';
import { isParameterScheduler } from 'features/parameters/types/parameterSchemas';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
type DefaultSchedulerType = DefaultSettingsFormData['scheduler'];
export function DefaultScheduler(props: UseControllerProps<DefaultSettingsFormData>) {
const { t } = useTranslation();
const { field } = useController(props);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
if (!isParameterScheduler(v?.value)) {
return;
}
const updatedValue = {
...(field.value as DefaultSchedulerType),
value: v.value,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(
() => SCHEDULER_OPTIONS.find((o) => o.value === (field.value as DefaultSchedulerType).value),
[field]
);
const isDisabled = useMemo(() => {
return !(field.value as DefaultSchedulerType).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={1} alignItems="flex-start">
<InformationalPopover feature="paramScheduler">
<FormLabel>{t('parameters.scheduler')}</FormLabel>
</InformationalPopover>
<Combobox isDisabled={isDisabled} value={value} options={SCHEDULER_OPTIONS} onChange={onChange} />
</FormControl>
);
}

View File

@ -0,0 +1,147 @@
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { IoPencil } from 'react-icons/io5';
import { useUpdateModelMutation } from 'services/api/endpoints/models';
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
import { DefaultCfgScale } from './DefaultCfgScale';
import { DefaultScheduler } from './DefaultScheduler';
import { DefaultSteps } from './DefaultSteps';
import { DefaultVae } from './DefaultVae';
import { DefaultVaePrecision } from './DefaultVaePrecision';
import { SettingToggle } from './SettingToggle';
export interface FormField<T> {
value: T;
isEnabled: boolean;
}
export type DefaultSettingsFormData = {
vae: FormField<string>;
vaePrecision: FormField<string>;
scheduler: FormField<ParameterScheduler>;
steps: FormField<number>;
cfgScale: FormField<number>;
cfgRescaleMultiplier: FormField<number>;
};
export const DefaultSettingsForm = ({
defaultSettingsDefaults,
}: {
defaultSettingsDefaults: DefaultSettingsFormData;
}) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const [updateModel, { isLoading }] = useUpdateModelMutation();
const { handleSubmit, control, formState } = useForm<DefaultSettingsFormData>({
defaultValues: defaultSettingsDefaults,
});
const onSubmit = useCallback<SubmitHandler<DefaultSettingsFormData>>(
(data) => {
if (!selectedModelKey) {
return;
}
const body = {
vae: data.vae.isEnabled ? data.vae.value : null,
vae_precision: data.vaePrecision.isEnabled ? data.vaePrecision.value : null,
cfg_scale: data.cfgScale.isEnabled ? data.cfgScale.value : null,
cfg_rescale_multiplier: data.cfgRescaleMultiplier.isEnabled ? data.cfgRescaleMultiplier.value : null,
steps: data.steps.isEnabled ? data.steps.value : null,
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
};
updateModel({
key: selectedModelKey,
body: { default_settings: body },
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.defaultSettingsSaved'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
},
[selectedModelKey, dispatch, updateModel, t]
);
return (
<>
<Flex gap="2" justifyContent="space-between" w="full" mb={5}>
<Heading fontSize="md">{t('modelManager.defaultSettings')}</Heading>
<Button
size="sm"
leftIcon={<IoPencil />}
colorScheme="invokeYellow"
isDisabled={!formState.isDirty}
onClick={handleSubmit(onSubmit)}
type="submit"
isLoading={isLoading}
>
{t('common.save')}
</Button>
</Flex>
<Flex flexDir="column" gap={8}>
<Flex gap={8}>
<Flex gap={4} w="full">
<SettingToggle control={control} name="vae" />
<DefaultVae control={control} name="vae" />
</Flex>
<Flex gap={4} w="full">
<SettingToggle control={control} name="vaePrecision" />
<DefaultVaePrecision control={control} name="vaePrecision" />
</Flex>
</Flex>
<Flex gap={8}>
<Flex gap={4} w="full">
<SettingToggle control={control} name="scheduler" />
<DefaultScheduler control={control} name="scheduler" />
</Flex>
<Flex gap={4} w="full">
<SettingToggle control={control} name="steps" />
<DefaultSteps control={control} name="steps" />
</Flex>
</Flex>
<Flex gap={8}>
<Flex gap={4} w="full">
<SettingToggle control={control} name="cfgScale" />
<DefaultCfgScale control={control} name="cfgScale" />
</Flex>
<Flex gap={4} w="full">
<SettingToggle control={control} name="cfgRescaleMultiplier" />
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
</Flex>
</Flex>
</Flex>
</>
);
};

View File

@ -0,0 +1,72 @@
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
type DefaultSteps = DefaultSettingsFormData['steps'];
export function DefaultSteps(props: UseControllerProps<DefaultSettingsFormData>) {
const { field } = useController(props);
const sliderMin = useAppSelector((s) => s.config.sd.steps.sliderMin);
const sliderMax = useAppSelector((s) => s.config.sd.steps.sliderMax);
const numberInputMin = useAppSelector((s) => s.config.sd.steps.numberInputMin);
const numberInputMax = useAppSelector((s) => s.config.sd.steps.numberInputMax);
const coarseStep = useAppSelector((s) => s.config.sd.steps.coarseStep);
const fineStep = useAppSelector((s) => s.config.sd.steps.fineStep);
const { t } = useTranslation();
const marks = useMemo(() => [sliderMin, Math.floor(sliderMax / 2), sliderMax], [sliderMax, sliderMin]);
const onChange = useCallback(
(v: number) => {
const updatedValue = {
...(field.value as DefaultSteps),
value: v,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => {
return (field.value as DefaultSteps).value;
}, [field.value]);
const isDisabled = useMemo(() => {
return !(field.value as DefaultSteps).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={1} alignItems="flex-start">
<InformationalPopover feature="paramSteps">
<FormLabel>{t('parameters.steps')}</FormLabel>
</InformationalPopover>
<Flex w="full" gap={1}>
<CompositeSlider
value={value}
min={sliderMin}
max={sliderMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
marks={marks}
isDisabled={isDisabled}
/>
<CompositeNumberInput
value={value}
min={numberInputMin}
max={numberInputMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
}

View File

@ -0,0 +1,65 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery, useGetVaeModelsQuery } from 'services/api/endpoints/models';
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
type DefaultVaeType = DefaultSettingsFormData['vae'];
export function DefaultVae(props: UseControllerProps<DefaultSettingsFormData>) {
const { t } = useTranslation();
const { field } = useController(props);
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const { compatibleOptions } = useGetVaeModelsQuery(undefined, {
selectFromResult: ({ data }) => {
const modelArray = map(data?.entities);
const compatibleOptions = modelArray
.filter((vae) => vae.base === modelData?.base)
.map((vae) => ({ label: vae.name, value: vae.key }));
const defaultOption = { label: 'Default VAE', value: 'default' };
return { compatibleOptions: [defaultOption, ...compatibleOptions] };
},
});
const onChange = useCallback<ComboboxOnChange>(
(v) => {
const newValue = !v?.value ? 'default' : v.value;
const updatedValue = {
...(field.value as DefaultVaeType),
value: newValue,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => {
return compatibleOptions.find((vae) => vae.value === (field.value as DefaultVaeType).value);
}, [compatibleOptions, field.value]);
const isDisabled = useMemo(() => {
return !(field.value as DefaultVaeType).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={1} alignItems="flex-start">
<InformationalPopover feature="paramVAE">
<FormLabel>{t('modelManager.vae')}</FormLabel>
</InformationalPopover>
<Combobox isDisabled={isDisabled} value={value} options={compatibleOptions} onChange={onChange} />
</FormControl>
);
}

View File

@ -0,0 +1,51 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { isParameterPrecision } from 'features/parameters/types/parameterSchemas';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { DefaultSettingsFormData } from './DefaultSettingsForm';
const options = [
{ label: 'FP16', value: 'fp16' },
{ label: 'FP32', value: 'fp32' },
];
type DefaultVaePrecisionType = DefaultSettingsFormData['vaePrecision'];
export function DefaultVaePrecision(props: UseControllerProps<DefaultSettingsFormData>) {
const { t } = useTranslation();
const { field } = useController(props);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
if (!isParameterPrecision(v?.value)) {
return;
}
const updatedValue = {
...(field.value as DefaultVaePrecisionType),
value: v.value,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => options.find((o) => o.value === (field.value as DefaultVaePrecisionType).value), [field]);
const isDisabled = useMemo(() => {
return !(field.value as DefaultVaePrecisionType).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={1} alignItems="flex-start">
<InformationalPopover feature="paramVAEPrecision">
<FormLabel>{t('modelManager.vaePrecision')}</FormLabel>
</InformationalPopover>
<Combobox isDisabled={isDisabled} value={value} options={options} onChange={onChange} />
</FormControl>
);
}

View File

@ -0,0 +1,28 @@
import { Switch } from '@invoke-ai/ui-library';
import type { ChangeEvent } from 'react';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { DefaultSettingsFormData, FormField } from './DefaultSettingsForm';
export function SettingToggle<T>(props: UseControllerProps<DefaultSettingsFormData>) {
const { field } = useController(props);
const value = useMemo(() => {
return !!(field.value as FormField<T>).isEnabled;
}, [field.value]);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
const updatedValue: FormField<T> = {
...(field.value as FormField<T>),
isEnabled: e.target.checked,
};
field.onChange(updatedValue);
},
[field]
);
return <Switch isChecked={value} onChange={onChange} />;
}

View File

@ -3,9 +3,9 @@ import { Combobox } from '@invoke-ai/ui-library';
 import { typedMemo } from 'common/util/typedMemo';
 import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
 import { useCallback, useMemo } from 'react';
-import type { UseControllerProps } from 'react-hook-form';
+import type { Control } from 'react-hook-form';
 import { useController } from 'react-hook-form';
-import type { AnyModelConfig } from 'services/api/types';
+import type { UpdateModelArg } from 'services/api/endpoints/models';

 const options: ComboboxOption[] = [
   { value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
@ -14,8 +14,12 @@ const options: ComboboxOption[] = [
   { value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
 ];

-const BaseModelSelect = (props: UseControllerProps<AnyModelConfig>) => {
-  const { field } = useController(props);
+type Props = {
+  control: Control<UpdateModelArg['body']>;
+};
+
+const BaseModelSelect = ({ control }: Props) => {
+  const { field } = useController({ control, name: 'base' });
   const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
   const onChange = useCallback<ComboboxOnChange>(
     (v) => {

View File

@ -1,27 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'none', label: '-' },
{ value: 'true', label: 'True' },
{ value: 'false', label: 'False' },
];
const BooleanSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value === 'true');
},
[field]
);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(BooleanSelect);

View File

@ -1,47 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController, useWatch } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const ModelFormatSelect = (props: UseControllerProps<AnyModelConfig>) => {
const { field, formState } = useController(props);
const type = useWatch({ control: props.control, name: 'type' });
const onChange = useCallback<ComboboxOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
const options: ComboboxOption[] = useMemo(() => {
const modelType = type || formState.defaultValues?.type;
if (modelType === 'lora') {
return [
{ value: 'lycoris', label: 'LyCORIS' },
{ value: 'diffusers', label: 'Diffusers' },
];
} else if (modelType === 'embedding') {
return [
{ value: 'embedding_file', label: 'Embedding File' },
{ value: 'embedding_folder', label: 'Embedding Folder' },
];
} else if (modelType === 'ip_adapter') {
return [{ value: 'invokeai', label: 'invokeai' }];
} else {
return [
{ value: 'diffusers', label: 'Diffusers' },
{ value: 'checkpoint', label: 'Checkpoint' },
];
}
}, [type, formState.defaultValues?.type]);
const value = useMemo(() => options.find((o) => o.value === field.value), [options, field.value]);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(ModelFormatSelect);

View File

@ -1,32 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { MODEL_TYPE_LABELS } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'main', label: MODEL_TYPE_LABELS['main'] as string },
{ value: 'lora', label: MODEL_TYPE_LABELS['lora'] as string },
{ value: 'embedding', label: MODEL_TYPE_LABELS['embedding'] as string },
{ value: 'vae', label: MODEL_TYPE_LABELS['vae'] as string },
{ value: 'controlnet', label: MODEL_TYPE_LABELS['controlnet'] as string },
{ value: 'ip_adapter', label: MODEL_TYPE_LABELS['ip_adapter'] as string },
{ value: 't2i_adapater', label: MODEL_TYPE_LABELS['t2i_adapter'] as string },
] as const;
const ModelTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(ModelTypeSelect);

View File

@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
 import { Combobox } from '@invoke-ai/ui-library';
 import { typedMemo } from 'common/util/typedMemo';
 import { useCallback, useMemo } from 'react';
-import type { UseControllerProps } from 'react-hook-form';
+import type { Control } from 'react-hook-form';
 import { useController } from 'react-hook-form';
-import type { AnyModelConfig } from 'services/api/types';
+import type { UpdateModelArg } from 'services/api/endpoints/models';

 const options: ComboboxOption[] = [
   { value: 'normal', label: 'Normal' },
@ -12,8 +12,12 @@ const options: ComboboxOption[] = [
   { value: 'depth', label: 'Depth' },
 ];

-const ModelVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
-  const { field } = useController(props);
+type Props = {
+  control: Control<UpdateModelArg['body']>;
+};
+
+const ModelVariantSelect = ({ control }: Props) => {
+  const { field } = useController({ control, name: 'variant' });
   const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
   const onChange = useCallback<ComboboxOnChange>(
     (v) => {

View File

@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
 import { Combobox } from '@invoke-ai/ui-library';
 import { typedMemo } from 'common/util/typedMemo';
 import { useCallback, useMemo } from 'react';
-import type { UseControllerProps } from 'react-hook-form';
+import type { Control } from 'react-hook-form';
 import { useController } from 'react-hook-form';
-import type { AnyModelConfig } from 'services/api/types';
+import type { UpdateModelArg } from 'services/api/endpoints/models';

 const options: ComboboxOption[] = [
   { value: 'none', label: '-' },
@ -13,8 +13,12 @@ const options: ComboboxOption[] = [
   { value: 'sample', label: 'sample' },
 ];

-const PredictionTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
-  const { field } = useController(props);
+type Props = {
+  control: Control<UpdateModelArg['body']>;
+};
+
+const PredictionTypeSelect = ({ control }: Props) => {
+  const { field } = useController({ control, name: 'prediction_type' });
   const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
   const onChange = useCallback<ComboboxOnChange>(
     (v) => {

View File

@ -1,27 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'none', label: '-' },
{ value: 'fp16', label: 'fp16' },
{ value: 'fp32', label: 'fp32' },
];
const RepoVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value);
},
[field]
);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(RepoVariantSelect);

View File

@ -0,0 +1,41 @@
import { Box, Flex } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { useMemo } from 'react';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import type { ModelType } from 'services/api/types';
import { TriggerPhrases } from './TriggerPhrases';
const MODEL_TYPE_TRIGGER_PHRASE: ModelType[] = ['main', 'lora'];
export const ModelMetadata = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const shouldShowTriggerPhraseSettings = useMemo(() => {
if (!data?.type) {
return false;
}
return MODEL_TYPE_TRIGGER_PHRASE.includes(data.type);
}, [data]);
const apiResponseFormatted = useMemo(() => {
if (!data?.source_api_response) {
return {};
}
return JSON.parse(data.source_api_response);
}, [data?.source_api_response]);
return (
<Flex flexDir="column" height="full" gap="3">
{shouldShowTriggerPhraseSettings && (
<Box layerStyle="second" borderRadius="base" p={3}>
<TriggerPhrases />
</Box>
)}
<DataViewer label="metadata" data={apiResponseFormatted} />
</Flex>
);
};

View File

@ -0,0 +1,106 @@
import {
Button,
Flex,
FormControl,
FormErrorMessage,
Input,
Tag,
TagCloseButton,
TagLabel,
} from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { ModelListHeader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelListHeader';
import type { ChangeEvent } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
export const TriggerPhrases = () => {
const { t } = useTranslation();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data: modelConfig } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const [phrase, setPhrase] = useState('');
const [updateModel, { isLoading }] = useUpdateModelMutation();
const handlePhraseChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setPhrase(e.target.value);
}, []);
const triggerPhrases = useMemo(() => {
return modelConfig?.trigger_phrases || [];
}, [modelConfig?.trigger_phrases]);
const errors = useMemo(() => {
const errors = [];
if (phrase.length && triggerPhrases.includes(phrase)) {
errors.push('Phrase is already in list');
}
return errors;
}, [phrase, triggerPhrases]);
const addTriggerPhrase = useCallback(async () => {
if (!selectedModelKey) {
return;
}
if (!phrase.length || triggerPhrases.includes(phrase)) {
return;
}
await updateModel({
key: selectedModelKey,
body: { trigger_phrases: [...triggerPhrases, phrase] },
}).unwrap();
setPhrase('');
}, [updateModel, selectedModelKey, phrase, triggerPhrases]);
const removeTriggerPhrase = useCallback(
async (phraseToRemove: string) => {
if (!selectedModelKey) {
return;
}
const filteredPhrases = triggerPhrases.filter((p) => p !== phraseToRemove);
await updateModel({ key: selectedModelKey, body: { trigger_phrases: filteredPhrases } }).unwrap();
},
[updateModel, selectedModelKey, triggerPhrases]
);
return (
<Flex flexDir="column" w="full" gap="5">
<ModelListHeader title={t('modelManager.triggerPhrases')} />
<form>
<FormControl w="full" isInvalid={Boolean(errors.length)}>
<Flex flexDir="column" w="full">
<Flex gap="3" alignItems="center" w="full">
<Input value={phrase} onChange={handlePhraseChange} placeholder={t('modelManager.typePhraseHere')} />
<Button
type="submit"
onClick={addTriggerPhrase}
isDisabled={Boolean(errors.length)}
isLoading={isLoading}
>
{t('common.add')}
</Button>
</Flex>
{!!errors.length && errors.map((error) => <FormErrorMessage key={error}>{error}</FormErrorMessage>)}
</Flex>
</FormControl>
</form>
<Flex gap="4" flexWrap="wrap" mt="3" mb="3">
{triggerPhrases.map((phrase, index) => (
<Tag size="md" key={index}>
<TagLabel>{phrase}</TagLabel>
<TagCloseButton onClick={removeTriggerPhrase.bind(null, phrase)} isDisabled={isLoading} />
</Tag>
))}
</Flex>
</Flex>
);
};

View File

@ -1,9 +1,58 @@
+import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs, Text } from '@invoke-ai/ui-library';
+import { skipToken } from '@reduxjs/toolkit/query';
 import { useAppSelector } from 'app/store/storeHooks';
+import { useTranslation } from 'react-i18next';
+import { useGetModelConfigQuery } from 'services/api/endpoints/models';

+import { ModelMetadata } from './Metadata/ModelMetadata';
+import { ModelAttrView } from './ModelAttrView';
 import { ModelEdit } from './ModelEdit';
 import { ModelView } from './ModelView';

 export const Model = () => {
+  const { t } = useTranslation();
   const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode);
-  return selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />;
+  const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
+  const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
+
+  if (isLoading) {
+    return <Text>{t('common.loading')}</Text>;
+  }
+
+  if (!data) {
+    return <Text>{t('common.somethingWentWrong')}</Text>;
+  }
+
+  return (
+    <>
+      <Flex flexDir="column" gap={1} p={2}>
+        <Heading as="h2" fontSize="lg">
+          {data.name}
+        </Heading>
+        {data.source && (
+          <Text variant="subtext">
+            {t('modelManager.source')}: {data?.source}
+          </Text>
+        )}
+        <Box mt="4">
+          <ModelAttrView label="Description" value={data.description} />
+        </Box>
+      </Flex>
+      <Tabs mt="4" h="100%">
+        <TabList>
+          <Tab>{t('modelManager.settings')}</Tab>
+          <Tab>{t('modelManager.metadata')}</Tab>
+        </TabList>
+        <TabPanels h="100%">
+          <TabPanel>{selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />}</TabPanel>
+          <TabPanel h="full">
+            <ModelMetadata />
+          </TabPanel>
+        </TabPanels>
+      </Tabs>
+    </>
+  );
 };

View File

@ -13,7 +13,7 @@ import { addToast } from 'features/system/store/systemSlice';
 import { makeToast } from 'features/system/util/makeToast';
 import { useCallback } from 'react';
 import { useTranslation } from 'react-i18next';
-import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
+import { useConvertModelMutation } from 'services/api/endpoints/models';
 import type { CheckpointModelConfig } from 'services/api/types';

 interface ModelConvertProps {
@ -24,7 +24,7 @@ export const ModelConvert = (props: ModelConvertProps) => {
   const { model } = props;
   const dispatch = useAppDispatch();
   const { t } = useTranslation();
-  const [convertModel, { isLoading }] = useConvertMainModelsMutation();
+  const [convertModel, { isLoading }] = useConvertModelMutation();
   const { isOpen, onOpen, onClose } = useDisclosure();

   const modelConvertHandler = useCallback(() => {

View File

@ -1,5 +1,6 @@
 import {
   Button,
+  Checkbox,
   Flex,
   FormControl,
   FormErrorMessage,
@ -19,66 +20,27 @@ import type { SubmitHandler } from 'react-hook-form';
 import { useForm } from 'react-hook-form';
 import { useTranslation } from 'react-i18next';
 import type { UpdateModelArg } from 'services/api/endpoints/models';
-import { useGetModelConfigQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
-import type { AnyModelConfig } from 'services/api/types';
+import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';

 import BaseModelSelect from './Fields/BaseModelSelect';
-import BooleanSelect from './Fields/BooleanSelect';
-import ModelFormatSelect from './Fields/ModelFormatSelect';
-import ModelTypeSelect from './Fields/ModelTypeSelect';
 import ModelVariantSelect from './Fields/ModelVariantSelect';
 import PredictionTypeSelect from './Fields/PredictionTypeSelect';
-import RepoVariantSelect from './Fields/RepoVariantSelect';

 export const ModelEdit = () => {
   const dispatch = useAppDispatch();
   const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
   const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
-  const [updateModel, { isLoading: isSubmitting }] = useUpdateModelsMutation();
+  const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
   const { t } = useTranslation();

-  // const modelData = useMemo(() => {
-  //   if (!data) {
-  //     return null;
-  //   }
-  //   const modelFormat = data.format;
-  //   const modelType = data.type;
-  //   if (modelType === 'main') {
-  //     if (modelFormat === 'diffusers') {
-  //       return data as DiffusersModelConfig;
-  //     } else if (modelFormat === 'checkpoint') {
-  //       return data as CheckpointModelConfig;
-  //     }
-  //   }
-  //   switch (modelType) {
-  //     case 'lora':
-  //       return data as LoRAModelConfig;
-  //     case 'embedding':
-  //       return data as TextualInversionModelConfig;
-  //     case 't2i_adapter':
-  //       return data as T2IAdapterModelConfig;
-  //     case 'ip_adapter':
-  //       return data as IPAdapterModelConfig;
-  //     case 'controlnet':
-  //       return data as ControlNetModelConfig;
-  //     case 'vae':
-  //       return data as VAEModelConfig;
-  //     default:
-  //       return null;
-  //   }
-  // }, [data]);
   const {
     register,
     handleSubmit,
     control,
     formState: { errors },
     reset,
-    watch,
   } = useForm<UpdateModelArg['body']>({
     defaultValues: {
       ...data,
@ -86,10 +48,7 @@ export const ModelEdit = () => {
     mode: 'onChange',
   });

-  const watchedModelType = watch('type');
-  const watchedModelFormat = watch('format');
-
-  const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
+  const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
     (values) => {
       if (!data?.key) {
         return;
@ -143,33 +102,31 @@
   return (
     <Flex flexDir="column" h="full">
       <form onSubmit={handleSubmit(onSubmit)}>
-        <FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
-          <Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
+        <Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
+          <FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
             <FormLabel hidden={true}>{t('modelManager.modelName')}</FormLabel>
             <Input
               {...register('name', {
-                validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
+                validate: (value) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
               })}
               size="lg"
             />
-            <Flex gap={2}>
-              <Button size="sm" onClick={handleClickCancel}>
-                {t('common.cancel')}
-              </Button>
-              <Button
-                size="sm"
-                colorScheme="invokeYellow"
-                onClick={handleSubmit(onSubmit)}
-                isLoading={isSubmitting}
-                isDisabled={Boolean(Object.keys(errors).length)}
-              >
-                {t('common.save')}
-              </Button>
-            </Flex>
-          </Flex>
-          {errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
-        </FormControl>
+            {errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
+          </FormControl>
+          <Button size="sm" onClick={handleClickCancel}>
+            {t('common.cancel')}
+          </Button>
+          <Button
+            size="sm"
+            colorScheme="invokeYellow"
+            onClick={handleSubmit(onSubmit)}
+            isLoading={isSubmitting}
+            isDisabled={Boolean(Object.keys(errors).length)}
+          >
+            {t('common.save')}
+          </Button>
+        </Flex>

         <Flex flexDir="column" gap={3} mt="4">
           <Flex>
@ -184,76 +141,22 @@
           <Flex gap={4}>
             <FormControl flexDir="column" alignItems="flex-start" gap={1}>
               <FormLabel>{t('modelManager.baseModel')}</FormLabel>
-              <BaseModelSelect control={control} name="base" />
-            </FormControl>
-            <FormControl flexDir="column" alignItems="flex-start" gap={1}>
-              <FormLabel>{t('modelManager.modelType')}</FormLabel>
-              <ModelTypeSelect<AnyModelConfig> control={control} name="type" />
+              <BaseModelSelect control={control} />
             </FormControl>
           </Flex>
-          <Flex gap={4}>
-            <FormControl flexDir="column" alignItems="flex-start" gap={1}>
-              <FormLabel>{t('common.format')}</FormLabel>
-              <ModelFormatSelect control={control} name="format" />
-            </FormControl>
-            <FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
-              <FormLabel>{t('modelManager.path')}</FormLabel>
-              <Input
-                {...register('path', {
-                  validate: (value) => value.trim().length > 0 || 'Must provide a path',
-                })}
-              />
-              {errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
-            </FormControl>
-          </Flex>
-          {watchedModelType === 'main' && (
-            <>
-              <Flex gap={4}>
-                {watchedModelFormat === 'diffusers' && (
-                  <FormControl flexDir="column" alignItems="flex-start" gap={1}>
-                    <FormLabel>{t('modelManager.repoVariant')}</FormLabel>
-                    <RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
-                  </FormControl>
-                )}
-                {watchedModelFormat === 'checkpoint' && (
-                  <FormControl flexDir="column" alignItems="flex-start" gap={1}>
-                    <FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
-                    <Input {...register('config')} />
-                  </FormControl>
-                )}
-                <FormControl flexDir="column" alignItems="flex-start" gap={1}>
-                  <FormLabel>{t('modelManager.variant')}</FormLabel>
-                  <ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
-                </FormControl>
-              </Flex>
-              <Flex gap={4}>
-                <FormControl flexDir="column" alignItems="flex-start" gap={1}>
-                  <FormLabel>{t('modelManager.predictionType')}</FormLabel>
-                  <PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
-                </FormControl>
-                <FormControl flexDir="column" alignItems="flex-start" gap={1}>
-                  <FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
-                  <BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
-                </FormControl>
-              </Flex>
-              <Flex gap={4}>
-                <FormControl flexDir="column" alignItems="flex-start" gap={1}>
-                  <FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
-                  <BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
-                </FormControl>
-                <FormControl flexDir="column" alignItems="flex-start" gap={1}>
-                  <FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
-                  <Input {...register('vae')} />
-                </FormControl>
-              </Flex>
-            </>
-          )}
-          {watchedModelType === 'ip_adapter' && (
+          {data.type === 'main' && data.format === 'checkpoint' && (
             <Flex gap={4}>
               <FormControl flexDir="column" alignItems="flex-start" gap={1}>
-                <FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
-                <Input {...register('image_encoder_model_id')} />
+                <FormLabel>{t('modelManager.variant')}</FormLabel>
+                <ModelVariantSelect control={control} />
+              </FormControl>
+              <FormControl flexDir="column" alignItems="flex-start" gap={1}>
+                <FormLabel>{t('modelManager.predictionType')}</FormLabel>
+                <PredictionTypeSelect control={control} />
+              </FormControl>
+              <FormControl flexDir="column" alignItems="flex-start" gap={1}>
+                <FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
+                <Checkbox {...register('upcast_attention')} />
               </FormControl>
             </Flex>
           )}

View File

@ -1,12 +1,11 @@
-import { Box, Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
+import { Box, Button, Flex, Text } from '@invoke-ai/ui-library';
 import { skipToken } from '@reduxjs/toolkit/query';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
-import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
 import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
 import { useCallback, useMemo } from 'react';
 import { useTranslation } from 'react-i18next';
 import { IoPencil } from 'react-icons/io5';
-import { useGetModelConfigQuery, useGetModelMetadataQuery } from 'services/api/endpoints/models';
+import { useGetModelConfigQuery } from 'services/api/endpoints/models';
 import type {
   CheckpointModelConfig,
   ControlNetModelConfig,
@ -18,6 +17,7 @@ import type {
   VAEModelConfig,
 } from 'services/api/types';

+import { DefaultSettings } from './DefaultSettings';
 import { ModelAttrView } from './ModelAttrView';
 import { ModelConvert } from './ModelConvert';
@ -26,7 +26,6 @@ export const ModelView = () => {
   const dispatch = useAppDispatch();
   const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
   const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
-  const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);

   const modelData = useMemo(() => {
     if (!data) {
@ -73,85 +72,49 @@
     return <Text>{t('common.somethingWentWrong')}</Text>;
   }

   return (
-    <Flex flexDir="column" h="full">
-      <Flex w="full" justifyContent="space-between">
-        <Flex flexDir="column" gap={1} p={2}>
-          <Heading as="h2" fontSize="lg">
-            {modelData.name}
-          </Heading>
-          {modelData.source && (
-            <Text variant="subtext">
-              {t('modelManager.source')}: {modelData.source}
-            </Text>
-          )}
-        </Flex>
-        <Flex gap={2}>
+    <Flex flexDir="column" h="full" gap="2">
+      <Box layerStyle="second" borderRadius="base" p={3}>
+        <Flex gap="2" justifyContent="flex-end" w="full">
           <Button size="sm" leftIcon={<IoPencil />} colorScheme="invokeYellow" onClick={handleEditModel}>
             {t('modelManager.edit')}
           </Button>
           {modelData.type === 'main' && modelData.format === 'checkpoint' && <ModelConvert model={modelData} />}
         </Flex>
-      </Flex>
-
-      <Flex flexDir="column" p={2} gap={3}>
-        <Flex>
-          <ModelAttrView label="Description" value={modelData.description} />
-        </Flex>
-        <Heading as="h3" fontSize="md" mt="4">
-          {t('modelManager.modelSettings')}
-        </Heading>
-        <Box layerStyle="second" borderRadius="base" p={3}>
-          <Flex flexDir="column" gap={3}>
-            <Flex gap={2}>
-              <ModelAttrView label={t('modelManager.baseModel')} value={modelData.base} />
-              <ModelAttrView label={t('modelManager.modelType')} value={modelData.type} />
-            </Flex>
-            <Flex gap={2}>
-              <ModelAttrView label={t('common.format')} value={modelData.format} />
-              <ModelAttrView label={t('modelManager.path')} value={modelData.path} />
-            </Flex>
-            {modelData.type === 'main' && (
-              <>
-                <Flex gap={2}>
-                  {modelData.format === 'diffusers' && (
-                    <ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
-                  )}
-                  {modelData.format === 'checkpoint' && (
-                    <ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
-                  )}
-                  <ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
-                </Flex>
-                <Flex gap={2}>
-                  <ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
-                  <ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
-                </Flex>
-                <Flex gap={2}>
-                  <ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
-                  <ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
-                </Flex>
-              </>
-            )}
-            {modelData.type === 'ip_adapter' && (
-              <Flex gap={2}>
-                <ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} />
-              </Flex>
-            )}
+        <Flex flexDir="column" gap={3}>
+          <Flex gap={2}>
+            <ModelAttrView label={t('modelManager.baseModel')} value={modelData.base} />
+            <ModelAttrView label={t('modelManager.modelType')} value={modelData.type} />
           </Flex>
-        </Box>
-      </Flex>
-      {metadata && (
-        <>
-          <Heading as="h3" fontSize="md" mt="4">
-            {t('modelManager.modelMetadata')}
-          </Heading>
-          <Flex h="full" w="full" p={2}>
-            <DataViewer label="metadata" data={metadata} />
+          <Flex gap={2}>
+            <ModelAttrView label={t('common.format')} value={modelData.format} />
+            <ModelAttrView label={t('modelManager.path')} value={modelData.path} />
           </Flex>
-        </>
-      )}
+          {modelData.type === 'main' && (
+            <Flex gap={2}>
+              {modelData.format === 'diffusers' && modelData.repo_variant && (
+                <ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
+              )}
+              {modelData.format === 'checkpoint' && (
+                <>
+                  <ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config_path} />
+                  <ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
+                  <ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
+                  <ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
+                </>
+              )}
+            </Flex>
+          )}
+          {modelData.type === 'ip_adapter' && (
+            <Flex gap={2}>
+              <ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} />
+            </Flex>
+          )}
+        </Flex>
+      </Box>
+      <Box layerStyle="second" borderRadius="base" p={3}>
+        <DefaultSettings />
+      </Box>
     </Flex>
   );
 };

Some files were not shown because too many files have changed in this diff.