Compare commits

..

546 Commits

Author SHA1 Message Date
cb99a5f0e4 Sort results of hashes in _hash_dir to always get the same result, remove dependency on os.walk 2024-03-07 15:54:54 -05:00
3bb50926af Testing different file chunk sizes 2024-03-07 15:25:20 -05:00
bb3f1b9ca6 Add threading to scan dir calls, cap thread pool in hash function to number of files 2024-03-07 15:21:45 -05:00
119d26e102 Remove manual memory management in hashlib_hasher in favor of using Python's built-in buffering 2024-03-07 14:08:15 -05:00
e1c16c33a4 Speed up hashing 2024-03-07 13:59:02 -05:00
ad70cdfe87 feat: undo/redo discard canvas staged image 2024-03-07 19:24:55 +11:00
549d461107 refactor: 🚨 satisfy the linter 2024-03-07 19:24:55 +11:00
cab3748010 feat: discard current inpaint item 2024-03-07 19:24:55 +11:00
779b3e0e8e tidy(ui): remove npm lockfile 2024-03-06 21:57:41 -05:00
9b48029bc9 tidy(mm): ModelImages service 2024-03-06 21:57:41 -05:00
347f1fd0b7 fix tests 2024-03-06 21:57:41 -05:00
4af5a09a68 cleanup 2024-03-06 21:57:41 -05:00
8df02623f2 cleanup 2024-03-06 21:57:41 -05:00
aa88fadc30 use webp images 2024-03-06 21:57:41 -05:00
8411029d93 get model image url from model config, added thumbnail formatting for images 2024-03-06 21:57:41 -05:00
239b1e8cc7 moved upload image field and added delete image functionality 2024-03-06 21:57:41 -05:00
8a68355926 got model images displaying, still need to clean up types and unused code 2024-03-06 21:57:41 -05:00
86aef9f31d removed modelimage for now 2024-03-06 21:57:41 -05:00
2f6964bfa5 fetching model image, still not working 2024-03-06 21:57:41 -05:00
c1cdfd132b moved model image to edit page, added model_images service 2024-03-06 21:57:41 -05:00
f6bfe5e6f2 created ugly model image upload component 2024-03-06 21:57:41 -05:00
b5a8455b5f translationBot(ui): update translation (Russian)
Currently translated at 94.6% (1431 of 1512 strings)

translationBot(ui): update translation (Russian)

Currently translated at 94.6% (1431 of 1512 strings)

Co-authored-by: Васянатор <ilabulanov339@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ru/
Translation: InvokeAI/Web UI
2024-03-07 11:47:01 +11:00
645ef081ea translationBot(ui): update translation (Italian)
Currently translated at 98.0% (1487 of 1516 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.0% (1482 of 1512 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.0% (1475 of 1505 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-03-07 11:47:01 +11:00
e68d7fa6d7 fix(ui): update types 2024-03-07 10:56:59 +11:00
c5ab1c7ad6 chore(ui): typegen 2024-03-07 10:56:59 +11:00
5a561cab78 fix(ui): typo 2024-03-07 10:56:59 +11:00
132790eebe tidy(nodes): use canonical capitalizations 2024-03-07 10:56:59 +11:00
c57f6ee885 fix(ui): fix metadata for graphs to use new enriched format 2024-03-07 10:56:59 +11:00
d4a2ea68fc chore(ui): typegen 2024-03-07 10:56:59 +11:00
528ac5dd25 refactor(nodes): model identifiers
- All models are identified by a key and optionally a submodel type via new model `ModelField`. Previously, a few model types had their own class, but not all of them. This inconsistency just added complexity without any benefit.
- Update all invocation to use the new format.
- In the node API, models are loaded by key or an instance of `ModelField` as a convenience.
- Add an enriched model schema for metadata. It includes key, hash, name, base and type.
2024-03-07 10:56:59 +11:00
afd9ae7712 tidy(mm): remove convenience methods from high level model manager service
These were added as a hold-me-over for the nodes API changes, no longer needed. A followup commit will fix the nodes API to not rely on these.
2024-03-07 10:56:59 +11:00
4eefed12f0 refactor: 🚨 please the almighty linter 2024-03-07 10:44:40 +11:00
4301a3d6fd feat: invert scroll direction for brush size 2024-03-07 10:44:40 +11:00
99c0662e3f fix(nodes): load config before doing anything else
This was preventing custom nodes from loading if a custom nodes dir was specified

Closes #5862
2024-03-07 10:36:27 +11:00
cdc0d0c182 add config_path to ModelRecordChanges 2024-03-07 10:29:29 +11:00
a00369a67a add config path as field in model update form when model is a checkpoint 2024-03-07 10:29:29 +11:00
4f096ac3ba feat(scripts): add frontend-types to Makefile to generate types 2024-03-07 10:16:44 +11:00
f5e3341465 feat(scripts): add support for file path & stdin to frontend typegen script 2024-03-07 10:16:44 +11:00
474852ef7e feat(scripts): add script to generate openapi schema 2024-03-07 10:16:44 +11:00
b1d72d411e only show default settings on main models 2024-03-07 09:07:43 +11:00
46614ee28f lint fix 2024-03-06 15:06:27 -05:00
b019f9bb8b make sure all metadata in viewer is rendered at correct font size - specifically fixes control adapter metadata being too big 2024-03-06 15:06:27 -05:00
b857692073 update uploads from canvas to controlnet to be intermediates so they do not appear in gallery 2024-03-06 15:06:27 -05:00
90fb7a1a59 move linear tab to be first on workflow edit mode 2024-03-06 15:06:27 -05:00
56fcf6af78 empty state for workflow with no linear fields in view mode 2024-03-06 15:06:27 -05:00
c4fe7e697b add right-padding to prompt textareas so that text does not go behind icons 2024-03-06 15:06:27 -05:00
2fd483dfc8 use base.800 on invokeBlue.400 for all gallery selected states 2024-03-06 15:06:27 -05:00
b9a9507422 update padding in color picker 2024-03-06 15:06:27 -05:00
f2744fd7d1 fix URL for get image workflow (#5882)
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2024-03-06 12:46:16 -05:00
fe6e879d38 fix(ui): invalidate InvocationCacheStatus query cache after clearing intermediates 2024-03-06 08:14:12 -05:00
d3ab08fe10 tests: add invocation cache tests 2024-03-06 08:14:12 -05:00
b0615bdfd4 fix(nodes): correctly serialize outputs
In order for delete by match to work, we need the whole invocation output to be stringified.

For some reason, the serialization of the output was set to only include the `type` field. It should instead include the whole output.

I don't understand how this ever worked unless pydantic had different serialization behaviour in v1 (though it appears to have been the same).

Closes #5805
2024-03-06 08:14:12 -05:00
bab20467fb fix(nodes): fix invocation cache clear method args 2024-03-06 08:14:12 -05:00
e24624109e fix(nodes): fix invocation cache ABC typing 2024-03-06 08:14:12 -05:00
458e7185b8 fix: 🐛 didn't include renamed file 2024-03-06 20:06:14 +11:00
a95128f5f2 refactor: ✏️ canvas mask compositor naming
changes `...MaskCompositer` spelling to `...MaskCompositor`
2024-03-06 20:06:14 +11:00
46f32c5e3c Remove references to the no longer existing invokeai.app.services.model_metadata package 2024-03-05 19:58:25 -05:00
e30cb4b52f updates for defaultModel (#5866)
* move defaultModel logic to modelsLoaded and update to work for key instead of name/base/type string

* lint fix

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2024-03-05 09:55:22 -05:00
ba1f6bf926 chore: lint 2024-03-05 23:50:19 +11:00
4a9cca6c2d fix(ui): format model API response data 2024-03-05 23:50:19 +11:00
b0275700b3 refactor(ui): compute prompt trigger options in the component
We can derive the valid trigger options in the component without needing to lift the options list into global state.
2024-03-05 23:50:19 +11:00
8319aca5f9 chore(ui): typegen 2024-03-05 23:50:19 +11:00
51a604f907 pkg(ui): do not fix knip in lint:fix script 2024-03-05 23:50:19 +11:00
7515d73628 make trigger phrases a list of options and add lora name as description to appear in dropdown 2024-03-05 23:50:19 +11:00
2c453aa531 fix type error 2024-03-05 23:50:19 +11:00
2cca6e4c76 check if lora is enabled before adding trigger phrases 2024-03-05 23:50:19 +11:00
ef171e890a use a listener to recalculate trigger phrases when model or lora list changes 2024-03-05 23:50:19 +11:00
caafbf2f0d only show trigger phrase settings on main and lora 2024-03-05 23:50:19 +11:00
2db5eaf907 lint fix 2024-03-05 23:50:19 +11:00
f234bf6256 cleanup 2024-03-05 23:50:19 +11:00
cfa78b4052 adapt embedding popover to work for trigger phrases also 2024-03-05 23:50:19 +11:00
ba1dd4b02b UI in MM to create trigger phrases 2024-03-05 23:50:19 +11:00
bcf58cac59 feat(mm): add config to skip model hash
This is useful for when you are using a memory DB and do not want to wait for all models to be hashed on startup.
2024-03-05 23:50:19 +11:00
e866d90ab2 tidy(mm): remove unused method on probe 2024-03-05 23:50:19 +11:00
e8797787cf fix(mm): fix incorrect calls to update_model 2024-03-05 23:50:19 +11:00
0082ecb22b feat(mm): add path to ModelRecordChanges 2024-03-05 23:50:19 +11:00
656839fcd1 fix(mm): fix typing on heuristic_import 2024-03-05 23:50:19 +11:00
99407c899f feat(ui): update UI to use new model config backend
- Update all queries
- Remove Advanced Add
- Removed un-editable, internal-only model attributes from model edit UI (e.g. format, repo variant, model type)
- Update model tags so the list refreshes when a model installs
- Rename some queries, components, variables, types to match backend
- Fix divide-by-zero in install queue
2024-03-05 23:50:19 +11:00
48119d9010 revert(mm): restore convert route 2024-03-05 23:50:19 +11:00
7c9128b253 tidy(mm): use canonical capitalization for all model-related enums, classes
For example, "Lora" -> "LoRA", "Vae" -> "VAE".
2024-03-05 23:50:19 +11:00
4f9bb00275 tidy(api): tidy mm routes
Rename MM routes to be consistent:
- "import" -> "install"
- "model_record" -> "model"

Comment several unused routes while I work (may end up removing them?):
- list model summary (we use the search route instead)
- add model record
- convert model
- merge models
2024-03-05 23:50:19 +11:00
78895b3e80 fix(mm): add missing inplace parameter to model install abc 2024-03-05 23:50:19 +11:00
3030a34b88 fix(mm): make type and format required in openapi schema for model config 2024-03-05 23:50:19 +11:00
58fa9c2fac fix(mm): do not allow extra fields on ModelRecordChanges 2024-03-05 23:50:19 +11:00
a8b6635050 fix(mm): make key required in openapi schema for model config 2024-03-05 23:50:19 +11:00
6829610a71 tests: rename "example_config" -> "example_it_config" 2024-03-05 23:50:19 +11:00
5551cf8ac4 feat(mm): revise update_model to use ModelRecordChanges 2024-03-05 23:50:19 +11:00
37b969d339 tidy(mm): add default_settings to model config 2024-03-05 23:50:19 +11:00
c953e61294 tidy(mm): "trigger_words" -> "trigger_phrases" 2024-03-05 23:50:19 +11:00
93dd3c848e tidy(mm): remove unused code in select_hf_files.py 2024-03-05 23:50:19 +11:00
02bde7bb75 tests: fix test_hf_model_select::test_select_multiple_weights on windows 2024-03-05 23:50:19 +11:00
3391c19926 chore: ruff 2024-03-05 23:50:19 +11:00
0f60b1ced4 fix(mm): use .value for model config discriminators
There is a breaking change in python 3.11 related to how enums with `str` as a mixin are formatted. This appears to have not caused any grief for us until now.

Re-jigger the discriminator setup to use `.value` so everything works on both python 3.10 and 3.11.
2024-03-05 23:50:19 +11:00
44c40d7d1a refactor(mm): remove unused metadata logic, fix tests
- Metadata is merged with the config. We can simplify the MM substantially and remove the handling for metadata.
- Per discussion, we don't have an ETA for frontend implementation of tags, and with the realization that the tags from CivitAI are largely useless, there's no reason to keep tags in the MM right now. When we are ready to implement tags on the frontend, we can refer back to the implementation here and use it if it supports the design.
- Fix all tests.
2024-03-05 23:50:19 +11:00
0b9a212363 tests: remove 60s timeout for tests
This makes it very difficult to troubleshoot tests. Our github actions now have timeouts, so there's no risk of a test stalling for ages.
2024-03-05 23:50:19 +11:00
c3aa985c93 refactor(mm): get metadata working 2024-03-05 23:50:19 +11:00
7cb0da1f66 refactor(mm): wip schema changes 2024-03-05 23:50:19 +11:00
3534366146 fix(mm): fix extraneous downloaded files in diffusers
Sometimes, diffusers model components (tokenizer, unet, etc.) have multiple weights files in the same directory.

In this situation, we assume the files are different versions of the same weights. For example, we may have multiple
formats (`.bin`, `.safetensors`) with different precisions. When downloading model files, we want to select only
the best of these files for the requested format and precision/variant.

The previous logic assumed that each model weights file would have the same base filename, but this assumption was
not always true. The logic is revised score each file and choose the best scoring file, resulting in only a single
file being downloaded for each submodel/subdirectory.
2024-03-05 23:50:19 +11:00
f2b5f8753f tidy(mm): remove json_schema_extra from config - not needed 2024-03-05 23:50:19 +11:00
f13f5984c0 fix(mm): update db schema & migration 2024-03-05 23:50:19 +11:00
94e1e64296 chore: ruff 2024-03-05 23:50:19 +11:00
2411bf53c0 tidy(mm): better descriptions for model configs 2024-03-05 23:50:19 +11:00
9378e47a06 feat(mm): add source_type to model configs 2024-03-05 23:50:19 +11:00
4471ea8ad1 refactor(mm): simplify model metadata schemas 2024-03-05 23:50:19 +11:00
2c835fd550 refactor(mm): WIP db schema 2024-03-05 23:50:19 +11:00
61b737bb9f tidy(mm): remove update method from ModelConfigBase
It's only used in the soon-to-be-removed model merge logic
2024-03-05 23:50:19 +11:00
a8cd3dfc99 refactor(mm): add models table (schema WIP), rename "original_hash" -> "hash" 2024-03-05 23:50:19 +11:00
0cce582f2f tidy(mm): remove current_hash 2024-03-05 23:50:19 +11:00
4347d1c7f7 tests(mm): fix some objects in tests 2024-03-05 23:50:19 +11:00
bd4fd9693d tidy(mm): rename ckpt "last_modified" -> "converted_at"
Clarify what this timestamp means
2024-03-05 23:50:19 +11:00
9b40c28144 tidy(mm): rename ckpy "config" -> "config_path" 2024-03-05 23:50:19 +11:00
16a5d718bf fix(mm): add config field to ckpt vaes 2024-03-05 23:50:19 +11:00
76cbc745e1 refactor(mm): add CheckpointConfigBase for all ckpt models 2024-03-05 23:50:19 +11:00
0a614943f6 fix(mm): fix broken get_model_discriminator_value 2024-03-05 23:50:19 +11:00
e426096d32 fix(mm): misc typing fixes for model loaders 2024-03-05 23:50:19 +11:00
c561cd751f fix(mm): use correct import path for ConfigMixin, ModelMixin 2024-03-05 23:50:19 +11:00
af9298f0ef tidy(mm): tidy class names in config.py 2024-03-05 23:50:19 +11:00
5b74117836 fix(mm): use generic for model loader registry
This preserves the typing for classes using the decorator
2024-03-05 23:50:19 +11:00
38474c9797 fix(mm): use correct import path for ModelMixin 2024-03-05 23:50:19 +11:00
b880a31039 refactor(mm): remove ztsnr_training field on _MainConfig
This is used to determine the CFG Rescale Multiplier setting. We'll handle this in the UI as a default setting.
2024-03-05 23:50:19 +11:00
dd31bc4586 refactor(mm): remove vae field on _MainConfig
We will handle default VAE selection in the UI.
2024-03-05 23:50:19 +11:00
316573df2d feat(mm): use callable discriminator for AnyModelConfig union 2024-03-05 23:50:19 +11:00
8b34f5298c Default model settings (#5850)
* UI in MM to create trigger phrases

* add scheduler and vaePrecision to config

* UI for configuring default settings for models'

* hook MM default model settings up to API

* add button to set default settings in parameters

* pull out trigger phrases

* back-end for default settings

* lint

* remove log;
gi

* ruff

* ruff format

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2024-03-04 09:39:03 -05:00
893bcd16fc Next: Allow in place local installs of models 2024-03-04 23:11:41 +11:00
f6028a4c61 Log a stack trace for invocation errors. 2024-03-04 23:01:56 +11:00
264aee3ffa translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2024-03-04 21:39:46 +11:00
4deb60f365 translationBot(ui): update translation (Italian)
Currently translated at 98.0% (1442 of 1470 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-03-04 21:39:46 +11:00
B N
f2d5fb176f translationBot(ui): update translation (German)
Currently translated at 80.4% (1183 of 1470 strings)

Co-authored-by: B N <berndnieschalk@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2024-03-04 21:39:46 +11:00
94005b5501 add button to navigate to model manager if tab is enabled 2024-03-03 19:50:50 -05:00
02dc1a8780 consolidate tabs for main model and concepts in generation panel 2024-03-03 19:50:50 -05:00
ef958568ac Update Transformers 4.37.2 -> 4.38.2 2024-03-03 19:41:33 -05:00
48e323d887 docs: added both create mask nodes to defaultNodes 2024-03-03 12:58:47 -05:00
735857479d fix(canvas): use corrected mask for pasteback 2024-03-03 12:58:47 -05:00
2f372d9b18 tests(mm): update tests to reflect using UUID for key 2024-03-03 14:32:14 +11:00
554d175792 feat(mm): improved model hash class
- Use memory view for hashlib algorithms (closer to python 3.11's filehash API in hashlib)
- Remove `sha1_fast` (realized it doesn't even hash the whole file, it just does the first block)
- Add support for custom file filters
- Update docstrings
- Update tests
2024-03-03 14:32:14 +11:00
ae99428883 fix(mm): use UUIDv4 for key
This changes the functionality of this PR to only use the updated hashing for model hashes with a UUID for the key.
2024-03-03 14:32:14 +11:00
863ce00712 tests(mm): add tests for ModelHash 2024-03-03 14:32:14 +11:00
86982f3059 feat(mm): make ModelHash instantiatable, taking an algorithm as arg 2024-03-03 14:32:14 +11:00
ec8ed530a7 feat(mm): modularize ModelHash to facilitate testing 2024-03-03 14:32:14 +11:00
982076d7d7 feat(mm): add hashing algos to ModelHash
- Some algos are slow, so it is now just called ModelHash
- Added all hashlib algos, plus BLAKE3 and the fast (but incorrect) SHA1 algo
2024-03-03 14:32:14 +11:00
2e4672f931 feat(mm): make hash.py a script for testing 2024-03-03 14:32:14 +11:00
908e915a71 feat(mm): use blake3 for hashing 2024-03-03 14:32:14 +11:00
a72056e0df make model key assignment deterministic
- When installing, model keys are now calculated from the model contents.
- .safetensors, .ckpt and other single file models are hashed with sha1
- The contents of diffusers directories are hashed using imohash (faster)

fixup yaml->sql db migration script to assign deterministic key

- this commit also detects and assigns the correct image encoder for
  ip adapter models.
2024-03-03 14:32:14 +11:00
d8d7ddf43a Remove attention map saving (#5845)
## What type of PR is this? (check all applicable)

- [x] Refactor
- [ ] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because

## Description

Attention map saving was a feature that existed a long time ago in
Invoke (>1 year ago). This PR strips out a bunch of dead code that still
remains from that feature and is polluting our diffusion implementation.

This change should not have any functional effect on the app.

## QA Instructions, Screenshots, Recordings

I did a quick smoke test of SD and SDXL image generation. All of the
deleted code was unused, so the risk should be relatively low.

## Merge Plan

- [x] Change target branch to `main` before merging.

## Added/updated tests?

- [ ] Yes
- [x] No: This PR just deletes a bunch of unused code.
2024-03-02 11:15:25 -05:00
cc45007dc4 Remove unused code for attention map saving. 2024-03-02 08:25:41 -05:00
73bec56c59 Delete unused functions from shared_invokeai_diffusion.py. 2024-03-02 08:25:41 -05:00
f8b54930f0 docs: update RELEASE.md 2024-03-02 08:23:49 -05:00
51cc9f9466 ci: add comments to workflows 2024-03-02 08:23:49 -05:00
d2ad465e96 ci: rename test matrix
Now python version: platform, e.g. `py3.10: linux-cpu`

This displays better in GH actions.
2024-03-02 08:23:49 -05:00
09037b7cd4 ci: add conditionals for jobs based on dispatch/call 2024-03-02 08:23:49 -05:00
b2a850b5ea ci: rename jobs, remove extraneous needs in release 2024-03-02 08:23:49 -05:00
3ba5c2b0b4 ci: split build job 2024-03-02 08:23:49 -05:00
06fc6ccfe5 ci: workflow & job names 2024-03-02 08:23:49 -05:00
0c6b0cfdab ci: tidy pr labeler labels 2024-03-02 08:23:49 -05:00
eef3373799 ci: fix workflows
Do not split up "on change" and "do the thing". Less convoluted, no catch-22 with required checks for PRs.
2024-03-02 08:23:49 -05:00
6935830f99 Remove unused constructor declared with typo in name: __int__. 2024-03-01 15:12:03 -05:00
7651eeea8d Merge sequential conditioning and cac conditioning logic to eliminate a bunch of duplication. 2024-03-01 15:12:03 -05:00
204e7d383b Remove outdated comments related to T2I-Adapters and ControlNets. 2024-03-01 15:12:03 -05:00
9bc4e7a593 Remove use of **kwargs in do_unet_step(...), where full parameter list is known and supported. 2024-03-01 15:12:03 -05:00
ad96857e0f Fix avoid storing extra conditioning info in two places. 2024-03-01 15:12:03 -05:00
8fb297e5f6 add regression tests for <NOKEY> issue 2024-02-29 22:31:05 -05:00
0832e1818e Fix problem of all installed models being assigned "<NOKEY>"
- Also fix redundant scanning of models directory at startup.
2024-02-29 22:31:05 -05:00
26d4d93e64 ci: update mkdocs workflow
Bit of a merge of the docs at https://github.com/squidfunk/mkdocs-material/blob/master/docs/publishing-your-site.md and the previous workflow.

Not sure how to test this without access to the main repo.
2024-02-29 21:57:20 -05:00
77f39aa935 ci: bump setup-python v4 -> v5 2024-02-29 21:57:20 -05:00
6aae940834 ci: clean up unused workflow 2024-02-29 21:57:20 -05:00
be8dcad1da feat(installer): do not delete dist/ 2024-02-29 21:57:20 -05:00
5f2e493244 feat(installer): print outputs 2024-02-29 21:57:20 -05:00
c60c9825cb feat(installer): add check for CI in create_installer.sh
If in CI, print a message saying so.

If not, prompt user to confirm that they are in the correct working directory.
2024-02-29 21:57:20 -05:00
6f368395df fix(installer): conditional syntax for old bash in create_installer.sh 2024-02-29 21:57:20 -05:00
ea4d071503 ci: add reasonable timeouts for jobs
The timeouts are at least 3x the expected time to complete the jobs.

This is particularly relevant for the `pytest` job. Occasionally, it hangs while running tests that do network things, and the job only times out after 6 hours.
2024-02-29 21:57:20 -05:00
b95e5d0730 ci: bump tj-actions/changed-files -> v41 2024-02-29 21:57:20 -05:00
99ee8f9099 feat(installer): remove vX-latest from tag_release
Also update RELEASE.md accordingly, and make the release.yml workflow match on `v*` tags.
2024-02-29 21:57:20 -05:00
50e58ff323 feat(installer): just use python3 in scripts 2024-02-29 21:57:20 -05:00
b5c12985e7 docs: update RELEASE.md 2024-02-29 21:57:20 -05:00
a865277667 ci: add comments to workflows 2024-02-29 21:57:20 -05:00
b2b65a9012 feat(installer): address feedback 2024-02-29 21:57:20 -05:00
9fe579dd99 docs: update docs/RELEASE.md 2024-02-29 21:57:20 -05:00
a0313ba634 feat: automated releases via github action
- Restructure & update code check workflows
- Add release workflow to handle checks/tests, build and publish to PyPI
- Add docs/RELEASE.md explaining the workflow & process
- `create_installer.sh`: Update to work with the release workflow
- `create_installer.sh` & `tag_release.sh`: Fix the ANSI escape codes for macOS
- `tag_release.sh`: Add check for python binary name
- `tag_release.sh`: Print `git remote -v` output
- `tag_release.sh`: Fix error when deleting nonexistant tags
2024-02-29 21:57:20 -05:00
3a2afe1d15 chore: ruff 2024-03-01 10:42:33 +11:00
813a086cfe fix race condition between downloading last file and starting install 2024-03-01 10:42:33 +11:00
e18533e3b5 add debugging statements and a timeout to download test 2024-03-01 10:42:33 +11:00
dd9daf8efb chore: ruff 2024-03-01 10:42:33 +11:00
ad86b29798 chore: remove pin on ruff
This ensures it matches the github workflow.

Also there's an update that stabilizes a number of formatting rules, so there will be a format commit after this.
2024-03-01 10:42:33 +11:00
8b03af391a fix(ui): fix metadata display issue 2024-03-01 10:42:33 +11:00
bbbd18f119 fix(ui): baseUrl hardcoded api path
We now hav multiple api versions for different routers, so we cannot hardcode the `/api/v1` portion of the baseUrl
2024-03-01 10:42:33 +11:00
c074beff7c fix(ui): typo in feature tooltips 2024-03-01 10:42:33 +11:00
0b07e2aad4 docs: add v3 -> v4 migration, invocation API docs 2024-03-01 10:42:33 +11:00
753919c6d7 docs(nodes): update all docstrings for public nodes API 2024-03-01 10:42:33 +11:00
2f26768d19 fix: make invocation_context.py accessible to mkdocs
Needs an `__init__.py`.
2024-03-01 10:42:33 +11:00
ae19971f65 docs: update mkdocs config 2024-03-01 10:42:33 +11:00
e364ce1d4e docs: bump mkdocs, add mkdocstrings
Also remove ancient requirements file - the docs dependencies are in the pyproject.toml file.
2024-03-01 10:42:33 +11:00
0b0128647b feat(nodes): revise model load API args 2024-03-01 10:42:33 +11:00
39725e9560 Next: Remove deprecated app.on_event usage in api runner 2024-03-01 10:42:33 +11:00
0305e90287 chore: ruff 2024-03-01 10:42:33 +11:00
ae34bcfbc0 fix: Assertion issue with SDXL Compel 2024-03-01 10:42:33 +11:00
01898d766f Fix merge with next 2024-03-01 10:42:33 +11:00
e7afae0159 Switch absolute path to as_posix in _walk_directory 2024-03-01 10:42:33 +11:00
f16e64084b Ruff checks 2024-03-01 10:42:33 +11:00
8992d89817 Fix directory called on _walk_directory 2024-03-01 10:42:33 +11:00
0fc2f90824 Switch ModelSearch from os.walk to os.scandir 2024-03-01 10:42:33 +11:00
c670dacc29 Ruff format 2024-03-01 10:42:33 +11:00
f475b78734 Ruff check 2024-03-01 10:42:33 +11:00
ca9b815c89 Extract TI loading logic into util, disallow it from ever failing a generation 2024-03-01 10:42:33 +11:00
8efd4284e9 Fix one last reference to the uncasted model 2024-03-01 10:42:33 +11:00
5922cee541 Allow TIs to be either a key or a name in the prompt during our transition to using keys 2024-03-01 10:42:33 +11:00
94e3857110 handle change to Civitai metadata schema for commercial usage 2024-03-01 10:42:33 +11:00
4b4b940461 updated to use new import model mutation 2024-03-01 10:42:33 +11:00
574d6538b9 fix(ui): merge conflict 2024-03-01 10:42:33 +11:00
3141c6efd5 chore(ui): bump deps
The only major version is `query-string`. The breaking change for it is dropping support for old versions of node. Not a problem for us.
2024-03-01 10:42:33 +11:00
9cf2897064 ci: change frontend check to dpdm 2024-03-01 10:42:33 +11:00
bcf742ef87 feat(ui): move from madge to dpdm for circular dependencies 2024-03-01 10:42:33 +11:00
f6c068afdd tidy(ui): fix circular dependencies in listeners 2024-03-01 10:42:33 +11:00
7d2e840590 tidy: remove some traces of ONNX 2024-03-01 10:42:33 +11:00
f0b3485ce9 chore(ui): typegen, update knip config
Knip should never touch the autogenerated types
2024-03-01 10:42:33 +11:00
37608cdea2 chore(ui): update pnpm-lock.yaml
Forgot to run `pnpm i` earlier after removing packages.
2024-03-01 10:42:33 +11:00
aafa464707 ci: add knip to ui check workflow 2024-03-01 10:42:33 +11:00
1176c549c0 feat(ui): configure knip 2024-03-01 10:42:33 +11:00
d90210fea6 tidy(ui): clean up unused code 6
unused files
2024-03-01 10:42:33 +11:00
d99bec8b1a tidy(ui): clean up unused code 5
variables, types and schemas
2024-03-01 10:42:33 +11:00
b661d93bd8 tidy(ui): clean up unused code 4
variables, types and schemas
2024-03-01 10:42:33 +11:00
dc64089c9d tidy(ui): clean up unused code 3
variables, types and schemas
2024-03-01 10:42:33 +11:00
a6f6fe581e tidy(ui): clean up unused code 2
types and schemas
2024-03-01 10:42:33 +11:00
12e859835b feat(mm): add log stmt for download complete event 2024-03-01 10:42:33 +11:00
b218282149 fix(ui): model install progress sets total bytes correctly 2024-03-01 10:42:33 +11:00
80065858ed chore(ui): lint 2024-03-01 10:42:33 +11:00
aaeef03593 fix(ui): fix remaining TS issues 2024-03-01 10:42:33 +11:00
97ecd99b9c fix(ui): fix up MM queries & types (wip) 2024-03-01 10:42:33 +11:00
202e739404 tidy(api): remove non-heuristic install route 2024-03-01 10:42:33 +11:00
10d36b4045 tidy(mm): remove ONNX from AnyModelConfig 2024-03-01 10:42:33 +11:00
8f93ae8d7c tidy(ui): clean up unused code 1
- Only export when necessary
- Remove totally usused functions, variables, state, etc
- Remove unused packages
2024-03-01 10:42:33 +11:00
506fa55f18 feat(ui): add knip + minimal config
https://knip.dev/

Replaces `unimported`
2024-03-01 10:42:33 +11:00
4c19d5cee4 fix(ui): fix missing component import 2024-03-01 10:42:33 +11:00
afa7043dcd ui: split the canvas mask blur and edge size setting 2024-03-01 10:42:33 +11:00
32b8478974 added add all button to scan models 2024-03-01 10:42:33 +11:00
d23f2de9d7 feat(ui): create metadata types for control adapters
These are the same as the existing control adapter types, but the model field is non-nullable, simplifying handling of these objects.
2024-03-01 10:42:33 +11:00
9abfb02bf0 fix(ui): model metadata handlers use model identifiers, not configs
Model metadata includes the main model, VAE and refiner model.

These used full model configs, as returned by the server, as their metadata type.

LoRA and control adapter metadata only use the metadata identifier.

This created a difference in handling. After parsing a model/vae/refiner, we have its name and can display it. But for LoRAs and control adapters, we only have the model key and must query for the full model config to get the name.

This change makes main model/vae/refiner metadata only have the model key, like LoRAs and control adapters.

The render function is now async so fetching can occur within it. All metadata fields with models now only contain the identifier, and fetch the model name to render their values.
2024-03-01 10:42:33 +11:00
7b4ef5926d fix(ui): CanvasPasteBack types 2024-03-01 10:42:33 +11:00
6c5be9e89c tidy(ui): remove unused metadata schemas 2024-03-01 10:42:33 +11:00
80697a71de feat(nodes): update LoRAMetadataItem model
LoRA model now at under `model` not `lora.
2024-03-01 10:42:33 +11:00
a253047d8e tidy(ui): tidy model identifier logic
- Move some files around
- Use util to extract key and base from model config
2024-03-01 10:42:33 +11:00
7176c5d9d6 feat(ui): optimize model query caching
When we retrieve a list of models, upsert that data into the `getModelConfig` and `getModelConfigByAttrs` query caches.

With this change, calls to those two queries are almost always going to be free, because their caches will already have all models in them. The exception is queries for models that no longer exist.
2024-03-01 10:42:33 +11:00
0b54bfb7c5 fix(ui): fix lora metadata item type 2024-03-01 10:42:33 +11:00
24daacecf2 fix(ui): fix node type 2024-03-01 10:42:33 +11:00
7326c78ab5 feat(ui): add transformation to width/height parameter schemas to round to multiple of 8
This allows image dimensions that are not multiples of 8 to still be recalled with best effort.
2024-03-01 10:42:33 +11:00
04545e792c fix(ui): fix lora metadata rendering 2024-03-01 10:42:33 +11:00
e6de915c34 fix(ui): fix type issues related to change in LoRA type 2024-03-01 10:42:33 +11:00
71ceab9094 feat(ui): migrate all metadata recall logic to new system 2024-03-01 10:42:33 +11:00
ff00ed8e80 fix(ui): use id for component key in control adapter components 2024-03-01 10:42:33 +11:00
ce3f9037cd feat(ui): no JSX in metadata handlers 2024-03-01 10:42:33 +11:00
d1f4cde8c7 feat(ui): refactor metadata handling (again)
Add concepts for metadata handlers. Handlers include parsers, recallers and validators for different metadata types:
- Parsers parse a raw metadata object of any shape to a structured object.
- Recallers load the parsed metadata into state. Recallers are optional, as some metadata types don't need to be loaded into state.
- Validators provide an additional layer of validation before recalling the metadata. This is needed because a metadata object may be valid, but not able to be recalled due to some other requirement, like base model compatibility. Validators are optional.

Sometimes metadata is not a single object but a list of items - like LoRAs. Metadata handlers may implement an optional set of "item" handlers which operate on individual items in the list.

Parsers and validators are async to allow fetching additional data, like a model config. Recallers are synchronous.

The these handlers are composed into a public API, exported as a `handlers` object. Besides the handlers functions, a metadata handler set includes:
- A function to get the label of the metadata type.
- An optional function to render the value of the metadata type.
- An optional function to render the _item_ value of the metadata type.
2024-03-01 10:42:33 +11:00
90327cb521 build(ui): do not fail build on eslint error in dev mode 2024-03-01 10:42:33 +11:00
4d5458648b chore(ui): typegen 2024-03-01 10:42:33 +11:00
8d8f1abd50 feat(api): add MM get_by_attrs route
Gets the first model that matches the given name, base and type. Raises 404 if there isn't one.

This will be used for backwards compatibility with old metadata.
2024-03-01 10:42:33 +11:00
e20a506e40 undo 2024-03-01 10:42:33 +11:00
77b8eed51b fix literal strings in MM UI 2024-03-01 10:42:33 +11:00
c954cd4c8d fix TI appearing as key in prompt 2024-03-01 10:42:33 +11:00
630d3615ca fix base model grouping in combobox 2024-03-01 10:42:33 +11:00
c80c0f0fb9 fix(mm): fix ModelCacheBase method name 2024-03-01 10:42:33 +11:00
37d66488c5 chore: ruff 2024-03-01 10:42:33 +11:00
371e3cc260 recover gracefuly from GPU out of memory errors (next version) 2024-03-01 10:42:33 +11:00
d22738723d clear out VRAM when an OOM occurs 2024-03-01 10:42:33 +11:00
fbd9ffdc5a feat(ui): bulk download click to download 2024-03-01 10:42:33 +11:00
04c060a89d fix(ui): fix node types for canvas graphs 2024-03-01 10:42:33 +11:00
6f591b324b chore(ui): typegen 2024-03-01 10:42:33 +11:00
82249cc634 tidy(nodes): rename canvas paste back 2024-03-01 10:42:33 +11:00
cc82ce820a fix: outpaint result not getting pasted back correctly 2024-03-01 10:42:33 +11:00
8e1fbd6ed1 fix: lint errors 2024-03-01 10:42:33 +11:00
68d79c002d canvas: improve paste back (or try to) 2024-03-01 10:42:33 +11:00
8f6c2a8b92 wip(ui): Replace 2 Layer Coherence pass with Gradient Mask 2024-03-01 10:42:33 +11:00
ea7b7bcf40 chore: ruff 2024-03-01 10:42:33 +11:00
1456c997fb fix(ui): fix merge issue 2024-03-01 10:42:33 +11:00
7fce234646 fix(ui): use new scan_folder response instead of hook to determine if models are installed already 2024-03-01 10:42:33 +11:00
9e02384674 chore(ui): typegen 2024-03-01 10:42:33 +11:00
531d6f40f4 feat(mm): add logic to scan_folder route to check if a model is already installed
This was done in the frontend before but it's something the backend should handle.

The logic compares the found model paths to the path and source of all installed models. It excludes core models.
2024-03-01 10:42:33 +11:00
98d60e7db5 chore(ui): lint 2024-03-01 10:42:33 +11:00
1436a5f295 build(ui): restore i18n eslint rule 2024-03-01 10:42:33 +11:00
e22c4987bf chore: ruff 2024-03-01 10:42:33 +11:00
4420392241 fix(ui): fix metadata route 2024-03-01 10:42:33 +11:00
1d410e6346 chore(ui): typegen 2024-03-01 10:42:33 +11:00
c98668e7f5 feat(api): mm metadata route "meta" -> "metadata" 2024-03-01 10:42:33 +11:00
740dbc0c32 lint fix 2024-03-01 10:42:33 +11:00
97181d159f updated translations 2024-03-01 10:42:33 +11:00
65b0d3d436 fix convert endpoint logic 2024-03-01 10:42:33 +11:00
baf1194cae clean up old model manager components and endpoints 2024-03-01 10:42:33 +11:00
9b1f63379a add model convert to checkpoint main models 2024-03-01 10:42:33 +11:00
c3f4e87a6e fix logic to see if scanned models are already installed, style tweaks 2024-03-01 10:42:33 +11:00
26a209a00d add error_reason to ModelInstallJob 2024-03-01 10:42:33 +11:00
625c86ba9a add error_reason to UI if import fails 2024-03-01 10:42:33 +11:00
53f0090197 fix types for ImportQueue, add QuickAdd for scan models 2024-03-01 10:42:33 +11:00
5496699d6c refactored and fixed issues with advanced import form 2024-03-01 10:42:33 +11:00
b5ce28e60b fix(ui): misc MM cleanup 2024-03-01 10:42:33 +11:00
816fb53a14 chore(ui): temp disable eslint i18 rule 2024-03-01 10:42:33 +11:00
793c7ec832 fix(ui): fix ImportMainModelResponse type 2024-03-01 10:42:33 +11:00
62c67d7c4b fix(ui): simplify model install event listeners 2024-03-01 10:42:33 +11:00
7c41b3439a fix(ui): fix model install event types 2024-03-01 10:42:33 +11:00
cdd2f18bbd added advanced import forms, not fully working yet 2024-03-01 10:42:33 +11:00
e7d7b37896 get positioning/scrolling working for scan results list 2024-03-01 10:42:33 +11:00
57a402053e basic scan working and renders results 2024-03-01 10:42:33 +11:00
9ae09e9a7c add scan model endpoint, break add model into tabs 2024-03-01 10:42:33 +11:00
5a12886dbb update metadata endpoint 2024-03-01 10:42:33 +11:00
5b7633f3c6 allow metadata-less models to be used for GET metadata endpoint 2024-03-01 10:42:33 +11:00
68f24d9f0d added status to import queue model 2024-03-01 10:42:33 +11:00
ea364bdf82 delete model imports and prune all finished, update state with socket messages 2024-03-01 10:42:33 +11:00
18904f79ef fix sync model endpoint 2024-03-01 10:42:33 +11:00
782d15af13 form error handling 2024-03-01 10:42:33 +11:00
86e2b39f0d finish model update 2024-03-01 10:42:33 +11:00
20576deae8 added socket listeners, added more info to ui 2024-03-01 10:42:33 +11:00
0a69779df9 edit view for model, depending on type and valid values 2024-03-01 10:42:33 +11:00
6b68971f38 hook up Add Model button 2024-03-01 10:42:33 +11:00
c46eb72d45 single model view 2024-03-01 10:42:33 +11:00
87ce74e05d added import model form and importqueue 2024-03-01 10:42:33 +11:00
c7d462b222 model list, filtering, searching 2024-03-01 10:42:33 +11:00
9068400433 workspace for mary and jenn 2024-03-01 10:42:33 +11:00
55f3c6e721 get old UI working somewhat with new endpoints 2024-03-01 10:42:33 +11:00
c778ab8db4 Allow passing in key on register 2024-03-01 10:42:33 +11:00
65b91356d0 Remove passing keys in on register 2024-03-01 10:42:33 +11:00
de9287a3e4 Run ruff 2024-03-01 10:42:33 +11:00
008716040b Allow users to run model manager without cuda 2024-03-01 10:42:33 +11:00
abc569c2dd fix(ui): roll back utility-types
It's `Required` util does not distribute over unions as expected. Also we have `ts-toolbelt` already for some utils.
2024-03-01 10:42:33 +11:00
3ed2963f43 feat(ui): refactor metadata handling
Refactor of metadata recall handling. This is in preparation for a backwards compatibility layer for models.

- Create helpers to fetch a model outside react (e.g. not in a hook)
- Created helpers to parse model metadata
- Renamed a lot of types that were confusing and/or had naming collisions
2024-03-01 10:42:33 +11:00
79b16596b5 chore(ui): typegen 2024-03-01 10:42:33 +11:00
239ecfaf79 fix(nodes): make fields on ModelConfigBase required
The setup of `ModelConfigBase` means autogenerated types have critical fields flagged as nullable (like `key` and `base`). Need to manually flag them as required.
2024-03-01 10:42:33 +11:00
0d9fbe5e04 feat(ui): replace type-fest with utility-types
- The new package has more useful types
- Only used `JsonObject` from `type-fest`; added an implementation of that type
2024-03-01 10:42:33 +11:00
cc41e8912c several small model install enhancements
- Support extended HF repoid syntax in TUI. This allows
  installation of subfolders and safetensors files, as in
  `XpucT/Deliberate::Deliberate_v5.safetensors`

- Add `error` and `error_traceback` properties to the install
  job objects.

- Rename the `heuristic_import` route to `heuristic_install`.

- Fix the example `config` input in the `heuristic_install` route.
2024-03-01 10:42:33 +11:00
1cec0bb179 use official Deliberate download repo 2024-03-01 10:42:33 +11:00
65dd4f4abc fix repo-id for the Deliberate v5 model
prevent lora and embedding file suffixes from being stripped during installation

apply psychedelicious patch to get compel to load proper TI embedding
2024-03-01 10:42:33 +11:00
5bb3aeaccd remove startup dependency on legacy models.yaml file 2024-03-01 10:42:33 +11:00
30a374a70f chore: typing 2024-03-01 10:42:33 +11:00
07dde92664 chore: typing fix 2024-03-01 10:42:33 +11:00
06cc57d82a feat(nodes): added gradient mask node 2024-03-01 10:42:33 +11:00
f7fc20459a Run ruff 2024-03-01 10:42:33 +11:00
9269bdd233 rename endpoint for scanning 2024-03-01 10:42:33 +11:00
97cfcd2eef Create /search endpoint, update model object structure in scan model page 2024-03-01 10:42:33 +11:00
571a86a965 chore(ui): bump deps
Notable updates:
- Minor version of RTK includes customizable selectors for RTK Query, so we can remove the patch that was added to ensure only the LRU memoize function was used for perf reasons. Updated to use the LRU memoize function.
- Major version of react-resizable-panels. No breaking changes, works great, and you can now resize all panels when dragging at the intersection point of panels. Cool!
- Minor (?) version of nanostores. `action` API is removed, we were using it in one spot. Fixed.
- @invoke-ai/eslint-config-react has all deps bumped and now has its dependent plugins/configs listed as normal dependencies (as opposed to peer deps). This means we can remove those packages from explicit dev deps.
2024-03-01 10:42:33 +11:00
dbd929df05 tidy(ui): remove debugging stmt 2024-03-01 10:42:33 +11:00
b59d23d608 fix(ui): handle new model format for metadata 2024-03-01 10:42:33 +11:00
9d9b417432 fix(ui): use model names in badges 2024-03-01 10:42:33 +11:00
34f3a39cc9 fix(nodes): fix TI loading 2024-03-01 10:42:33 +11:00
e3c23baae9 fix(ui): fix package build 2024-03-01 10:42:33 +11:00
6a923cce70 feat(ui): do not subscribe to bulk download sio room if baseUrl is set 2024-03-01 10:42:33 +11:00
c0f0f2f39e feat(ui): revise bulk download listeners
- Use a single listener for all of the to keep them in one spot
- Use the bulk download item name as a toast id so we can update the existing toasts
- Update handling to work with other environments
- Move all bulk download handling from components to listener
2024-03-01 10:42:33 +11:00
64908eda55 chore(ui): typegen 2024-03-01 10:42:33 +11:00
a37b60db13 feat(bulk_download): update response model, messages 2024-03-01 10:42:33 +11:00
9e296f6916 implementing download for bulk_download events 2024-03-01 10:42:33 +11:00
ab94484c6c setting up event listeners for bulk download socket 2024-03-01 10:42:33 +11:00
5cba55d670 test: clean up & fix tests
- Deduplicate the mock invocation services. This is possible now that the import order issue is resolved.
- Merge `DummyEventService` into `TestEventService` and update all tests to use `TestEventService`.
2024-03-01 10:42:33 +11:00
cbb997e7d0 tidy(bulk_download): don't store events service separately
Using the invoker object directly leaves no ambiguity as to what `_events_bus` actually is.
2024-03-01 10:42:33 +11:00
98441ad08d tidy(bulk_download): do not rely on pagination API to get all images for board
We can get all images for the board as a list of image names, then pass that to `_image_handler` to get the DTOs, decoupling from the pagination API.
2024-03-01 10:42:33 +11:00
80c67dd6e0 tidy(bulk_download): nit - use or as a coalescing operator
Just a bit cleaner.
2024-03-01 10:42:33 +11:00
38af234108 tidy(bulk_download): use single underscore for private attrs
Double underscores are used in the app but it doesn't actually do or convey anything that single underscores don't already do. Considered unpythonic except for actual dunder/magic methods.
2024-03-01 10:42:33 +11:00
2291122c2b tidy(bulk_download): remove class-level attr annotations
These can be misleading as they shadow actual assigned class attributes. This pattern is in the rest of the app but it shouldn't be.
2024-03-01 10:42:33 +11:00
bf3b10cb1c tidy(bulk_download): remove extraneous abstract methods
`start`, `stop` and `__init__` are not required in implementations of an ABC or service.
2024-03-01 10:42:33 +11:00
7f8f182a00 tidy(bulk_download): clean up comments 2024-03-01 10:42:33 +11:00
e51867756a adding bulk_download_item_name to socket events 2024-03-01 10:42:33 +11:00
a8d7cf4e97 refactoring handlers to do null check 2024-03-01 10:42:33 +11:00
037cac8154 removing dependency on an output folder, embrace python temp folder for bulk download 2024-03-01 10:42:33 +11:00
0ab9fe6987 relocating event_service fixture due to import ordering 2024-03-01 10:42:33 +11:00
b5a9ed351d moving the responsibility of cleaning up board names to the service not the route 2024-03-01 10:42:33 +11:00
5f4b406cfe updating imports to satisfy ruff 2024-03-01 10:42:33 +11:00
f15aa562c2 using temp directory for downloads 2024-03-01 10:42:33 +11:00
d0f3571e59 returning the bulk_download_item_name on response for possible polling 2024-03-01 10:42:33 +11:00
b5ca1643a6 narrowing bulk_download stop service scope 2024-03-01 10:42:33 +11:00
39c01a833d adding test coverage for new bulk download routes 2024-03-01 10:42:33 +11:00
79eb871683 cleaning up bulk download zip after the response is complete 2024-03-01 10:42:33 +11:00
7544b350f3 replacing import removed during rebase 2024-03-01 10:42:33 +11:00
284ba041bd 97% test coverage on bulk_download 2024-03-01 10:42:33 +11:00
7d91426d8f refactoring bulk_download to be better managed 2024-03-01 10:42:33 +11:00
db812133e7 refactoring dummy event service, DRY principal; adding bulk_download_event to existing invoker tests 2024-03-01 10:42:33 +11:00
795fbf0e81 refactoring bulkdownload to consider image category 2024-03-01 10:42:33 +11:00
7114d64b86 fixing issue where default board did not return images 2024-03-01 10:42:33 +11:00
c43ea9f25c using the board name to download boards 2024-03-01 10:42:33 +11:00
52b0deb179 reworking some of the logic to use a default room, adding endpoint to download file on complete 2024-03-01 10:42:33 +11:00
7ecc18938b linted and styling 2024-03-01 10:42:33 +11:00
56d2d220a8 implementation of bulkdownload background task 2024-03-01 10:42:33 +11:00
f1967c3393 adding socket events for bulk download 2024-03-01 10:42:33 +11:00
812e24cbd2 groundwork for the bulk_download_service 2024-03-01 10:42:33 +11:00
8afe328af0 fix(ui): get workflow editor model selects working 2024-03-01 10:42:33 +11:00
e771c5f467 fix(ui): get refiner model select working 2024-03-01 10:42:33 +11:00
e7e3045a8a fix(ui): get vae model select working 2024-03-01 10:42:33 +11:00
f870f810d5 fix(ui): get embedding select working 2024-03-01 10:42:33 +11:00
a793103d7a fix(ui): get lora select working 2024-03-01 10:42:33 +11:00
7e5a85496e chore(ui): bump @invoke-ai/ui-library 2024-03-01 10:42:33 +11:00
ca7e928710 fix(ui): fix low-hanging fruit types 2024-03-01 10:42:33 +11:00
5b133ad198 Add a few convenience targets to Makefile
- "test" to run pytests
- "frontend-install" to reinstall pnpm's node modeuls
2024-03-01 10:42:33 +11:00
89fa36a818 chore(nodes): update TODO comment 2024-03-01 10:42:33 +11:00
e3f9da29ba tidy(nodes): clean up profiler/stats in processor, better comments 2024-03-01 10:42:33 +11:00
763debdeeb fix(nodes): fix typing on stats service context manager 2024-03-01 10:42:33 +11:00
8bf9fd34ad fix(nodes): fix model load events
was accessing incorrect properties in event data
2024-03-01 10:42:33 +11:00
0b0cb0ccc6 feat(nodes): making invocation class var in processor 2024-03-01 10:42:33 +11:00
fa39523b11 feat(nodes): improved error messages in processor 2024-03-01 10:42:33 +11:00
16676feea8 feat(nodes): make processor thread limit and polling interval configurable 2024-03-01 10:42:33 +11:00
0788a27a80 tests(nodes): fix tests following removal of services 2024-03-01 10:42:33 +11:00
d53a2a2d4e chore(nodes): better comments for invocation context 2024-03-01 10:42:33 +11:00
ccfe6b6bef chore(nodes): "context_data" -> "data"
Changed within InvocationContext, for brevity.
2024-03-01 10:42:33 +11:00
fdac0c3c9b refactor(nodes): move is_canceled to context.util 2024-03-01 10:42:33 +11:00
18adcc1dd2 feat(nodes): add whole queue_item to InvocationContextData
No reason to not have the whole thing in there.
2024-03-01 10:42:33 +11:00
86c50f2d5b tidy(nodes): remove extraneous comments 2024-03-01 10:42:33 +11:00
3cfac8b843 feat(nodes): better invocation error messages 2024-03-01 10:42:33 +11:00
0788b6ecee chore(nodes): add comments for cancel state 2024-03-01 10:42:33 +11:00
317d076a1a feat(nodes): promote is_canceled to public node API 2024-03-01 10:42:33 +11:00
725c03cf87 refactor(nodes): merge processors
Consolidate graph processing logic into session processor.

With graphs as the unit of work, and the session queue distributing graphs, we no longer need the invocation queue or processor.

Instead, the session processor dequeues the next session and processes it in a simple loop, greatly simplifying the app.

- Remove `graph_execution_manager` service.
- Remove `queue` (invocation queue) service.
- Remove `processor` (invocation processor) service.
- Remove queue-related logic from `Invoker`. It now only starts and stops the services, providing them with access to other services.
- Remove unused `invocation_retrieval_error` and `session_retrieval_error` events, these are no longer needed.
- Clean up stats service now that it is less coupled to the rest of the app.
- Refactor cancellation logic - cancellations now originate from session queue (i.e. HTTP cancel endpoint) and are emitted as events. Processor gets the events and sets the canceled event. Access to this event is provided to the invocation context for e.g. the step callback.
- Remove `sessions` router; it provided access to `graph_executions` but that no longer exists.
2024-03-01 10:42:33 +11:00
da9991e361 tidy(nodes): remove commented tests 2024-03-01 10:42:33 +11:00
67daa127e3 chore(ui): typegen 2024-03-01 10:42:33 +11:00
7e71effa17 tidy(nodes): remove no-op model_config
Because we now customize the JSON Schema creation for GraphExecutionState, the model_config did nothing.
2024-03-01 10:42:33 +11:00
e93bd15392 tidy(nodes): remove LibraryGraphs
The workflow library supersedes this unused feature.
2024-03-01 10:42:33 +11:00
0b81703c9f tidy(nodes): move node tests to parent dir
Thanks to the resolution of the import vs union issue, we can put tests anywhere.
2024-03-01 10:42:33 +11:00
641d235102 tidy(nodes): remove GraphInvocation
`GraphInvocation` is a node that can contain a whole graph. It is removed for a number of reasons:

1. This feature was unused (the UI doesn't support it) and there is no plan for it to be used.

The use-case it served is known in other node execution engines as "node groups" or "blocks" - a self-contained group of nodes, which has group inputs and outputs. This is a planned feature that will be handled client-side.

2. It adds substantial complexity to the graph processing logic. It's probably not enough to have a measurable performance impact but it does make it harder to work in the graph logic.

3. It allows for graphs to be recursive, and the improved invocations union handling does not play well with it. Actually, it works fine within `graph.py` but not in the tests for some reason. I do not understand why. There's probably a workaround, but I took this as encouragement to remove `GraphInvocation` from the app since we don't use it.
2024-03-01 10:42:33 +11:00
b79ae3a101 fix(nodes): fix OpenAPI schema generation
The change to `Graph.nodes` and `GraphExecutionState.results` validation requires some fanagling to get the OpenAPI schema generation to work. See new comments for a details.
2024-03-01 10:42:33 +11:00
731860c332 feat(nodes): JIT graph nodes validation
We use pydantic to validate a union of valid invocations when instantiating a graph.

Previously, we constructed the union while creating the `Graph` class. This introduces a dependency on the order of imports.

For example, consider a setup where we have 3 invocations in the app:

- Python executes the module where `FirstInvocation` is defined, registering `FirstInvocation`.
- Python executes the module where `SecondInvocation` is defined, registering `SecondInvocation`.
- Python executes the module where `Graph` is defined. A union of invocations is created and used to define the `Graph.nodes` field. The union contains `FirstInvocation` and `SecondInvocation`.
- Python executes the module where `ThirdInvocation` is defined, registering `ThirdInvocation`.
- A graph is created that includes `ThirdInvocation`. Pydantic validates the graph using the union, which does not know about `ThirdInvocation`, raising a `ValidationError` about an unknown invocation type.

This scenario has been particularly problematic in tests, where we may create invocations dynamically. The test files have to be structured in such a way that the imports happen in the right order. It's a major pain.

This PR refactors the validation of graph nodes to resolve this issue:

- `BaseInvocation` gets a new method `get_typeadapter`. This builds a pydantic `TypeAdapter` for the union of all registered invocations, caching it after the first call.
- `Graph.nodes`'s type is widened to `dict[str, BaseInvocation]`. This actually is a nice bonus, because we get better type hints whenever we reference `some_graph.nodes`.
- A "plain" field validator takes over the validation logic for `Graph.nodes`. "Plain" validators totally override pydantic's own validation logic. The validator grabs the `TypeAdapter` from `BaseInvocation`, then validates each node with it. The validation is identical to the previous implementation - we get the same errors.

`BaseInvocationOutput` gets the same treatment.
2024-03-01 10:42:33 +11:00
af2117dc0c remove errant def that was crashing invokeai-configure 2024-03-01 10:42:33 +11:00
1242cb4f85 one more redundant RGB convert removed 2024-03-01 10:42:33 +11:00
cd070d8be9 chore: ruff formatting 2024-03-01 10:42:33 +11:00
56ac2104e3 chore(invocations): remove redundant RGB conversions 2024-03-01 10:42:33 +11:00
965867151b chore(invocations): use IMAGE_MODES constant literal 2024-03-01 10:42:33 +11:00
2d007ce532 fix: removed custom module 2024-03-01 10:42:33 +11:00
92394ab751 fix(nodes): canny preprocessor uses RGBA again 2024-03-01 10:42:33 +11:00
43d94c8108 feat(nodes): format option for get_image method
Also default CNet preprocessors to "RGB"
2024-03-01 10:42:33 +11:00
fc20822595 fix: Alpha channel causing issue with DW Processor 2024-03-01 10:42:33 +11:00
5a3195f757 final tidying before marking PR as ready for review
- Replace AnyModelLoader with ModelLoaderRegistry
- Fix type check errors in multiple files
- Remove apparently unneeded `get_model_config_enum()` method from model manager
- Remove last vestiges of old model manager
- Updated tests and documentation

resolve conflict with seamless.py
2024-03-01 10:42:33 +11:00
5d612ec095 Tidy names and locations of modules
- Rename old "model_management" directory to "model_management_OLD" in order to catch
  dangling references to original model manager.
- Caught and fixed most dangling references (still checking)
- Rename lora, textual_inversion and model_patcher modules
- Introduce a RawModel base class to simplfy the Union returned by the
  model loaders.
- Tidy up the model manager 2-related tests. Add useful fixtures, and
  a finalizer to the queue and installer fixtures that will stop the
  services and release threads.
2024-03-01 10:42:33 +11:00
996eb96b4e Fix issues identified during PR review by RyanjDick and brandonrising
- ModelMetadataStoreService is now injected into ModelRecordStoreService
  (these two services are really joined at the hip, and should someday be merged)
- ModelRecordStoreService is now injected into ModelManagerService
- Reduced timeout value for the various installer and download wait*() methods
- Introduced a Mock modelmanager for testing
- Removed bare print() statement with _logger in the install helper backend.
- Removed unused code from model loader init file
- Made `locker` a private variable in the `LoadedModel` object.
- Fixed up model merge frontend (will be deprecated anyway!)
2024-03-01 10:42:33 +11:00
f1597bd6da chore(ui): lint 2024-03-01 10:42:33 +11:00
e50b76571a feat(ui): fix main model & control adapter model selects 2024-03-01 10:42:33 +11:00
db363b5178 refactor(ui): url builders for each router
The MM2 router is at `api/v2/models`. URL builder utils make this a bit easier to manage.
2024-03-01 10:42:33 +11:00
dab939f7d1 feat(ui): update model identifier to be key (wip)
- Update most model identifiers to be `{key: string}` instead of name/base/type. Doesn't change the model select components yet.
- Update model _parameters_, stored in redux, to be `{key: string, base: BaseModel}` - we need to store the base model to be able to check model compatibility. May want to store the whole config? Not sure...
2024-03-01 10:42:33 +11:00
6df3c450e8 fix(nodes): fix t2i adapter model loading 2024-03-01 10:42:33 +11:00
b7ba65fef4 fix(ui): update model types 2024-03-01 10:42:33 +11:00
fc107ed711 tests(ui): add type tests 2024-03-01 10:42:33 +11:00
cb804e75ed tests(ui): enable vitest type testing
This is useful for the zod schemas and types we have created to match the backend.
2024-03-01 10:42:33 +11:00
7996d43af9 chore(ui): typegen 2024-03-01 10:42:33 +11:00
fab30b5a11 feat(ui): export components type 2024-03-01 10:42:33 +11:00
651ac56b2c fix(ui): fix type issues 2024-03-01 10:42:33 +11:00
68f53460f0 chore: lint 2024-03-01 10:42:33 +11:00
c80987eb8a chore: ruff 2024-03-01 10:42:33 +11:00
539570cc7a feat(nodes): update invocation context for mm2, update nodes model usage 2024-03-01 10:42:33 +11:00
88d6de4101 Raise InvalidModelConfigException when unable to detect load class in ModelLoader 2024-03-01 10:42:33 +11:00
4c6e34b216 Update _get_hf_load_class to support clipvision models 2024-03-01 10:42:33 +11:00
262cbaacdd References to context.services.model_manager.store.get_model can only accept keys, remove invalid assertion 2024-03-01 10:42:33 +11:00
35e8a33dfd Remove references to model_records service, change submodel property on ModelInfo to submodel_type to support new params in model manager 2024-03-01 10:42:33 +11:00
b0835db47d improve swagger documentation 2024-03-01 10:42:33 +11:00
3e330d7d9d fix a number of typechecking errors 2024-03-01 10:42:33 +11:00
ff6e94f828 add route for model conversion from safetensors to diffusers
- Begin to add SwaggerUI documentation for AnyModelConfig and other
  discriminated Unions.
2024-03-01 10:42:33 +11:00
a2cc4047f9 add a JIT download_and_cache() call to the model installer 2024-03-01 10:42:33 +11:00
4027e845d4 add back the heuristic_import() method and extend repo_ids to arbitrary file paths 2024-03-01 10:42:33 +11:00
a23dedd2ee make model manager v2 ready for PR review
- Replace legacy model manager service with the v2 manager.

- Update invocations to use new load interface.

- Fixed many but not all type checking errors in the invocations. Most
  were unrelated to model manager

- Updated routes. All the new routes live under the route tag
  `model_manager_v2`. To avoid confusion with the old routes,
  they have the URL prefix `/api/v2/models`. The old routes
  have been de-registered.

- Added a pytest for the loader.

- Updated documentation in contributing/MODEL_MANAGER.md
2024-03-01 10:42:33 +11:00
7956602b19 consolidate model manager parts into a single class 2024-03-01 10:42:33 +11:00
8db01ab1b3 probe for required encoder for IPAdapters and add to config 2024-03-01 10:42:33 +11:00
db340bc253 fix invokeai_configure script to work with new mm; rename CLIs 2024-03-01 10:42:33 +11:00
78ef946e01 BREAKING CHANGES: invocations now require model key, not base/type/name
- Implement new model loader and modify invocations and embeddings

- Finish implementation loaders for all models currently supported by
  InvokeAI.

- Move lora, textual_inversion, and model patching support into
  backend/embeddings.

- Restore support for model cache statistics collection (a little ugly,
  needs work).

- Fixed up invocations that load and patch models.

- Move seamless and silencewarnings utils into better location
2024-03-01 10:42:33 +11:00
5745ce9c7d Multiple refinements on loaders:
- Cache stat collection enabled.
- Implemented ONNX loading.
- Add ability to specify the repo version variant in installer CLI.
- If caller asks for a repo version that doesn't exist, will fall back
  to empty version rather than raising an error.
2024-03-01 10:42:33 +11:00
0d3addc69b added textual inversion and lora loaders 2024-03-01 10:42:33 +11:00
67eb715093 loaders for main, controlnet, ip-adapter, clipvision and t2i 2024-03-01 10:42:33 +11:00
8ba5360269 model loading and conversion implemented for vaes 2024-03-01 10:42:33 +11:00
b8e875bb73 add ram cache module and support files 2024-03-01 10:42:33 +11:00
010c4eae65 add concept of repo variant 2024-03-01 10:42:33 +11:00
95453a22b1 tests(ui): add parseFieldType.test.ts 2024-03-01 10:42:33 +11:00
30db708c4f feat(ui): add more types of FieldParseError
Unfortunately you cannot test for both a specific type of error and match its message. Splitting the error classes makes it easier to test expected error conditions.
2024-03-01 10:42:33 +11:00
fe27af461a feat(ui): add vitest
- Add vitest.
- Consolidate vite configs into single file (easier to config everything based on env for testing)
2024-03-01 10:42:33 +11:00
f8525837b2 feat(ui): workflow schema v3 (WIP)
The changes aim to deduplicate data between workflows and node templates, decoupling workflows from internal implementation details. A good amount of data that was needlessly duplicated from the node template to the workflow is removed.

These changes substantially reduce the file size of workflows (and therefore the images with embedded workflows):

- Default T2I SD1.5 workflow JSON is reduced from 23.7kb (798 lines) to 10.9kb (407 lines).
- Default tiled upscale workflow JSON is reduced from 102.7kb (3341 lines) to 51.9kb (1774 lines).

The trade-off is that we need to reference node templates to get things like the field type and other things. In practice, this is a non-issue, because we need a node template to do anything with a node anyways.

- Field types are not included in the workflow. They are always pulled from the node templates.

The field type is now properly an internal implementation detail and we can change it as needed. Previously this would require a migration for the workflow itself. With the v3 schema, the structure of a field type is an internal implementation detail that we are free to change as we see fit.

- Workflow nodes no long have an `outputs` property and there is no longer such a thing as a `FieldOutputInstance`. These are only on the templates.

These were never referenced at a time when we didn't also have the templates available, and there'd be no reason to do so.

- Node width and height are no longer stored in the node.

These weren't used. Also, per https://reactflow.dev/api-reference/types/node, we shouldn't be programmatically changing these properties. A future enhancement can properly add node resizing.

- `nodeTemplates` slice is merged back into `nodesSlice` as `nodes.templates`. Turns out it's just a hassle having these separate in separate slices.

- Workflow migration logic updated to support the new schema. V1 workflows migrate all the way to v3 now.

- Changes throughout the nodes code to accommodate the above changes.
2024-03-01 10:42:33 +11:00
5fbfed30ac chore(ui): regen types 2024-03-01 10:42:33 +11:00
7a2159beeb feat(nodes): add more missing exports to invocation_api
Crawled through a few custom nodes to figure out what I had missed.
2024-03-01 10:42:33 +11:00
25f64d5b19 chore(nodes): "SAMPLER_NAME_VALUES" -> "SCHEDULER_NAME_VALUES"
This was named inaccurately.
2024-03-01 10:42:33 +11:00
b845e890d1 chore(nodes): remove deprecation logic for nodes API 2024-03-01 10:42:33 +11:00
6d31bc5326 chore(nodes): export model-related objects from invocation_api 2024-03-01 10:42:33 +11:00
0f8af643d1 chore(backend): rename ModelInfo -> LoadedModelInfo
We have two different classes named `ModelInfo` which might need to be used by API consumers. We need to export both but have to deal with this naming collision.

The `ModelInfo` I've renamed here is the one that is returned when a model is loaded. It's the object least likely to be used by API consumers.
2024-03-01 10:42:33 +11:00
e0694a2856 feat(nodes): use LATENT_SCALE_FACTOR in primitives.py, noise.py
- LatentsOutput.build
- NoiseOutput.build
- Noise.width, Noise.height multiple_of
2024-03-01 10:42:33 +11:00
e5d8921cf2 feat(nodes): extract LATENT_SCALE_FACTOR to constants.py 2024-03-01 10:42:33 +11:00
fece935438 feat(nodes): use TemporaryDirectory to handle ephemeral storage in ObjectSerializerDisk
Replace `delete_on_startup: bool` & associated logic with `ephemeral: bool` and `TemporaryDirectory`.

The temp dir is created inside of `output_dir`. For example, if `output_dir` is `invokeai/outputs/tensors/`, then the temp dir might be `invokeai/outputs/tensors/tmpvj35ht7b/`.

The temp dir is cleaned up when the service is stopped, or when it is GC'd if not properly stopped.

In the event of a catastrophic crash where the temp files are not cleaned up, the user can delete the tempdir themselves.

This situation may not occur in normal use, but if you kill the process, python cannot clean up the temp dir itself. This includes running the app in a debugger and killing the debugger process - something I do relatively often.

Tests updated.
2024-03-01 10:42:33 +11:00
11f64dab38 tests: test ObjectSerializerDisk class name extraction 2024-03-01 10:42:33 +11:00
670f2f75e9 chore(nodes): update ObjectSerializerForwardCache docstring 2024-03-01 10:42:33 +11:00
66d0ec3f6c chore(nodes): fix pyright ignore 2024-03-01 10:42:33 +11:00
6087ace4f1 tidy(nodes): "latents" -> "obj" 2024-03-01 10:42:33 +11:00
a9b1aad3d7 tidy(nodes): do not store unnecessarily store invoker 2024-03-01 10:42:33 +11:00
9edb995647 feat(nodes): make delete on startup configurable for obj serializer
- The default is to not delete on startup - feels safer.
- The two services using this class _do_ delete on startup.
- The class has "ephemeral" removed from its name.
- Tests & app updated for this change.
2024-03-01 10:42:33 +11:00
091f4cb583 fix(nodes): use metadata/board_id if provided by user, overriding WithMetadata/WithBoard-provided values 2024-03-01 10:42:33 +11:00
1655061c96 tidy(nodes): clarify comment 2024-03-01 10:42:33 +11:00
220baae793 Revert "feat(nodes): use LATENT_SCALE_FACTOR const in tensor output builders"
This reverts commit ef18fc546560277302f3886e456da9a47e8edce0.
2024-03-01 10:42:33 +11:00
e08f16763b feat(nodes): use LATENT_SCALE_FACTOR const in tensor output builders 2024-03-01 10:42:33 +11:00
6d25789705 tests: fix broken tests 2024-03-01 10:42:33 +11:00
aff44c0e58 tidy(nodes): minor spelling correction 2024-03-01 10:42:33 +11:00
34d23366f4 tests: add object serializer tests
These test both object serializer and its forward cache implementation.
2024-03-01 10:42:33 +11:00
23de78ec9f feat(nodes): allow _delete_all in obj serializer to be called at any time
`_delete_all` logged how many items it deleted, and had to be called _after_ service start bc it needed access to logger.

Move the logger call to the startup method and return the the deleted stats from `_delete_all`. This lets `_delete_all` be called at any time.
2024-03-01 10:42:33 +11:00
507aeac8a5 tidy(nodes): remove object serializer on_saved
It's unused.
2024-03-01 10:42:33 +11:00
9f382419dc revert(nodes): revert making tensors/conditioning use item storage
Turns out they are just different enough in purpose that the implementations would be rather unintuitive. I've made a separate ObjectSerializer service to handle tensors and conditioning.

Refined the class a bit too.
2024-03-01 10:42:33 +11:00
73d871116c feat(nodes): support custom exception in ephemeral disk storage 2024-03-01 10:42:33 +11:00
ab58d34f9b feat(nodes): support custom save and load functions in ItemStorageEphemeralDisk 2024-03-01 10:42:33 +11:00
9cda62c2a7 feat(nodes): create helper function to generate the item ID 2024-03-01 10:42:33 +11:00
a50c7c1cd7 feat(nodes): use ItemStorageABC for tensors and conditioning
Turns out `ItemStorageABC` was almost identical to `PickleStorageBase`. Instead of maintaining separate classes, we can use `ItemStorageABC` for both.

There's only one change needed - the `ItemStorageABC.set` method must return the newly stored item's ID. This allows us to let the service handle the responsibility of naming the item, but still create the requisite output objects during node execution.

The naming implementation is improved here. It extracts the name of the generic and appends a UUID to that string when saving items.
2024-03-01 10:42:33 +11:00
ca09bd63a3 tidy(nodes): do not refer to files as latents in PickleStorageTorch (again) 2024-03-01 10:42:33 +11:00
c96f50cc9a feat(nodes): ItemStorageABC typevar no longer bound to pydantic.BaseModel
This bound is totally unnecessary. There's no requirement for any implementation of `ItemStorageABC` to work only on pydantic models.
2024-03-01 10:42:33 +11:00
de63e888d6 fix(nodes): add super init to PickleStorageTorch 2024-03-01 10:42:33 +11:00
5dd158a2d4 tidy(nodes): do not refer to files as latents in PickleStorageTorch 2024-03-01 10:42:33 +11:00
0710fb3fb0 feat(nodes): replace latents service with tensors and conditioning services
- New generic class `PickleStorageBase`, implements the same API as `LatentsStorageBase`, use for storing non-serializable data via pickling
- Implementation `PickleStorageTorch` uses `torch.save` and `torch.load`, same as `LatentsStorageDisk`
- Add `tensors: PickleStorageBase[torch.Tensor]` to `InvocationServices`
- Add `conditioning: PickleStorageBase[ConditioningFieldData]` to `InvocationServices`
- Remove `latents` service and all `LatentsStorage` classes
- Update `InvocationContext` and all usage of old `latents` service to use the new services/context wrapper methods
2024-03-01 10:42:33 +11:00
31db62ba99 tidy(nodes): delete onnx.py
It doesn't work and keeping it updated to prevent the app from starting was getting tedious. Deleted.
2024-03-01 10:42:33 +11:00
322a60f48f fix(nodes): rearrange fields.py to avoid needing forward refs 2024-03-01 10:42:33 +11:00
b386b1b8af tidy(nodes): remove unnecessary, shadowing class attr declarations 2024-03-01 10:42:33 +11:00
70034d26e2 feat(ui): revise graphs to not use LinearUIOutputInvocation
See this comment for context: https://github.com/invoke-ai/InvokeAI/pull/5491#discussion_r1480760629

- Remove this now-unnecessary node from all graphs
- Update graphs' terminal image-outputting nodes' `is_intermediate` and `board` fields appropriately
- Add util function to prepare the `board` field, tidy the utils
- Update `socketInvocationComplete` listener to work correctly with this change

I've manually tested all graph permutations that were changed (I think this is all...) to ensure images go to the gallery as expected:
- ad-hoc upscaling
- t2i w/ sd1.5
- t2i w/ sd1.5 & hrf
- t2i w/ sdxl
- t2i w/ sdxl + refiner
- i2i w/ sd1.5
- i2i w/ sdxl
- i2i w/ sdxl + refiner
- canvas t2i w/ sd1.5
- canvas t2i w/ sdxl
- canvas t2i w/ sdxl + refiner
- canvas i2i w/ sd1.5
- canvas i2i w/ sdxl
- canvas i2i w/ sdxl + refiner
- canvas inpaint w/ sd1.5
- canvas inpaint w/ sdxl
- canvas inpaint w/ sdxl + refiner
- canvas outpaint w/ sd1.5
- canvas outpaint w/ sdxl
- canvas outpaint w/ sdxl + refiner
2024-03-01 10:42:33 +11:00
d60f1965d1 chore(ui): regen types 2024-03-01 10:42:33 +11:00
7fbdfbf9e5 feat(nodes): add WithBoard field helper class
This class works the same way as `WithMetadata` - it simply adds a `board` field to the node. The context wrapper function is able to pull the board id from this. This allows image-outputting nodes to get a board field "for free", and have their outputs automatically saved to it.

This is a breaking change for node authors who may have a field called `board`, because it makes `board` a reserved field name. I'll look into how to avoid this - maybe by naming this invoke-managed field `_board` to avoid collisions?

Supporting changes:
- `WithBoard` is added to all image-outputting nodes, giving them the ability to save to board.
- Unused, duplicate `WithMetadata` and `WithWorkflow` classes are deleted from `baseinvocation.py`. The "real" versions are in `fields.py`.
- Remove `LinearUIOutputInvocation`. Now that all nodes that output images also have a `board` field by default, this node is no longer necessary. See comment here for context: https://github.com/invoke-ai/InvokeAI/pull/5491#discussion_r1480760629
- Without `LinearUIOutputInvocation`, the `ImagesInferface.update` method is no longer needed, and removed.

Note: This commit does not bump all node versions. I will ensure that is done correctly before merging the PR of which this commit is a part.

Note: A followup commit will implement the frontend changes to support this change.
2024-03-01 10:42:33 +11:00
e137071543 remove unused configdict import 2024-03-01 10:42:33 +11:00
5d2f70b3ef fix(ui): remove original l2i node in HRF graph 2024-03-01 10:42:33 +11:00
47d05fdd81 fix(nodes): do not freeze or cache config in context wrapper
- The config is already cached by the config class's `get_config()` method.
- The config mutates itself in its `root_path` property getter. Freezing the class makes any attempt to grab a path from the config error. Unfortunately this means we cannot easily freeze the class without fiddling with the inner workings of `InvokeAIAppConfig`, which is outside the scope here.
2024-03-01 10:42:33 +11:00
958b80acdd feat(nodes): context.data -> context._data 2024-03-01 10:42:33 +11:00
5730ae9b96 feat(nodes): context.__services -> context._services 2024-03-01 10:42:33 +11:00
60e2eff94d feat(nodes): cache invocation interface config 2024-03-01 10:42:33 +11:00
dcafbb9988 feat(nodes): do not hide services in invocation context interfaces 2024-03-01 10:42:33 +11:00
cc8d713c57 fix(nodes): restore missing context type annotations 2024-03-01 10:42:33 +11:00
59c77832d8 tests(nodes): fix mock InvocationContext 2024-03-01 10:42:33 +11:00
cbf22d8a80 chore(nodes): add comments for ConfigInterface 2024-03-01 10:42:33 +11:00
e11af7de9b feat(nodes): export more things from `invocation_api" 2024-03-01 10:42:33 +11:00
95dd5aad16 feat(nodes): add boards interface to invocation context 2024-03-01 10:42:33 +11:00
4ce21087d3 fix(nodes): restore type annotations for InvocationContext 2024-03-01 10:42:33 +11:00
281c334531 feat(nodes): do not freeze InvocationContextData, prevents it from being subclassesd 2024-03-01 10:42:33 +11:00
282b483d14 feat: tweak pyright config 2024-03-01 10:42:33 +11:00
a466f7a94b feat(nodes): create invocation_api.py
This is the public API for invocations.

Everything a custom node might need should be re-exported from this file.
2024-03-01 10:42:33 +11:00
05fb485d33 feat(nodes): move ConditioningFieldData to conditioning_data.py 2024-03-01 10:42:33 +11:00
6452c706e1 tests: fix missing arg for InvocationContext 2024-03-01 10:42:33 +11:00
f612a96afd feat(nodes): restore previous invocation context methods with deprecation warnings 2024-03-01 10:42:33 +11:00
9af0553652 chore: ruff 2024-03-01 10:42:33 +11:00
1616974b48 feat(nodes): tidy invocation_context.py, improve comments 2024-03-01 10:42:33 +11:00
ef27283569 tests: fix tests for new invocation context 2024-03-01 10:42:33 +11:00
a79a450e9d docs: update INVOCATIONS.md 2024-03-01 10:42:33 +11:00
8637c40661 feat(nodes): update all invocations to use new invocation context
Update all invocations to use the new context. The changes are all fairly simple, but there are a lot of them.

Supporting minor changes:
- Patch bump for all nodes that use the context
- Update invocation processor to provide new context
- Minor change to `EventServiceBase` to accept a node's ID instead of the dict version of a node
- Minor change to `ModelManagerService` to support the new wrapped context
- Fanagling of imports to avoid circular dependencies
2024-03-01 10:42:33 +11:00
9bc2d09889 feat: add pyright config
I was having issues with mypy bother over- and under-reporting certain problems. I've added a pyright config.
2024-03-01 10:42:33 +11:00
3d98446d5d feat(nodes): restricts invocation context power
Creates a low-power `InvocationContext` with simplified methods and data.

See `invocation_context.py` for detailed comments.
2024-03-01 10:42:33 +11:00
992b02aa65 tidy(nodes): move all field things to fields.py
Unfortunately, this is necessary to prevent circular imports at runtime.
2024-03-01 10:42:33 +11:00
63ab5ff5a2 translationBot(ui): update translation (Russian)
Currently translated at 98.3% (1398 of 1422 strings)

Co-authored-by: Васянатор <ilabulanov339@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ru/
Translation: InvokeAI/Web UI
2024-02-29 23:27:36 +11:00
9a8a9c5848 translationBot(ui): update translation (Italian)
Currently translated at 98.0% (1441 of 1470 strings)

Co-authored-by: Samantha Morello <tildsart@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-02-29 23:27:36 +11:00
1a3ffb6e94 translationBot(ui): update translation (German)
Currently translated at 80.4% (1183 of 1470 strings)

Co-authored-by: Alexander Eichhorn <pfannkuchensack@einfach-doof.de>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2024-02-29 23:27:36 +11:00
3a09bceea4 Update communityNodes.md
Updated description of metadata nodes
2024-02-26 14:20:09 -05:00
2ec6b51d8b translationBot(ui): update translation (Italian)
Currently translated at 97.2% (1430 of 1470 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-02-26 17:41:00 +11:00
B N
34b0ea20dc translationBot(ui): update translation (German)
Currently translated at 80.3% (1181 of 1470 strings)

translationBot(ui): update translation (German)

Currently translated at 80.1% (1178 of 1470 strings)

Co-authored-by: B N <berndnieschalk@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2024-02-26 17:41:00 +11:00
9986fce1a6 translationBot(ui): update translation (German)
Currently translated at 80.0% (1176 of 1470 strings)

Co-authored-by: Alexander Eichhorn <pfannkuchensack@einfach-doof.de>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2024-02-23 07:57:15 +11:00
228f1d7f62 translationBot(ui): update translation (Italian)
Currently translated at 95.6% (1406 of 1470 strings)

translationBot(ui): update translation (Italian)

Currently translated at 93.9% (1381 of 1470 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-02-23 07:57:15 +11:00
B N
01a6378dc1 translationBot(ui): update translation (German)
Currently translated at 78.8% (1159 of 1470 strings)

Co-authored-by: B N <berndnieschalk@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2024-02-23 07:57:15 +11:00
e01769294f translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2024-02-20 22:33:03 +11:00
16aa261e28 updated tooltip popovers (#5751)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [X] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [X] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [ ] Yes
- [ ] No


## Description
Added new tooltip popovers and updated copy of existing ones

## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Merge Plan

<!--
A merge plan describes how this PR should be handled after it is
approved.

Example merge plans:
- "This PR can be merged when approved"
- "This must be squash-merged when approved"
- "DO NOT MERGE - I will rebase and tidy commits before merging"
- "#dev-chat on discord needs to be advised of this change when it is
merged"

A merge plan is particularly important for large PRs or PRs that touch
the
database in any way.
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2024-02-19 13:12:47 -05:00
1dabf18d14 Merge branch 'main' into chainchompa/tooltip-popovers 2024-02-19 13:04:15 -05:00
115d92b1ae updated copy 2024-02-19 12:50:35 -05:00
f0d4c71960 updated tooltip popovers 2024-02-19 12:50:11 -05:00
3e48edda6f add latent-upscale to communityNodes.md (#5728)
Adds the 'latent upscale' community node
2024-02-19 16:53:35 +00:00
716b584f03 translationBot(ui): update translation (Italian)
Currently translated at 97.1% (1384 of 1424 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-02-19 08:18:33 +11:00
B N
d43b843c23 translationBot(ui): update translation (German)
Currently translated at 80.2% (1143 of 1424 strings)

Co-authored-by: B N <berndnieschalk@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2024-02-18 01:47:01 +11:00
f36b5990ed fix(ui): do not provide auth headers for openapi.json 2024-02-15 10:38:26 -05:00
5706237ec7 {release} 3.7.0 (#5727)
## What type of PR is this? (check all applicable)

Release - Invoke 3.7.0

## Have you discussed this change with the InvokeAI team?
- [X] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [X] Yes
- [ ] No


## Description
Invoke 3.7.0 Release

## QA Instructions, Screenshots, Recordings
Test Installer: 

[InvokeAI-installer-v3.7.0.zip](https://github.com/invoke-ai/InvokeAI/files/14298200/InvokeAI-installer-v3.7.0.zip)

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Merge Plan
Merge once approved
<!--
A merge plan describes how this PR should be handled after it is
approved.

Example merge plans:
- "This PR can be merged when approved"
- "This must be squash-merged when approved"
- "DO NOT MERGE - I will rebase and tidy commits before merging"
- "#dev-chat on discord needs to be advised of this change when it is
merged"

A merge plan is particularly important for large PRs or PRs that touch
the
database in any way.
-->

## Added/updated tests?

- [ ] Yes
- [X] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
1. Release on PyPi
2. Release on GitHub
3. Announce on Discord
2024-02-15 07:59:20 -07:00
163b22a7b3 {release} 3.7.0 2024-02-15 07:34:31 -07:00
287 changed files with 8108 additions and 6793 deletions

View File

@ -0,0 +1,33 @@
name: install frontend dependencies
description: Installs frontend dependencies with pnpm, with caching
runs:
using: 'composite'
steps:
- name: setup node 18
uses: actions/setup-node@v4
with:
node-version: '18'
- name: setup pnpm
uses: pnpm/action-setup@v2
with:
version: 8
run_install: false
- name: get pnpm store directory
shell: bash
run: |
echo "STORE_PATH=$(pnpm store path --silent)" >> $GITHUB_ENV
- name: setup cache
uses: actions/cache@v4
with:
path: ${{ env.STORE_PATH }}
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
restore-keys: |
${{ runner.os }}-pnpm-store-
- name: install frontend dependencies
run: pnpm install --prefer-frozen-lockfile
shell: bash
working-directory: invokeai/frontend/web

28
.github/pr_labels.yml vendored
View File

@ -1,59 +1,59 @@
Root:
root:
- changed-files:
- any-glob-to-any-file: '*'
PythonDeps:
python-deps:
- changed-files:
- any-glob-to-any-file: 'pyproject.toml'
Python:
python:
- changed-files:
- all-globs-to-any-file:
- 'invokeai/**'
- '!invokeai/frontend/web/**'
PythonTests:
python-tests:
- changed-files:
- any-glob-to-any-file: 'tests/**'
CICD:
ci-cd:
- changed-files:
- any-glob-to-any-file: .github/**
Docker:
docker:
- changed-files:
- any-glob-to-any-file: docker/**
Installer:
installer:
- changed-files:
- any-glob-to-any-file: installer/**
Documentation:
docs:
- changed-files:
- any-glob-to-any-file: docs/**
Invocations:
invocations:
- changed-files:
- any-glob-to-any-file: 'invokeai/app/invocations/**'
Backend:
backend:
- changed-files:
- any-glob-to-any-file: 'invokeai/backend/**'
Api:
api:
- changed-files:
- any-glob-to-any-file: 'invokeai/app/api/**'
Services:
services:
- changed-files:
- any-glob-to-any-file: 'invokeai/app/services/**'
FrontendDeps:
frontend-deps:
- changed-files:
- any-glob-to-any-file:
- '**/*/package.json'
- '**/*/pnpm-lock.yaml'
Frontend:
frontend:
- changed-files:
- any-glob-to-any-file: 'invokeai/frontend/web/**'

View File

@ -11,7 +11,7 @@ on:
- 'docker/docker-entrypoint.sh'
- 'workflows/build-container.yml'
tags:
- 'v*'
- 'v*.*.*'
workflow_dispatch:
permissions:

45
.github/workflows/build-installer.yml vendored Normal file
View File

@ -0,0 +1,45 @@
# Builds and uploads the installer and python build artifacts.
name: build installer
on:
workflow_dispatch:
workflow_call:
jobs:
build-installer:
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <2 min
steps:
- name: checkout
uses: actions/checkout@v4
- name: setup python
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
- name: install pypa/build
run: pip install --upgrade build
- name: setup frontend
uses: ./.github/actions/install-frontend-deps
- name: create installer
id: create_installer
run: ./create_installer.sh
working-directory: installer
- name: upload python distribution artifact
uses: actions/upload-artifact@v4
with:
name: dist
path: ${{ steps.create_installer.outputs.DIST_PATH }}
- name: upload installer artifact
uses: actions/upload-artifact@v4
with:
name: ${{ steps.create_installer.outputs.INSTALLER_FILENAME }}
path: ${{ steps.create_installer.outputs.INSTALLER_PATH }}

68
.github/workflows/frontend-checks.yml vendored Normal file
View File

@ -0,0 +1,68 @@
# Runs frontend code quality checks.
#
# Checks for changes to frontend files before running the checks.
# When manually triggered or when called from another workflow, always runs the checks.
name: 'frontend checks'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
workflow_call:
defaults:
run:
working-directory: invokeai/frontend/web
jobs:
frontend-checks:
runs-on: ubuntu-latest
timeout-minutes: 10 # expected run time: <2 min
steps:
- uses: actions/checkout@v4
- name: check for changed frontend files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
frontend:
- 'invokeai/frontend/web/**'
- name: install dependencies
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: ./.github/actions/install-frontend-deps
- name: tsc
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:tsc'
shell: bash
- name: dpdm
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:dpdm'
shell: bash
- name: eslint
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:eslint'
shell: bash
- name: prettier
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:prettier'
shell: bash
- name: knip
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:knip'
shell: bash

48
.github/workflows/frontend-tests.yml vendored Normal file
View File

@ -0,0 +1,48 @@
# Runs frontend tests.
#
# Checks for changes to frontend files before running the tests.
# When manually triggered or called from another workflow, always runs the tests.
name: 'frontend tests'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
workflow_call:
defaults:
run:
working-directory: invokeai/frontend/web
jobs:
frontend-tests:
runs-on: ubuntu-latest
timeout-minutes: 10 # expected run time: <2 min
steps:
- uses: actions/checkout@v4
- name: check for changed frontend files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
frontend:
- 'invokeai/frontend/web/**'
- name: install dependencies
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: ./.github/actions/install-frontend-deps
- name: vitest
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm test:no-watch'
shell: bash

View File

@ -1,6 +1,6 @@
name: "Pull Request Labeler"
name: 'label PRs'
on:
- pull_request_target
- pull_request_target
jobs:
labeler:
@ -9,8 +9,10 @@ jobs:
pull-requests: write
runs-on: ubuntu-latest
steps:
- name: Checkout
- name: checkout
uses: actions/checkout@v4
- uses: actions/labeler@v5
- name: label PRs
uses: actions/labeler@v5
with:
configuration-path: .github/pr_labels.yml

View File

@ -1,45 +0,0 @@
name: Lint frontend
on:
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
push:
branches:
- 'main'
merge_group:
workflow_dispatch:
defaults:
run:
working-directory: invokeai/frontend/web
jobs:
lint-frontend:
if: github.event.pull_request.draft == false
runs-on: ubuntu-22.04
steps:
- name: Setup Node 18
uses: actions/setup-node@v4
with:
node-version: '18'
- name: Checkout
uses: actions/checkout@v4
- name: Setup pnpm
uses: pnpm/action-setup@v2
with:
version: '8.12.1'
- name: Install dependencies
run: 'pnpm install --prefer-frozen-lockfile'
- name: Typescript
run: 'pnpm run lint:tsc'
- name: Madge
run: 'pnpm run lint:dpdm'
- name: ESLint
run: 'pnpm run lint:eslint'
- name: Prettier
run: 'pnpm run lint:prettier'
- name: Knip
run: 'pnpm run lint:knip'

View File

@ -1,51 +1,49 @@
name: mkdocs-material
# This is a mostly a copy-paste from https://github.com/squidfunk/mkdocs-material/blob/master/docs/publishing-your-site.md
name: mkdocs
on:
push:
branches:
- 'refs/heads/main'
- main
workflow_dispatch:
permissions:
contents: write
jobs:
mkdocs-material:
deploy:
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
env:
REPO_URL: '${{ github.server_url }}/${{ github.repository }}'
REPO_NAME: '${{ github.repository }}'
SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
steps:
- name: checkout sources
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: checkout
uses: actions/checkout@v4
- name: setup python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
- name: install requirements
env:
PIP_USE_PEP517: 1
run: |
python -m \
pip install ".[docs]"
- name: set cache id
run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
- name: confirm buildability
run: |
python -m \
mkdocs build \
--clean \
--verbose
- name: use cache
uses: actions/cache@v4
with:
key: mkdocs-material-${{ env.cache_id }}
path: .cache
restore-keys: |
mkdocs-material-
- name: deploy to gh-pages
if: ${{ github.ref == 'refs/heads/main' }}
run: |
python -m \
mkdocs gh-deploy \
--clean \
--force
- name: install dependencies
run: python -m pip install ".[docs]"
- name: build & deploy
run: mkdocs gh-deploy --force

View File

@ -1,67 +0,0 @@
name: PyPI Release
on:
workflow_dispatch:
inputs:
publish_package:
description: 'Publish build on PyPi? [true/false]'
required: true
default: 'false'
jobs:
build-and-release:
if: github.repository == 'invoke-ai/InvokeAI'
runs-on: ubuntu-22.04
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
TWINE_NON_INTERACTIVE: 1
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Node 18
uses: actions/setup-node@v4
with:
node-version: '18'
- name: Setup pnpm
uses: pnpm/action-setup@v2
with:
version: '8.12.1'
- name: Install frontend dependencies
run: pnpm install --prefer-frozen-lockfile
working-directory: invokeai/frontend/web
- name: Build frontend
run: pnpm run build
working-directory: invokeai/frontend/web
- name: Install python dependencies
run: pip install --upgrade build twine
- name: Build python package
run: python3 -m build
- name: Upload build as workflow artifact
uses: actions/upload-artifact@v4
with:
name: dist
path: dist
- name: Check distribution
run: twine check dist/*
- name: Check PyPI versions
if: github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')
run: |
pip install --upgrade requests
python -c "\
import scripts.pypi_helper; \
EXISTS=scripts.pypi_helper.local_on_pypi(); \
print(f'PACKAGE_EXISTS={EXISTS}')" >> $GITHUB_ENV
- name: Publish build on PyPi
if: env.PACKAGE_EXISTS == 'False' && env.TWINE_PASSWORD != '' && github.event.inputs.publish_package == 'true'
run: twine upload dist/*

64
.github/workflows/python-checks.yml vendored Normal file
View File

@ -0,0 +1,64 @@
# Runs python code quality checks.
#
# Checks for changes to python files before running the checks.
# When manually triggered or called from another workflow, always runs the tests.
#
# TODO: Add mypy or pyright to the checks.
name: 'python checks'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
workflow_call:
jobs:
python-checks:
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <1 min
steps:
- name: checkout
uses: actions/checkout@v4
- name: check for changed python files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
python:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: setup python
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
- name: install ruff
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: pip install ruff
shell: bash
- name: ruff check
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: ruff check --output-format=github .
shell: bash
- name: ruff format
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: ruff format --check .
shell: bash

94
.github/workflows/python-tests.yml vendored Normal file
View File

@ -0,0 +1,94 @@
# Runs python tests on a matrix of python versions and platforms.
#
# Checks for changes to python files before running the tests.
# When manually triggered or called from another workflow, always runs the tests.
name: 'python tests'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
workflow_call:
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
matrix:
strategy:
matrix:
python-version:
- '3.10'
- '3.11'
platform:
- linux-cuda-11_7
- linux-rocm-5_2
- linux-cpu
- macos-default
- windows-cpu
include:
- platform: linux-cuda-11_7
os: ubuntu-22.04
github-env: $GITHUB_ENV
- platform: linux-rocm-5_2
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
github-env: $GITHUB_ENV
- platform: linux-cpu
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- platform: macos-default
os: macOS-12
github-env: $GITHUB_ENV
- platform: windows-cpu
os: windows-2022
github-env: $env:GITHUB_ENV
name: 'py${{ matrix.python-version }}: ${{ matrix.platform }}'
runs-on: ${{ matrix.os }}
timeout-minutes: 15 # expected run time: 2-6 min, depending on platform
env:
PIP_USE_PEP517: '1'
steps:
- name: checkout
uses: actions/checkout@v4
- name: check for changed python files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
python:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: setup python
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: pyproject.toml
- name: install dependencies
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
env:
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
run: >
pip3 install --editable=".[test]"
- name: run pytest
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: pytest

96
.github/workflows/release.yml vendored Normal file
View File

@ -0,0 +1,96 @@
# Main release workflow. Triggered on tag push or manual trigger.
#
# - Runs all code checks and tests
# - Verifies the app version matches the tag version.
# - Builds the installer and build, uploading them as artifacts.
# - Publishes to TestPyPI and PyPI. Both are conditional on the previous steps passing and require a manual approval.
#
# See docs/RELEASE.md for more information on the release process.
name: release
on:
push:
tags:
- 'v*'
workflow_dispatch:
jobs:
check-version:
runs-on: ubuntu-latest
steps:
- name: checkout
uses: actions/checkout@v4
- name: check python version
uses: samuelcolvin/check-python-version@v4
id: check-python-version
with:
version_file_path: invokeai/version/invokeai_version.py
frontend-checks:
uses: ./.github/workflows/frontend-checks.yml
frontend-tests:
uses: ./.github/workflows/frontend-tests.yml
python-checks:
uses: ./.github/workflows/python-checks.yml
python-tests:
uses: ./.github/workflows/python-tests.yml
build:
uses: ./.github/workflows/build-installer.yml
publish-testpypi:
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <1 min
needs:
[
check-version,
frontend-checks,
frontend-tests,
python-checks,
python-tests,
build,
]
environment:
name: testpypi
url: https://test.pypi.org/p/invokeai
steps:
- name: download distribution from build job
uses: actions/download-artifact@v4
with:
name: dist
path: dist/
- name: publish distribution to TestPyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://test.pypi.org/legacy/
publish-pypi:
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <1 min
needs:
[
check-version,
frontend-checks,
frontend-tests,
python-checks,
python-tests,
build,
]
environment:
name: pypi
url: https://pypi.org/p/invokeai
steps:
- name: download distribution from build job
uses: actions/download-artifact@v4
with:
name: dist
path: dist/
- name: publish distribution to PyPI
uses: pypa/gh-action-pypi-publish@release/v1

View File

@ -1,24 +0,0 @@
name: style checks
on:
pull_request:
push:
branches: main
jobs:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install dependencies with pip
run: |
pip install ruff
- run: ruff check --output-format=github .
- run: ruff format --check .

View File

@ -1,129 +0,0 @@
name: Test invoke.py pip
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
matrix:
if: github.event.pull_request.draft == false
strategy:
matrix:
python-version:
# - '3.9'
- '3.10'
pytorch:
- linux-cuda-11_7
- linux-rocm-5_2
- linux-cpu
- macos-default
- windows-cpu
include:
- pytorch: linux-cuda-11_7
os: ubuntu-22.04
github-env: $GITHUB_ENV
- pytorch: linux-rocm-5_2
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
github-env: $GITHUB_ENV
- pytorch: linux-cpu
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- pytorch: macos-default
os: macOS-12
github-env: $GITHUB_ENV
- pytorch: windows-cpu
os: windows-2022
github-env: $env:GITHUB_ENV
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}
env:
PIP_USE_PEP517: '1'
steps:
- name: Checkout sources
id: checkout-sources
uses: actions/checkout@v3
- name: Check for changed python files
id: changed-files
uses: tj-actions/changed-files@v41
with:
files_yaml: |
python:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: set test prompt to main branch validation
if: steps.changed-files.outputs.python_any_changed == 'true'
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
- name: setup python
if: steps.changed-files.outputs.python_any_changed == 'true'
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: pyproject.toml
- name: install invokeai
if: steps.changed-files.outputs.python_any_changed == 'true'
env:
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
run: >
pip3 install
--editable=".[test]"
- name: run pytest
if: steps.changed-files.outputs.python_any_changed == 'true'
id: run-pytest
run: pytest
# - name: run invokeai-configure
# env:
# HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
# run: >
# invokeai-configure
# --yes
# --default_only
# --full-precision
# # can't use fp16 weights without a GPU
# - name: run invokeai
# id: run-invokeai
# env:
# # Set offline mode to make sure configure preloaded successfully.
# HF_HUB_OFFLINE: 1
# HF_DATASETS_OFFLINE: 1
# TRANSFORMERS_OFFLINE: 1
# INVOKEAI_OUTDIR: ${{ github.workspace }}/results
# run: >
# invokeai
# --no-patchmatch
# --no-nsfw_checker
# --precision=float32
# --always_use_cpu
# --use_memory_db
# --outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
# --from_file ${{ env.TEST_PROMPTS }}
# - name: Archive results
# env:
# INVOKEAI_OUTDIR: ${{ github.workspace }}/results
# uses: actions/upload-artifact@v3
# with:
# name: results
# path: ${{ env.INVOKEAI_OUTDIR }}

View File

@ -7,7 +7,7 @@ embeddedLanguageFormatting: auto
overrides:
- files: '*.md'
options:
proseWrap: always
proseWrap: preserve
printWidth: 80
parser: markdown
cursorOffset: -1

View File

@ -10,10 +10,11 @@ help:
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
@echo "test" Run the unit tests.
@echo "frontend-install" Install the pnpm modules needed for the front end
@echo "test Run the unit tests."
@echo "frontend-install Install the pnpm modules needed for the front end"
@echo "frontend-build Build the frontend in order to run on localhost:9090"
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
@echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
@ -53,6 +54,9 @@ frontend-build:
frontend-dev:
cd invokeai/frontend/web && pnpm dev
frontend-typegen:
cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen
# Installer zip file
installer-zip:
cd installer && ./create_installer.sh

142
docs/RELEASE.md Normal file
View File

@ -0,0 +1,142 @@
# Release Process
The app is published in twice, in different build formats.
- A [PyPI] distribution. This includes both a source distribution and built distribution (a wheel). Users install with `pip install invokeai`. The updater uses this build.
- An installer on the [InvokeAI Releases Page]. This is a zip file with install scripts and a wheel. This is only used for new installs.
## General Prep
Make a developer call-out for PRs to merge. Merge and test things out.
While the release workflow does not include end-to-end tests, it does pause before publishing so you can download and test the final build.
## Release Workflow
The `release.yml` workflow runs a number of jobs to handle code checks, tests, build and publish on PyPI.
It is triggered on **tag push**, when the tag matches `v*`. It doesn't matter if you've prepped a release branch like `release/v3.5.0` or are releasing from `main` - it works the same.
> Because commits are reference-counted, it is safe to create a release branch, tag it, let the workflow run, then delete the branch. So long as the tag exists, that commit will exist.
### Triggering the Workflow
Run `make tag-release` to tag the current commit and kick off the workflow.
The release may also be dispatched [manually].
### Workflow Jobs and Process
The workflow consists of a number of concurrently-run jobs, and two final publish jobs.
The publish jobs require manual approval and are only run if the other jobs succeed.
#### `check-version` Job
This job checks that the git ref matches the app version. It matches the ref against the `__version__` variable in `invokeai/version/invokeai_version.py`.
When the workflow is triggered by tag push, the ref is the tag. If the workflow is run manually, the ref is the target selected from the **Use workflow from** dropdown.
This job uses [samuelcolvin/check-python-version].
> Any valid [version specifier] works, so long as the tag matches the version. The release workflow works exactly the same for `RC`, `post`, `dev`, etc.
#### Check and Test Jobs
- **`python-tests`**: runs `pytest` on matrix of platforms
- **`python-checks`**: runs `ruff` (format and lint)
- **`frontend-tests`**: runs `vitest`
- **`frontend-checks`**: runs `prettier` (format), `eslint` (lint), `dpdm` (circular refs), `tsc` (static type check) and `knip` (unused imports)
> **TODO** We should add `mypy` or `pyright` to the **`check-python`** job.
> **TODO** We should add an end-to-end test job that generates an image.
#### `build-installer` Job
This sets up both python and frontend dependencies and builds the python package. Internally, this runs `installer/create_installer.sh` and uploads two artifacts:
- **`dist`**: the python distribution, to be published on PyPI
- **`InvokeAI-installer-${VERSION}.zip`**: the installer to be included in the GitHub release
#### Sanity Check & Smoke Test
At this point, the release workflow pauses as the remaining publish jobs require approval.
A maintainer should go to the **Summary** tab of the workflow, download the installer and test it. Ensure the app loads and generates.
> The same wheel file is bundled in the installer and in the `dist` artifact, which is uploaded to PyPI. You should end up with the exactly the same installation of the `invokeai` package from any of these methods.
#### PyPI Publish Jobs
The publish jobs will run if any of the previous jobs fail.
They use [GitHub environments], which are configured as [trusted publishers] on PyPI.
Both jobs require a maintainer to approve them from the workflow's **Summary** tab.
- Click the **Review deployments** button
- Select the environment (either `testpypi` or `pypi`)
- Click **Approve and deploy**
> **If the version already exists on PyPI, the publish jobs will fail.** PyPI only allows a given version to be published once - you cannot change it. If version published on PyPI has a problem, you'll need to "fail forward" by bumping the app version and publishing a followup release.
#### `publish-testpypi` Job
Publishes the distribution on the [Test PyPI] index, using the `testpypi` GitHub environment.
This job is not required for the production PyPI publish, but included just in case you want to test the PyPI release.
If approved and successful, you could try out the test release like this:
```sh
# Create a new virtual environment
python -m venv ~/.test-invokeai-dist --prompt test-invokeai-dist
# Install the distribution from Test PyPI
pip install --index-url https://test.pypi.org/simple/ invokeai
# Run and test the app
invokeai-web
# Cleanup
deactivate
rm -rf ~/.test-invokeai-dist
```
#### `publish-pypi` Job
Publishes the distribution on the production PyPI index, using the `pypi` GitHub environment.
## Publish the GitHub Release with installer
Once the release is published to PyPI, it's time to publish the GitHub release.
1. [Draft a new release] on GitHub, choosing the tag that triggered the release.
2. Write the release notes, describing important changes. The **Generate release notes** button automatically inserts the changelog and new contributors, and you can copy/paste the intro from previous releases.
3. Upload the zip file created in **`build`** job into the Assets section of the release notes. You can also upload the zip into the body of the release notes, since it can be hard for users to find the Assets section.
4. Check the **Set as a pre-release** and **Create a discussion for this release** checkboxes at the bottom of the release page.
5. Publish the pre-release.
6. Announce the pre-release in Discord.
> **TODO** Workflows can create a GitHub release from a template and upload release assets. One popular action to handle this is [ncipollo/release-action]. A future enhancement to the release process could set this up.
## Manual Build
The `build installer` workflow can be dispatched manually. This is useful to test the installer for a given branch or tag.
No checks are run, it just builds.
## Manual Release
The `release` workflow can be dispatched manually. You must dispatch the workflow from the right tag, else it will fail the version check.
This functionality is available as a fallback in case something goes wonky. Typically, releases should be triggered via tag push as described above.
[InvokeAI Releases Page]: https://github.com/invoke-ai/InvokeAI/releases
[PyPI]: https://pypi.org/
[Draft a new release]: https://github.com/invoke-ai/InvokeAI/releases/new
[Test PyPI]: https://test.pypi.org/
[version specifier]: https://packaging.python.org/en/latest/specifications/version-specifiers/
[ncipollo/release-action]: https://github.com/ncipollo/release-action
[GitHub environments]: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment
[trusted publishers]: https://docs.pypi.org/trusted-publishers/
[samuelcolvin/check-python-version]: https://github.com/samuelcolvin/check-python-version
[manually]: #manual-release

View File

@ -32,7 +32,6 @@ model. These are the:
Responsible for loading a model from disk
into RAM and VRAM and getting it ready for inference.
## Location of the Code
The four main services can be found in
@ -67,19 +66,17 @@ provides the following fields:
| `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
| `base_model` | BaseModelType | The base model that the model is compatible with |
| `path` | str | Location of model on disk |
| `original_hash` | str | Hash of the model when it was first installed |
| `current_hash` | str | Most recent hash of the model's contents |
| `hash` | str | Hash of the model |
| `description` | str | Human-readable description of the model (optional) |
| `source` | str | Model's source URL or repo id (optional) |
The `key` is a unique 32-character random ID which was generated at
install time. The `original_hash` field stores a hash of the model's
install time. The `hash` field stores a hash of the model's
contents at install time obtained by sampling several parts of the
model's files using the `imohash` library. Over the course of the
model's lifetime it may be transformed in various ways, such as
changing its precision or converting it from a .safetensors to a
diffusers model. When this happens, `original_hash` is unchanged, but
`current_hash` is updated to indicate the current contents.
diffusers model.
`ModelType`, `ModelFormat` and `BaseModelType` are string enums that
are defined in `invokeai.backend.model_manager.config`. They are also
@ -94,7 +91,6 @@ The `path` field can be absolute or relative. If relative, it is taken
to be relative to the `models_dir` setting in the user's
`invokeai.yaml` file.
### CheckpointConfig
This adds support for checkpoint configurations, and adds the
@ -228,9 +224,9 @@ The way it works is as follows:
1. Retrieve the value of the `model_config_db` option from the user's
`invokeai.yaml` config file.
2. If `model_config_db` is `auto` (the default), then:
- Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object
* Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object
opened on the passed connection and lock.
- Open up a new connection to `databases/invokeai.db` if `conn`
* Open up a new connection to `databases/invokeai.db` if `conn`
and/or `lock` are missing (see note below).
3. If `model_config_db` is a Path, then use `from_db_file`
to return the appropriate type of ModelRecordService.
@ -255,7 +251,7 @@ store = ModelRecordServiceBase.open(config, db_conn, lock)
Configurations can be retrieved in several ways.
#### get_model(key) -> AnyModelConfig:
#### get_model(key) -> AnyModelConfig
The basic functionality is to call the record store object's
`get_model()` method with the desired model's unique key. It returns
@ -272,28 +268,28 @@ print(model_conf.path)
If the key is unrecognized, this call raises an
`UnknownModelException`.
#### exists(key) -> AnyModelConfig:
#### exists(key) -> AnyModelConfig
Returns True if a model with the given key exists in the databsae.
#### search_by_path(path) -> AnyModelConfig:
#### search_by_path(path) -> AnyModelConfig
Returns the configuration of the model whose path is `path`. The path
is matched using a simple string comparison and won't correctly match
models referred to by different paths (e.g. using symbolic links).
#### search_by_name(name, base, type) -> List[AnyModelConfig]:
#### search_by_name(name, base, type) -> List[AnyModelConfig]
This method searches for models that match some combination of `name`,
`BaseType` and `ModelType`. Calling without any arguments will return
all the models in the database.
#### all_models() -> List[AnyModelConfig]:
#### all_models() -> List[AnyModelConfig]
Return all the model configs in the database. Exactly equivalent to
calling `search_by_name()` with no arguments.
#### search_by_tag(tags) -> List[AnyModelConfig]:
#### search_by_tag(tags) -> List[AnyModelConfig]
`tags` is a list of strings. This method returns a list of model
configs that contain all of the given tags. Examples:
@ -312,11 +308,11 @@ commercializable_models = [x for x in store.all_models() \
if x.license.contains('allowCommercialUse=Sell')]
```
#### version() -> str:
#### version() -> str
Returns the version of the database, currently at `3.2`
#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase:
#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase
This method exists to ease the transition from the previous version of
the model manager, in which `get_model()` took the three arguments
@ -337,7 +333,7 @@ model and pass its key to `get_model()`.
Several methods allow you to create and update stored model config
records.
#### add_model(key, config) -> AnyModelConfig:
#### add_model(key, config) -> AnyModelConfig
Given a key and a configuration, this will add the model's
configuration record to the database. `config` can either be a subclass of
@ -352,7 +348,7 @@ model with the same key is already in the database, or an
`InvalidModelConfigException` if a dict was passed and Pydantic
experienced a parse or validation error.
### update_model(key, config) -> AnyModelConfig:
### update_model(key, config) -> AnyModelConfig
Given a key and a configuration, this will update the model
configuration record in the database. `config` can be either a
@ -370,31 +366,31 @@ The `ModelInstallService` class implements the
shop for all your model install needs. It provides the following
functionality:
- Registering a model config record for a model already located on the
* Registering a model config record for a model already located on the
local filesystem, without moving it or changing its path.
- Installing a model alreadiy located on the local filesystem, by
* Installing a model alreadiy located on the local filesystem, by
moving it into the InvokeAI root directory under the
`models` folder (or wherever config parameter `models_dir`
specifies).
- Probing of models to determine their type, base type and other key
* Probing of models to determine their type, base type and other key
information.
- Interface with the InvokeAI event bus to provide status updates on
* Interface with the InvokeAI event bus to provide status updates on
the download, installation and registration process.
- Downloading a model from an arbitrary URL and installing it in
* Downloading a model from an arbitrary URL and installing it in
`models_dir`.
- Special handling for Civitai model URLs which allow the user to
* Special handling for Civitai model URLs which allow the user to
paste in a model page's URL or download link
- Special handling for HuggingFace repo_ids to recursively download
* Special handling for HuggingFace repo_ids to recursively download
the contents of the repository, paying attention to alternative
variants such as fp16.
- Saving tags and other metadata about the model into the invokeai database
* Saving tags and other metadata about the model into the invokeai database
when fetching from a repo that provides that type of information,
(currently only Civitai and HuggingFace).
@ -443,7 +439,6 @@ required parameters:
| `metadata_store` | Optional[ModelMetadataStore] | Metadata storage object |
|`session` | Optional[requests.Session] | Swap in a different Session object (usually for debugging) |
Once initialized, the installer will provide the following methods:
#### install_job = installer.heuristic_import(source, [config], [access_token])
@ -457,12 +452,12 @@ The `source` is a string that can be any of these forms
1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`)
2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
3. A HuggingFace repo_id with any of the following formats:
- `model/name` -- entire model
- `model/name:fp32` -- entire model, using the fp32 variant
- `model/name:fp16:vae` -- vae submodel, using the fp16 variant
- `model/name::vae` -- vae submodel, using default precision
- `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
- `model/name::path/to/model.safetensors` -- an individual model file, default variant
* `model/name` -- entire model
* `model/name:fp32` -- entire model, using the fp32 variant
* `model/name:fp16:vae` -- vae submodel, using the fp16 variant
* `model/name::vae` -- vae submodel, using default precision
* `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
* `model/name::path/to/model.safetensors` -- an individual model file, default variant
Note that by specifying a relative path to the top of the HuggingFace
repo, you can download and install arbitrary models files.
@ -566,7 +561,6 @@ details.
This is used for a model that is located on a locally-accessible Posix
filesystem, such as a local disk or networked fileshare.
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `path` | str | Path | None | Path to the model file or directory |
@ -625,7 +619,6 @@ HuggingFace has the most complicated `ModelSource` structure:
| `subfolder` | Path | None | Look for the model in a subfolder of the repo. |
| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |
The `repo_id` is the repository ID, such as `stabilityai/sdxl-turbo`.
The `variant` is one of the various diffusers formats that HuggingFace
@ -661,7 +654,6 @@ in. To download these files, you must provide an
`HfFolder.get_token()` will be called to fill it in with the cached
one.
#### Monitoring the install job process
When you create an install job with `import_model()`, it launches the
@ -682,7 +674,6 @@ The `ModelInstallJob` class has the following structure:
| `error_type` | `str` | Name of the exception that led to an error status |
| `error` | `str` | Traceback of the error |
If the `event_bus` argument was provided, events will also be
broadcast to the InvokeAI event bus. The events will appear on the bus
as an event of type `EventServiceBase.model_event`, a timestamp and
@ -702,14 +693,13 @@ following keys:
| `total_bytes` | int | Total size of all the files that make up the model |
| `parts` | List[Dict]| Information on the progress of the individual files that make up the model |
The parts is a list of dictionaries that give information on each of
the components pieces of the download. The dictionary's keys are
`source`, `local_path`, `bytes` and `total_bytes`, and correspond to
the like-named keys in the main event.
Note that downloading events will not be issued for local models, and
that downloading events occur *before* the running event.
that downloading events occur _before_ the running event.
##### `model_install_running`
@ -752,7 +742,6 @@ properties: `waiting`, `downloading`, `running`, `complete`, `errored`
and `cancelled`, as well as `in_terminal_state`. The last will return
True if the job is in the complete, errored or cancelled states.
#### Model configuration and probing
The install service uses the `invokeai.backend.model_manager.probe`
@ -862,7 +851,6 @@ This method is similar to `unregister()`, but also unconditionally
deletes the corresponding model weights file(s), regardless of whether
they are inside or outside the InvokeAI models hierarchy.
#### path = installer.download_and_cache(remote_source, [access_token], [timeout])
This utility routine will download the model file located at source,
@ -974,7 +962,7 @@ is in its lifecycle. Values are defined in the string enum
`DownloadJobStatus`, a symbol available from
`invokeai.app.services.download_manager`. Possible values are:
| **Value** | **String Value** | ** Description ** |
| **Value** | **String Value** | **Description** |
|--------------|---------------------|-------------------|
| `IDLE` | idle | Job created, but not submitted to the queue |
| `ENQUEUED` | enqueued | Job is patiently waiting on the queue |
@ -1040,11 +1028,11 @@ While a job is being downloaded, the queue will emit events at
periodic intervals. A typical series of events during a successful
download session will look like this:
- enqueued
- running
- running
- running
- completed
* enqueued
* running
* running
* running
* completed
There will be a single enqueued event, followed by one or more running
events, and finally one `completed`, `error` or `cancelled`
@ -1053,12 +1041,12 @@ events.
It is possible for a caller to pause download temporarily, in which
case the events may look something like this:
- enqueued
- running
- running
- paused
- running
- completed
* enqueued
* running
* running
* paused
* running
* completed
The download queue logs when downloads start and end (unless `quiet`
is set to True at initialization time) but doesn't log any progress
@ -1187,7 +1175,6 @@ and is equivalent to manually specifying a destination of
Here is the full list of arguments that can be provided to
`create_download_job()`:
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `source` | Union[str, Path, AnyHttpUrl] | | Download remote or local source |
@ -1275,7 +1262,7 @@ for getting the model to run. For example "author" is metadata, while
"type", "base" and "format" are not. The latter fields are part of the
model's config, as defined in `invokeai.backend.model_manager.config`.
### Example Usage:
### Example Usage
```
from invokeai.backend.model_manager.metadata import (
@ -1328,7 +1315,6 @@ This is the common base class for metadata:
| `author` | str | Model's author |
| `tags` | Set[str] | Model tags |
Note that the model config record also has a `name` field. It is
intended that the config record version be locally customizable, while
the metadata version is read-only. However, enforcing this is expected
@ -1348,7 +1334,6 @@ This descends from `ModelMetadataBase` and adds the following fields:
| `last_modified`| datetime | Date of last commit of this model to the repo |
| `files` | List[Path] | List of the files in the model repo |
#### `CivitaiMetadata`
This descends from `ModelMetadataBase` and adds the following fields:
@ -1415,7 +1400,6 @@ testing suite to avoid hitting the internet.
The HuggingFace and Civitai fetcher subclasses add additional
repo-specific fetching methods:
#### HuggingFaceMetadataFetch
This overrides its base class `from_json()` method to return a
@ -1434,7 +1418,6 @@ retrieves its metadata. Functionally equivalent to `from_id()`, the
only difference is that it returna a `CivitaiMetadata` object rather
than an `AnyModelRepoMetadata`.
### Metadata Storage
The `ModelMetadataStore` provides a simple facility to store model
@ -1567,7 +1550,6 @@ The returned `LoadedModel` object contains a copy of the configuration
record returned by the model record `get_model()` method, as well as
the in-memory loaded model:
| **Attribute Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
@ -1581,7 +1563,6 @@ return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are obvious.
`LoadedModel` acts as a context manager. The context loads the model
into the execution device (e.g. VRAM on CUDA systems), locks the model
in the execution device for the duration of the context, and returns
@ -1595,9 +1576,9 @@ with model_info as vae:
`get_model_by_key()` may raise any of the following exceptions:
- `UnknownModelException` -- key not in database
- `ModelNotFoundException` -- key in database but model not found at path
- `NotImplementedException` -- the loader doesn't know how to load this type of model
* `UnknownModelException` -- key not in database
* `ModelNotFoundException` -- key in database but model not found at path
* `NotImplementedException` -- the loader doesn't know how to load this type of model
### Emitting model loading events
@ -1724,6 +1705,7 @@ object, or in `context.services.model_manager` from within an
invocation.
In the examples below, we have retrieved the manager using:
```
mm = ApiDependencies.invoker.services.model_manager
```

View File

@ -0,0 +1,45 @@
# Invocation API
Each invocation's `invoke` method is provided a single arg - the Invocation
Context.
This object provides access to various methods, used to interact with the
application. Loading and saving images, logging messages, etc.
!!! warning ""
This API may shift slightly until the release of v4.0.0 as we work through a few final updates to the Model Manager.
```py
class MyInvocation(BaseInvocation):
...
def invoke(self, context: InvocationContext) -> ImageOutput:
image_pil = context.images.get_pil(image_name)
# Do something to the image
image_dto = context.images.save(image_pil)
# Log a message
context.logger.info(f"Did something cool, image saved!")
...
```
<!-- prettier-ignore-start -->
::: invokeai.app.services.shared.invocation_context.InvocationContext
options:
members: false
::: invokeai.app.services.shared.invocation_context.ImagesInterface
::: invokeai.app.services.shared.invocation_context.TensorsInterface
::: invokeai.app.services.shared.invocation_context.ConditioningInterface
::: invokeai.app.services.shared.invocation_context.ModelsInterface
::: invokeai.app.services.shared.invocation_context.LoggerInterface
::: invokeai.app.services.shared.invocation_context.ConfigInterface
::: invokeai.app.services.shared.invocation_context.UtilInterface
::: invokeai.app.services.shared.invocation_context.BoardsInterface
<!-- prettier-ignore-end -->

View File

@ -0,0 +1,148 @@
# Invoke v4.0.0 Nodes API Migration guide
Invoke v4.0.0 is versioned as such due to breaking changes to the API utilized
by nodes, both core and custom.
## Motivation
Prior to v4.0.0, the `invokeai` python package has not be set up to be utilized
as a library. That is to say, it didn't have any explicitly public API, and node
authors had to work with the unstable internal application API.
v4.0.0 introduces a stable public API for nodes.
## Changes
There are two node-author-facing changes:
1. Import Paths
1. Invocation Context API
### Import Paths
All public objects are now exported from `invokeai.invocation_api`:
```py
# Old
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
InputField,
InvocationContext,
invocation,
)
from invokeai.app.invocations.primitives import ImageField
# New
from invokeai.invocation_api import (
BaseInvocation,
ImageField,
InputField,
InvocationContext,
invocation,
)
```
It's possible that we've missed some classes you need in your node. Please let
us know if that's the case.
### Invocation Context API
Most nodes utilize the Invocation Context, an object that is passed to the
`invoke` that provides access to data and services a node may need.
Until now, that object and the services it exposed were internal. Exposing them
to nodes means that changes to our internal implementation could break nodes.
The methods on the services are also often fairly complicated and allowed nodes
to footgun.
In v4.0.0, this object has been refactored to be much simpler.
See [INVOCATION_API](./INVOCATION_API.md) for full details of the API.
!!! warning ""
This API may shift slightly until the release of v4.0.0 as we work through a few final updates to the Model Manager.
#### Improved Service Methods
The biggest offender was the image save method:
```py
# Old
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=context.workflow,
)
# New
image_dto = context.images.save(image=image)
```
Other methods are simplified, or enhanced with additional functionality:
```py
# Old
image = context.services.images.get_pil_image(image_name)
# New
image = context.images.get_pil(image_name)
image_cmyk = context.images.get_pil(image_name, "CMYK")
```
We also had some typing issues around tensors:
```py
# Old
# `latents` typed as `torch.Tensor`, but could be `ConditioningFieldData`
latents = context.services.latents.get(self.latents.latents_name)
# `data` typed as `torch.Tenssor,` but could be `ConditioningFieldData`
context.services.latents.save(latents_name, data)
# New - separate methods for tensors and conditioning data w/ correct typing
# Also, the service generates the names
tensor_name = context.tensors.save(tensor)
tensor = context.tensors.load(tensor_name)
# For conditioning
cond_name = context.conditioning.save(cond_data)
cond_data = context.conditioning.load(cond_name)
```
#### Output Construction
Core Outputs have builder functions right on them - no need to manually
construct these objects, or use an extra utility:
```py
# Old
image_output = ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
latents_output = build_latents_output(latents_name=name, latents=latents, seed=None)
noise_output = NoiseOutput(
noise=LatentsField(latents_name=latents_name, seed=seed),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
cond_output = ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
# New
image_output = ImageOutput.build(image_dto)
latents_output = LatentsOutput.build(latents_name=name, latents=noise, seed=self.seed)
noise_output = NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed)
cond_output = ConditioningOutput.build(conditioning_name)
```
You can still create the objects using constructors if you want, but we suggest
using the builder methods.

View File

@ -32,6 +32,7 @@ To use a community workflow, download the the `.json` node graph file and load i
+ [Image to Character Art Image Nodes](#image-to-character-art-image-nodes)
+ [Image Picker](#image-picker)
+ [Image Resize Plus](#image-resize-plus)
+ [Latent Upscale](#latent-upscale)
+ [Load Video Frame](#load-video-frame)
+ [Make 3D](#make-3d)
+ [Mask Operations](#mask-operations)
@ -290,6 +291,13 @@ View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/image-resize-plus-node/master/.readme/node.png" width="500" />
--------------------------------
### Latent Upscale
**Description:** This node uses a small (~2.4mb) model to upscale the latents used in a Stable Diffusion 1.5 or Stable Diffusion XL image generation, rather than the typical interpolation method, avoiding the traditional downsides of the latent upscale technique.
**Node Link:** [https://github.com/gogurtenjoyer/latent-upscale](https://github.com/gogurtenjoyer/latent-upscale)
--------------------------------
### Load Video Frame
@ -346,12 +354,21 @@ See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/mai
**Description:** A set of nodes for Metadata. Collect Metadata from within an `iterate` node & extract metadata from an image.
- `Metadata Item Linked` - Allows collecting of metadata while within an iterate node with no need for a collect node or conversion to metadata node.
- `Metadata From Image` - Provides Metadata from an image.
- `Metadata To String` - Extracts a String value of a label from metadata.
- `Metadata To Integer` - Extracts an Integer value of a label from metadata.
- `Metadata To Float` - Extracts a Float value of a label from metadata.
- `Metadata To Scheduler` - Extracts a Scheduler value of a label from metadata.
- `Metadata Item Linked` - Allows collecting of metadata while within an iterate node with no need for a collect node or conversion to metadata node
- `Metadata From Image` - Provides Metadata from an image
- `Metadata To String` - Extracts a String value of a label from metadata
- `Metadata To Integer` - Extracts an Integer value of a label from metadata
- `Metadata To Float` - Extracts a Float value of a label from metadata
- `Metadata To Scheduler` - Extracts a Scheduler value of a label from metadata
- `Metadata To Bool` - Extracts Bool types from metadata
- `Metadata To Model` - Extracts model types from metadata
- `Metadata To SDXL Model` - Extracts SDXL model types from metadata
- `Metadata To LoRAs` - Extracts Loras from metadata.
- `Metadata To SDXL LoRAs` - Extracts SDXL Loras from metadata
- `Metadata To ControlNets` - Extracts ControNets from metadata
- `Metadata To IP-Adapters` - Extracts IP-Adapters from metadata
- `Metadata To T2I-Adapters` - Extracts T2I-Adapters from metadata
- `Denoise Latents + Metadata` - This is an inherited version of the existing `Denoise Latents` node but with a metadata input and output.
**Node Link:** https://github.com/skunkworxdark/metadata-linked-nodes

View File

@ -19,6 +19,8 @@ their descriptions.
| Conditioning Primitive | A conditioning tensor primitive value |
| Content Shuffle Processor | Applies content shuffle processing to image |
| ControlNet | Collects ControlNet info to pass to other nodes |
| Create Denoise Mask | Converts a greyscale or transparency image into a mask for denoising. |
| Create Gradient Mask | Creates a mask for Gradient ("soft", "differential") inpainting that gradually expands during denoising. Improves edge coherence. |
| Denoise Latents | Denoises noisy latents to decodable images |
| Divide Integers | Divides two numbers |
| Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator |

View File

@ -1,5 +0,0 @@
mkdocs
mkdocs-material>=8, <9
mkdocs-git-revision-date-localized-plugin
mkdocs-redirects==1.2.0

View File

@ -1,5 +0,0 @@
:root {
--md-primary-fg-color: #35A4DB;
--md-primary-fg-color--light: #35A4DB;
--md-primary-fg-color--dark: #35A4DB;
}

View File

@ -2,22 +2,18 @@
set -e
BCYAN="\e[1;36m"
BYELLOW="\e[1;33m"
BGREEN="\e[1;32m"
BRED="\e[1;31m"
RED="\e[31m"
RESET="\e[0m"
function is_bin_in_path {
builtin type -P "$1" &>/dev/null
}
BCYAN="\033[1;36m"
BYELLOW="\033[1;33m"
BGREEN="\033[1;32m"
BRED="\033[1;31m"
RED="\033[31m"
RESET="\033[0m"
function git_show {
git show -s --format=oneline --abbrev-commit "$1" | cat
}
if [[ -v "VIRTUAL_ENV" ]]; then
if [[ ! -z "${VIRTUAL_ENV}" ]]; then
# we can't just call 'deactivate' because this function is not exported
# to the environment of this script from the bash process that runs the script
echo -e "${BRED}A virtual environment is activated. Please deactivate it before proceeding.${RESET}"
@ -26,31 +22,63 @@ fi
cd "$(dirname "$0")"
echo
echo -e "${BYELLOW}This script must be run from the installer directory!${RESET}"
echo "The current working directory is $(pwd)"
read -p "If that looks right, press any key to proceed, or CTRL-C to exit..."
echo
# Some machines only have `python3` in PATH, others have `python` - make an alias.
# We can use a function to approximate an alias within a non-interactive shell.
if ! is_bin_in_path python && is_bin_in_path python3; then
function python {
python3 "$@"
}
fi
VERSION=$(
cd ..
python -c "from invokeai.version import __version__ as version; print(version)"
python3 -c "from invokeai.version import __version__ as version; print(version)"
)
PATCH=""
VERSION="v${VERSION}${PATCH}"
VERSION="v${VERSION}"
if [[ ! -z ${CI} ]]; then
echo
echo -e "${BCYAN}CI environment detected${RESET}"
echo
else
echo
echo -e "${BYELLOW}This script must be run from the installer directory!${RESET}"
echo "The current working directory is $(pwd)"
read -p "If that looks right, press any key to proceed, or CTRL-C to exit..."
echo
fi
echo -e "${BGREEN}HEAD${RESET}:"
git_show HEAD
echo
# ---------------------- FRONTEND ----------------------
pushd ../invokeai/frontend/web >/dev/null
echo "Installing frontend dependencies..."
echo
pnpm i --frozen-lockfile
echo
if [[ ! -z ${CI} ]]; then
echo "Building frontend without checks..."
# In CI, we have already done the frontend checks and can just build
pnpm vite build
else
echo "Running checks and building frontend..."
# This runs all the frontend checks and builds
pnpm build
fi
echo
popd
# ---------------------- BACKEND ----------------------
echo
echo "Building wheel..."
echo
# install the 'build' package in the user site packages, if needed
# could be improved by using a temporary venv, but it's tiny and harmless
if [[ $(python3 -c 'from importlib.util import find_spec; print(find_spec("build") is None)') == "True" ]]; then
pip install --user build
fi
rm -rf ../build
python3 -m build --outdir dist/ ../.
# ----------------------
echo
@ -78,10 +106,28 @@ chmod a+x InvokeAI-Installer/install.sh
cp install.bat.in InvokeAI-Installer/install.bat
cp WinLongPathsEnabled.reg InvokeAI-Installer/
# Zip everything up
zip -r InvokeAI-installer-$VERSION.zip InvokeAI-Installer
FILENAME=InvokeAI-installer-$VERSION.zip
# clean up
rm -rf InvokeAI-Installer tmp dist ../invokeai/frontend/web/dist/
# Zip everything up
zip -r ${FILENAME} InvokeAI-Installer
echo
echo -e "${BGREEN}Built installer: ./${FILENAME}${RESET}"
echo -e "${BGREEN}Built PyPi distribution: ./dist${RESET}"
# clean up, but only if we are not in a github action
if [[ -z ${CI} ]]; then
echo
echo "Cleaning up intermediate build files..."
rm -rf InvokeAI-Installer tmp ../invokeai/frontend/web/dist/
fi
if [[ ! -z ${CI} ]]; then
echo
echo "Setting GitHub action outputs..."
echo "INSTALLER_FILENAME=${FILENAME}" >>$GITHUB_OUTPUT
echo "INSTALLER_PATH=installer/${FILENAME}" >>$GITHUB_OUTPUT
echo "DIST_PATH=installer/dist/" >>$GITHUB_OUTPUT
fi
exit 0

View File

@ -2,12 +2,12 @@
set -e
BCYAN="\e[1;36m"
BYELLOW="\e[1;33m"
BGREEN="\e[1;32m"
BRED="\e[1;31m"
RED="\e[31m"
RESET="\e[0m"
BCYAN="\033[1;36m"
BYELLOW="\033[1;33m"
BGREEN="\033[1;32m"
BRED="\033[1;31m"
RED="\033[31m"
RESET="\033[0m"
function does_tag_exist {
git rev-parse --quiet --verify "refs/tags/$1" >/dev/null
@ -23,49 +23,40 @@ function git_show {
VERSION=$(
cd ..
python -c "from invokeai.version import __version__ as version; print(version)"
python3 -c "from invokeai.version import __version__ as version; print(version)"
)
PATCH=""
MAJOR_VERSION=$(echo $VERSION | sed 's/\..*$//')
VERSION="v${VERSION}${PATCH}"
LATEST_TAG="v${MAJOR_VERSION}-latest"
if does_tag_exist $VERSION; then
echo -e "${BCYAN}${VERSION}${RESET} already exists:"
git_show_ref tags/$VERSION
echo
fi
if does_tag_exist $LATEST_TAG; then
echo -e "${BCYAN}${LATEST_TAG}${RESET} already exists:"
git_show_ref tags/$LATEST_TAG
echo
fi
echo -e "${BGREEN}HEAD${RESET}:"
git_show
echo
echo -e -n "Create tags ${BCYAN}${VERSION}${RESET} and ${BCYAN}${LATEST_TAG}${RESET} @ ${BGREEN}HEAD${RESET}, ${RED}deleting existing tags on remote${RESET}? "
echo -e "${BGREEN}git remote -v${RESET}:"
git remote -v
echo
echo -e -n "Create tags ${BCYAN}${VERSION}${RESET} @ ${BGREEN}HEAD${RESET}, ${RED}deleting existing tags on origin remote${RESET}? "
read -e -p 'y/n [n]: ' input
RESPONSE=${input:='n'}
if [ "$RESPONSE" == 'y' ]; then
echo
echo -e "Deleting ${BCYAN}${VERSION}${RESET} tag on remote..."
git push --delete origin $VERSION
echo -e "Deleting ${BCYAN}${VERSION}${RESET} tag on origin remote..."
git push origin :refs/tags/$VERSION
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${VERSION}${RESET} locally..."
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${VERSION}${RESET} on locally..."
if ! git tag -fa $VERSION; then
echo "Existing/invalid tag"
exit -1
fi
echo -e "Deleting ${BCYAN}${LATEST_TAG}${RESET} tag on remote..."
git push --delete origin $LATEST_TAG
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${LATEST_TAG}${RESET} locally..."
git tag -fa $LATEST_TAG
echo -e "Pushing updated tags to remote..."
echo -e "Pushing updated tags to origin remote..."
git push origin --tags
fi
exit 0

0
invokeai/app/__init__.py Normal file
View File

View File

@ -25,8 +25,8 @@ from ..services.invocation_cache.invocation_cache_memory import MemoryInvocation
from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.model_images.model_images_default import ModelImageFileStorageDisk
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_metadata import ModelMetadataStoreSQL
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
@ -72,6 +72,8 @@ class ApiDependencies:
image_files = DiskImageFileStorage(f"{output_folder}/images")
model_images_folder = config.models_path
db = init_db(config=config, logger=logger, image_files=image_files)
configuration = config
@ -93,10 +95,10 @@ class ApiDependencies:
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
)
download_queue_service = DownloadQueueService(event_bus=events)
model_metadata_service = ModelMetadataStoreSQL(db=db)
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager(
app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
model_record_service=ModelRecordServiceSQL(db=db),
download_queue=download_queue_service,
events=events,
)
@ -120,6 +122,7 @@ class ApiDependencies:
images=images,
invocation_cache=invocation_cache,
logger=logger,
model_images=model_images_service,
model_manager=model_manager,
download_queue=download_queue_service,
names=names,

View File

@ -1,27 +1,26 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records."""
import io
import pathlib
import shutil
from hashlib import sha1
from random import randbytes
from typing import Any, Dict, List, Optional, Set
import traceback
from typing import Any, Dict, List, Optional
from fastapi import Body, Path, Query, Response
from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, ConfigDict, Field
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
ModelRecordOrderBy,
ModelSummary,
UnknownModelException,
)
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
@ -30,14 +29,15 @@ from invokeai.backend.model_manager.config import (
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.search import ModelSearch
from ..dependencies import ApiDependencies
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
# images are immutable; set a high max-age
IMAGE_MAX_AGE = 31536000
class ModelsList(BaseModel):
"""Return list of configs."""
@ -47,15 +47,6 @@ class ModelsList(BaseModel):
model_config = ConfigDict(use_enum_values=True)
class ModelTagSet(BaseModel):
"""Return tags for a set of models."""
key: str
name: str
author: str
tags: Set[str]
##############################################################################
# These are example inputs and outputs that are used in places where Swagger
# is unable to generate a correct example.
@ -66,19 +57,16 @@ example_model_config = {
"base": "sd-1",
"type": "main",
"format": "checkpoint",
"config": "string",
"config_path": "string",
"key": "string",
"original_hash": "string",
"current_hash": "string",
"hash": "string",
"description": "string",
"source": "string",
"last_modified": 0,
"vae": "string",
"converted_at": 0,
"variant": "normal",
"prediction_type": "epsilon",
"repo_variant": "fp16",
"upcast_attention": False,
"ztsnr_training": False,
}
example_model_input = {
@ -87,50 +75,12 @@ example_model_input = {
"base": "sd-1",
"type": "main",
"format": "checkpoint",
"config": "configs/stable-diffusion/v1-inference.yaml",
"config_path": "configs/stable-diffusion/v1-inference.yaml",
"description": "Model description",
"vae": None,
"variant": "normal",
}
example_model_metadata = {
"name": "ip_adapter_sd_image_encoder",
"author": "InvokeAI",
"tags": [
"transformers",
"safetensors",
"clip_vision_model",
"endpoints_compatible",
"region:us",
"has_space",
"license:apache-2.0",
],
"files": [
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
"path": "ip_adapter_sd_image_encoder/README.md",
"size": 628,
"sha256": None,
},
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
"path": "ip_adapter_sd_image_encoder/config.json",
"size": 560,
"sha256": None,
},
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
"path": "ip_adapter_sd_image_encoder/model.safetensors",
"size": 2528373448,
"sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
},
],
"type": "huggingface",
"id": "InvokeAI/ip_adapter_sd_image_encoder",
"tag_dict": {"license": "apache-2.0"},
"last_modified": "2023-09-23T17:33:25Z",
}
##############################################################################
# ROUTES
##############################################################################
@ -162,6 +112,9 @@ async def list_model_records(
found_models.extend(
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
)
for model in found_models:
cover_image = ApiDependencies.invoker.services.model_images.get_url(model.key)
model.cover_image = cover_image
return ModelsList(models=found_models)
@ -205,53 +158,23 @@ async def get_model_record(
record_store = ApiDependencies.invoker.services.model_manager.store
try:
config: AnyModelConfig = record_store.get_model(key)
cover_image = ApiDependencies.invoker.services.model_images.get_url(key)
config.cover_image = cover_image
return config
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@model_manager_router.get("/summary", operation_id="list_model_summary")
async def list_model_summary(
page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of models per page"),
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
) -> PaginatedResults[ModelSummary]:
"""Gets a page of model summary data."""
record_store = ApiDependencies.invoker.services.model_manager.store
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
return results
@model_manager_router.get(
"/i/{key}/metadata",
operation_id="get_model_metadata",
responses={
200: {
"description": "The model metadata was retrieved successfully",
"content": {"application/json": {"example": example_model_metadata}},
},
400: {"description": "Bad request"},
},
)
async def get_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
) -> Optional[AnyModelRepoMetadata]:
"""Get a model metadata object."""
record_store = ApiDependencies.invoker.services.model_manager.store
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
return result
@model_manager_router.get(
"/tags",
operation_id="list_tags",
)
async def list_tags() -> Set[str]:
"""Get a unique set of all the model tags."""
record_store = ApiDependencies.invoker.services.model_manager.store
result: Set[str] = record_store.list_tags()
return result
# @model_manager_router.get("/summary", operation_id="list_model_summary")
# async def list_model_summary(
# page: int = Query(default=0, description="The page to get"),
# per_page: int = Query(default=10, description="The number of models per page"),
# order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
# ) -> PaginatedResults[ModelSummary]:
# """Gets a page of model summary data."""
# record_store = ApiDependencies.invoker.services.model_manager.store
# results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
# return results
class FoundModel(BaseModel):
@ -323,19 +246,6 @@ async def scan_for_models(
return scan_results
@model_manager_router.get(
"/tags/search",
operation_id="search_by_metadata_tags",
)
async def search_by_metadata_tags(
tags: Set[str] = Query(default=None, description="Tags to search for"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_manager.store
results = record_store.search_by_metadata_tag(tags)
return ModelsList(models=results)
@model_manager_router.patch(
"/i/{key}",
operation_id="update_model_record",
@ -352,15 +262,13 @@ async def search_by_metadata_tags(
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
info: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
],
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
"""Update a model's config."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
try:
model_response: AnyModelConfig = record_store.update_model(key, config=info)
model_response: AnyModelConfig = record_store.update_model(key, changes=changes)
logger.info(f"Updated model: {key}")
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@ -370,16 +278,85 @@ async def update_model_record(
return model_response
@model_manager_router.get(
"/i/{key}/image",
operation_id="get_model_image",
responses={
200: {
"description": "The model image was fetched successfully",
},
400: {"description": "Bad request"},
404: {"description": "The model image could not be found"},
},
status_code=200,
)
async def get_model_image(
key: str = Path(description="The name of model image file to get"),
) -> FileResponse:
"""Gets an image file that previews the model"""
try:
path = ApiDependencies.invoker.services.model_images.get_path(key)
response = FileResponse(
path,
media_type="image/png",
filename=key + ".png",
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception:
raise HTTPException(status_code=404)
@model_manager_router.patch(
"/i/{key}/image",
operation_id="update_model_image",
responses={
200: {
"description": "The model image was updated successfully",
},
400: {"description": "Bad request"},
},
status_code=200,
)
async def update_model_image(
key: Annotated[str, Path(description="Unique key of model")],
image: UploadFile,
) -> None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
logger = ApiDependencies.invoker.services.logger
model_images = ApiDependencies.invoker.services.model_images
try:
model_images.save(pil_image, key)
logger.info(f"Updated image for model: {key}")
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return
@model_manager_router.delete(
"/i/{key}",
operation_id="del_model_record",
operation_id="delete_model",
responses={
204: {"description": "Model deleted successfully"},
404: {"description": "Model not found"},
},
status_code=204,
)
async def del_model_record(
async def delete_model(
key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
"""
@ -400,42 +377,62 @@ async def del_model_record(
raise HTTPException(status_code=404, detail=str(e))
@model_manager_router.post(
"/i/",
operation_id="add_model_record",
@model_manager_router.delete(
"/i/{key}/image",
operation_id="delete_model_image",
responses={
201: {
"description": "The model added successfully",
"content": {"application/json": {"example": example_model_config}},
204: {"description": "Model image deleted successfully"},
404: {"description": "Model image not found"},
},
409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"},
},
status_code=201,
status_code=204,
)
async def add_model_record(
config: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
],
) -> AnyModelConfig:
"""Add a model using the configuration information appropriate for its type."""
async def delete_model_image(
key: str = Path(description="Unique key of model image to remove from model_images directory."),
) -> None:
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
if config.key == "<NOKEY>":
config.key = sha1(randbytes(100)).hexdigest()
logger.info(f"Created model {config.key} for {config.name}")
model_images = ApiDependencies.invoker.services.model_images
try:
record_store.add_model(config.key, config)
except DuplicateModelException as e:
model_images.delete(key)
logger.info(f"Deleted model image: {key}")
return
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
raise HTTPException(status_code=404, detail=str(e))
# now fetch it out
result: AnyModelConfig = record_store.get_model(config.key)
return result
# @model_manager_router.post(
# "/i/",
# operation_id="add_model_record",
# responses={
# 201: {
# "description": "The model added successfully",
# "content": {"application/json": {"example": example_model_config}},
# },
# 409: {"description": "There is already a model corresponding to this path or repo_id"},
# 415: {"description": "Unrecognized file/folder format"},
# },
# status_code=201,
# )
# async def add_model_record(
# config: Annotated[
# AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
# ],
# ) -> AnyModelConfig:
# """Add a model using the configuration information appropriate for its type."""
# logger = ApiDependencies.invoker.services.logger
# record_store = ApiDependencies.invoker.services.model_manager.store
# try:
# record_store.add_model(config)
# except DuplicateModelException as e:
# logger.error(str(e))
# raise HTTPException(status_code=409, detail=str(e))
# except InvalidModelException as e:
# logger.error(str(e))
# raise HTTPException(status_code=415)
# # now fetch it out
# result: AnyModelConfig = record_store.get_model(config.key)
# return result
@model_manager_router.post(
@ -451,6 +448,7 @@ async def add_model_record(
)
async def install_model(
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
# TODO(MM2): Can we type this?
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
@ -493,6 +491,7 @@ async def install_model(
source=source,
config=config,
access_token=access_token,
inplace=bool(inplace),
)
logger.info(f"Started installation of {source}")
except UnknownModelException as e:
@ -508,10 +507,10 @@ async def install_model(
@model_manager_router.get(
"/import",
operation_id="list_model_install_jobs",
"/install",
operation_id="list_model_installs",
)
async def list_model_install_jobs() -> List[ModelInstallJob]:
async def list_model_installs() -> List[ModelInstallJob]:
"""Return the list of model install jobs.
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
@ -525,9 +524,8 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
* "cancelled" -- Job was cancelled before completion.
Once completed, information about the model such as its size, base
model, type, and metadata can be retrieved from the `config_out`
field. For multi-file models such as diffusers, information on individual files
can be retrieved from `download_parts`.
model and type can be retrieved from the `config_out` field. For multi-file models such as diffusers,
information on individual files can be retrieved from `download_parts`.
See the example and schema below for more information.
"""
@ -536,7 +534,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
@model_manager_router.get(
"/import/{id}",
"/install/{id}",
operation_id="get_model_install_job",
responses={
200: {"description": "Success"},
@ -556,7 +554,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
@model_manager_router.delete(
"/import/{id}",
"/install/{id}",
operation_id="cancel_model_install_job",
responses={
201: {"description": "The job was cancelled successfully"},
@ -574,8 +572,8 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
installer.cancel_job(job)
@model_manager_router.patch(
"/import",
@model_manager_router.delete(
"/install",
operation_id="prune_model_install_jobs",
responses={
204: {"description": "All completed and errored jobs have been pruned"},
@ -645,7 +643,7 @@ async def convert_model(
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
# loading the model will convert it into a cached diffusers file
model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler)
# Get the path of the converted model from the loader
cache_path = loader.convert_cache.cache_path(key)
@ -654,7 +652,8 @@ async def convert_model(
# temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name
model_config.name = f"{original_name}.DELETE"
store.update_model(key, config=model_config)
changes = ModelRecordChanges(name=model_config.name)
store.update_model(key, changes=changes)
# install the diffusers
try:
@ -663,7 +662,7 @@ async def convert_model(
config={
"name": original_name,
"description": model_config.description,
"original_hash": model_config.original_hash,
"hash": model_config.hash,
"source": model_config.source,
},
)
@ -671,10 +670,6 @@ async def convert_model(
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
# get the original metadata
if orig_metadata := store.get_metadata(key):
store.metadata_store.add_metadata(new_key, orig_metadata)
# delete the original safetensors file
installer.delete(key)
@ -686,66 +681,66 @@ async def convert_model(
return new_config
@model_manager_router.put(
"/merge",
operation_id="merge",
responses={
200: {
"description": "Model converted successfully",
"content": {"application/json": {"example": example_model_config}},
},
400: {"description": "Bad request"},
404: {"description": "Model not found"},
409: {"description": "There is already a model registered at this location"},
},
)
async def merge(
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
force: bool = Body(
description="Force merging of models created with different versions of diffusers",
default=False,
),
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
merge_dest_directory: Optional[str] = Body(
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
default=None,
),
) -> AnyModelConfig:
"""
Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
```
Argument Description [default]
-------- ----------------------
keys List of 2-3 model keys to merge together. All models must use the same base type.
merged_model_name Name for the merged model [Concat model names]
alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
merge_dest_directory Specify a directory to store the merged model in [models directory]
```
"""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
installer = ApiDependencies.invoker.services.model_manager.install
merger = ModelMerger(installer)
model_names = [installer.record_store.get_model(x).name for x in keys]
response = merger.merge_diffusion_models_and_save(
model_keys=keys,
merged_model_name=merged_model_name or "+".join(model_names),
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory=dest,
)
except UnknownModelException:
raise HTTPException(
status_code=404,
detail=f"One or more of the models '{keys}' not found",
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
# @model_manager_router.put(
# "/merge",
# operation_id="merge",
# responses={
# 200: {
# "description": "Model converted successfully",
# "content": {"application/json": {"example": example_model_config}},
# },
# 400: {"description": "Bad request"},
# 404: {"description": "Model not found"},
# 409: {"description": "There is already a model registered at this location"},
# },
# )
# async def merge(
# keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
# merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
# alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
# force: bool = Body(
# description="Force merging of models created with different versions of diffusers",
# default=False,
# ),
# interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
# merge_dest_directory: Optional[str] = Body(
# description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
# default=None,
# ),
# ) -> AnyModelConfig:
# """
# Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
# ```
# Argument Description [default]
# -------- ----------------------
# keys List of 2-3 model keys to merge together. All models must use the same base type.
# merged_model_name Name for the merged model [Concat model names]
# alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
# force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
# interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
# merge_dest_directory Specify a directory to store the merged model in [models directory]
# ```
# """
# logger = ApiDependencies.invoker.services.logger
# try:
# logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
# dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
# installer = ApiDependencies.invoker.services.model_manager.install
# merger = ModelMerger(installer)
# model_names = [installer.record_store.get_model(x).name for x in keys]
# response = merger.merge_diffusion_models_and_save(
# model_keys=keys,
# merged_model_name=merged_model_name or "+".join(model_names),
# alpha=alpha,
# interp=interp,
# force=force,
# merge_dest_directory=dest,
# )
# except UnknownModelException:
# raise HTTPException(
# status_code=404,
# detail=f"One or more of the models '{keys}' not found",
# )
# except ValueError as e:
# raise HTTPException(status_code=400, detail=str(e))
# return response

View File

@ -3,10 +3,8 @@
# values from the command line or config file.
import sys
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.version.invokeai_version import __version__
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
from .services.config import InvokeAIAppConfig
app_config = InvokeAIAppConfig.get_config()
@ -19,6 +17,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
import asyncio
import mimetypes
import socket
from contextlib import asynccontextmanager
from inspect import signature
from pathlib import Path
from typing import Any
@ -39,6 +38,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
@ -58,6 +58,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
BaseInvocation,
UIConfigBase,
)
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
if is_mps_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
@ -71,9 +72,25 @@ logger = InvokeAILogger.get_logger(config=app_config)
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
@asynccontextmanager
async def lifespan(app: FastAPI):
# Add startup event to load dependencies
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
yield
# Shut down threads
ApiDependencies.shutdown()
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(title="Invoke - Community Edition", docs_url=None, redoc_url=None, separate_input_output_schemas=False)
app = FastAPI(
title="Invoke - Community Edition",
docs_url=None,
redoc_url=None,
separate_input_output_schemas=False,
lifespan=lifespan,
)
# Add event handler
event_handler_id: int = id(app)
@ -96,18 +113,6 @@ app.add_middleware(
app.add_middleware(GZipMiddleware, minimum_size=1000)
# Add startup event to load dependencies
@app.on_event("startup")
async def startup_event() -> None:
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
# Shut down threads
@app.on_event("shutdown")
async def shutdown_event() -> None:
ApiDependencies.shutdown()
# Include all routers
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(model_manager.model_manager_router, prefix="/api")

View File

@ -1,17 +1,11 @@
from typing import Iterator, List, Optional, Tuple, Union
from typing import Iterator, List, Optional, Tuple, Union, cast
import torch
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
OutputField,
UIComponent,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
@ -25,13 +19,8 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
)
from invokeai.backend.util.devices import torch_dtype
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from .model import ClipField
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .model import CLIPField
# unconditioned: Optional[torch.Tensor]
@ -57,7 +46,7 @@ class CompelInvocation(BaseInvocation):
description=FieldDescriptions.compel_prompt,
ui_component=UIComponent.Textarea,
)
clip: ClipField = InputField(
clip: CLIPField = InputField(
title="CLIP",
description=FieldDescriptions.clip,
input=Input.Connection,
@ -65,16 +54,16 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
tokenizer_info = context.models.load(self.clip.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
text_encoder_info = context.models.load(self.clip.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
@ -138,18 +127,18 @@ class SDXLPromptInvocationBase:
def run_clip_compel(
self,
context: InvocationContext,
clip_field: ClipField,
clip_field: CLIPField,
prompt: str,
get_pooled: bool,
lora_prefix: str,
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
tokenizer_info = context.models.load(clip_field.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
text_encoder_info = context.models.load(clip_field.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel)
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
# return zero on empty
if prompt == "" and zero_on_empty:
@ -174,7 +163,7 @@ class SDXLPromptInvocationBase:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_field.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
lora_info = context.models.load(lora.lora)
lora_model = lora_info.model
assert isinstance(lora_model, LoRAModelRaw)
yield (lora_model, lora.weight)
@ -196,7 +185,8 @@ class SDXLPromptInvocationBase:
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
):
assert isinstance(text_encoder, CLIPTextModel)
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
text_encoder = cast(CLIPTextModel, text_encoder)
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
@ -263,8 +253,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
crop_left: int = InputField(default=0, description="")
target_width: int = InputField(default=1024, description="")
target_height: int = InputField(default=1024, description="")
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
@ -350,7 +340,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
crop_top: int = InputField(default=0, description="")
crop_left: int = InputField(default=0, description="")
aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
@ -380,10 +370,10 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
@invocation_output("clip_skip_output")
class ClipSkipInvocationOutput(BaseInvocationOutput):
"""Clip skip node output"""
class CLIPSkipInvocationOutput(BaseInvocationOutput):
"""CLIP skip node output"""
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation(
@ -393,15 +383,15 @@ class ClipSkipInvocationOutput(BaseInvocationOutput):
category="conditioning",
version="1.0.0",
)
class ClipSkipInvocation(BaseInvocation):
class CLIPSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model."""
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
def invoke(self, context: InvocationContext) -> CLIPSkipInvocationOutput:
self.clip.skipped_layers += self.skipped_layers
return ClipSkipInvocationOutput(
return CLIPSkipInvocationOutput(
clip=self.clip,
)

View File

@ -34,6 +34,7 @@ from invokeai.app.invocations.fields import (
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
@ -51,15 +52,9 @@ CONTROLNET_RESIZE_VALUES = Literal[
]
class ControlNetModelField(BaseModel):
"""ControlNet model field"""
key: str = Field(description="Model config record key for the ControlNet model")
class ControlField(BaseModel):
image: ImageField = Field(description="The control image")
control_model: ControlNetModelField = Field(description="The ControlNet model to use")
control_model: ModelField = Field(description="The ControlNet model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@ -95,7 +90,7 @@ class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image")
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
control_model: ModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
control_weight: Union[float, List[float]] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
)

View File

@ -228,7 +228,7 @@ class ConditioningField(BaseModel):
# endregion
class MetadataField(RootModel):
class MetadataField(RootModel[dict[str, Any]]):
"""
Pydantic model for metadata with custom root of type dict[str, Any].
Metadata is stored without a strict schema.

View File

@ -11,25 +11,17 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType, ModelType
# LS: Consider moving these two classes into model.py
class IPAdapterModelField(BaseModel):
key: str = Field(description="Key to the IP-Adapter model")
class CLIPVisionModelField(BaseModel):
key: str = Field(description="Key to the CLIP Vision image encoder model")
from invokeai.backend.model_manager.config import BaseModelType, IPAdapterConfig, ModelType
class IPAdapterField(BaseModel):
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
ip_adapter_model: ModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: ModelField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
@ -62,7 +54,7 @@ class IPAdapterInvocation(BaseInvocation):
# Inputs
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
ip_adapter_model: IPAdapterModelField = InputField(
ip_adapter_model: ModelField = InputField(
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
)
@ -90,18 +82,18 @@ class IPAdapterInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, IPAdapterConfig)
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_models = context.models.search_by_attrs(
model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
assert len(image_encoder_models) == 1
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)
return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,
ip_adapter_model=self.ip_adapter_model,
image_encoder_model=image_encoder_model,
image_encoder_model=ModelField(key=image_encoder_models[0].key),
weight=self.weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,

View File

@ -26,6 +26,7 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
from PIL import Image, ImageFilter
from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import (
@ -75,7 +76,7 @@ from .baseinvocation import (
invocation_output,
)
from .controlnet_image_processors import ControlField
from .model import ModelInfo, UNetField, VaeField
from .model import ModelField, UNetField, VAEField
if choose_torch_device() == torch.device("mps"):
from torch import mps
@ -118,7 +119,7 @@ class SchedulerInvocation(BaseInvocation):
class CreateDenoiseMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
@ -153,7 +154,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
)
if image_tensor is not None:
vae_info = context.models.load(**self.vae.vae.model_dump())
vae_info = context.models.load(self.vae.vae)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
@ -173,6 +174,16 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
)
@invocation_output("gradient_mask_output")
class GradientMaskOutput(BaseInvocationOutput):
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
expanded_mask_area: ImageField = OutputField(
description="Image representing the total gradient area of the mask. For paste-back purposes."
)
@invocation(
"create_gradient_mask",
title="Create Gradient Mask",
@ -193,49 +204,53 @@ class CreateGradientMaskInvocation(BaseInvocation):
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
if self.edge_radius > 0:
if self.coherence_mode == "Box Blur":
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
else: # Gaussian Blur OR Staged
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
# redistribute blur so that the edges are 0 and blur out to 1
# redistribute blur so that the original edges are 0 and blur outwards to 1
blur_tensor = (blur_tensor - 0.5) * 2
threshold = 1 - self.minimum_denoise
if self.coherence_mode == "Staged":
# wherever the blur_tensor is masked to any degree, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
# wherever the blur_tensor is less than fully masked, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
# multiply original mask to force actually masked regions to 0
blur_tensor = mask_tensor * blur_tensor
else:
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
return DenoiseMaskOutput.build(
mask_name=mask_name,
masked_latents_name=None,
gradient=True,
# compute a [0, 1] mask from the blur_tensor
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
expanded_image_dto = context.images.save(expanded_mask_image)
return GradientMaskOutput(
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
)
def get_scheduler(
context: InvocationContext,
scheduler_info: ModelInfo,
scheduler_info: ModelField,
scheduler_name: str,
seed: int,
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
orig_scheduler_info = context.models.load(scheduler_info)
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
@ -360,7 +375,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
) -> ConditioningData:
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
extra_conditioning_info = c.extra_conditioning
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
@ -370,7 +384,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
text_embeddings=c,
guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
extra=extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=0.0, # threshold,
warmup=0.2, # warmup,
@ -449,7 +462,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# and if weight is None, populate with default 1.0?
controlnet_data = []
for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
# control_models.append(control_model)
control_image_field = control_info.image
@ -511,11 +524,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.models.load(key=single_ip_adapter.ip_adapter_model.key)
context.models.load(single_ip_adapter.ip_adapter_model)
)
image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
@ -526,6 +538,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
single_ipa_images, image_encoder_model
@ -565,8 +578,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(key=t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
image = context.images.get_pil(t2i_adapter_field.image.image_name)
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@ -719,12 +732,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
return
unet_info = context.models.load(**self.unet.unet.model_dump())
unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
@ -777,10 +791,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_end=self.denoising_end,
)
(
result_latents,
result_attention_map_saver,
) = pipeline.latents_from_embeddings(
result_latents = pipeline.latents_from_embeddings(
latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
@ -821,7 +832,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
description=FieldDescriptions.latents,
input=Input.Connection,
)
vae: VaeField = InputField(
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
@ -832,8 +843,8 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(**self.vae.vae.model_dump())
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module)
latents = latents.to(vae.device)
@ -999,7 +1010,7 @@ class ImageToLatentsInvocation(BaseInvocation):
image: ImageField = InputField(
description="The image to encode",
)
vae: VaeField = InputField(
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
@ -1055,7 +1066,7 @@ class ImageToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(**self.vae.vae.model_dump())
vae_info = context.models.load(self.vae.vae)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:

View File

@ -8,7 +8,10 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.controlnet_image_processors import (
CONTROLNET_MODE_VALUES,
CONTROLNET_RESIZE_VALUES,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
@ -17,10 +20,8 @@ from invokeai.app.invocations.fields import (
OutputField,
UIType,
)
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from ...version import __version__
@ -30,10 +31,20 @@ class MetadataItemField(BaseModel):
value: Any = Field(description=FieldDescriptions.metadata_item_value)
class ModelMetadataField(BaseModel):
"""Model Metadata Field"""
key: str
hash: str
name: str
base: BaseModelType
type: ModelType
class LoRAMetadataField(BaseModel):
"""LoRA Metadata Field"""
model: LoRAModelField = Field(description=FieldDescriptions.lora_model)
model: ModelMetadataField = Field(description=FieldDescriptions.lora_model)
weight: float = Field(description=FieldDescriptions.lora_weight)
@ -41,7 +52,7 @@ class IPAdapterMetadataField(BaseModel):
"""IP Adapter Field, minus the CLIP Vision Encoder model"""
image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: IPAdapterModelField = Field(
ip_adapter_model: ModelMetadataField = Field(
description="The IP-Adapter model.",
)
weight: Union[float, list[float]] = Field(
@ -51,6 +62,33 @@ class IPAdapterMetadataField(BaseModel):
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
class T2IAdapterMetadataField(BaseModel):
image: ImageField = Field(description="The T2I-Adapter image prompt.")
t2i_adapter_model: ModelMetadataField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
class ControlNetMetadataField(BaseModel):
image: ImageField = Field(description="The control image")
control_model: ModelMetadataField = Field(description="The ControlNet model to use")
control_weight: Union[float, list[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@invocation_output("metadata_item_output")
class MetadataItemOutput(BaseInvocationOutput):
"""Metadata Item Output"""
@ -140,14 +178,14 @@ class CoreMetadataInvocation(BaseInvocation):
default=None,
description="The number of skipped CLIP layers",
)
model: Optional[MainModelField] = InputField(default=None, description="The main model used for inference")
controlnets: Optional[list[ControlField]] = InputField(
model: Optional[ModelMetadataField] = InputField(default=None, description="The main model used for inference")
controlnets: Optional[list[ControlNetMetadataField]] = InputField(
default=None, description="The ControlNets used for inference"
)
ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField(
default=None, description="The IP Adapters used for inference"
)
t2iAdapters: Optional[list[T2IAdapterField]] = InputField(
t2iAdapters: Optional[list[T2IAdapterMetadataField]] = InputField(
default=None, description="The IP Adapters used for inference"
)
loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference")
@ -159,7 +197,7 @@ class CoreMetadataInvocation(BaseInvocation):
default=None,
description="The name of the initial image",
)
vae: Optional[VAEModelField] = InputField(
vae: Optional[ModelMetadataField] = InputField(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
@ -190,7 +228,7 @@ class CoreMetadataInvocation(BaseInvocation):
)
# SDXL Refiner
refiner_model: Optional[MainModelField] = InputField(
refiner_model: Optional[ModelMetadataField] = InputField(
default=None,
description="The SDXL Refiner model used",
)
@ -222,10 +260,9 @@ class CoreMetadataInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> MetadataOutput:
"""Collects and outputs a CoreMetadata object"""
return MetadataOutput(
metadata=MetadataField.model_validate(
self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
)
)
as_dict = self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
as_dict["app_version"] = __version__
return MetadataOutput(metadata=MetadataField.model_validate(as_dict))
model_config = ConfigDict(extra="allow")

View File

@ -6,8 +6,8 @@ from pydantic import BaseModel, Field
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import SubModelType
from ...backend.model_manager import SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -16,33 +16,33 @@ from .baseinvocation import (
)
class ModelInfo(BaseModel):
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
class ModelField(BaseModel):
key: str = Field(description="Key of the model")
submodel_type: Optional[SubModelType] = Field(description="Submodel type", default=None)
class LoraInfo(ModelInfo):
weight: float = Field(description="Lora's weight which to use when apply to model")
class LoRAField(BaseModel):
lora: ModelField = Field(description="Info to load lora model")
weight: float = Field(description="Weight to apply to lora model")
class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
unet: ModelField = Field(description="Info to load unet submodel")
scheduler: ModelField = Field(description="Info to load scheduler submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
class ClipField(BaseModel):
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
class CLIPField(BaseModel):
tokenizer: ModelField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelField = Field(description="Info to load text_encoder submodel")
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
class VAEField(BaseModel):
vae: ModelField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@ -57,14 +57,14 @@ class UNetOutput(BaseInvocationOutput):
class VAEOutput(BaseInvocationOutput):
"""Base class for invocations that output a VAE field"""
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation_output("clip_output")
class CLIPOutput(BaseInvocationOutput):
"""Base class for invocations that output a CLIP field"""
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
@invocation_output("model_loader_output")
@ -74,18 +74,6 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
pass
class MainModelField(BaseModel):
"""Main model field"""
key: str = Field(description="Model key")
class LoRAModelField(BaseModel):
"""LoRA model field"""
key: str = Field(description="LoRA model key")
@invocation(
"main_model_loader",
title="Main Model",
@ -96,62 +84,40 @@ class LoRAModelField(BaseModel):
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
model: ModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
# TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
key = self.model.key
# TODO: not found exceptions
if not context.models.exists(key):
raise Exception(f"Unknown model {key}")
if not context.models.exists(self.model.key):
raise Exception(f"Unknown model {self.model.key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
key=key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
key=key,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
key=key,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=key,
submodel_type=SubModelType.Vae,
),
),
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
vae=VAEField(vae=vae),
)
@invocation_output("lora_loader_output")
class LoraLoaderOutput(BaseInvocationOutput):
class LoRALoaderOutput(BaseInvocationOutput):
"""Model loader output"""
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1")
class LoraLoaderInvocation(BaseInvocation):
class LoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
default=None,
@ -159,46 +125,41 @@ class LoraLoaderInvocation(BaseInvocation):
input=Input.Connection,
title="UNet",
)
clip: Optional[ClipField] = InputField(
clip: Optional[CLIPField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP",
)
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
def invoke(self, context: InvocationContext) -> LoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise Exception(f"Unkown lora: {lora_key}!")
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'LoRA "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip')
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'LoRA "{lora_key}" already applied to clip')
output = LoraLoaderOutput()
output = LoRALoaderOutput()
if self.unet is not None:
output.unet = copy.deepcopy(self.unet)
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = copy.deepcopy(self.clip)
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
@ -207,12 +168,12 @@ class LoraLoaderInvocation(BaseInvocation):
@invocation_output("sdxl_lora_loader_output")
class SDXLLoraLoaderOutput(BaseInvocationOutput):
class SDXLLoRALoaderOutput(BaseInvocationOutput):
"""SDXL LoRA Loader Output"""
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
clip2: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
@invocation(
@ -222,10 +183,10 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
category="model",
version="1.0.1",
)
class SDXLLoraLoaderInvocation(BaseInvocation):
class SDXLLoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
default=None,
@ -233,65 +194,59 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
input=Input.Connection,
title="UNet",
)
clip: Optional[ClipField] = InputField(
clip: Optional[CLIPField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP 1",
)
clip2: Optional[ClipField] = InputField(
clip2: Optional[CLIPField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP 2",
)
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise Exception(f"Unknown lora: {lora_key}!")
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'LoRA "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip')
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'LoRA "{lora_key}" already applied to clip')
if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip2')
if self.clip2 is not None and any(lora.lora.key == lora_key for lora in self.clip2.loras):
raise Exception(f'LoRA "{lora_key}" already applied to clip2')
output = SDXLLoraLoaderOutput()
output = SDXLLoRALoaderOutput()
if self.unet is not None:
output.unet = copy.deepcopy(self.unet)
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = copy.deepcopy(self.clip)
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
if self.clip2 is not None:
output.clip2 = copy.deepcopy(self.clip2)
output.clip2 = self.clip2.model_copy(deep=True)
output.clip2.loras.append(
LoraInfo(
key=lora_key,
submodel_type=None,
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
@ -299,17 +254,11 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
return output
class VAEModelField(BaseModel):
"""Vae model field"""
key: str = Field(description="Model's key")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
class VaeLoaderInvocation(BaseInvocation):
class VAELoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
vae_model: VAEModelField = InputField(
vae_model: ModelField = InputField(
description=FieldDescriptions.vae_model,
input=Input.Direct,
title="VAE",
@ -321,7 +270,7 @@ class VaeLoaderInvocation(BaseInvocation):
if not context.models.exists(key):
raise Exception(f"Unkown vae: {key}!")
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
return VAEOutput(vae=VAEField(vae=self.vae_model))
@invocation_output("seamless_output")
@ -329,7 +278,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
"""Modified Seamless Model output"""
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
vae: Optional[VaeField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
vae: Optional[VAEField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
@invocation(
@ -348,7 +297,7 @@ class SeamlessModeInvocation(BaseInvocation):
input=Input.Connection,
title="UNet",
)
vae: Optional[VaeField] = InputField(
vae: Optional[VAEField] = InputField(
default=None,
description=FieldDescriptions.vae_model,
input=Input.Connection,

View File

@ -8,7 +8,7 @@ from .baseinvocation import (
invocation,
invocation_output,
)
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
from .model import CLIPField, ModelField, UNetField, VAEField
@invocation_output("sdxl_model_loader_output")
@ -16,9 +16,9 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
"""SDXL base model loader output"""
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation_output("sdxl_refiner_model_loader_output")
@ -26,15 +26,15 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
"""SDXL refiner model loader output"""
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1")
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
model: MainModelField = InputField(
model: ModelField = InputField(
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
)
# TODO: precision?
@ -46,48 +46,19 @@ class SDXLModelLoaderInvocation(BaseInvocation):
if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return SDXLModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
key=model_key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=model_key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
clip2=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.Vae,
),
),
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
vae=VAEField(vae=vae),
)
@ -101,10 +72,8 @@ class SDXLModelLoaderInvocation(BaseInvocation):
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""
model: MainModelField = InputField(
description=FieldDescriptions.sdxl_refiner_model,
input=Input.Direct,
ui_type=UIType.SDXLRefinerModel,
model: ModelField = InputField(
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
)
# TODO: precision?
@ -115,34 +84,14 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return SDXLRefinerModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
key=model_key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=model_key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip2=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.Vae,
),
),
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
vae=VAEField(vae=vae),
)

View File

@ -10,17 +10,14 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
class T2IAdapterModelField(BaseModel):
key: str = Field(description="Model record key for the T2I-Adapter model")
class T2IAdapterField(BaseModel):
image: ImageField = Field(description="The T2I-Adapter image prompt.")
t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.")
t2i_adapter_model: ModelField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
@ -55,7 +52,7 @@ class T2IAdapterInvocation(BaseInvocation):
# Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.")
t2i_adapter_model: T2IAdapterModelField = InputField(
t2i_adapter_model: ModelField = InputField(
description="The T2I-Adapter model.",
title="T2I-Adapter Model",
input=Input.Direct,

View File

@ -166,6 +166,7 @@ two configs are kept in separate sections of the config file:
...
"""
from __future__ import annotations
import os
@ -255,6 +256,7 @@ class InvokeAIAppConfig(InvokeAISettings):
profile_graphs : bool = Field(default=False, description="Enable graph profiling", json_schema_extra=Categories.Development)
profile_prefix : Optional[str] = Field(default=None, description="An optional prefix for profile output files.", json_schema_extra=Categories.Development)
profiles_dir : Path = Field(default=Path('profiles'), description="Directory for graph profiles", json_schema_extra=Categories.Development)
skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce startup time.", json_schema_extra=Categories.Development)
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)

View File

@ -1,4 +1,5 @@
"""Init file for download queue."""
from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
from .download_default import DownloadQueueService, TqdmProgress

View File

@ -224,20 +224,13 @@ class DownloadQueueService(DownloadQueueServiceBase):
job.job_started = get_iso_timestamp()
self._do_download(job)
self._signal_job_complete(job)
except (OSError, HTTPError) as excp:
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
job.error = traceback.format_exc()
try:
self._signal_job_error(job, excp)
except:
pass
except DownloadJobCancelledException:
try:
self._signal_job_cancelled(job)
self._cleanup_cancelled_job(job)
except:
pass
finally:
job.job_ended = get_iso_timestamp()

View File

@ -41,8 +41,9 @@ class InvocationCacheBase(ABC):
"""Clears the cache"""
pass
@staticmethod
@abstractmethod
def create_key(self, invocation: BaseInvocation) -> int:
def create_key(invocation: BaseInvocation) -> int:
"""Gets the key for the invocation's cache item"""
pass

View File

@ -61,9 +61,7 @@ class MemoryInvocationCache(InvocationCacheBase):
self._delete_oldest_access(number_to_delete)
self._cache[key] = CachedItem(
invocation_output,
invocation_output.model_dump_json(
warnings=False, exclude_defaults=True, exclude_unset=True, include={"type"}
),
invocation_output.model_dump_json(warnings=False, exclude_defaults=True, exclude_unset=True),
)
def _delete_oldest_access(self, number_to_delete: int) -> None:
@ -81,7 +79,7 @@ class MemoryInvocationCache(InvocationCacheBase):
with self._lock:
return self._delete(key)
def clear(self, *args, **kwargs) -> None:
def clear(self) -> None:
with self._lock:
if self._max_cache_size == 0:
return

View File

@ -25,6 +25,7 @@ if TYPE_CHECKING:
from .images.images_base import ImageServiceABC
from .invocation_cache.invocation_cache_base import InvocationCacheBase
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .model_images.model_images_base import ModelImageFileStorageBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
@ -49,6 +50,7 @@ class InvocationServices:
image_files: "ImageFileStorageBase",
image_records: "ImageRecordStorageBase",
logger: "Logger",
model_images: "ModelImageFileStorageBase",
model_manager: "ModelManagerServiceBase",
download_queue: "DownloadQueueServiceBase",
performance_statistics: "InvocationStatsServiceBase",
@ -72,6 +74,7 @@ class InvocationServices:
self.image_files = image_files
self.image_records = image_records
self.logger = logger
self.model_images = model_images
self.model_manager = model_manager
self.download_queue = download_queue
self.performance_statistics = performance_statistics

View File

@ -0,0 +1,33 @@
from abc import ABC, abstractmethod
from pathlib import Path
from PIL.Image import Image as PILImageType
class ModelImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, model_key: str) -> PILImageType:
"""Retrieves a model image as PIL Image."""
pass
@abstractmethod
def get_path(self, model_key: str) -> Path:
"""Gets the internal path to a model image."""
pass
@abstractmethod
def get_url(self, model_key: str) -> str | None:
"""Gets the URL to fetch a model image."""
pass
@abstractmethod
def save(self, image: PILImageType, model_key: str) -> None:
"""Saves a model image."""
pass
@abstractmethod
def delete(self, model_key: str) -> None:
"""Deletes a model image."""
pass

View File

@ -0,0 +1,20 @@
# TODO: Should these excpetions subclass existing python exceptions?
class ModelImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message="Model image file not found"):
super().__init__(message)
class ModelImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message="Model image file not saved"):
super().__init__(message)
class ModelImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message="Model image file not deleted"):
super().__init__(message)

View File

@ -0,0 +1,79 @@
from pathlib import Path
from PIL import Image
from PIL.Image import Image as PILImageType
from send2trash import send2trash
from invokeai.app.services.invoker import Invoker
from invokeai.app.util.thumbnails import make_thumbnail
from .model_images_base import ModelImageFileStorageBase
from .model_images_common import (
ModelImageFileDeleteException,
ModelImageFileNotFoundException,
ModelImageFileSaveException,
)
class ModelImageFileStorageDisk(ModelImageFileStorageBase):
"""Stores images on disk"""
def __init__(self, model_images_folder: Path):
self._model_images_folder = model_images_folder
self._validate_storage_folders()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def get(self, model_key: str) -> PILImageType:
try:
path = self.get_path(model_key)
if not self._validate_path(path):
raise ModelImageFileNotFoundException
return Image.open(path)
except FileNotFoundError as e:
raise ModelImageFileNotFoundException from e
def save(self, image: PILImageType, model_key: str) -> None:
try:
self._validate_storage_folders()
image_path = self._model_images_folder / (model_key + ".webp")
thumbnail = make_thumbnail(image, 256)
thumbnail.save(image_path, format="webp")
except Exception as e:
raise ModelImageFileSaveException from e
def get_path(self, model_key: str) -> Path:
path = self._model_images_folder / (model_key + ".webp")
return path
def get_url(self, model_key: str) -> str | None:
path = self.get_path(model_key)
if not self._validate_path(path):
return
return self._invoker.services.urls.get_model_image_url(model_key)
def delete(self, model_key: str) -> None:
try:
path = self.get_path(model_key)
if not self._validate_path(path):
raise ModelImageFileNotFoundException
send2trash(path)
except Exception as e:
raise ModelImageFileDeleteException from e
def _validate_path(self, path: Path) -> bool:
"""Validates the path given for an image."""
return path.exists()
def _validate_storage_folders(self) -> None:
"""Checks if the required folders exist and create them if they don't"""
self._model_images_folder.mkdir(parents=True, exist_ok=True)

View File

@ -18,16 +18,16 @@ from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..model_metadata import ModelMetadataStoreBase
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
DOWNLOADING = "downloading" # downloading of model files in process
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
@ -150,6 +150,13 @@ ModelSource = Annotated[
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
CivitaiModelSource: ModelSourceType.CivitAI,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
@ -229,6 +236,11 @@ class ModelInstallJob(BaseModel):
"""Return true if job is downloading."""
return self.status == InstallStatus.DOWNLOADING
@property
def downloads_done(self) -> bool:
"""Return true if job's downloads ae done."""
return self.status == InstallStatus.DOWNLOADS_DONE
@property
def running(self) -> bool:
"""Return true if job is running."""
@ -254,7 +266,6 @@ class ModelInstallServiceBase(ABC):
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
metadata_store: ModelMetadataStoreBase,
event_bus: Optional["EventServiceBase"] = None,
):
"""
@ -341,6 +352,7 @@ class ModelInstallServiceBase(ABC):
source: str,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
) -> ModelInstallJob:
r"""Install the indicated model using heuristics to interpret user intentions.
@ -386,7 +398,7 @@ class ModelInstallServiceBase(ABC):
will override corresponding autoassigned probe fields in the
model's config record. Use it to override
`name`, `description`, `base_type`, `model_type`, `format`,
`prediction_type`, `image_size`, and/or `ztsnr_training`.
`prediction_type`, and/or `image_size`.
This will download the model located at `source`,
probe it, and install it into the models directory.

View File

@ -4,10 +4,10 @@ import os
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from hashlib import sha256
from pathlib import Path
from queue import Empty, Queue
from random import randbytes
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Set, Union
@ -21,14 +21,17 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
CheckpointConfigBase,
InvalidModelConfigException,
ModelRepoVariant,
ModelSourceType,
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
CivitaiMetadataFetch,
@ -36,12 +39,14 @@ from invokeai.backend.model_manager.metadata import (
ModelMetadataWithFiles,
RemoteModelFile,
)
from invokeai.backend.model_manager.metadata.metadata_base import CivitaiMetadata, HuggingFaceMetadata
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import Chdir, InvokeAILogger
from invokeai.backend.util.devices import choose_precision, choose_torch_device
from .model_install_base import (
MODEL_SOURCE_TO_TYPE_MAP,
CivitaiModelSource,
HFModelSource,
InstallStatus,
@ -91,7 +96,6 @@ class ModelInstallService(ModelInstallServiceBase):
self._running = False
self._session = session
self._next_job_id = 0
self._metadata_store = record_store.metadata_store # for convenience
@property
def app_config(self) -> InvokeAIAppConfig: # noqa D102
@ -140,6 +144,7 @@ class ModelInstallService(ModelInstallServiceBase):
config = config or {}
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
config["source_type"] = ModelSourceType.Path
return self._register(model_path, config)
def install_path(
@ -149,11 +154,11 @@ class ModelInstallService(ModelInstallServiceBase):
) -> str: # noqa D102
model_path = Path(model_path)
config = config or {}
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
info: AnyModelConfig = self._probe_model(Path(model_path), config)
old_hash = info.current_hash
if self._app_config.skip_model_hash:
config["hash"] = uuid_string()
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
if preferred_name := config.get("name"):
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
@ -167,8 +172,6 @@ class ModelInstallService(ModelInstallServiceBase):
raise DuplicateModelException(
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
) from excp
new_hash = FastModelHash.hash(new_path)
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
return self._register(
new_path,
@ -181,13 +184,14 @@ class ModelInstallService(ModelInstallServiceBase):
source: str,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
) -> ModelInstallJob:
variants = "|".join(ModelRepoVariant.__members__.values())
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
source_obj: Optional[StringLikeSource] = None
if Path(source).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source))
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
elif match := re.match(hf_repoid_re, source):
source_obj = HFModelSource(
repo_id=match.group(1),
@ -277,14 +281,20 @@ class ModelInstallService(ModelInstallServiceBase):
self._scan_models_directory()
if autoimport := self._app_config.autoimport_dir:
self._logger.info("Scanning autoimport directory for new models")
installed = self.scan_directory(self._app_config.root_path / autoimport)
installed: List[str] = []
# Use ThreadPoolExecutor to scan dirs in parallel
with ThreadPoolExecutor() as executor:
future_models = [executor.submit(self.scan_directory, self._app_config.root_path / autoimport / cur_model_type.value) for cur_model_type in ModelType]
[installed.extend(models.result()) for models in as_completed(future_models)]
self._logger.info(f"{len(installed)} new models registered")
self._logger.info("Model installer (re)initialized")
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
if len([entry for entry in os.scandir(scan_dir) if not entry.name.startswith(".")]) == 0:
return []
callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback)
search = ModelSearch(on_model_found=callback, config=self._app_config)
self._models_installed.clear()
search.search(scan_dir)
return list(self._models_installed)
@ -370,21 +380,24 @@ class ModelInstallService(ModelInstallServiceBase):
self._signal_job_errored(job)
elif (
job.waiting or job.downloading
job.waiting or job.downloads_done
): # local jobs will be in waiting state, remote jobs will be downloading state
job.total_bytes = self._stat_size(job.local_path)
job.bytes = job.total_bytes
self._signal_job_running(job)
job.config_in["source"] = str(job.source)
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
# enter the metadata, if there is any
if isinstance(job.source_metadata, (CivitaiMetadata, HuggingFaceMetadata)):
job.config_in["source_api_response"] = job.source_metadata.api_response
if isinstance(job.source_metadata, CivitaiMetadata) and job.source_metadata.trigger_phrases:
job.config_in["trigger_phrases"] = job.source_metadata.trigger_phrases
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
else:
key = self.install_path(job.local_path, job.config_in)
job.config_out = self.record_store.get_model(key)
# enter the metadata, if there is any
if job.source_metadata:
self._metadata_store.add_metadata(key, job.source_metadata)
self._signal_job_completed(job)
except InvalidModelConfigException as excp:
@ -442,13 +455,13 @@ class ModelInstallService(ModelInstallServiceBase):
self.unregister(key)
self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
for cur_base_model in BaseModelType:
for cur_model_type in ModelType:
models_dir = Path(cur_base_model.value, cur_model_type.value)
installed.update(self.scan_directory(models_dir))
# Use ThreadPoolExecutor to scan dirs in parallel
with ThreadPoolExecutor() as executor:
future_models = [executor.submit(self.scan_directory, Path(cur_base_model.value, cur_model_type.value)) for cur_base_model in BaseModelType for cur_model_type in ModelType]
[installed.update(models.result()) for models in as_completed(future_models)]
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
def _sync_model_path(self, key: str, ignore_hash_change: bool = False) -> AnyModelConfig:
def _sync_model_path(self, key: str) -> AnyModelConfig:
"""
Move model into the location indicated by its basetype, type and name.
@ -469,15 +482,8 @@ class ModelInstallService(ModelInstallServiceBase):
new_path = models_dir / model.base.value / model.type.value / model.name
self._logger.info(f"Moving {model.name} to {new_path}.")
new_path = self._move_model(old_path, new_path)
new_hash = FastModelHash.hash(new_path)
model.path = new_path.relative_to(models_dir).as_posix()
if model.current_hash != new_hash:
assert (
ignore_hash_change
), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
model.current_hash = new_hash
self._logger.info(f"Model has new hash {model.current_hash}, but will continue to be identified by {key}")
self.record_store.update_model(key, model)
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
return model
def _scan_register(self, model: Path) -> bool:
@ -529,22 +535,14 @@ class ModelInstallService(ModelInstallServiceBase):
move(old_path, new_path)
return new_path
def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
info: AnyModelConfig = ModelProbe.probe(Path(model_path))
if config: # used to override probe fields
for key, value in config.items():
setattr(info, key, value)
return info
def _create_key(self) -> str:
return sha256(randbytes(100)).hexdigest()[0:32]
def _register(
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
) -> str:
key = self._create_key()
if config and not config.get("key", None):
config["key"] = key
config = config or {}
if self._app_config.skip_model_hash:
config["hash"] = uuid_string()
info = info or ModelProbe.probe(model_path, config)
model_path = model_path.absolute()
@ -554,11 +552,11 @@ class ModelInstallService(ModelInstallServiceBase):
info.path = model_path.as_posix()
# add 'main' specific fields
if hasattr(info, "config"):
if isinstance(info, CheckpointConfigBase):
# make config relative to our root
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
self.record_store.add_model(info.key, info)
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config_path).resolve()
info.config_path = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
self.record_store.add_model(info)
return info.key
def _next_id(self) -> int:
@ -579,13 +577,15 @@ class ModelInstallService(ModelInstallServiceBase):
source=source,
config_in=config or {},
local_path=Path(source.path),
inplace=source.inplace,
inplace=source.inplace or False,
)
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
if not source.access_token:
self._logger.info("No Civitai access token provided; some models may not be downloadable.")
metadata = CivitaiMetadataFetch(self._session).from_id(str(source.version_id))
metadata = CivitaiMetadataFetch(self._session, self.app_config.get_config().civitai_api_key).from_id(
str(source.version_id)
)
assert isinstance(metadata, ModelMetadataWithFiles)
remote_files = metadata.download_urls(session=self._session)
return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)
@ -613,15 +613,17 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
# URLs from Civitai or HuggingFace will be handled specially
url_patterns = {
r"^https?://civitai.com/": CivitaiMetadataFetch,
r"^https?://huggingface.co/[^/]+/[^/]+$": HuggingFaceMetadataFetch,
}
metadata = None
for pattern, fetcher in url_patterns.items():
if re.match(pattern, str(source.url), re.IGNORECASE):
metadata = fetcher(self._session).from_url(source.url)
break
fetcher = None
try:
fetcher = self.get_fetcher_from_url(str(source.url))
except ValueError:
pass
kwargs: dict[str, Any] = {"session": self._session}
if fetcher is CivitaiMetadataFetch:
kwargs["api_key"] = self._app_config.get_config().civitai_api_key
if fetcher is not None:
metadata = fetcher(**kwargs).from_url(source.url)
self._logger.debug(f"metadata={metadata}")
if metadata and isinstance(metadata, ModelMetadataWithFiles):
remote_files = metadata.download_urls(session=self._session)
@ -636,7 +638,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_remote_model(
self,
source: ModelSource,
source: HFModelSource | CivitaiModelSource | URLModelSource,
remote_files: List[RemoteModelFile],
metadata: Optional[AnyModelRepoMetadata],
config: Optional[Dict[str, Any]],
@ -664,7 +666,7 @@ class ModelInstallService(ModelInstallServiceBase):
# In the event that there is a subfolder specified in the source,
# we need to remove it from the destination path in order to avoid
# creating unwanted subfolders
if hasattr(source, "subfolder") and source.subfolder:
if isinstance(source, HFModelSource) and source.subfolder:
root = Path(remote_files[0].path.parts[0])
subfolder = root / source.subfolder
else:
@ -749,8 +751,8 @@ class ModelInstallService(ModelInstallServiceBase):
self._download_cache.pop(download_job.source, None)
# are there any more active jobs left in this task?
if all(x.complete for x in install_job.download_parts):
# now enqueue job for actual installation into the models directory
if install_job.downloading and all(x.complete for x in install_job.download_parts):
install_job.status = InstallStatus.DOWNLOADS_DONE
self._install_queue.put(install_job)
# Let other threads know that the number of downloads has changed
@ -851,3 +853,11 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"{job.source}: model installation was cancelled")
if self._event_bus:
self._event_bus.emit_model_install_cancelled(str(job.source))
@staticmethod
def get_fetcher_from_url(url: str):
if re.match(r"^https?://civitai.com/", url.lower()):
return CivitaiMetadataFetch
elif re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'")

View File

@ -1,15 +1,11 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from abc import ABC, abstractmethod
from typing import Optional
import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase
@ -70,32 +66,3 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def stop(self, invoker: Invoker) -> None:
pass
@abstractmethod
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass
@abstractmethod
def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass
@abstractmethod
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass

View File

@ -1,14 +1,10 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""
from typing import Optional
import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
@ -18,7 +14,7 @@ from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallService, ModelInstallServiceBase
from ..model_load import ModelLoadService, ModelLoadServiceBase
from ..model_records import ModelRecordServiceBase, UnknownModelException
from ..model_records import ModelRecordServiceBase
from .model_manager_base import ModelManagerServiceBase
@ -64,56 +60,6 @@ class ModelManagerService(ModelManagerServiceBase):
if hasattr(service, "stop"):
service.stop(invoker)
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
return self.load.load_model(model_config, submodel_type, context_data)
def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
config = self.store.get_model(key)
return self.load.load_model(config, submodel_type, context_data)
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
"""
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
This is provided for API compatability with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that ws returned.
:param model_name: Name of to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""
configs = self.store.search_by_attr(model_name, base_model, model_type)
if len(configs) == 0:
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
elif len(configs) > 1:
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
else:
return self.load.load_model(configs[0], submodel, context_data)
@classmethod
def build_model_manager(
cls,

View File

@ -1,9 +0,0 @@
"""Init file for ModelMetadataStoreService module."""
from .metadata_store_base import ModelMetadataStoreBase
from .metadata_store_sql import ModelMetadataStoreSQL
__all__ = [
"ModelMetadataStoreBase",
"ModelMetadataStoreSQL",
]

View File

@ -1,65 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Storage for Model Metadata
"""
from abc import ABC, abstractmethod
from typing import List, Set, Tuple
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class ModelMetadataStoreBase(ABC):
"""Store, search and fetch model metadata retrieved from remote repositories."""
@abstractmethod
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
@abstractmethod
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
@abstractmethod
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
@abstractmethod
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
@abstractmethod
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
@abstractmethod
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
@abstractmethod
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
@abstractmethod
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""

View File

@ -1,222 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""
import sqlite3
from typing import List, Optional, Set, Tuple
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
from .metadata_store_base import ModelMetadataStoreBase
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
"""Store, search and fetch model metadata retrieved from remote repositories."""
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_metadata(
id,
metadata
)
VALUES (?,?);
""",
(
model_key,
json_serialized,
),
)
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
raise UnknownMetadataException from excp
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata FROM model_metadata
WHERE id=?;
""",
(model_key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownMetadataException("model metadata not found")
return ModelMetadataFetchBase.from_json(rows[0])
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT id,metadata FROM model_metadata;
""",
(),
)
rows = self._cursor.fetchall()
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
json_serialized = metadata.model_dump_json() # turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_metadata
SET
metadata=?
WHERE id=?;
""",
(json_serialized, model_key),
)
if self._cursor.rowcount == 0:
raise UnknownMetadataException("model metadata not found")
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_metadata(model_key)
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
self._cursor.execute(
"""--sql
select tag_text from tags;
"""
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
with self._db.lock:
try:
matches: Optional[Set[str]] = None
for tag in tags:
self._cursor.execute(
"""--sql
SELECT a.model_id FROM model_tags AS a,
tags AS b
WHERE a.tag_id=b.tag_id
AND b.tag_text=?;
""",
(tag,),
)
model_keys = {x[0] for x in self._cursor.fetchall()}
if matches is None:
matches = model_keys
matches = matches.intersection(model_keys)
except sqlite3.Error as e:
raise e
return matches if matches else set()
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE author=?;
""",
(author,),
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE name=?;
""",
(name,),
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
"""Update tags for the model referenced by model_key."""
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)

View File

@ -1,4 +1,5 @@
"""Init file for model record services."""
from .model_records_base import ( # noqa F401
DuplicateModelException,
InvalidModelException,

View File

@ -6,20 +6,19 @@ Abstract base class for storing and retrieving model configuration records.
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import List, Optional, Set, Union
from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager import (
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..model_metadata import ModelMetadataStoreBase
from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType
class DuplicateModelException(Exception):
@ -60,11 +59,34 @@ class ModelSummary(BaseModel):
tags: Set[str] = Field(description="tags associated with model")
class ModelRecordChanges(BaseModelExcludeNull):
"""A set of changes to apply to a model."""
# Changes applicable to all models
name: Optional[str] = Field(description="Name of the model.", default=None)
path: Optional[str] = Field(description="Path to the model.", default=None)
description: Optional[str] = Field(description="Model description", default=None)
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[ModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
# Checkpoint-specific changes
# TODO(MM2): Should we expose these? Feels footgun-y...
variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
prediction_type: Optional[SchedulerPredictionType] = Field(
description="The prediction type of the model.", default=None
)
upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
config_path: Optional[str] = Field(description="Path to config file for model", default=None)
class ModelRecordServiceBase(ABC):
"""Abstract base class for storage and retrieval of model configs."""
@abstractmethod
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
"""
Add a model to the database.
@ -88,13 +110,12 @@ class ModelRecordServiceBase(ABC):
pass
@abstractmethod
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
"""
Update the model, returning the updated version.
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
:param key: Unique key for the model to be updated.
:param changes: A set of changes to apply to this model. Changes are validated before being written.
"""
pass
@ -109,40 +130,17 @@ class ModelRecordServiceBase(ABC):
"""
pass
@property
@abstractmethod
def metadata_store(self) -> ModelMetadataStoreBase:
"""Return a ModelMetadataStore initialized on the same database."""
pass
@abstractmethod
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
"""
Retrieve metadata (if any) from when model was downloaded from a repo.
Retrieve the configuration for the indicated model.
:param key: Model key
:param hash: Hash of model config to be fetched.
Exceptions: UnknownModelException
"""
pass
@abstractmethod
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
"""List metadata for all models that have it."""
pass
@abstractmethod
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
"""
Search model metadata for ones with all listed tags and return their corresponding configs.
:param tags: Set of tags to search for. All tags must be present.
"""
pass
@abstractmethod
def list_tags(self) -> Set[str]:
"""Return a unique set of all the model tags in the metadata database."""
pass
@abstractmethod
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
@ -217,21 +215,3 @@ class ModelRecordServiceBase(ABC):
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
)
return model_configs[0]
def rename_model(
self,
key: str,
new_name: str,
) -> AnyModelConfig:
"""
Rename the indicated model. Just a special case of update_model().
In some implementations, renaming the model may involve changing where
it is stored on the filesystem. So this is broken out.
:param key: Model key
:param new_name: New name for model
"""
config = self.get_model(key)
config.name = new_name
return self.update_model(key, config)

View File

@ -39,12 +39,11 @@ Typical usage:
configs = store.search_by_attr(base_model='sd-2', model_type='main')
"""
import json
import sqlite3
from math import ceil
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import List, Optional, Union
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import (
@ -54,12 +53,11 @@ from invokeai.backend.model_manager.config import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import (
DuplicateModelException,
ModelRecordChanges,
ModelRecordOrderBy,
ModelRecordServiceBase,
ModelSummary,
@ -70,7 +68,7 @@ from .model_records_base import (
class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database."""
def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase):
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
@ -79,14 +77,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
super().__init__()
self._db = db
self._cursor = db.conn.cursor()
self._metadata_store = metadata_store
@property
def db(self) -> SqliteDatabase:
"""Return the underlying database."""
return self._db
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
"""
Add a model to the database.
@ -96,23 +93,19 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_config (
INSERT INTO models (
id,
original_hash,
config
)
VALUES (?,?,?);
VALUES (?,?);
""",
(
key,
record.original_hash,
json_serialized,
config.key,
config.model_dump_json(),
),
)
self._db.conn.commit()
@ -120,12 +113,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
except sqlite3.IntegrityError as e:
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "model_config.path" in str(e):
msg = f"A model with path '{record.path}' is already installed"
elif "model_config.name" in str(e):
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
else:
msg = f"A model with key '{key}' is already installed"
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
else:
raise e
@ -133,7 +126,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db.conn.rollback()
raise e
return self.get_model(key)
return self.get_model(config.key)
def del_model(self, key: str) -> None:
"""
@ -147,7 +140,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
try:
self._cursor.execute(
"""--sql
DELETE FROM model_config
DELETE FROM models
WHERE id=?;
""",
(key,),
@ -159,21 +152,20 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db.conn.rollback()
raise e
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
"""
Update the model, returning the updated version.
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
record = self.get_model(key)
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
json_serialized = record.model_dump_json()
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_config
UPDATE models
SET
config=?
WHERE id=?;
@ -200,7 +192,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM model_config
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
@ -211,6 +203,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the databse.
@ -221,7 +228,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
"""--sql
select count(*) FROM model_config
select count(*) FROM models
WHERE id=?;
""",
(key,),
@ -247,9 +254,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
If none of the optional filters are passed, will return all
models in the database.
"""
results = []
where_clause = []
bindings = []
where_clause: list[str] = []
bindings: list[str] = []
if model_name:
where_clause.append("name=?")
bindings.append(model_name)
@ -266,14 +272,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
f"""--sql
select config, strftime('%s',updated_at) FROM model_config
SELECT config, strftime('%s',updated_at) FROM models
{where};
""",
tuple(bindings),
)
results = [
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
]
result = self._cursor.fetchall()
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result]
return results
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
@ -282,7 +287,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM model_config
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
@ -293,13 +298,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated original_hash."""
"""Return models with the indicated hash."""
results = []
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM model_config
WHERE original_hash=?;
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
@ -308,83 +313,35 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
]
return results
@property
def metadata_store(self) -> ModelMetadataStoreBase:
"""Return a ModelMetadataStore initialized on the same database."""
return self._metadata_store
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""
Retrieve metadata (if any) from when model was downloaded from a repo.
:param key: Model key
"""
store = self.metadata_store
try:
metadata = store.get_metadata(key)
return metadata
except UnknownMetadataException:
return None
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
"""
Search model metadata for ones with all listed tags and return their corresponding configs.
:param tags: Set of tags to search for. All tags must be present.
"""
store = ModelMetadataStoreSQL(self._db)
keys = store.search_by_tag(tags)
return [self.get_model(x) for x in keys]
def list_tags(self) -> Set[str]:
"""Return a unique set of all the model tags in the metadata database."""
store = ModelMetadataStoreSQL(self._db)
return store.list_tags()
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
"""List metadata for all models that have it."""
store = ModelMetadataStoreSQL(self._db)
return store.list_all_metadata()
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name",
ModelRecordOrderBy.Type: "a.type",
ModelRecordOrderBy.Base: "a.base",
ModelRecordOrderBy.Name: "a.name",
ModelRecordOrderBy.Format: "a.format",
ModelRecordOrderBy.Default: "type, base, format, name",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]:
"""Fix up results so that there are no null values."""
result: Dict[str, Union[str, int, Set[str]]] = {}
for key, item in summary.items():
result[key] = item or ""
result["tags"] = set(json.loads(summary["tags"] or "[]"))
return result
# Lock so that the database isn't updated while we're doing the two queries.
with self._db.lock:
# query1: get the total number of model configs
self._cursor.execute(
"""--sql
select count(*) from model_config;
select count(*) from models;
""",
(),
)
total = int(self._cursor.fetchone()[0])
# query2: fetch key fields from the join of model_config and model_metadata
# query2: fetch key fields
self._cursor.execute(
f"""--sql
SELECT a.id as key, a.type, a.base, a.format, a.name,
json_extract(a.config, '$.description') as description,
json_extract(b.metadata, '$.tags') as tags
FROM model_config AS a
LEFT JOIN model_metadata AS b on a.id=b.id
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
@ -395,7 +352,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
),
)
rows = self._cursor.fetchall()
items = [ModelSummary.model_validate(_fixup(dict(x))) for x in rows]
items = [ModelSummary.model_validate(dict(x)) for x in rows]
return PaginatedResults(
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
)

View File

@ -200,6 +200,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._invoker.services.logger.error(
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
)
self._invoker.services.logger.error(error)
# Send error event
self._invoker.services.events.emit_invocation_error(

View File

@ -1,7 +1,7 @@
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union
from PIL.Image import Image
from torch import Tensor
@ -13,15 +13,16 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.model import ModelField
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
"""
@ -65,75 +66,86 @@ class InvocationContextInterface:
class BoardsInterface(InvocationContextInterface):
def create(self, board_name: str) -> BoardDTO:
"""
Creates a board.
"""Creates a board.
:param board_name: The name of the board to create.
Args:
board_name: The name of the board to create.
Returns:
The created board DTO.
"""
return self._services.boards.create(board_name)
def get_dto(self, board_id: str) -> BoardDTO:
"""
Gets a board DTO.
"""Gets a board DTO.
:param board_id: The ID of the board to get.
Args:
board_id: The ID of the board to get.
Returns:
The board DTO.
"""
return self._services.boards.get_dto(board_id)
def get_all(self) -> list[BoardDTO]:
"""
Gets all boards.
"""Gets all boards.
Returns:
A list of all boards.
"""
return self._services.boards.get_all()
def add_image_to_board(self, board_id: str, image_name: str) -> None:
"""
Adds an image to a board.
"""Adds an image to a board.
:param board_id: The ID of the board to add the image to.
:param image_name: The name of the image to add to the board.
Args:
board_id: The ID of the board to add the image to.
image_name: The name of the image to add to the board.
"""
return self._services.board_images.add_image_to_board(board_id, image_name)
def get_all_image_names_for_board(self, board_id: str) -> list[str]:
"""
Gets all image names for a board.
"""Gets all image names for a board.
:param board_id: The ID of the board to get the image names for.
Args:
board_id: The ID of the board to get the image names for.
Returns:
A list of all image names for the board.
"""
return self._services.board_images.get_all_board_image_names_for_board(board_id)
class LoggerInterface(InvocationContextInterface):
def debug(self, message: str) -> None:
"""
Logs a debug message.
"""Logs a debug message.
:param message: The message to log.
Args:
message: The message to log.
"""
self._services.logger.debug(message)
def info(self, message: str) -> None:
"""
Logs an info message.
"""Logs an info message.
:param message: The message to log.
Args:
message: The message to log.
"""
self._services.logger.info(message)
def warning(self, message: str) -> None:
"""
Logs a warning message.
"""Logs a warning message.
:param message: The message to log.
Args:
message: The message to log.
"""
self._services.logger.warning(message)
def error(self, message: str) -> None:
"""
Logs an error message.
"""Logs an error message.
:param message: The message to log.
Args:
message: The message to log.
"""
self._services.logger.error(message)
@ -146,20 +158,23 @@ class ImagesInterface(InvocationContextInterface):
image_category: ImageCategory = ImageCategory.GENERAL,
metadata: Optional[MetadataField] = None,
) -> ImageDTO:
"""
Saves an image, returning its DTO.
"""Saves an image, returning its DTO.
If the current queue item has a workflow or metadata, it is automatically saved with the image.
:param image: The image to save, as a PIL image.
:param board_id: The board ID to add the image to, if it should be added. It the invocation \
Args:
image: The image to save, as a PIL image.
board_id: The board ID to add the image to, if it should be added. It the invocation \
inherits from `WithBoard`, that board will be used automatically. **Use this only if \
you want to override or provide a board manually!**
:param image_category: The category of the image. Only the GENERAL category is added \
image_category: The category of the image. Only the GENERAL category is added \
to the gallery.
:param metadata: The metadata to save with the image, if it should have any. If the \
metadata: The metadata to save with the image, if it should have any. If the \
invocation inherits from `WithMetadata`, that metadata will be used automatically. \
**Use this only if you want to override or provide metadata manually!**
Returns:
The saved image DTO.
"""
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
@ -189,11 +204,14 @@ class ImagesInterface(InvocationContextInterface):
)
def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image:
"""
Gets an image as a PIL Image object.
"""Gets an image as a PIL Image object.
:param image_name: The name of the image to get.
:param mode: The color mode to convert the image to. If None, the original mode is used.
Args:
image_name: The name of the image to get.
mode: The color mode to convert the image to. If None, the original mode is used.
Returns:
The image as a PIL Image object.
"""
image = self._services.images.get_pil_image(image_name)
if mode and mode != image.mode:
@ -206,158 +224,202 @@ class ImagesInterface(InvocationContextInterface):
return image
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
"""
Gets an image's metadata, if it has any.
"""Gets an image's metadata, if it has any.
:param image_name: The name of the image to get the metadata for.
Args:
image_name: The name of the image to get the metadata for.
Returns:
The image's metadata, if it has any.
"""
return self._services.images.get_metadata(image_name)
def get_dto(self, image_name: str) -> ImageDTO:
"""
Gets an image as an ImageDTO object.
"""Gets an image as an ImageDTO object.
:param image_name: The name of the image to get.
Args:
image_name: The name of the image to get.
Returns:
The image as an ImageDTO object.
"""
return self._services.images.get_dto(image_name)
class TensorsInterface(InvocationContextInterface):
def save(self, tensor: Tensor) -> str:
"""
Saves a tensor, returning its name.
"""Saves a tensor, returning its name.
:param tensor: The tensor to save.
Args:
tensor: The tensor to save.
Returns:
The name of the saved tensor.
"""
name = self._services.tensors.save(obj=tensor)
return name
def load(self, name: str) -> Tensor:
"""
Loads a tensor by name.
"""Loads a tensor by name.
:param name: The name of the tensor to load.
Args:
name: The name of the tensor to load.
Returns:
The loaded tensor.
"""
return self._services.tensors.load(name)
class ConditioningInterface(InvocationContextInterface):
def save(self, conditioning_data: ConditioningFieldData) -> str:
"""
Saves a conditioning data object, returning its name.
"""Saves a conditioning data object, returning its name.
:param conditioning_data: The conditioning data to save.
Args:
conditioning_data: The conditioning data to save.
Returns:
The name of the saved conditioning data.
"""
name = self._services.conditioning.save(obj=conditioning_data)
return name
def load(self, name: str) -> ConditioningFieldData:
"""
Loads conditioning data by name.
"""Loads conditioning data by name.
:param name: The name of the conditioning data to load.
Args:
name: The name of the conditioning data to load.
Returns:
The loaded conditioning data.
"""
return self._services.conditioning.load(name)
class ModelsInterface(InvocationContextInterface):
def exists(self, key: str) -> bool:
"""
Checks if a model exists.
def exists(self, identifier: Union[str, "ModelField"]) -> bool:
"""Checks if a model exists.
:param key: The key of the model.
"""
return self._services.model_manager.store.exists(key)
Args:
identifier: The key or ModelField representing the model.
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
Returns:
True if the model exists, False if not.
"""
Loads a model.
if isinstance(identifier, str):
return self._services.model_manager.store.exists(identifier)
:param key: The key of the model.
:param submodel_type: The submodel of the model to get.
:returns: An object representing the loaded model.
return self._services.model_manager.store.exists(identifier.key)
def load(self, identifier: Union[str, "ModelField"], submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Loads a model.
Args:
identifier: The key or ModelField representing the model.
submodel_type: The submodel of the model to get.
Returns:
An object representing the loaded model.
"""
# The model manager emits events as it loads the model. It needs the context data to build
# the event payloads.
return self._services.model_manager.load_model_by_key(
key=key, submodel_type=submodel_type, context_data=self._data
)
if isinstance(identifier, str):
model = self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
else:
_submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key)
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
def load_by_attrs(
self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""
Loads a model by its attributes.
"""Loads a model by its attributes.
:param model_name: Name of to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
"""
return self._services.model_manager.load_model_by_attr(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
context_data=self._data,
)
Args:
name: Name of the model.
base: The models' base type, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
type: Type of the model, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
submodel_type: The type of submodel to load, e.g. `SubModelType.UNet`, `SubModelType.TextEncoder`, etc. Only main
models have submodels.
def get_config(self, key: str) -> AnyModelConfig:
Returns:
An object representing the loaded model.
"""
Gets a model's info, an dict-like object.
:param key: The key of the model.
"""
return self._services.model_manager.store.get_model(key=key)
configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
if len(configs) == 0:
raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""
Gets a model's metadata, if it has any.
if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
:param key: The key of the model.
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
def get_config(self, identifier: Union[str, "ModelField"]) -> AnyModelConfig:
"""Gets a model's config.
Args:
identifier: The key or ModelField representing the model.
Returns:
The model's config.
"""
return self._services.model_manager.store.get_metadata(key=key)
if isinstance(identifier, str):
return self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.store.get_model(identifier.key)
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
"""
Searches for models by path.
"""Searches for models by path.
:param path: The path to search for.
Args:
path: The path to search for.
Returns:
A list of models that match the path.
"""
return self._services.model_manager.store.search_by_path(path)
def search_by_attrs(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
name: Optional[str] = None,
base: Optional[BaseModelType] = None,
type: Optional[ModelType] = None,
format: Optional[ModelFormat] = None,
) -> list[AnyModelConfig]:
"""
Searches for models by attributes.
"""Searches for models by attributes.
:param model_name: Name of to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
Args:
name: The name to search for (exact match).
base: The base to search for, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
type: Type type of model to search for, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
format: The format of model to search for, e.g. `ModelFormat.Checkpoint`, `ModelFormat.Diffusers`, etc.
Returns:
A list of models that match the attributes.
"""
return self._services.model_manager.store.search_by_attr(
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_format=model_format,
model_name=name,
base_model=base,
model_type=type,
model_format=format,
)
class ConfigInterface(InvocationContextInterface):
def get(self) -> InvokeAIAppConfig:
"""Gets the app's config."""
"""Gets the app's config.
Returns:
The app's config.
"""
return self._services.configuration.get_config()
@ -370,7 +432,11 @@ class UtilInterface(InvocationContextInterface):
self._cancel_event = cancel_event
def is_canceled(self) -> bool:
"""Checks if the current invocation has been canceled."""
"""Checks if the current session has been canceled.
Returns:
True if the current session has been canceled, False if not.
"""
return self._cancel_event.is_set()
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
@ -380,8 +446,9 @@ class UtilInterface(InvocationContextInterface):
This should be called after each denoising step.
:param intermediate_state: The intermediate state of the diffusion pipeline.
:param base_model: The base model for the current denoising step.
Args:
intermediate_state: The intermediate state of the diffusion pipeline.
base_model: The base model for the current denoising step.
"""
stable_diffusion_step_callback(
@ -394,8 +461,17 @@ class UtilInterface(InvocationContextInterface):
class InvocationContext:
"""
The `InvocationContext` provides access to various services and data for the current invocation.
"""Provides access to various services and data for the current invocation.
Attributes:
images (ImagesInterface): Methods to save, get and update images and their metadata.
tensors (TensorsInterface): Methods to save and get tensors, including image, noise, masks, and masked images.
conditioning (ConditioningInterface): Methods to save and get conditioning data.
models (ModelsInterface): Methods to check if a model exists, get a model, and get a model's info.
logger (LoggerInterface): The app logger.
config (ConfigInterface): The app config.
util (UtilInterface): Utility methods, including a method to check if an invocation was canceled and step callbacks.
boards (BoardsInterface): Methods to interact with boards.
"""
def __init__(
@ -438,11 +514,14 @@ def build_invocation_context(
data: InvocationContextData,
cancel_event: threading.Event,
) -> InvocationContext:
"""
Builds the invocation context for a specific invocation execution.
"""Builds the invocation context for a specific invocation execution.
:param services: The invocation services to wrap.
:param data: The invocation context data.
Args:
services: The invocation services to wrap.
data: The invocation context data.
Returns:
The invocation context.
"""
logger = LoggerInterface(services=services, data=data)

View File

@ -9,6 +9,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -35,6 +36,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_4())
migrator.register_migration(build_migration_5())
migrator.register_migration(build_migration_6())
migrator.register_migration(build_migration_7())
migrator.run_migrations()
return db

View File

@ -0,0 +1,88 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration7Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._create_models_table(cursor)
self._drop_old_models_tables(cursor)
def _drop_old_models_tables(self, cursor: sqlite3.Cursor) -> None:
"""Drops the old model_records, model_metadata, model_tags and tags tables."""
tables = ["model_records", "model_metadata", "model_tags", "tags"]
for table in tables:
cursor.execute(f"DROP TABLE IF EXISTS {table};")
def _create_models_table(self, cursor: sqlite3.Cursor) -> None:
"""Creates the v4.0.0 models table."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS models (
id TEXT NOT NULL PRIMARY KEY,
hash TEXT GENERATED ALWAYS as (json_extract(config, '$.hash')) VIRTUAL NOT NULL,
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
description TEXT GENERATED ALWAYS as (json_extract(config, '$.description')) VIRTUAL,
source TEXT GENERATED ALWAYS as (json_extract(config, '$.source')) VIRTUAL NOT NULL,
source_type TEXT GENERATED ALWAYS as (json_extract(config, '$.source_type')) VIRTUAL NOT NULL,
source_api_response TEXT GENERATED ALWAYS as (json_extract(config, '$.source_api_response')) VIRTUAL,
trigger_phrases TEXT GENERATED ALWAYS as (json_extract(config, '$.trigger_phrases')) VIRTUAL,
-- Serialized JSON representation of the whole config object, which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
]
# Add trigger for `updated_at`.
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS models_updated_at
AFTER UPDATE
ON models FOR EACH ROW
BEGIN
UPDATE models SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
]
# Add indexes for searchable fields
indices = [
"CREATE INDEX IF NOT EXISTS base_index ON models(base);",
"CREATE INDEX IF NOT EXISTS type_index ON models(type);",
"CREATE INDEX IF NOT EXISTS name_index ON models(name);",
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON models(path);",
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def build_migration_7() -> Migration:
"""
Build the migration from database version 6 to 7.
This migration does the following:
- Adds the new models table
- Drops the old model_records, model_metadata, model_tags and tags tables.
- TODO(MM2): Migrates model names and descriptions from `models.yaml` to the new table (?).
"""
migration_7 = Migration(
from_version=6,
to_version=7,
callback=Migration7Callback(),
)
return migration_7

View File

@ -3,7 +3,6 @@
import json
import sqlite3
from hashlib import sha1
from logging import Logger
from pathlib import Path
from typing import Optional
@ -22,7 +21,7 @@ from invokeai.backend.model_manager.config import (
ModelConfigFactory,
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.model_manager.hash import ModelHash
ModelsValidator = TypeAdapter(AnyModelConfig)
@ -73,19 +72,27 @@ class MigrateModelYamlToDb1:
base_type, model_type, model_name = str(model_key).split("/")
try:
hash = FastModelHash.hash(self.config.models_path / stanza.path)
hash = ModelHash().hash(self.config.models_path / stanza.path)
except OSError:
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
continue
assert isinstance(model_key, str)
new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base"] = BaseModelType(base_type)
stanza["type"] = ModelType(model_type)
stanza["name"] = model_name
stanza["original_hash"] = hash
stanza["current_hash"] = hash
new_key = hash # deterministic key assignment
# special case for ip adapters, which need the new `image_encoder_model_id` field
if stanza["type"] == ModelType.IPAdapter:
try:
stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
self.config.models_path / stanza.path
)
except OSError:
self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
continue
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
@ -95,7 +102,7 @@ class MigrateModelYamlToDb1:
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
self._update_model(key, new_config)
else:
self.logger.info(f"Adding model {model_name} with key {model_key}")
self.logger.info(f"Adding model {model_name} with key {new_key}")
self._add_model(new_key, new_config)
except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database")
@ -143,9 +150,14 @@ class MigrateModelYamlToDb1:
""",
(
key,
record.original_hash,
record.hash,
json_serialized,
),
)
except sqlite3.IntegrityError as exc:
raise DuplicateModelException(f"{record.name}: model is already in database") from exc
def _get_image_encoder_model_id(self, model_path: Path) -> str:
with open(model_path / "image_encoder.txt") as f:
encoder = f.read()
return encoder.strip()

View File

@ -17,8 +17,7 @@ class MigrateCallback(Protocol):
See :class:`Migration` for an example.
"""
def __call__(self, cursor: sqlite3.Cursor) -> None:
...
def __call__(self, cursor: sqlite3.Cursor) -> None: ...
class MigrationError(RuntimeError):

View File

@ -8,3 +8,8 @@ class UrlServiceBase(ABC):
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets the URL for an image or thumbnail."""
pass
@abstractmethod
def get_model_image_url(self, model_key: str) -> str:
"""Gets the URL for a model image"""
pass

View File

@ -4,8 +4,9 @@ from .urls_base import UrlServiceBase
class LocalUrlService(UrlServiceBase):
def __init__(self, base_url: str = "api/v1"):
def __init__(self, base_url: str = "api/v1", base_url_v2: str = "api/v2"):
self._base_url = base_url
self._base_url_v2 = base_url_v2
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
image_basename = os.path.basename(image_name)
@ -15,3 +16,6 @@ class LocalUrlService(UrlServiceBase):
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
return f"{self._base_url}/images/i/{image_basename}/full"
def get_model_image_url(self, model_key: str) -> str:
return f"{self._base_url_v2}/models/i/{model_key}/image"

View File

@ -1,55 +0,0 @@
import json
from typing import Optional
from pydantic import ValidationError
from invokeai.app.services.shared.graph import Edge
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
"""
Parses raw session string, returning a dict of the graph.
Only the general graph shape is validated; none of the fields are validated.
Any `metadata_accumulator` nodes and edges are removed.
Any validation failure will return None.
"""
graph = json.loads(session_raw).get("graph", None)
# sanity check make sure the graph is at least reasonably shaped
if (
not isinstance(graph, dict)
or "nodes" not in graph
or not isinstance(graph["nodes"], dict)
or "edges" not in graph
or not isinstance(graph["edges"], list)
):
# something has gone terribly awry, return an empty dict
return None
try:
# delete the `metadata_accumulator` node
del graph["nodes"]["metadata_accumulator"]
except KeyError:
# no accumulator node, all good
pass
# delete any edges to or from it
for i, edge in enumerate(graph["edges"]):
try:
# try to parse the edge
Edge(**edge)
except ValidationError:
# something has gone terribly awry, return an empty dict
return None
if (
edge["source"]["node_id"] == "metadata_accumulator"
or edge["destination"]["node_id"] == "metadata_accumulator"
):
del graph["edges"][i]
return graph

View File

@ -22,7 +22,7 @@ def generate_ti_list(
for trigger in extract_ti_triggers_from_prompt(prompt):
name_or_key = trigger[1:-1]
try:
loaded_model = context.models.load(key=name_or_key)
loaded_model = context.models.load(name_or_key)
model = loaded_model.model
assert isinstance(model, TextualInversionModelRaw)
assert loaded_model.config.base == base
@ -30,7 +30,7 @@ def generate_ti_list(
except UnknownModelException:
try:
loaded_model = context.models.load_by_attrs(
model_name=name_or_key, base_model=base, model_type=ModelType.TextualInversion
name=name_or_key, base=base, type=ModelType.TextualInversion
)
model = loaded_model.model
assert isinstance(model, TextualInversionModelRaw)

View File

@ -1,6 +1,7 @@
"""
Initialization file for invokeai.backend.image_util methods.
"""
from .patchmatch import PatchMatch # noqa: F401
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata # noqa: F401
from .seamless import configure_model_padding # noqa: F401

View File

@ -3,6 +3,7 @@ This module defines a singleton object, "invisible_watermark" that
wraps the invisible watermark model. It respects the global "invisible_watermark"
configuration variable, that allows the watermarking to be supressed.
"""
import cv2
import numpy as np
from imwatermark import WatermarkEncoder

View File

@ -4,6 +4,7 @@ wraps the actual patchmatch object. It respects the global
"try_patchmatch" attribute, so that patchmatch loading can
be suppressed or deferred
"""
import numpy as np
import invokeai.backend.util.logging as logger

View File

@ -6,6 +6,7 @@ PngWriter -- Converts Images generated by T2I into PNGs, finds
Exports function retrieve_metadata(path)
"""
import json
import os
import re

View File

@ -3,6 +3,7 @@ This module defines a singleton object, "safety_checker" that
wraps the safety_checker model. It respects the global "nsfw_checker"
configuration variable, that allows the checker to be supressed.
"""
import numpy as np
from PIL import Image

View File

@ -1,6 +1,7 @@
"""
Check that the invokeai_root is correctly configured and exit if not.
"""
import sys
from invokeai.app.services.config import InvokeAIAppConfig

View File

@ -1,4 +1,5 @@
"""Utility (backend) functions used by model_install.py"""
from logging import Logger
from pathlib import Path
from typing import Any, Dict, List, Optional
@ -18,7 +19,6 @@ from invokeai.app.services.model_install import (
ModelInstallService,
ModelInstallServiceBase,
)
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import (
@ -38,7 +38,7 @@ def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordService
logger = InvokeAILogger.get_logger(config=app_config)
image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
db = init_db(config=app_config, logger=logger, image_files=image_files)
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db)
return obj

View File

@ -1,4 +1,5 @@
"""Re-export frequently-used symbols from the Model Manager backend."""
from .config import (
AnyModel,
AnyModelConfig,

View File

@ -19,15 +19,19 @@ Typical usage:
Validation errors will raise an InvalidModelConfigException error.
"""
import time
from enum import Enum
from typing import Literal, Optional, Type, Union
import torch
from diffusers import ModelMixin
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.util.misc import uuid_string
from ..raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models
@ -55,8 +59,8 @@ class ModelType(str, Enum):
ONNX = "onnx"
Main = "main"
Vae = "vae"
Lora = "lora"
VAE = "vae"
LoRA = "lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
@ -72,9 +76,9 @@ class SubModelType(str, Enum):
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Vae = "vae"
VaeDecoder = "vae_decoder"
VaeEncoder = "vae_encoder"
VAE = "vae"
VAEDecoder = "vae_decoder"
VAEEncoder = "vae_encoder"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
@ -92,8 +96,8 @@ class ModelFormat(str, Enum):
Diffusers = "diffusers"
Checkpoint = "checkpoint"
Lycoris = "lycoris"
Onnx = "onnx"
LyCORIS = "lycoris"
ONNX = "onnx"
Olive = "olive"
EmbeddingFile = "embedding_file"
EmbeddingFolder = "embedding_folder"
@ -111,128 +115,188 @@ class SchedulerPredictionType(str, Enum):
class ModelRepoVariant(str, Enum):
"""Various hugging face variants on the diffusers format."""
DEFAULT = "" # model files without "fp16" or other qualifier - empty str
Default = "" # model files without "fp16" or other qualifier - empty str
FP16 = "fp16"
FP32 = "fp32"
ONNX = "onnx"
OPENVINO = "openvino"
FLAX = "flax"
OpenVINO = "openvino"
Flax = "flax"
class ModelSourceType(str, Enum):
"""Model source type."""
Path = "path"
Url = "url"
HFRepoID = "hf_repo_id"
CivitAI = "civitai"
class ModelDefaultSettings(BaseModel):
vae: str | None
vae_precision: str | None
scheduler: SCHEDULER_NAME_VALUES | None
steps: int | None
cfg_scale: float | None
cfg_rescale_multiplier: float | None
class ModelConfigBase(BaseModel):
"""Base class for model configuration information."""
path: str = Field(description="filesystem path to the model file or directory")
name: str = Field(description="model name")
base: BaseModelType = Field(description="base model")
type: ModelType = Field(description="type of the model")
format: ModelFormat = Field(description="model format")
key: str = Field(description="unique key for model", default="<NOKEY>")
original_hash: Optional[str] = Field(
description="original fasthash of model contents", default=None
) # this is assigned at install time and will not change
current_hash: Optional[str] = Field(
description="current fasthash of model contents", default=None
) # if model is converted or otherwise modified, this will hold updated hash
description: Optional[str] = Field(description="human readable description of the model", default=None)
source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time)
key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
hash: str = Field(description="The hash of the model file(s).")
path: str = Field(
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
)
name: str = Field(description="Name of the model.")
base: BaseModelType = Field(description="The base model.")
description: Optional[str] = Field(description="Model description", default=None)
source: str = Field(description="The original source of the model (path, URL or repo_id).")
source_type: ModelSourceType = Field(description="The type of source")
source_api_response: Optional[str] = Field(
description="The original API response from the source, as stringified JSON.", default=None
)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[ModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
schema["required"].extend(
["key", "base", "type", "format", "original_hash", "current_hash", "source", "last_modified"]
)
schema["required"].extend(["key", "type", "format"])
model_config = ConfigDict(
use_enum_values=False,
validate_assignment=True,
json_schema_extra=json_schema_extra,
)
def update(self, attributes: Dict[str, Any]) -> None:
"""Update the object with fields in dict."""
for key, value in attributes.items():
setattr(self, key, value) # may raise a validation error
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
class _CheckpointConfig(ModelConfigBase):
class CheckpointConfigBase(ModelConfigBase):
"""Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
config: str = Field(description="path to the checkpoint model config file")
config_path: str = Field(description="path to the checkpoint model config file")
converted_at: Optional[float] = Field(
description="When this model was last converted to diffusers", default_factory=time.time
)
class _DiffusersConfig(ModelConfigBase):
class DiffusersConfigBase(ModelConfigBase):
"""Model config for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
class LoRAConfig(ModelConfigBase):
class LoRALyCORISConfig(ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
type: Literal[ModelType.Lora] = ModelType.Lora
format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
type: Literal[ModelType.LoRA] = ModelType.LoRA
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
class VaeCheckpointConfig(ModelConfigBase):
class LoRADiffusersConfig(ModelConfigBase):
"""Model config for LoRA/Diffusers models."""
type: Literal[ModelType.LoRA] = ModelType.LoRA
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.Diffusers.value}")
class VAECheckpointConfig(CheckpointConfigBase):
"""Model config for standalone VAE models."""
type: Literal[ModelType.Vae] = ModelType.Vae
type: Literal[ModelType.VAE] = ModelType.VAE
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Checkpoint.value}")
class VaeDiffusersConfig(ModelConfigBase):
class VAEDiffusersConfig(ModelConfigBase):
"""Model config for standalone VAE models (diffusers version)."""
type: Literal[ModelType.Vae] = ModelType.Vae
type: Literal[ModelType.VAE] = ModelType.VAE
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
class ControlNetDiffusersConfig(_DiffusersConfig):
class ControlNetDiffusersConfig(DiffusersConfigBase):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")
class ControlNetCheckpointConfig(_CheckpointConfig):
class ControlNetCheckpointConfig(CheckpointConfigBase):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Checkpoint.value}")
class TextualInversionConfig(ModelConfigBase):
class TextualInversionFileConfig(ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
class _MainConfig(ModelConfigBase):
"""Model config for main models."""
class TextualInversionFolderConfig(ModelConfigBase):
"""Model config for textual inversion embeddings."""
vae: Optional[str] = Field(default=None)
variant: ModelVariantType = ModelVariantType.Normal
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
ztsnr_training: bool = False
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
class MainCheckpointConfig(CheckpointConfigBase):
"""Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main
variant: ModelVariantType = ModelVariantType.Normal
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
class MainDiffusersConfig(DiffusersConfigBase):
"""Model config for main diffusers models."""
type: Literal[ModelType.Main] = ModelType.Main
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
class IPAdapterConfig(ModelConfigBase):
"""Model config for IP Adaptor format models."""
@ -241,65 +305,76 @@ class IPAdapterConfig(ModelConfigBase):
image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
class CLIPVisionDiffusersConfig(ModelConfigBase):
"""Model config for ClipVision."""
"""Model config for CLIPVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")
class T2IConfig(ModelConfigBase):
class T2IAdapterConfig(ModelConfigBase):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
_ControlNetConfig = Annotated[
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
Field(discriminator="format"),
]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
AnyModelConfig = Union[
_MainModelConfig,
_VaeConfig,
_ControlNetConfig,
# ModelConfigBase,
LoRAConfig,
TextualInversionConfig,
IPAdapterConfig,
CLIPVisionDiffusersConfig,
T2IConfig,
def get_model_discriminator_value(v: Any) -> str:
"""
Computes the discriminator value for a model config.
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
"""
format_ = None
type_ = None
if isinstance(v, dict):
format_ = v.get("format")
if isinstance(format_, Enum):
format_ = format_.value
type_ = v.get("type")
if isinstance(type_, Enum):
type_ = type_.value
else:
format_ = v.format.value
type_ = v.type.value
v = f"{type_}.{format_}"
return v
AnyModelConfig = Annotated[
Union[
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
# This is a known issue. Please see:
# https://github.com/tiangolo/fastapi/discussions/9761 and
# https://github.com/tiangolo/fastapi/discussions/9287
# AnyModelConfig = Annotated[
# Union[
# _MainModelConfig,
# _ONNXConfig,
# _VaeConfig,
# _ControlNetConfig,
# LoRAConfig,
# TextualInversionConfig,
# IPAdapterConfig,
# CLIPVisionDiffusersConfig,
# T2IConfig,
# ],
# Field(discriminator="type"),
# ]
class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects."""
@ -331,6 +406,6 @@ class ModelConfigFactory(object):
assert model is not None
if key:
model.key = key
if timestamp:
model.last_modified = timestamp
if isinstance(model, CheckpointConfigBase) and timestamp is not None:
model.converted_at = timestamp
return model # type: ignore

View File

@ -15,7 +15,7 @@
#
# Adapted for use in InvokeAI by Lincoln Stein, July 2023
#
""" Conversion script for the Stable Diffusion checkpoints."""
"""Conversion script for the Stable Diffusion checkpoints."""
import re
from contextlib import nullcontext

View File

@ -11,56 +11,178 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
import hashlib
import os
from pathlib import Path
from typing import Dict, Union
from imohash import hashfile
from typing import Callable, Literal, Optional, Union
from concurrent.futures import ThreadPoolExecutor, as_completed
class FastModelHash(object):
"""FastModelHash obect provides one public class method, hash()."""
from blake3 import blake3
@classmethod
def hash(cls, model_location: Union[str, Path]) -> str:
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
ALGORITHM = Literal[
"md5",
"sha1",
"sha224",
"sha256",
"sha384",
"sha512",
"blake2b",
"blake2s",
"sha3_224",
"sha3_256",
"sha3_384",
"sha3_512",
"shake_128",
"shake_256",
"blake3",
]
class ModelHash:
"""
Return hexdigest string for model located at model_location.
Creates a hash of a model using a specified algorithm.
:param model_location: Path to the model
Args:
algorithm: Hashing algorithm to use. Defaults to BLAKE3.
file_filter: A function that takes a file name and returns True if the file should be included in the hash.
If the model is a single file, it is hashed directly using the provided algorithm.
If the model is a directory, each model weights file in the directory is hashed using the provided algorithm.
Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
that directory hashes are never weaker than the file hashes.
Usage:
```py
# BLAKE3 hash
ModelHash().hash("path/to/some/model.safetensors")
# MD5
ModelHash("md5").hash("path/to/model/dir/")
```
"""
model_location = Path(model_location)
if model_location.is_file():
return cls._hash_file(model_location)
elif model_location.is_dir():
return cls._hash_dir(model_location)
def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
if algorithm == "blake3":
self._hash_file = self._blake3
elif algorithm in hashlib.algorithms_available:
self._hash_file = self._get_hashlib(algorithm)
else:
raise OSError(f"Not a valid file or directory: {model_location}")
raise ValueError(f"Algorithm {algorithm} not available")
@classmethod
def _hash_file(cls, model_location: Union[str, Path]) -> str:
self._file_filter = file_filter or self._default_file_filter
def hash(self, model_path: Union[str, Path]) -> str:
"""
Fasthash a single file and return its hexdigest.
Return hexdigest of hash of model located at model_path using the algorithm provided at class instantiation.
:param model_location: Path to the model file
If model_path is a directory, the hash is computed by hashing the hashes of all model files in the
directory. The final composite hash is always computed using BLAKE3.
Args:
model_path: Path to the model
Returns:
str: Hexdigest of the hash of the model
"""
# we return md5 hash of the filehash to make it shorter
# cryptographic security not needed here
return hashlib.md5(hashfile(model_location)).hexdigest()
@classmethod
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
components: Dict[str, str] = {}
model_path = Path(model_path)
if model_path.is_file():
return self._hash_file(model_path)
elif model_path.is_dir():
return self._hash_dir(model_path)
else:
raise OSError(f"Not a valid file or directory: {model_path}")
for root, _dirs, files in os.walk(model_location):
for file in files:
# only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted.
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
continue
path = (Path(root) / file).as_posix()
fast_hash = cls._hash_file(path)
components.update({path: fast_hash})
def _hash_dir(self, dir: Path) -> str:
"""Compute the hash for all files in a directory and return a hexdigest.
# hash all the model hashes together, using alphabetic file order
md5 = hashlib.md5()
for _path, fast_hash in sorted(components.items()):
md5.update(fast_hash.encode("utf-8"))
return md5.hexdigest()
Args:
dir: Path to the directory
Returns:
str: Hexdigest of the hash of the directory
"""
model_component_paths = self._get_file_paths(dir, self._file_filter)
# Use ThreadPoolExecutor to hash files in parallel
with ThreadPoolExecutor(min(((os.cpu_count() or 1) + 4), len(model_component_paths))) as executor:
future_to_component = {executor.submit(self._hash_file, component): component for component in sorted(model_component_paths)}
component_hashes = [future.result() for future in as_completed(future_to_component)]
# BLAKE3 to hash the hashes
composite_hasher = blake3()
component_hashes.sort()
for h in component_hashes:
composite_hasher.update(h.encode("utf-8"))
return composite_hasher.hexdigest()
@staticmethod
def _get_file_paths(model_path: Path, file_filter: Callable[[str], bool]) -> list[Path]:
"""Return a list of all model files in the directory.
Args:
model_path: Path to the model
file_filter: Function that takes a file name and returns True if the file should be included in the list.
Returns:
List of all model files in the directory
"""
files: list[Path] = []
entries = [entry for entry in os.scandir(model_path.as_posix()) if not entry.name.startswith(".")]
dirs = [entry for entry in entries if entry.is_dir()]
file_paths = [entry.path for entry in entries if entry.is_file() and file_filter(entry.path)]
files.extend([Path(file) for file in file_paths])
for dir in dirs:
files.extend(ModelHash._get_file_paths(Path(dir.path), file_filter))
return files
@staticmethod
def _blake3(file_path: Path) -> str:
"""Hashes a file using BLAKE3
Args:
file_path: Path to the file to hash
Returns:
Hexdigest of the hash of the file
"""
file_hasher = blake3(max_threads=blake3.AUTO)
file_hasher.update_mmap(file_path)
return file_hasher.hexdigest()
@staticmethod
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
"""Factory function that returns a function to hash a file with the given algorithm.
Args:
algorithm: Hashing algorithm to use
Returns:
A function that hashes a file using the given algorithm
"""
def hashlib_hasher(file_path: Path) -> str:
"""Hashes a file using a hashlib algorithm."""
hasher = hashlib.new(algorithm)
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(8 * 1024), b""):
hasher.update(chunk)
return hasher.hexdigest()
return hashlib_hasher
@staticmethod
def _default_file_filter(file_path: str) -> bool:
"""A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
Args:
file_path: Path to the file
Returns:
True if the file matches the given extensions, otherwise False
"""
return file_path.endswith(MODEL_FILE_EXTENSIONS)

View File

@ -2,6 +2,7 @@
"""
Init file for the model loader.
"""
from importlib import import_module
from pathlib import Path

View File

@ -1,6 +1,7 @@
"""
Disk-based converted model cache.
"""
from abc import ABC, abstractmethod
from pathlib import Path

View File

@ -13,6 +13,7 @@ from invokeai.backend.model_manager import (
ModelRepoVariant,
SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
@ -50,7 +51,7 @@ class ModelLoader(ModelLoaderBase):
:param submodel_type: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
if model_config.type == "main" and not submodel_type:
if model_config.type is ModelType.Main and not submodel_type:
raise InvalidModelConfigException("submodel_type is required when loading a main model")
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
@ -80,7 +81,7 @@ class ModelLoader(ModelLoaderBase):
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
return self._convert_model(config, model_path, cache_path)
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
return False
def _load_if_needed(
@ -119,7 +120,7 @@ class ModelLoader(ModelLoaderBase):
return calc_model_size_by_fs(
model_path=model_path,
subfolder=submodel_type.value if submodel_type else None,
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
)
# This needs to be implemented in subclasses that handle checkpoints

View File

@ -14,10 +14,9 @@ Use like this:
).load_model(model_config, submodel_type)
"""
import hashlib
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Type
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
from ..config import (
AnyModelConfig,
@ -26,8 +25,6 @@ from ..config import (
ModelFormat,
ModelType,
SubModelType,
VaeCheckpointConfig,
VaeDiffusersConfig,
)
from . import ModelLoaderBase
@ -60,6 +57,9 @@ class ModelLoaderRegistryBase(ABC):
"""
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
class ModelLoaderRegistry:
"""
This class allows model loaders to register their type, base and format.
@ -70,10 +70,10 @@ class ModelLoaderRegistry:
@classmethod
def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
"""Define a decorator which registers the subclass of loader."""
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
key = cls._to_registry_key(base, type, format)
if key in cls._registry:
raise Exception(
@ -89,33 +89,15 @@ class ModelLoaderRegistry:
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type."""
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type
key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
implementation = cls._registry.get(key1) or cls._registry.get(key2)
if not implementation:
raise NotImplementedError(
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
)
return implementation, conf2, submodel_type
@classmethod
def _handle_subtype_overrides(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
model_path = Path(config.vae)
config_class = (
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
)
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
submodel_type = None
else:
new_conf = config
return new_conf, submodel_type
return implementation, config, submodel_type
@staticmethod
def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:

View File

@ -3,8 +3,8 @@
from pathlib import Path
import safetensors
import torch
from safetensors.torch import load_file as safetensors_load_file
from invokeai.backend.model_manager import (
AnyModelConfig,
@ -12,6 +12,7 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
from .. import ModelLoaderRegistry
@ -20,15 +21,15 @@ from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
class ControlnetLoader(GenericDiffusersLoader):
class ControlNetLoader(GenericDiffusersLoader):
"""Class to load ControlNet models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if config.format != ModelFormat.Checkpoint:
if not isinstance(config, CheckpointConfigBase):
return False
elif (
dest_path.exists()
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
@ -37,13 +38,13 @@ class ControlnetLoader(GenericDiffusersLoader):
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}")
raise Exception(f"ControlNet conversion not supported for model type: {config.base}")
else:
assert hasattr(config, "config")
config_file = config.config
assert isinstance(config, CheckpointConfigBase)
config_file = config.config_path
if model_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
checkpoint = safetensors_load_file(model_path, device="cpu")
else:
checkpoint = torch.load(model_path, map_location="cpu")

View File

@ -3,9 +3,10 @@
import sys
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Optional
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from invokeai.backend.model_manager import (
AnyModel,
@ -41,6 +42,7 @@ class GenericDiffusersLoader(ModelLoader):
# TO DO: Add exception handling
def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
"""Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load."""
result = None
if submodel_type:
try:
config = self._load_diffusers_config(model_path, config_name="model_index.json")
@ -64,6 +66,7 @@ class GenericDiffusersLoader(ModelLoader):
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
assert result is not None
return result
# TO DO: Add exception handling
@ -75,7 +78,7 @@ class GenericDiffusersLoader(ModelLoader):
result: ModelMixin = getattr(res_type, class_name)
return result
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> dict[str, Any]:
return ConfigLoader.load_config(model_path, config_name=config_name)
@ -83,8 +86,8 @@ class ConfigLoader(ConfigMixin):
"""Subclass of ConfigMixin for loading diffusers configuration files."""
@classmethod
def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
def load_config(cls, *args: Any, **kwargs: Any) -> dict[str, Any]: # pyright: ignore [reportIncompatibleMethodOverride]
"""Load a diffusrs ConfigMixin configuration."""
cls.config_name = kwargs.pop("config_name")
# Diffusers doesn't provide typing info
# TODO(psyche): the types on this diffusers method are not correct
return super().load_config(*args, **kwargs) # type: ignore

View File

@ -31,7 +31,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
if submodel_type is not None:
raise ValueError("There are no submodels in an IP-Adapter model.")
model = build_ip_adapter(
ip_adapter_ckpt_path=model_path / "ip_adapter.bin",
ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
device=torch.device("cpu"),
dtype=self._torch_dtype,
)

View File

@ -1,7 +1,6 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for LoRA model loading in InvokeAI."""
from logging import Logger
from pathlib import Path
from typing import Optional, Tuple
@ -23,9 +22,9 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
from .. import ModelLoader, ModelLoaderRegistry
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
class LoraLoader(ModelLoader):
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
class LoRALoader(ModelLoader):
"""Class to load LoRA models."""
# We cheat a little bit to get access to the model base

View File

@ -18,7 +18,7 @@ from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.ONNX)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
class OnnyxDiffusersModel(GenericDiffusersLoader):
"""Class to load onnx models."""

View File

@ -1,11 +1,11 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for StableDiffusion model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from invokeai.backend.model_manager import (
AnyModel,
@ -17,7 +17,7 @@ from invokeai.backend.model_manager import (
ModelVariantType,
SubModelType,
)
from invokeai.backend.model_manager.config import MainCheckpointConfig
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from .. import ModelLoaderRegistry
@ -55,11 +55,11 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
return result
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if config.format != ModelFormat.Checkpoint:
if not isinstance(config, CheckpointConfigBase):
return False
elif (
dest_path.exists()
and (dest_path / "model_index.json").stat().st_mtime >= (config.last_modified or 0.0)
and (dest_path / "model_index.json").stat().st_mtime >= (config.converted_at or 0.0)
and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
@ -74,7 +74,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
)
config_file = config.config
config_file = config.config_path
self._logger.info(f"Converting {model_path} to diffusers format")
convert_ckpt_to_diffusers(

View File

@ -1,7 +1,6 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for TI model loading in InvokeAI."""
from pathlib import Path
from typing import Optional, Tuple

View File

@ -3,9 +3,9 @@
from pathlib import Path
import safetensors
import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_file as safetensors_load_file
from invokeai.backend.model_manager import (
AnyModelConfig,
@ -13,24 +13,25 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
class VaeLoader(GenericDiffusersLoader):
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class VAELoader(GenericDiffusersLoader):
"""Class to load VAE models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if config.format != ModelFormat.Checkpoint:
if not isinstance(config, CheckpointConfigBase):
return False
elif (
dest_path.exists()
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
@ -38,16 +39,15 @@ class VaeLoader(GenericDiffusersLoader):
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
# TO DO: check whether sdxl VAE models convert.
# TODO(MM2): check whether sdxl VAE models convert.
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}")
raise Exception(f"VAE conversion not supported for model type: {config.base}")
else:
config_file = (
"v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
)
assert isinstance(config, CheckpointConfigBase)
config_file = config.config_path
if model_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
checkpoint = safetensors_load_file(model_path, device="cpu")
else:
checkpoint = torch.load(model_path, map_location="cpu")
@ -55,7 +55,7 @@ class VaeLoader(GenericDiffusersLoader):
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
ckpt_config = OmegaConf.load(self._app_config.root_path / config_file)
assert isinstance(ckpt_config, DictConfig)
vae_model = convert_ldm_vae_to_diffusers(

View File

@ -16,6 +16,7 @@ from diffusers import AutoPipelineForText2Image
from diffusers.utils import logging as dlogging
from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
from . import (
@ -117,7 +118,6 @@ class ModelMerger(object):
config = self._installer.app_config
store = self._installer.record_store
base_models: Set[BaseModelType] = set()
vae = None
variant = None if self._installer.app_config.full_precision else "fp16"
assert (
@ -134,10 +134,6 @@ class ModelMerger(object):
"normal"
), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
# pick up the first model's vae
if key == model_keys[0]:
vae = info.vae
# tally base models used
base_models.add(info.base)
model_paths.extend([config.models_path / info.path])
@ -163,12 +159,10 @@ class ModelMerger(object):
# update model's config
model_config = self._installer.record_store.get_model(key)
model_config.update(
{
"name": merged_model_name,
"description": f"Merge of models {', '.join(model_names)}",
"vae": vae,
}
model_config.name = merged_model_name
model_config.description = f"Merge of models {', '.join(model_names)}"
self._installer.record_store.update_model(
key, ModelRecordChanges(name=model_config.name, description=model_config.description)
)
self._installer.record_store.update_model(key, model_config)
return model_config

View File

@ -18,15 +18,14 @@ assert isinstance(data, CivitaiMetadata)
if data.allow_commercial_use:
print("Commercial use of this model is allowed")
"""
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch, ModelMetadataFetchBase
from .metadata_base import (
AnyModelRepoMetadata,
AnyModelRepoMetadataValidator,
BaseMetadata,
CivitaiMetadata,
CommercialUsage,
HuggingFaceMetadata,
LicenseRestrictions,
ModelMetadataWithFiles,
RemoteModelFile,
UnknownMetadataException,
@ -37,10 +36,8 @@ __all__ = [
"AnyModelRepoMetadataValidator",
"CivitaiMetadata",
"CivitaiMetadataFetch",
"CommercialUsage",
"HuggingFaceMetadata",
"HuggingFaceMetadataFetch",
"LicenseRestrictions",
"ModelMetadataFetchBase",
"BaseMetadata",
"ModelMetadataWithFiles",

View File

@ -23,22 +23,21 @@ metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words)
"""
import json
import re
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Optional
import requests
from pydantic import TypeAdapter, ValidationError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.backend.model_manager import ModelRepoVariant
from invokeai.backend.model_manager.config import ModelRepoVariant
from ..metadata_base import (
AnyModelRepoMetadata,
CivitaiMetadata,
CommercialUsage,
LicenseRestrictions,
RemoteModelFile,
UnknownMetadataException,
)
@ -52,10 +51,13 @@ CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
StringSetAdapter = TypeAdapter(set[str])
class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""Fetch model metadata from Civitai."""
def __init__(self, session: Optional[Session] = None):
def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
@ -63,6 +65,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
this module without an internet connection.
"""
self._requests = session or requests.Session()
self._api_key = api_key
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
@ -102,22 +105,21 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
May raise an `UnknownMetadataException`.
"""
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(model_url).json()
return self._from_model_json(model_json)
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
return self._from_api_response(model_json)
def _from_model_json(self, model_json: Dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
try:
version_id = version_id or model_json["modelVersions"][0]["id"]
version_id = version_id or api_response["modelVersions"][0]["id"]
except TypeError as excp:
raise UnknownMetadataException from excp
# loop till we find the section containing the version requested
version_sections = [x for x in model_json["modelVersions"] if x["id"] == version_id]
version_sections = [x for x in api_response["modelVersions"] if x["id"] == version_id]
if not version_sections:
raise UnknownMetadataException(f"Version {version_id} not found in model metadata")
version_json = version_sections[0]
safe_thumbnails = [x["url"] for x in version_json["images"] if x["nsfw"] == "None"]
# Civitai has one "primary" file plus others such as VAEs. We only fetch the primary.
primary = [x for x in version_json["files"] if x.get("primary")]
@ -134,36 +136,23 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
url = url + f"?type={primary_file['type']}{metadata_string}"
model_files = [
RemoteModelFile(
url=url,
url=self._get_url_with_api_key(url),
path=Path(primary_file["name"]),
size=int(primary_file["sizeKB"] * 1024),
sha256=primary_file["hashes"]["SHA256"],
)
]
try:
trigger_phrases = StringSetAdapter.validate_python(version_json.get("trainedWords"))
except ValidationError:
trigger_phrases: set[str] = set()
return CivitaiMetadata(
id=model_json["id"],
name=version_json["name"],
version_id=version_json["id"],
version_name=version_json["name"],
created=datetime.fromisoformat(_fix_timezone(version_json["createdAt"])),
updated=datetime.fromisoformat(_fix_timezone(version_json["updatedAt"])),
published=datetime.fromisoformat(_fix_timezone(version_json["publishedAt"])),
base_model_trained_on=version_json["baseModel"], # note - need a dictionary to turn into a BaseModelType
files=model_files,
download_url=version_json["downloadUrl"],
thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
author=model_json["creator"]["username"],
description=model_json["description"],
version_description=version_json["description"] or "",
tags=model_json["tags"],
trained_words=version_json["trainedWords"],
nsfw=model_json["nsfw"],
restrictions=LicenseRestrictions(
AllowNoCredit=model_json["allowNoCredit"],
AllowCommercialUse={CommercialUsage(x) for x in model_json["allowCommercialUse"]},
AllowDerivatives=model_json["allowDerivatives"],
AllowDifferentLicense=model_json["allowDifferentLicense"],
),
trigger_phrases=trigger_phrases,
api_response=json.dumps(version_json),
)
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
@ -174,14 +163,14 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""
if model_id is None:
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
version = self._requests.get(version_url).json()
version = self._requests.get(self._get_url_with_api_key(version_url)).json()
if error := version.get("error"):
raise UnknownMetadataException(error)
model_id = version["modelId"]
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(model_url).json()
return self._from_model_json(model_json, version_id)
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
return self._from_api_response(model_json, version_id)
@classmethod
def from_json(cls, json: str) -> CivitaiMetadata:
@ -189,6 +178,11 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
metadata = CivitaiMetadata.model_validate_json(json)
return metadata
def _get_url_with_api_key(self, url: str) -> str:
if not self._api_key:
return url
def _fix_timezone(date: str) -> str:
return re.sub(r"Z$", "+00:00", date)
if "?" in url:
return f"{url}&token={self._api_key}"
return f"{url}?token={self._api_key}"

View File

@ -13,6 +13,7 @@ metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
print(metadata.tags)
"""
import json
import re
from pathlib import Path
from typing import Optional
@ -23,7 +24,7 @@ from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFo
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.backend.model_manager import ModelRepoVariant
from invokeai.backend.model_manager.config import ModelRepoVariant
from ..metadata_base import (
AnyModelRepoMetadata,
@ -60,6 +61,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
# Little loop which tries fetching a revision corresponding to the selected variant.
# If not available, then set variant to None and get the default.
# If this too fails, raise exception.
model_info = None
while not model_info:
try:
@ -72,23 +74,24 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
else:
variant = None
files: list[RemoteModelFile] = []
_, name = id.split("/")
return HuggingFaceMetadata(
id=model_info.id,
author=model_info.author,
name=name,
last_modified=model_info.last_modified,
tag_dict=model_info.card_data.to_dict() if model_info.card_data else {},
tags=model_info.tags,
files=[
for s in model_info.siblings or []:
assert s.rfilename is not None
assert s.size is not None
files.append(
RemoteModelFile(
url=hf_hub_url(id, x.rfilename, revision=variant),
path=Path(name, x.rfilename),
size=x.size,
sha256=x.lfs.get("sha256") if x.lfs else None,
url=hf_hub_url(id, s.rfilename, revision=variant),
path=Path(name, s.rfilename),
size=s.size,
sha256=s.lfs.get("sha256") if s.lfs else None,
)
for x in model_info.siblings
],
)
return HuggingFaceMetadata(
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str)
)
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:

View File

@ -14,10 +14,8 @@ versions of these fields are intended to be kept in sync with the
remote repo.
"""
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
from typing import List, Literal, Optional, Union
from huggingface_hub import configure_http_backend, hf_hub_url
from pydantic import BaseModel, Field, TypeAdapter
@ -34,31 +32,6 @@ class UnknownMetadataException(Exception):
"""Raised when no metadata is available for a model."""
class CommercialUsage(str, Enum):
"""Type of commercial usage allowed."""
No = "None"
Image = "Image"
Rent = "Rent"
RentCivit = "RentCivit"
Sell = "Sell"
class LicenseRestrictions(BaseModel):
"""Broad categories of licensing restrictions."""
AllowNoCredit: bool = Field(
description="if true, model can be redistributed without crediting author", default=False
)
AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
AllowDifferentLicense: bool = Field(
description="if true, derivatives of this model be redistributed under a different license", default=False
)
AllowCommercialUse: Optional[Set[CommercialUsage] | CommercialUsage] = Field(
description="Type of commercial use allowed if no commercial use is allowed.", default=None
)
class RemoteModelFile(BaseModel):
"""Information about a downloadable file that forms part of a model."""
@ -72,8 +45,6 @@ class ModelMetadataBase(BaseModel):
"""Base class for model metadata information."""
name: str = Field(description="model's name")
author: str = Field(description="model's author")
tags: Set[str] = Field(description="tags provided by model source")
class BaseMetadata(ModelMetadataBase):
@ -111,60 +82,16 @@ class CivitaiMetadata(ModelMetadataWithFiles):
"""Extended metadata fields provided by Civitai."""
type: Literal["civitai"] = "civitai"
id: int = Field(description="Civitai version identifier")
version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
version_id: int = Field(description="Civitai model version identifier")
created: datetime = Field(description="date the model was created")
updated: datetime = Field(description="date the model was last modified")
published: datetime = Field(description="date the model was published to Civitai")
description: str = Field(description="text description of model; may contain HTML")
version_description: str = Field(
description="text description of the model's reversion; usually change history; may contain HTML"
)
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
download_url: AnyHttpUrl = Field(description="download URL for this model")
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
weight_minmax: Tuple[float, float] = Field(
description="minimum and maximum slider values for a LoRA or other secondary model", default=(-1.0, +2.0)
) # note: For future use
@property
def credit_required(self) -> bool:
"""Return True if you must give credit for derivatives of this model and images generated from it."""
return not self.restrictions.AllowNoCredit
@property
def allow_commercial_use(self) -> bool:
"""Return True if commercial use is allowed."""
if self.restrictions.AllowCommercialUse is None:
return False
else:
# accommodate schema change
acu = self.restrictions.AllowCommercialUse
commercial_usage = acu if isinstance(acu, set) else {acu}
return CommercialUsage.No not in commercial_usage
@property
def allow_derivatives(self) -> bool:
"""Return True if derivatives of this model can be redistributed."""
return self.restrictions.AllowDerivatives
@property
def allow_different_license(self) -> bool:
"""Return true if derivatives of this model can use a different license."""
return self.restrictions.AllowDifferentLicense
trigger_phrases: set[str] = Field(description="Trigger phrases extracted from the API response")
api_response: Optional[str] = Field(description="Response from the Civitai API as stringified JSON", default=None)
class HuggingFaceMetadata(ModelMetadataWithFiles):
"""Extended metadata fields provided by HuggingFace."""
type: Literal["huggingface"] = "huggingface"
id: str = Field(description="huggingface model id")
tag_dict: Dict[str, Any]
last_modified: datetime = Field(description="date of last commit to repo")
id: str = Field(description="The HF model id")
api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
def download_urls(
self,
@ -193,7 +120,7 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
# the next step reads model_index.json to determine which subdirectories belong
# to the model
if Path(f"{prefix}model_index.json") in paths:
url = hf_hub_url(self.id, filename="model_index.json", subfolder=subfolder)
url = hf_hub_url(self.id, filename="model_index.json", subfolder=str(subfolder) if subfolder else None)
resp = session.get(url)
resp.raise_for_status()
submodels = resp.json()

View File

@ -1,221 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""
import sqlite3
from typing import List, Optional, Set, Tuple
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .fetch import ModelMetadataFetchBase
from .metadata_base import AnyModelRepoMetadata, UnknownMetadataException
class ModelMetadataStore:
"""Store, search and fetch model metadata retrieved from remote repositories."""
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_metadata(
id,
metadata
)
VALUES (?,?);
""",
(
model_key,
json_serialized,
),
)
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
raise UnknownMetadataException from excp
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata FROM model_metadata
WHERE id=?;
""",
(model_key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownMetadataException("model metadata not found")
return ModelMetadataFetchBase.from_json(rows[0])
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT id,metadata FROM model_metadata;
""",
(),
)
rows = self._cursor.fetchall()
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
json_serialized = metadata.model_dump_json() # turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_metadata
SET
metadata=?
WHERE id=?;
""",
(json_serialized, model_key),
)
if self._cursor.rowcount == 0:
raise UnknownMetadataException("model metadata not found")
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_metadata(model_key)
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
self._cursor.execute(
"""--sql
select tag_text from tags;
"""
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
with self._db.lock:
try:
matches: Optional[Set[str]] = None
for tag in tags:
self._cursor.execute(
"""--sql
SELECT a.model_id FROM model_tags AS a,
tags AS b
WHERE a.tag_id=b.tag_id
AND b.tag_text=?;
""",
(tag,),
)
model_keys = {x[0] for x in self._cursor.fetchall()}
if matches is None:
matches = model_keys
matches = matches.intersection(model_keys)
except sqlite3.Error as e:
raise e
return matches if matches else set()
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE author=?;
""",
(author,),
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE name=?;
""",
(name,),
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
"""Update tags for the model referenced by model_key."""
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)

View File

@ -8,6 +8,7 @@ import torch
from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.util.util import SilenceWarnings
from .config import (
@ -17,11 +18,12 @@ from .config import (
ModelConfigFactory,
ModelFormat,
ModelRepoVariant,
ModelSourceType,
ModelType,
ModelVariantType,
SchedulerPredictionType,
)
from .hash import FastModelHash
from .hash import ModelHash
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
CkptType = Dict[str, Any]
@ -82,6 +84,9 @@ class ProbeBase(object):
class ModelProbe(object):
hasher = ModelHash()
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
"diffusers": {},
"checkpoint": {},
@ -95,8 +100,8 @@ class ModelProbe(object):
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"LatentConsistencyModelPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae,
"AutoencoderTiny": ModelType.Vae,
"AutoencoderKL": ModelType.VAE,
"AutoencoderTiny": ModelType.VAE,
"ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
"T2IAdapter": ModelType.T2IAdapter,
@ -108,14 +113,6 @@ class ModelProbe(object):
) -> None:
cls.PROBES[format][model_type] = probe_class
@classmethod
def heuristic_probe(
cls,
model_path: Path,
fields: Optional[Dict[str, Any]] = None,
) -> AnyModelConfig:
return cls.probe(model_path, fields)
@classmethod
def probe(
cls,
@ -137,19 +134,21 @@ class ModelProbe(object):
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
model_info = None
model_type = None
if format_type == "diffusers":
if format_type is ModelFormat.Diffusers:
model_type = cls.get_model_type_from_folder(model_path)
else:
model_type = cls.get_model_type_from_checkpoint(model_path)
format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type
format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type
probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class:
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
hash = FastModelHash.hash(model_path)
probe = probe_class(model_path)
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
fields["source"] = fields.get("source") or model_path.as_posix()
fields["key"] = fields.get("key", uuid_string())
fields["path"] = model_path.as_posix()
fields["type"] = fields.get("type") or model_type
fields["base"] = fields.get("base") or probe.get_base_type()
@ -161,15 +160,17 @@ class ModelProbe(object):
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["original_hash"] = fields.get("original_hash") or hash
fields["current_hash"] = fields.get("current_hash") or hash
fields["hash"] = fields.get("hash") or cls.hasher.hash(model_path)
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
# additional fields needed for main and controlnet models
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
fields["config"] = cls._get_checkpoint_config_path(
if (
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
and fields["format"] is ModelFormat.Checkpoint
):
fields["config_path"] = cls._get_checkpoint_config_path(
model_path,
model_type=fields["type"],
base_type=fields["base"],
@ -179,7 +180,7 @@ class ModelProbe(object):
# additional fields needed for main non-checkpoint models
elif fields["type"] == ModelType.Main and fields["format"] in [
ModelFormat.Onnx,
ModelFormat.ONNX,
ModelFormat.Olive,
ModelFormat.Diffusers,
]:
@ -188,7 +189,7 @@ class ModelProbe(object):
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
)
model_info = ModelConfigFactory.make_config(fields, key=fields.get("key", None))
model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None))
return model_info
@classmethod
@ -213,11 +214,11 @@ class ModelProbe(object):
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
return ModelType.Main
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
return ModelType.Vae
return ModelType.VAE
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
return ModelType.Lora
return ModelType.LoRA
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
return ModelType.Lora
return ModelType.LoRA
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
return ModelType.ControlNet
elif key in {"emb_params", "string_to_param"}:
@ -239,7 +240,7 @@ class ModelProbe(object):
if (folder_path / f"learned_embeds.{suffix}").exists():
return ModelType.TextualInversion
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
return ModelType.Lora
return ModelType.LoRA
if (folder_path / "unet/model.onnx").exists():
return ModelType.ONNX
if (folder_path / "image_encoder.txt").exists():
@ -285,13 +286,21 @@ class ModelProbe(object):
if possible_conf.exists():
return possible_conf.absolute()
if model_type == ModelType.Main:
if model_type is ModelType.Main:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type]
elif model_type == ModelType.ControlNet:
elif model_type is ModelType.ControlNet:
config_file = (
"../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
"../controlnet/cldm_v15.yaml"
if base_type is BaseModelType.StableDiffusion1
else "../controlnet/cldm_v21.yaml"
)
elif model_type is ModelType.VAE:
config_file = (
"../stable-diffusion/v1-inference.yaml"
if base_type is BaseModelType.StableDiffusion1
else "../stable-diffusion/v2-inference.yaml"
)
else:
raise InvalidModelConfigException(
@ -497,12 +506,12 @@ class FolderProbeBase(ProbeBase):
if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16
if "openvino_model" in x.name:
return ModelRepoVariant.OPENVINO
return ModelRepoVariant.OpenVINO
if "flax_model" in x.name:
return ModelRepoVariant.FLAX
return ModelRepoVariant.Flax
if x.suffix == ".onnx":
return ModelRepoVariant.ONNX
return ModelRepoVariant.DEFAULT
return ModelRepoVariant.Default
class PipelineFolderProbe(FolderProbeBase):
@ -708,8 +717,8 @@ class T2IAdapterFolderProbe(FolderProbeBase):
############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
@ -717,8 +726,8 @@ ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderPro
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)

View File

@ -28,6 +28,7 @@ from typing import Callable, Optional, Set, Union
from pydantic import BaseModel, Field
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
default_logger: Logger = InvokeAILogger.get_logger()
@ -117,13 +118,10 @@ class ModelSearch(ModelSearchBase):
"""
models_found: Set[Path] = Field(default_factory=set)
scanned_dirs: Set[Path] = Field(default_factory=set)
pruned_paths: Set[Path] = Field(default_factory=set)
config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
def search_started(self) -> None:
self.models_found = set()
self.scanned_dirs = set()
self.pruned_paths = set()
if self.on_search_started:
self.on_search_started(self._directory)
@ -139,29 +137,28 @@ class ModelSearch(ModelSearchBase):
def search(self, directory: Union[Path, str]) -> Set[Path]:
self._directory = Path(directory)
if not self._directory.is_absolute():
self._directory = self.config.models_path / self._directory
self.stats = SearchStats() # zero out
self.search_started() # This will initialize _models_found to empty
self._walk_directory(directory)
self._walk_directory(self._directory)
self.search_completed()
return self.models_found
def _walk_directory(self, path: Union[Path, str]) -> None:
for root, dirs, files in os.walk(path, followlinks=True):
# don't descend into directories that start with a "."
# to avoid the Mac .DS_STORE issue.
if str(Path(root).name).startswith("."):
self.pruned_paths.add(Path(root))
if any(Path(root).is_relative_to(x) for x in self.pruned_paths):
continue
self.stats.items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path.parent in self.scanned_dirs:
self.scanned_dirs.add(path)
continue
def _walk_directory(self, path: Union[Path, str], max_depth: int = 20) -> None:
absolute_path = Path(path)
if (
len(absolute_path.parts) - len(self._directory.parts) > max_depth
or not absolute_path.exists()
or absolute_path.parent in self.models_found
):
return
entries = os.scandir(absolute_path.as_posix())
entries = [entry for entry in entries if not entry.name.startswith(".")]
dirs = [entry for entry in entries if entry.is_dir()]
file_names = [entry.name for entry in entries if entry.is_file()]
if any(
(path / x).exists()
x in file_names
for x in [
"config.json",
"model_index.json",
@ -170,22 +167,23 @@ class ModelSearch(ModelSearchBase):
"image_encoder.txt",
]
):
self.scanned_dirs.add(path)
try:
self.model_found(path)
self.model_found(absolute_path)
return
except KeyboardInterrupt:
raise
except Exception as e:
self.logger.warning(str(e))
return
for n in file_names:
if n.endswith((".ckpt", ".bin", ".pth", ".safetensors", ".pt")):
try:
self.model_found(absolute_path / n)
except KeyboardInterrupt:
raise
except Exception as e:
self.logger.warning(str(e))
for f in files:
path = Path(root) / f
if path.parent in self.scanned_dirs:
continue
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
try:
self.model_found(path)
except KeyboardInterrupt:
raise
except Exception as e:
self.logger.warning(str(e))
for d in dirs:
self._walk_directory(absolute_path / d)

Some files were not shown because too many files have changed in this diff Show More