Compare commits

..

276 Commits

Author SHA1 Message Date
80fd3d3f3c cleanup: Remove manual offload from Depth Anything Processor (#5812)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [ ] Yes
- [ ] No


## Description


## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Merge Plan

<!--
A merge plan describes how this PR should be handled after it is
approved.

Example merge plans:
- "This PR can be merged when approved"
- "This must be squash-merged when approved"
- "DO NOT MERGE - I will rebase and tidy commits before merging"
- "#dev-chat on discord needs to be advised of this change when it is
merged"

A merge plan is particularly important for large PRs or PRs that touch
the
database in any way.
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2024-03-01 23:13:06 +05:30
41b77cd5ff fix: minor fixes to types in the DA Detector 2024-03-01 23:08:41 +05:30
6f77477a1c cleanup: remove manual offload from depth anything 2024-03-01 23:08:41 +05:30
7cfbe5a62a docs: add v3 -> v4 migration, invocation API docs 2024-02-29 15:33:13 -05:00
68344ecac9 docs(nodes): update all docstrings for public nodes API 2024-02-29 15:33:13 -05:00
84dc5c5c7b fix: make invocation_context.py accessible to mkdocs
Needs an `__init__.py`.
2024-02-29 15:33:13 -05:00
691ecb1f5b docs: update mkdocs config 2024-02-29 15:33:13 -05:00
90b84c650f docs: bump mkdocs, add mkdocstrings
Also remove ancient requirements file - the docs dependencies are in the pyproject.toml file.
2024-02-29 15:33:13 -05:00
014be0ab67 feat(nodes): revise model load API args 2024-02-29 15:33:13 -05:00
e5d9f33f7b Next: Remove deprecated app.on_event usage in api runner 2024-02-29 20:06:07 +11:00
5a87e7b3f8 chore: ruff 2024-02-29 20:05:39 +11:00
f8b673dc85 fix: Assertion issue with SDXL Compel 2024-02-29 20:05:39 +11:00
cb8e0cbf35 Fix merge with next 2024-02-29 00:35:48 -05:00
33bd9da26c Switch absolute path to as_posix in _walk_directory 2024-02-29 00:35:48 -05:00
9190abd487 Ruff checks 2024-02-29 00:35:48 -05:00
ff47334f22 Fix directory called on _walk_directory 2024-02-29 00:35:48 -05:00
a8c3efd98a Switch ModelSearch from os.walk to os.scandir 2024-02-29 00:35:48 -05:00
8c6860a2c5 Ruff format 2024-02-28 09:49:56 -05:00
fa8263e6f0 Ruff check 2024-02-28 09:49:56 -05:00
e4b8cb1d34 Extract TI loading logic into util, disallow it from ever failing a generation 2024-02-28 09:49:56 -05:00
408a800593 Fix one last reference to the uncasted model 2024-02-28 09:49:56 -05:00
9e5e3f1019 Allow TIs to be either a key or a name in the prompt during our transition to using keys 2024-02-28 09:49:56 -05:00
98a13aa7dc handle change to Civitai metadata schema for commercial usage 2024-02-28 16:15:29 +11:00
4418c118db added add all button to scan models (#5811)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [X] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [X] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [ ] Yes
- [ ] No


## Description


## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Merge Plan

<!--
A merge plan describes how this PR should be handled after it is
approved.

Example merge plans:
- "This PR can be merged when approved"
- "This must be squash-merged when approved"
- "DO NOT MERGE - I will rebase and tidy commits before merging"
- "#dev-chat on discord needs to be advised of this change when it is
merged"

A merge plan is particularly important for large PRs or PRs that touch
the
database in any way.
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2024-02-27 09:56:23 -05:00
110b0bc8fe updated to use new import model mutation 2024-02-27 09:48:41 -05:00
175cfe41a4 Merge branch 'next' into chainchompa/add-all-scan 2024-02-27 09:44:35 -05:00
a12d54afb9 added add all button to scan models 2024-02-27 09:43:02 -05:00
18af5348a2 fix(ui): merge conflict 2024-02-27 08:38:37 -05:00
b18c8e1c96 chore(ui): bump deps
The only major version is `query-string`. The breaking change for it is dropping support for old versions of node. Not a problem for us.
2024-02-27 08:38:37 -05:00
ea1e647174 ci: change frontend check to dpdm 2024-02-27 08:38:37 -05:00
af059f2cff feat(ui): move from madge to dpdm for circular dependencies 2024-02-27 08:38:37 -05:00
d8e21091e7 tidy(ui): fix circular dependencies in listeners 2024-02-27 08:38:37 -05:00
344041fd3a tidy: remove some traces of ONNX 2024-02-27 08:38:37 -05:00
588a220dd4 chore(ui): typegen, update knip config
Knip should never touch the autogenerated types
2024-02-27 08:38:37 -05:00
770d4092b9 chore(ui): update pnpm-lock.yaml
Forgot to run `pnpm i` earlier after removing packages.
2024-02-27 08:38:37 -05:00
33fe02bdff ci: add knip to ui check workflow 2024-02-27 08:38:37 -05:00
8a353bc1e3 feat(ui): configure knip 2024-02-27 08:38:37 -05:00
240f4801db tidy(ui): clean up unused code 6
unused files
2024-02-27 08:38:37 -05:00
da50507b2d tidy(ui): clean up unused code 5
variables, types and schemas
2024-02-27 08:38:37 -05:00
67d150ab66 tidy(ui): clean up unused code 4
variables, types and schemas
2024-02-27 08:38:37 -05:00
40d70add76 tidy(ui): clean up unused code 3
variables, types and schemas
2024-02-27 08:38:37 -05:00
7bd9bf3ba5 tidy(ui): clean up unused code 2
types and schemas
2024-02-27 08:38:37 -05:00
c94d607089 feat(mm): add log stmt for download complete event 2024-02-27 08:38:37 -05:00
ad801e54d4 fix(ui): model install progress sets total bytes correctly 2024-02-27 08:38:37 -05:00
fb4db83911 chore(ui): lint 2024-02-27 08:38:37 -05:00
cc229c3ea0 fix(ui): fix remaining TS issues 2024-02-27 08:38:37 -05:00
ca00fabd79 fix(ui): fix up MM queries & types (wip) 2024-02-27 08:38:37 -05:00
b361fabf81 tidy(api): remove non-heuristic install route 2024-02-27 08:38:37 -05:00
00669200c7 tidy(mm): remove ONNX from AnyModelConfig 2024-02-27 08:38:37 -05:00
fa07e82d2a tidy(ui): clean up unused code 1
- Only export when necessary
- Remove totally usused functions, variables, state, etc
- Remove unused packages
2024-02-27 08:38:37 -05:00
3632c5cd57 feat(ui): add knip + minimal config
https://knip.dev/

Replaces `unimported`
2024-02-27 08:38:37 -05:00
daef68d3c1 fix(ui): fix missing component import 2024-02-27 08:38:37 -05:00
ba29376fba ui: split the canvas mask blur and edge size setting 2024-02-27 07:32:13 -05:00
3efd9465eb feat(ui): create metadata types for control adapters
These are the same as the existing control adapter types, but the model field is non-nullable, simplifying handling of these objects.
2024-02-26 14:49:38 -05:00
a3b11c04cb fix(ui): model metadata handlers use model identifiers, not configs
Model metadata includes the main model, VAE and refiner model.

These used full model configs, as returned by the server, as their metadata type.

LoRA and control adapter metadata only use the metadata identifier.

This created a difference in handling. After parsing a model/vae/refiner, we have its name and can display it. But for LoRAs and control adapters, we only have the model key and must query for the full model config to get the name.

This change makes main model/vae/refiner metadata only have the model key, like LoRAs and control adapters.

The render function is now async so fetching can occur within it. All metadata fields with models now only contain the identifier, and fetch the model name to render their values.
2024-02-26 14:49:38 -05:00
8f9e3ac795 fix(ui): CanvasPasteBack types 2024-02-26 14:49:38 -05:00
2367f53367 tidy(ui): remove unused metadata schemas 2024-02-26 14:49:38 -05:00
8b9f0a9551 feat(nodes): update LoRAMetadataItem model
LoRA model now at under `model` not `lora.
2024-02-26 14:49:38 -05:00
ab57976e42 tidy(ui): tidy model identifier logic
- Move some files around
- Use util to extract key and base from model config
2024-02-26 14:49:38 -05:00
3c103c89f3 feat(ui): optimize model query caching
When we retrieve a list of models, upsert that data into the `getModelConfig` and `getModelConfigByAttrs` query caches.

With this change, calls to those two queries are almost always going to be free, because their caches will already have all models in them. The exception is queries for models that no longer exist.
2024-02-26 14:49:38 -05:00
0f19176944 fix(ui): fix lora metadata item type 2024-02-26 14:49:38 -05:00
fc09a954b5 fix(ui): fix node type 2024-02-26 14:49:38 -05:00
e7eee29825 feat(ui): add transformation to width/height parameter schemas to round to multiple of 8
This allows image dimensions that are not multiples of 8 to still be recalled with best effort.
2024-02-26 14:49:38 -05:00
2c1ba23f61 fix(ui): fix lora metadata rendering 2024-02-26 14:49:38 -05:00
58ef6dc6ce fix(ui): fix type issues related to change in LoRA type 2024-02-26 14:49:38 -05:00
8faefa89fe feat(ui): migrate all metadata recall logic to new system 2024-02-26 14:49:38 -05:00
02f59a3831 fix(ui): use id for component key in control adapter components 2024-02-26 14:49:38 -05:00
2555be3058 feat(ui): no JSX in metadata handlers 2024-02-26 14:49:38 -05:00
e174ce038f feat(ui): refactor metadata handling (again)
Add concepts for metadata handlers. Handlers include parsers, recallers and validators for different metadata types:
- Parsers parse a raw metadata object of any shape to a structured object.
- Recallers load the parsed metadata into state. Recallers are optional, as some metadata types don't need to be loaded into state.
- Validators provide an additional layer of validation before recalling the metadata. This is needed because a metadata object may be valid, but not able to be recalled due to some other requirement, like base model compatibility. Validators are optional.

Sometimes metadata is not a single object but a list of items - like LoRAs. Metadata handlers may implement an optional set of "item" handlers which operate on individual items in the list.

Parsers and validators are async to allow fetching additional data, like a model config. Recallers are synchronous.

The these handlers are composed into a public API, exported as a `handlers` object. Besides the handlers functions, a metadata handler set includes:
- A function to get the label of the metadata type.
- An optional function to render the value of the metadata type.
- An optional function to render the _item_ value of the metadata type.
2024-02-26 14:49:38 -05:00
0f10faf0d4 build(ui): do not fail build on eslint error in dev mode 2024-02-26 14:49:38 -05:00
393e32f8a7 chore(ui): typegen 2024-02-26 14:49:38 -05:00
70412464c8 feat(api): add MM get_by_attrs route
Gets the first model that matches the given name, base and type. Raises 404 if there isn't one.

This will be used for backwards compatibility with old metadata.
2024-02-26 14:49:38 -05:00
30fdb9dbfd undo 2024-02-26 14:44:37 -05:00
66f6013436 fix literal strings in MM UI 2024-02-26 14:44:37 -05:00
49b04f7db8 fix TI appearing as key in prompt 2024-02-26 14:20:28 -05:00
253dc5d43d fix base model grouping in combobox 2024-02-26 14:20:28 -05:00
3ccb4e6ff9 fix(mm): fix ModelCacheBase method name 2024-02-26 17:38:31 +11:00
200a9d1801 chore: ruff 2024-02-26 17:38:31 +11:00
b09a76ea0d recover gracefuly from GPU out of memory errors (next version) 2024-02-26 17:38:31 +11:00
8a2030e78a clear out VRAM when an OOM occurs 2024-02-26 17:38:31 +11:00
dfa5505ed8 feat(ui): bulk download click to download 2024-02-25 22:23:15 -05:00
f8b731b900 fix(ui): fix node types for canvas graphs 2024-02-24 19:38:16 +11:00
fd9ab0fb7d chore(ui): typegen 2024-02-24 19:38:16 +11:00
f504a5c96e tidy(nodes): rename canvas paste back 2024-02-24 19:38:16 +11:00
afe6639b9c fix: outpaint result not getting pasted back correctly 2024-02-24 19:38:16 +11:00
1f1bf15099 fix: lint errors 2024-02-24 19:38:16 +11:00
8fa238f100 canvas: improve paste back (or try to) 2024-02-24 19:38:16 +11:00
30b6a0ee23 wip(ui): Replace 2 Layer Coherence pass with Gradient Mask 2024-02-24 19:38:16 +11:00
784878c300 chore: ruff 2024-02-24 19:04:52 +11:00
b51b163400 fix(ui): fix merge issue 2024-02-24 19:04:52 +11:00
7e13224ec8 fix(ui): use new scan_folder response instead of hook to determine if models are installed already 2024-02-24 19:04:52 +11:00
7bc454209c chore(ui): typegen 2024-02-24 19:04:52 +11:00
cc7f6c7048 feat(mm): add logic to scan_folder route to check if a model is already installed
This was done in the frontend before but it's something the backend should handle.

The logic compares the found model paths to the path and source of all installed models. It excludes core models.
2024-02-24 19:04:52 +11:00
8b8d950137 chore(ui): lint 2024-02-24 19:04:52 +11:00
24fd7f41ff build(ui): restore i18n eslint rule 2024-02-24 19:04:52 +11:00
7c5e458372 chore: ruff 2024-02-24 19:04:52 +11:00
a5dba4b0d9 fix(ui): fix metadata route 2024-02-24 19:04:52 +11:00
72fb1cefff chore(ui): typegen 2024-02-24 19:04:52 +11:00
a64f1c0b20 feat(api): mm metadata route "meta" -> "metadata" 2024-02-24 19:04:52 +11:00
974658107d lint fix 2024-02-24 19:04:52 +11:00
07fb5d5c19 updated translations 2024-02-24 19:04:52 +11:00
20c75e7a7e fix convert endpoint logic 2024-02-24 19:04:52 +11:00
cfcb68696c clean up old model manager components and endpoints 2024-02-24 19:04:52 +11:00
7b1b6d3235 add model convert to checkpoint main models 2024-02-24 19:04:52 +11:00
aefba52a0a fix logic to see if scanned models are already installed, style tweaks 2024-02-24 19:04:52 +11:00
6af46f9c5f add error_reason to ModelInstallJob 2024-02-24 19:04:52 +11:00
190702d011 add error_reason to UI if import fails 2024-02-24 19:04:52 +11:00
7785e8ff79 fix types for ImportQueue, add QuickAdd for scan models 2024-02-24 19:04:52 +11:00
b3beaefa04 refactored and fixed issues with advanced import form 2024-02-24 19:04:52 +11:00
98be81354a fix(ui): misc MM cleanup 2024-02-24 19:04:52 +11:00
2a2a5eb775 chore(ui): temp disable eslint i18 rule 2024-02-24 19:04:52 +11:00
4a42b15b42 fix(ui): fix ImportMainModelResponse type 2024-02-24 19:04:52 +11:00
f24d5e5e31 fix(ui): simplify model install event listeners 2024-02-24 19:04:52 +11:00
4b106bc903 fix(ui): fix model install event types 2024-02-24 19:04:52 +11:00
135ef9066f added advanced import forms, not fully working yet 2024-02-24 19:04:52 +11:00
0567f98e4a get positioning/scrolling working for scan results list 2024-02-24 19:04:52 +11:00
5b66baa3ec basic scan working and renders results 2024-02-24 19:04:52 +11:00
a022aaf258 add scan model endpoint, break add model into tabs 2024-02-24 19:04:52 +11:00
94065b090a update metadata endpoint 2024-02-24 19:04:52 +11:00
091bf9220b allow metadata-less models to be used for GET metadata endpoint 2024-02-24 19:04:52 +11:00
8d243b1fca added status to import queue model 2024-02-24 19:04:52 +11:00
23c412e011 delete model imports and prune all finished, update state with socket messages 2024-02-24 19:04:52 +11:00
66692f02aa fix sync model endpoint 2024-02-24 19:04:52 +11:00
38af1c3a81 form error handling 2024-02-24 19:04:52 +11:00
7b4b7e3781 finish model update 2024-02-24 19:04:52 +11:00
02a3472505 added socket listeners, added more info to ui 2024-02-24 19:04:52 +11:00
909d354a38 edit view for model, depending on type and valid values 2024-02-24 19:04:52 +11:00
7801b8c42f hook up Add Model button 2024-02-24 19:04:52 +11:00
4fd259bb89 single model view 2024-02-24 19:04:52 +11:00
b8b3ef9725 added import model form and importqueue 2024-02-24 19:04:52 +11:00
3a8d5dc349 model list, filtering, searching 2024-02-24 19:04:52 +11:00
358cac9674 workspace for mary and jenn 2024-02-24 19:04:52 +11:00
bdc2b8069b get old UI working somewhat with new endpoints 2024-02-24 19:04:52 +11:00
09295ae43b Allow passing in key on register 2024-02-23 14:47:14 -05:00
80ad14d89f Remove passing keys in on register 2024-02-23 14:33:49 -05:00
c674eb3168 Run ruff 2024-02-23 14:33:49 -05:00
63138640a7 Allow users to run model manager without cuda 2024-02-23 14:33:49 -05:00
d103ff0d6e fix(ui): roll back utility-types
It's `Required` util does not distribute over unions as expected. Also we have `ts-toolbelt` already for some utils.
2024-02-23 07:53:45 +11:00
94931e8ac0 feat(ui): refactor metadata handling
Refactor of metadata recall handling. This is in preparation for a backwards compatibility layer for models.

- Create helpers to fetch a model outside react (e.g. not in a hook)
- Created helpers to parse model metadata
- Renamed a lot of types that were confusing and/or had naming collisions
2024-02-23 07:53:45 +11:00
b409f3aaf9 chore(ui): typegen 2024-02-23 07:53:45 +11:00
f96b7f2e11 fix(nodes): make fields on ModelConfigBase required
The setup of `ModelConfigBase` means autogenerated types have critical fields flagged as nullable (like `key` and `base`). Need to manually flag them as required.
2024-02-23 07:53:45 +11:00
de3be4bd30 feat(ui): replace type-fest with utility-types
- The new package has more useful types
- Only used `JsonObject` from `type-fest`; added an implementation of that type
2024-02-23 07:53:45 +11:00
cc12f57a5a several small model install enhancements
- Support extended HF repoid syntax in TUI. This allows
  installation of subfolders and safetensors files, as in
  `XpucT/Deliberate::Deliberate_v5.safetensors`

- Add `error` and `error_traceback` properties to the install
  job objects.

- Rename the `heuristic_import` route to `heuristic_install`.

- Fix the example `config` input in the `heuristic_install` route.
2024-02-23 07:48:23 +11:00
613f11a3ac use official Deliberate download repo 2024-02-23 07:48:04 +11:00
a6e2d2c5e0 fix repo-id for the Deliberate v5 model
prevent lora and embedding file suffixes from being stripped during installation

apply psychedelicious patch to get compel to load proper TI embedding
2024-02-23 07:48:04 +11:00
ae14df97d6 remove startup dependency on legacy models.yaml file 2024-02-23 07:47:39 +11:00
a6e1ac6096 chore: typing 2024-02-22 10:04:33 -05:00
8530635540 chore: typing fix 2024-02-22 10:04:33 -05:00
b2b7aed030 feat(nodes): added gradient mask node 2024-02-22 10:04:33 -05:00
970d45f691 Run ruff 2024-02-22 09:50:02 -05:00
19b9a22d93 rename endpoint for scanning 2024-02-22 09:50:02 -05:00
c0d9990344 Create /search endpoint, update model object structure in scan model page 2024-02-22 09:50:02 -05:00
4ac5e307c4 chore(ui): bump deps
Notable updates:
- Minor version of RTK includes customizable selectors for RTK Query, so we can remove the patch that was added to ensure only the LRU memoize function was used for perf reasons. Updated to use the LRU memoize function.
- Major version of react-resizable-panels. No breaking changes, works great, and you can now resize all panels when dragging at the intersection point of panels. Cool!
- Minor (?) version of nanostores. `action` API is removed, we were using it in one spot. Fixed.
- @invoke-ai/eslint-config-react has all deps bumped and now has its dependent plugins/configs listed as normal dependencies (as opposed to peer deps). This means we can remove those packages from explicit dev deps.
2024-02-22 07:27:28 +11:00
2815f737fe tidy(ui): remove debugging stmt 2024-02-22 07:26:47 +11:00
63e96fd1ea fix(ui): handle new model format for metadata 2024-02-22 07:26:47 +11:00
66ab56246a fix(ui): use model names in badges 2024-02-22 07:26:47 +11:00
20a56bc757 fix(nodes): fix TI loading 2024-02-22 07:26:47 +11:00
82925e1539 fix(ui): fix package build 2024-02-21 08:31:55 -05:00
0137a0db7b feat(ui): do not subscribe to bulk download sio room if baseUrl is set 2024-02-21 00:00:25 +11:00
b410793684 feat(ui): revise bulk download listeners
- Use a single listener for all of the to keep them in one spot
- Use the bulk download item name as a toast id so we can update the existing toasts
- Update handling to work with other environments
- Move all bulk download handling from components to listener
2024-02-21 00:00:25 +11:00
894e9f127b chore(ui): typegen 2024-02-21 00:00:25 +11:00
dd9b1c8eec feat(bulk_download): update response model, messages 2024-02-21 00:00:25 +11:00
8d9c566656 implementing download for bulk_download events 2024-02-21 00:00:25 +11:00
9db7e073a3 setting up event listeners for bulk download socket 2024-02-21 00:00:25 +11:00
5f64ed5bd5 test: clean up & fix tests
- Deduplicate the mock invocation services. This is possible now that the import order issue is resolved.
- Merge `DummyEventService` into `TestEventService` and update all tests to use `TestEventService`.
2024-02-20 23:39:30 +11:00
7f75f6226b tidy(bulk_download): don't store events service separately
Using the invoker object directly leaves no ambiguity as to what `_events_bus` actually is.
2024-02-20 23:39:30 +11:00
6dc819fd47 tidy(bulk_download): do not rely on pagination API to get all images for board
We can get all images for the board as a list of image names, then pass that to `_image_handler` to get the DTOs, decoupling from the pagination API.
2024-02-20 23:39:30 +11:00
0cc81e5d63 tidy(bulk_download): nit - use or as a coalescing operator
Just a bit cleaner.
2024-02-20 23:39:30 +11:00
daecc54153 tidy(bulk_download): use single underscore for private attrs
Double underscores are used in the app but it doesn't actually do or convey anything that single underscores don't already do. Considered unpythonic except for actual dunder/magic methods.
2024-02-20 23:39:30 +11:00
4c31c7f9f1 tidy(bulk_download): remove class-level attr annotations
These can be misleading as they shadow actual assigned class attributes. This pattern is in the rest of the app but it shouldn't be.
2024-02-20 23:39:30 +11:00
d709c5519f tidy(bulk_download): remove extraneous abstract methods
`start`, `stop` and `__init__` are not required in implementations of an ABC or service.
2024-02-20 23:39:30 +11:00
5d84ecef49 tidy(bulk_download): clean up comments 2024-02-20 23:39:30 +11:00
641d246213 adding bulk_download_item_name to socket events 2024-02-20 23:39:30 +11:00
2e53aa48c9 refactoring handlers to do null check 2024-02-20 23:39:30 +11:00
ef12631450 removing dependency on an output folder, embrace python temp folder for bulk download 2024-02-20 23:39:30 +11:00
d9eb626b62 relocating event_service fixture due to import ordering 2024-02-20 23:39:30 +11:00
8033589629 moving the responsibility of cleaning up board names to the service not the route 2024-02-20 23:39:30 +11:00
124075ae7a updating imports to satisfy ruff 2024-02-20 23:39:30 +11:00
0bde933c89 using temp directory for downloads 2024-02-20 23:39:30 +11:00
fc5c5b6bdd returning the bulk_download_item_name on response for possible polling 2024-02-20 23:39:30 +11:00
ff53563152 narrowing bulk_download stop service scope 2024-02-20 23:39:30 +11:00
12b0d735e7 adding test coverage for new bulk download routes 2024-02-20 23:39:30 +11:00
d06ee94fd3 cleaning up bulk download zip after the response is complete 2024-02-20 23:39:30 +11:00
9dbdb6cf7c replacing import removed during rebase 2024-02-20 23:39:30 +11:00
7c091570fe 97% test coverage on bulk_download 2024-02-20 23:39:30 +11:00
e99f3482cc refactoring bulk_download to be better managed 2024-02-20 23:39:30 +11:00
d999c9ffd6 refactoring dummy event service, DRY principal; adding bulk_download_event to existing invoker tests 2024-02-20 23:39:30 +11:00
888db8ac46 refactoring bulkdownload to consider image category 2024-02-20 23:39:30 +11:00
7deef2cb27 fixing issue where default board did not return images 2024-02-20 23:39:30 +11:00
ada807af0c using the board name to download boards 2024-02-20 23:39:30 +11:00
aa132fb9e3 reworking some of the logic to use a default room, adding endpoint to download file on complete 2024-02-20 23:39:30 +11:00
98a01368b8 linted and styling 2024-02-20 23:39:30 +11:00
fc9a62dbf5 implementation of bulkdownload background task 2024-02-20 23:39:30 +11:00
4d8bec1605 adding socket events for bulk download 2024-02-20 23:39:30 +11:00
cf9dad83bc groundwork for the bulk_download_service 2024-02-20 23:39:30 +11:00
0d0a2a5c91 fix(ui): get workflow editor model selects working 2024-02-20 13:33:31 +11:00
0cab636ab0 fix(ui): get refiner model select working 2024-02-20 13:33:31 +11:00
de097ec58a fix(ui): get vae model select working 2024-02-20 13:33:31 +11:00
bb6f426162 fix(ui): get embedding select working 2024-02-20 13:33:31 +11:00
663f135b3c fix(ui): get lora select working 2024-02-20 13:33:31 +11:00
2f2097662a chore(ui): bump @invoke-ai/ui-library 2024-02-20 13:33:31 +11:00
458c29cfa5 fix(ui): fix low-hanging fruit types 2024-02-20 13:33:31 +11:00
4bec01d6f2 Add a few convenience targets to Makefile
- "test" to run pytests
- "frontend-install" to reinstall pnpm's node modeuls
2024-02-20 10:02:46 +11:00
9d79ee8dc4 chore(nodes): update TODO comment 2024-02-20 09:54:01 +11:00
78dd460348 tidy(nodes): clean up profiler/stats in processor, better comments 2024-02-20 09:54:01 +11:00
9d27d354cf fix(nodes): fix typing on stats service context manager 2024-02-20 09:54:01 +11:00
e8725a1099 fix(nodes): fix model load events
was accessing incorrect properties in event data
2024-02-20 09:54:01 +11:00
479d65b6e1 feat(nodes): making invocation class var in processor 2024-02-20 09:54:01 +11:00
5d4b388dfd feat(nodes): improved error messages in processor 2024-02-20 09:54:01 +11:00
4956fa282b feat(nodes): make processor thread limit and polling interval configurable 2024-02-20 09:54:01 +11:00
51133522b7 tests(nodes): fix tests following removal of services 2024-02-20 09:54:01 +11:00
6d5cc8b1ff chore(nodes): better comments for invocation context 2024-02-20 09:54:01 +11:00
08a5bb90e2 chore(nodes): "context_data" -> "data"
Changed within InvocationContext, for brevity.
2024-02-20 09:54:01 +11:00
39bdf5c4e9 refactor(nodes): move is_canceled to context.util 2024-02-20 09:54:01 +11:00
f31e4205aa feat(nodes): add whole queue_item to InvocationContextData
No reason to not have the whole thing in there.
2024-02-20 09:54:01 +11:00
4d05c4ff66 tidy(nodes): remove extraneous comments 2024-02-20 09:54:01 +11:00
7e88d2a7f1 feat(nodes): better invocation error messages 2024-02-20 09:54:01 +11:00
556f6aa174 chore(nodes): add comments for cancel state 2024-02-20 09:54:01 +11:00
6a74048af8 feat(nodes): promote is_canceled to public node API 2024-02-20 09:54:01 +11:00
2cb51bff11 refactor(nodes): merge processors
Consolidate graph processing logic into session processor.

With graphs as the unit of work, and the session queue distributing graphs, we no longer need the invocation queue or processor.

Instead, the session processor dequeues the next session and processes it in a simple loop, greatly simplifying the app.

- Remove `graph_execution_manager` service.
- Remove `queue` (invocation queue) service.
- Remove `processor` (invocation processor) service.
- Remove queue-related logic from `Invoker`. It now only starts and stops the services, providing them with access to other services.
- Remove unused `invocation_retrieval_error` and `session_retrieval_error` events, these are no longer needed.
- Clean up stats service now that it is less coupled to the rest of the app.
- Refactor cancellation logic - cancellations now originate from session queue (i.e. HTTP cancel endpoint) and are emitted as events. Processor gets the events and sets the canceled event. Access to this event is provided to the invocation context for e.g. the step callback.
- Remove `sessions` router; it provided access to `graph_executions` but that no longer exists.
2024-02-20 09:54:01 +11:00
851e835e0e tidy(nodes): remove commented tests 2024-02-20 09:48:14 +11:00
fe04f28841 chore(ui): typegen 2024-02-20 09:48:14 +11:00
258fc006ec tidy(nodes): remove no-op model_config
Because we now customize the JSON Schema creation for GraphExecutionState, the model_config did nothing.
2024-02-20 09:48:14 +11:00
dcb4ee47d5 tidy(nodes): remove LibraryGraphs
The workflow library supersedes this unused feature.
2024-02-20 09:48:14 +11:00
1a56f5aaf9 tidy(nodes): move node tests to parent dir
Thanks to the resolution of the import vs union issue, we can put tests anywhere.
2024-02-20 09:48:14 +11:00
5fc745653a tidy(nodes): remove GraphInvocation
`GraphInvocation` is a node that can contain a whole graph. It is removed for a number of reasons:

1. This feature was unused (the UI doesn't support it) and there is no plan for it to be used.

The use-case it served is known in other node execution engines as "node groups" or "blocks" - a self-contained group of nodes, which has group inputs and outputs. This is a planned feature that will be handled client-side.

2. It adds substantial complexity to the graph processing logic. It's probably not enough to have a measurable performance impact but it does make it harder to work in the graph logic.

3. It allows for graphs to be recursive, and the improved invocations union handling does not play well with it. Actually, it works fine within `graph.py` but not in the tests for some reason. I do not understand why. There's probably a workaround, but I took this as encouragement to remove `GraphInvocation` from the app since we don't use it.
2024-02-20 09:48:14 +11:00
47b5a90177 fix(nodes): fix OpenAPI schema generation
The change to `Graph.nodes` and `GraphExecutionState.results` validation requires some fanagling to get the OpenAPI schema generation to work. See new comments for a details.
2024-02-20 09:48:14 +11:00
81518ee1af feat(nodes): JIT graph nodes validation
We use pydantic to validate a union of valid invocations when instantiating a graph.

Previously, we constructed the union while creating the `Graph` class. This introduces a dependency on the order of imports.

For example, consider a setup where we have 3 invocations in the app:

- Python executes the module where `FirstInvocation` is defined, registering `FirstInvocation`.
- Python executes the module where `SecondInvocation` is defined, registering `SecondInvocation`.
- Python executes the module where `Graph` is defined. A union of invocations is created and used to define the `Graph.nodes` field. The union contains `FirstInvocation` and `SecondInvocation`.
- Python executes the module where `ThirdInvocation` is defined, registering `ThirdInvocation`.
- A graph is created that includes `ThirdInvocation`. Pydantic validates the graph using the union, which does not know about `ThirdInvocation`, raising a `ValidationError` about an unknown invocation type.

This scenario has been particularly problematic in tests, where we may create invocations dynamically. The test files have to be structured in such a way that the imports happen in the right order. It's a major pain.

This PR refactors the validation of graph nodes to resolve this issue:

- `BaseInvocation` gets a new method `get_typeadapter`. This builds a pydantic `TypeAdapter` for the union of all registered invocations, caching it after the first call.
- `Graph.nodes`'s type is widened to `dict[str, BaseInvocation]`. This actually is a nice bonus, because we get better type hints whenever we reference `some_graph.nodes`.
- A "plain" field validator takes over the validation logic for `Graph.nodes`. "Plain" validators totally override pydantic's own validation logic. The validator grabs the `TypeAdapter` from `BaseInvocation`, then validates each node with it. The validation is identical to the previous implementation - we get the same errors.

`BaseInvocationOutput` gets the same treatment.
2024-02-20 09:48:14 +11:00
b06d63fb34 remove errant def that was crashing invokeai-configure 2024-02-19 17:31:53 +11:00
5278a64301 one more redundant RGB convert removed 2024-02-19 17:31:08 +11:00
4de4473c0f chore: ruff formatting 2024-02-19 17:31:08 +11:00
2c28a850ca chore(invocations): remove redundant RGB conversions 2024-02-19 17:31:08 +11:00
6dada3326d chore(invocations): use IMAGE_MODES constant literal 2024-02-19 17:31:08 +11:00
2dfdc02ec8 fix: removed custom module 2024-02-19 17:31:08 +11:00
1f19db4c6a fix(nodes): canny preprocessor uses RGBA again 2024-02-19 17:31:08 +11:00
7c150c27f2 feat(nodes): format option for get_image method
Also default CNet preprocessors to "RGB"
2024-02-19 17:31:08 +11:00
248916c190 fix: Alpha channel causing issue with DW Processor 2024-02-19 08:17:56 +11:00
be8b99eed5 final tidying before marking PR as ready for review
- Replace AnyModelLoader with ModelLoaderRegistry
- Fix type check errors in multiple files
- Remove apparently unneeded `get_model_config_enum()` method from model manager
- Remove last vestiges of old model manager
- Updated tests and documentation

resolve conflict with seamless.py
2024-02-19 08:16:56 +11:00
2ad0752582 Tidy names and locations of modules
- Rename old "model_management" directory to "model_management_OLD" in order to catch
  dangling references to original model manager.
- Caught and fixed most dangling references (still checking)
- Rename lora, textual_inversion and model_patcher modules
- Introduce a RawModel base class to simplfy the Union returned by the
  model loaders.
- Tidy up the model manager 2-related tests. Add useful fixtures, and
  a finalizer to the queue and installer fixtures that will stop the
  services and release threads.
2024-02-19 08:16:56 +11:00
ba1f8878dd Fix issues identified during PR review by RyanjDick and brandonrising
- ModelMetadataStoreService is now injected into ModelRecordStoreService
  (these two services are really joined at the hip, and should someday be merged)
- ModelRecordStoreService is now injected into ModelManagerService
- Reduced timeout value for the various installer and download wait*() methods
- Introduced a Mock modelmanager for testing
- Removed bare print() statement with _logger in the install helper backend.
- Removed unused code from model loader init file
- Made `locker` a private variable in the `LoadedModel` object.
- Fixed up model merge frontend (will be deprecated anyway!)
2024-02-19 08:16:56 +11:00
bc524026f9 feat(ui): update model identifiers to use key (#5730)
## What type of PR is this? (check all applicable)

- [x] Refactor

## Description

- Update zod schemas & types to use key instead of name/base/type
- Use new `CustomSelect` component instead of `ComboBox` for main model
select and control adapter model selects (less jank, will switch to
ComboBox based on CustomSelect for v4 so you can search the select)

## QA Instructions, Screenshots, Recordings

If you hold your breath, you should be able to generate with a control
adapter.

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Merge Plan

This PR can be merged when approved. Frontend tests not passing.

<!--
A merge plan describes how this PR should be handled after it is
approved.

Example merge plans:
- "This PR can be merged when approved"
- "This must be squash-merged when approved"
- "DO NOT MERGE - I will rebase and tidy commits before merging"
- "#dev-chat on discord needs to be advised of this change when it is
merged"

A merge plan is particularly important for large PRs or PRs that touch
the
database in any way.
-->
2024-02-16 11:17:35 -05:00
ad7c571983 fix(nodes): fix t2i adapter model loading (#5731)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Description

Fixes t2i adapter loading

## Merge Plan

This PR can be merged when approved

<!--
A merge plan describes how this PR should be handled after it is
approved.

Example merge plans:
- "This PR can be merged when approved"
- "This must be squash-merged when approved"
- "DO NOT MERGE - I will rebase and tidy commits before merging"
- "#dev-chat on discord needs to be advised of this change when it is
merged"

A merge plan is particularly important for large PRs or PRs that touch
the
database in any way.
-->
2024-02-16 11:17:21 -05:00
8559c6a392 fix(nodes): fix t2i adapter model loading 2024-02-16 22:51:47 +11:00
c7904a32f4 chore(ui): lint 2024-02-16 22:42:15 +11:00
17f5484f5b feat(ui): fix main model & control adapter model selects 2024-02-16 22:41:09 +11:00
86a372b02f refactor(ui): url builders for each router
The MM2 router is at `api/v2/models`. URL builder utils make this a bit easier to manage.
2024-02-16 21:57:30 +11:00
2e9aa9391d feat(ui): update model identifier to be key (wip)
- Update most model identifiers to be `{key: string}` instead of name/base/type. Doesn't change the model select components yet.
- Update model _parameters_, stored in redux, to be `{key: string, base: BaseModel}` - we need to store the base model to be able to check model compatibility. May want to store the whole config? Not sure...
2024-02-16 18:56:02 +11:00
0c8112cf28 fix(ui): update model types 2024-02-15 22:17:16 +11:00
019898c7be tests(ui): add type tests 2024-02-15 22:16:55 +11:00
2b1ff8d196 tests(ui): enable vitest type testing
This is useful for the zod schemas and types we have created to match the backend.
2024-02-15 22:16:11 +11:00
79fb691b4d chore(ui): typegen 2024-02-15 22:15:21 +11:00
560ae17e21 feat(ui): export components type 2024-02-15 21:16:25 +11:00
2bd1ab2f1c fix(ui): fix type issues 2024-02-15 20:53:41 +11:00
ed43472582 chore: lint 2024-02-15 20:52:44 +11:00
6e5e9176c0 chore: ruff 2024-02-15 20:50:47 +11:00
4c6bcdbc18 feat(nodes): update invocation context for mm2, update nodes model usage 2024-02-15 20:43:41 +11:00
20e6d4fa3c Raise InvalidModelConfigException when unable to detect load class in ModelLoader 2024-02-15 18:00:16 +11:00
8e51392910 Update _get_hf_load_class to support clipvision models 2024-02-15 18:00:16 +11:00
0b1c2acd61 References to context.services.model_manager.store.get_model can only accept keys, remove invalid assertion 2024-02-15 18:00:16 +11:00
86ac55ab5f Remove references to model_records service, change submodel property on ModelInfo to submodel_type to support new params in model manager 2024-02-15 18:00:16 +11:00
3e82f63c7e improve swagger documentation 2024-02-15 18:00:08 +11:00
631f6cae19 fix a number of typechecking errors 2024-02-15 18:00:08 +11:00
0845a0ed84 add route for model conversion from safetensors to diffusers
- Begin to add SwaggerUI documentation for AnyModelConfig and other
  discriminated Unions.
2024-02-15 18:00:08 +11:00
46c8ce9fed add a JIT download_and_cache() call to the model installer 2024-02-15 18:00:08 +11:00
13a9ea35b5 add back the heuristic_import() method and extend repo_ids to arbitrary file paths 2024-02-15 18:00:08 +11:00
94e8d1b6d5 make model manager v2 ready for PR review
- Replace legacy model manager service with the v2 manager.

- Update invocations to use new load interface.

- Fixed many but not all type checking errors in the invocations. Most
  were unrelated to model manager

- Updated routes. All the new routes live under the route tag
  `model_manager_v2`. To avoid confusion with the old routes,
  they have the URL prefix `/api/v2/models`. The old routes
  have been de-registered.

- Added a pytest for the loader.

- Updated documentation in contributing/MODEL_MANAGER.md
2024-02-15 18:00:08 +11:00
2b1dc74080 consolidate model manager parts into a single class 2024-02-15 17:57:14 +11:00
f7e558d165 probe for required encoder for IPAdapters and add to config 2024-02-15 17:56:01 +11:00
d959276217 fix invokeai_configure script to work with new mm; rename CLIs 2024-02-15 17:56:01 +11:00
dfcf38be91 BREAKING CHANGES: invocations now require model key, not base/type/name
- Implement new model loader and modify invocations and embeddings

- Finish implementation loaders for all models currently supported by
  InvokeAI.

- Move lora, textual_inversion, and model patching support into
  backend/embeddings.

- Restore support for model cache statistics collection (a little ugly,
  needs work).

- Fixed up invocations that load and patch models.

- Move seamless and silencewarnings utils into better location
2024-02-15 17:56:01 +11:00
fbded1c0f2 Multiple refinements on loaders:
- Cache stat collection enabled.
- Implemented ONNX loading.
- Add ability to specify the repo version variant in installer CLI.
- If caller asks for a repo version that doesn't exist, will fall back
  to empty version rather than raising an error.
2024-02-15 17:51:07 +11:00
ad2926a24c added textual inversion and lora loaders 2024-02-15 17:51:07 +11:00
34d5cad4c9 loaders for main, controlnet, ip-adapter, clipvision and t2i 2024-02-15 17:51:07 +11:00
60aa3d4893 model loading and conversion implemented for vaes 2024-02-15 17:50:51 +11:00
5c2884569e add ram cache module and support files 2024-02-15 17:50:31 +11:00
a1307b9f2e add concept of repo variant 2024-02-15 17:50:31 +11:00
554 changed files with 18347 additions and 27002 deletions

View File

@ -36,8 +36,10 @@ jobs:
- name: Typescript
run: 'pnpm run lint:tsc'
- name: Madge
run: 'pnpm run lint:madge'
run: 'pnpm run lint:dpdm'
- name: ESLint
run: 'pnpm run lint:eslint'
- name: Prettier
run: 'pnpm run lint:prettier'
- name: Knip
run: 'pnpm run lint:knip'

View File

@ -6,33 +6,44 @@ default: help
help:
@echo Developer commands:
@echo
@echo "ruff Run ruff, fixing any safely-fixable errors and formatting"
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
@echo "frontend-build Build the frontend in order to run on localhost:9090"
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
@echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
@echo "ruff Run ruff, fixing any safely-fixable errors and formatting"
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
@echo "test" Run the unit tests.
@echo "frontend-install" Install the pnpm modules needed for the front end
@echo "frontend-build Build the frontend in order to run on localhost:9090"
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
@echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
# Runs ruff, fixing any safely-fixable errors and formatting
ruff:
ruff check . --fix
ruff format .
ruff check . --fix
ruff format .
# Runs ruff, fixing all errors it can fix and formatting
ruff-unsafe:
ruff check . --fix --unsafe-fixes
ruff format .
ruff check . --fix --unsafe-fixes
ruff format .
# Runs mypy, using the config in pyproject.toml
mypy:
mypy scripts/invokeai-web.py
mypy scripts/invokeai-web.py
# Runs mypy, ignoring the config in pyproject.toml but still ignoring missing (untyped) imports
# (many files are ignored by the config, so this is useful for checking all files)
mypy-all:
mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports
mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports
# Run the unit tests
test:
pytest ./tests
# Install the pnpm modules needed for the front end
frontend-install:
rm -rf invokeai/frontend/web/node_modules
cd invokeai/frontend/web && pnpm install
# Build the frontend
frontend-build:

View File

@ -28,7 +28,7 @@ model. These are the:
Hugging Face, as well as discriminating among model versions in
Civitai, but can be used for arbitrary content.
* _ModelLoadServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
* _ModelLoadServiceBase_
Responsible for loading a model from disk
into RAM and VRAM and getting it ready for inference.
@ -41,10 +41,10 @@ The four main services can be found in
* `invokeai/app/services/model_records/`
* `invokeai/app/services/model_install/`
* `invokeai/app/services/downloads/`
* `invokeai/app/services/model_loader/` (**under development**)
* `invokeai/app/services/model_load/`
Code related to the FastAPI web API can be found in
`invokeai/app/api/routers/model_records.py`.
`invokeai/app/api/routers/model_manager_v2.py`.
***
@ -84,10 +84,10 @@ diffusers model. When this happens, `original_hash` is unchanged, but
`ModelType`, `ModelFormat` and `BaseModelType` are string enums that
are defined in `invokeai.backend.model_manager.config`. They are also
imported by, and can be reexported from,
`invokeai.app.services.model_record_service`:
`invokeai.app.services.model_manager.model_records`:
```
from invokeai.app.services.model_record_service import ModelType, ModelFormat, BaseModelType
from invokeai.app.services.model_records import ModelType, ModelFormat, BaseModelType
```
The `path` field can be absolute or relative. If relative, it is taken
@ -123,7 +123,7 @@ taken to be the `models_dir` directory.
`variant` is an enumerated string class with values `normal`,
`inpaint` and `depth`. If needed, it can be imported if needed from
either `invokeai.app.services.model_record_service` or
either `invokeai.app.services.model_records` or
`invokeai.backend.model_manager.config`.
### ONNXSD2Config
@ -134,7 +134,7 @@ either `invokeai.app.services.model_record_service` or
| `upcast_attention` | bool | Model requires its attention module to be upcast |
The `SchedulerPredictionType` enum can be imported from either
`invokeai.app.services.model_record_service` or
`invokeai.app.services.model_records` or
`invokeai.backend.model_manager.config`.
### Other config classes
@ -157,15 +157,6 @@ indicates that the model is compatible with any of the base
models. This works OK for some models, such as the IP Adapter image
encoders, but is an all-or-nothing proposition.
Another issue is that the config class hierarchy is paralleled to some
extent by a `ModelBase` class hierarchy defined in
`invokeai.backend.model_manager.models.base` and its subclasses. These
are classes representing the models after they are loaded into RAM and
include runtime information such as load status and bytes used. Some
of the fields, including `name`, `model_type` and `base_model`, are
shared between `ModelConfigBase` and `ModelBase`, and this is a
potential source of confusion.
## Reading and Writing Model Configuration Records
The `ModelRecordService` provides the ability to retrieve model
@ -177,11 +168,11 @@ initialization and can be retrieved within an invocation from the
`InvocationContext` object:
```
store = context.services.model_record_store
store = context.services.model_manager.store
```
or from elsewhere in the code by accessing
`ApiDependencies.invoker.services.model_record_store`.
`ApiDependencies.invoker.services.model_manager.store`.
### Creating a `ModelRecordService`
@ -190,7 +181,7 @@ you can directly create either a `ModelRecordServiceSQL` or a
`ModelRecordServiceFile` object:
```
from invokeai.app.services.model_record_service import ModelRecordServiceSQL, ModelRecordServiceFile
from invokeai.app.services.model_records import ModelRecordServiceSQL, ModelRecordServiceFile
store = ModelRecordServiceSQL.from_connection(connection, lock)
store = ModelRecordServiceSQL.from_db_file('/path/to/sqlite_database.db')
@ -252,7 +243,7 @@ So a typical startup pattern would be:
```
import sqlite3
from invokeai.app.services.thread import lock
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
@ -260,19 +251,6 @@ db_conn = sqlite3.connect(config.db_path.as_posix(), check_same_thread=False)
store = ModelRecordServiceBase.open(config, db_conn, lock)
```
_A note on simultaneous access to `invokeai.db`_: The current InvokeAI
service architecture for the image and graph databases is careful to
use a shared sqlite3 connection and a thread lock to ensure that two
threads don't attempt to access the database simultaneously. However,
the default `sqlite3` library used by Python reports using
**Serialized** mode, which allows multiple threads to access the
database simultaneously using multiple database connections (see
https://www.sqlite.org/threadsafe.html and
https://ricardoanderegg.com/posts/python-sqlite-thread-safety/). Therefore
it should be safe to allow the record service to open its own SQLite
database connection. Opening a model record service should then be as
simple as `ModelRecordServiceBase.open(config)`.
### Fetching a Model's Configuration from `ModelRecordServiceBase`
Configurations can be retrieved in several ways.
@ -468,6 +446,44 @@ required parameters:
Once initialized, the installer will provide the following methods:
#### install_job = installer.heuristic_import(source, [config], [access_token])
This is a simplified interface to the installer which takes a source
string, an optional model configuration dictionary and an optional
access token.
The `source` is a string that can be any of these forms
1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`)
2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
3. A HuggingFace repo_id with any of the following formats:
- `model/name` -- entire model
- `model/name:fp32` -- entire model, using the fp32 variant
- `model/name:fp16:vae` -- vae submodel, using the fp16 variant
- `model/name::vae` -- vae submodel, using default precision
- `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
- `model/name::path/to/model.safetensors` -- an individual model file, default variant
Note that by specifying a relative path to the top of the HuggingFace
repo, you can download and install arbitrary models files.
The variant, if not provided, will be automatically filled in with
`fp32` if the user has requested full precision, and `fp16`
otherwise. If a variant that does not exist is requested, then the
method will install whatever HuggingFace returns as its default
revision.
`config` is an optional dict of values that will override the
autoprobed values for model type, base, scheduler prediction type, and
so forth. See [Model configuration and
probing](#Model-configuration-and-probing) for details.
`access_token` is an optional access token for accessing resources
that need authentication.
The method will return a `ModelInstallJob`. This object is discussed
at length in the following section.
#### install_job = installer.import_model()
The `import_model()` method is the core of the installer. The
@ -486,9 +502,10 @@ source2 = LocalModelSource(path='/opt/models/sushi_diffusers') # a local dif
source3 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5') # a repo_id
source4 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='vae') # a subfolder within a repo_id
source5 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', variant='fp16') # a named variant of a HF model
source6 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='OrangeMix/OrangeMix1.ckpt') # path to an individual model file
source6 = URLModelSource(url='https://civitai.com/api/download/models/63006') # model located at a URL
source7 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token
source7 = URLModelSource(url='https://civitai.com/api/download/models/63006') # model located at a URL
source8 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token
for source in [source1, source2, source3, source4, source5, source6, source7]:
install_job = installer.install_model(source)
@ -544,7 +561,6 @@ can be passed to `import_model()`.
attributes returned by the model prober. See the section below for
details.
#### LocalModelSource
This is used for a model that is located on a locally-accessible Posix
@ -737,7 +753,7 @@ and `cancelled`, as well as `in_terminal_state`. The last will return
True if the job is in the complete, errored or cancelled states.
#### Model confguration and probing
#### Model configuration and probing
The install service uses the `invokeai.backend.model_manager.probe`
module during import to determine the model's type, base type, and
@ -776,6 +792,14 @@ returns a list of completed jobs. The optional `timeout` argument will
return from the call if jobs aren't completed in the specified
time. An argument of 0 (the default) will block indefinitely.
#### jobs = installer.wait_for_job(job, [timeout])
Like `wait_for_installs()`, but block until a specific job has
completed or errored, and then return the job. The optional `timeout`
argument will return from the call if the job doesn't complete in the
specified time. An argument of 0 (the default) will block
indefinitely.
#### jobs = installer.list_jobs()
Return a list of all active and complete `ModelInstallJobs`.
@ -838,6 +862,31 @@ This method is similar to `unregister()`, but also unconditionally
deletes the corresponding model weights file(s), regardless of whether
they are inside or outside the InvokeAI models hierarchy.
#### path = installer.download_and_cache(remote_source, [access_token], [timeout])
This utility routine will download the model file located at source,
cache it, and return the path to the cached file. It does not attempt
to determine the model type, probe its configuration values, or
register it with the models database.
You may provide an access token if the remote source requires
authorization. The call will block indefinitely until the file is
completely downloaded, cancelled or raises an error of some sort. If
you provide a timeout (in seconds), the call will raise a
`TimeoutError` exception if the download hasn't completed in the
specified period.
You may use this mechanism to request any type of file, not just a
model. The file will be stored in a subdirectory of
`INVOKEAI_ROOT/models/.cache`. If the requested file is found in the
cache, its path will be returned without redownloading it.
Be aware that the models cache is cleared of infrequently-used files
and directories at regular intervals when the size of the cache
exceeds the value specified in Invoke's `convert_cache` configuration
variable.
#### List[str]=installer.scan_directory(scan_dir: Path, install: bool)
This method will recursively scan the directory indicated in
@ -1128,7 +1177,7 @@ job = queue.create_download_job(
event_handlers=[my_handler1, my_handler2], # if desired
start=True,
)
```
```
The `filename` argument forces the downloader to use the specified
name for the file rather than the name provided by the remote source,
@ -1171,6 +1220,13 @@ queue or was not created by this queue.
This method will block until all the active jobs in the queue have
reached a terminal state (completed, errored or cancelled).
#### queue.wait_for_job(job, [timeout])
This method will block until the indicated job has reached a terminal
state (completed, errored or cancelled). If the optional timeout is
provided, the call will block for at most timeout seconds, and raise a
TimeoutError otherwise.
#### jobs = queue.list_jobs()
This will return a list of all jobs, including ones that have not yet
@ -1449,9 +1505,9 @@ set of keys to the corresponding model config objects.
Find all model metadata records that have the given author and return
a set of keys to the corresponding model config objects.
# The remainder of this documentation is provisional, pending implementation of the Load service
***
## Let's get loaded, the lowdown on ModelLoadService
## The Lowdown on the ModelLoadService
The `ModelLoadService` is responsible for loading a named model into
memory so that it can be used for inference. Despite the fact that it
@ -1465,7 +1521,7 @@ create alternative instances if you wish.
### Creating a ModelLoadService object
The class is defined in
`invokeai.app.services.model_loader_service`. It is initialized with
`invokeai.app.services.model_load`. It is initialized with
an InvokeAIAppConfig object, from which it gets configuration
information such as the user's desired GPU and precision, and with a
previously-created `ModelRecordServiceBase` object, from which it
@ -1475,26 +1531,29 @@ Here is a typical initialization pattern:
```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.app.services.model_loader_service import ModelLoadService
from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegistry
config = InvokeAIAppConfig.get_config()
store = ModelRecordServiceBase.open(config)
loader = ModelLoadService(config, store)
ram_cache = ModelCache(
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
)
convert_cache = ModelConvertCache(
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
)
loader = ModelLoadService(
app_config=config,
ram_cache=ram_cache,
convert_cache=convert_cache,
registry=ModelLoaderRegistry
)
```
Note that we are relying on the contents of the application
configuration to choose the implementation of
`ModelRecordServiceBase`.
### load_model(model_config, [submodel_type], [context]) -> LoadedModel
### get_model(key, [submodel_type], [context]) -> ModelInfo:
*** TO DO: change to get_model(key, context=None, **kwargs)
The `get_model()` method, like its similarly-named cousin in
`ModelRecordService`, receives the unique key that identifies the
The `load_model()` method takes an `AnyModelConfig` returned by
`ModelRecordService.get_model()` and returns the corresponding loaded
model. It loads the model into memory, gets the model ready for use,
and returns a `ModelInfo` object.
and returns a `LoadedModel` object.
The optional second argument, `subtype` is a `SubModelType` string
enum, such as "vae". It is mandatory when used with a main model, and
@ -1504,46 +1563,45 @@ The optional third argument, `context` can be provided by
an invocation to trigger model load event reporting. See below for
details.
The returned `ModelInfo` object shares some fields in common with
`ModelConfigBase`, but is otherwise a completely different beast:
The returned `LoadedModel` object contains a copy of the configuration
record returned by the model record `get_model()` method, as well as
the in-memory loaded model:
| **Field Name** | **Type** | **Description** |
| **Attribute Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `key` | str | The model key derived from the ModelRecordService database |
| `name` | str | Name of this model |
| `base_model` | BaseModelType | Base model for this model |
| `type` | ModelType or SubModelType | Either the model type (non-main) or the submodel type (main models)|
| `location` | Path or str | Location of the model on the filesystem |
| `precision` | torch.dtype | The torch.precision to use for inference |
| `context` | ModelCache.ModelLocker | A context class used to lock the model in VRAM while in use |
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
| `model` | AnyModel | The instantiated model (details below) |
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
The types for `ModelInfo` and `SubModelType` can be imported from
`invokeai.app.services.model_loader_service`.
Because the loader can return multiple model types, it is typed to
return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are obvious.
To use the model, you use the `ModelInfo` as a context manager using
the following pattern:
`LoadedModel` acts as a context manager. The context loads the model
into the execution device (e.g. VRAM on CUDA systems), locks the model
in the execution device for the duration of the context, and returns
the model. Use it like this:
```
model_info = loader.get_model('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with model_info as vae:
image = vae.decode(latents)[0]
```
The `vae` model will stay locked in the GPU during the period of time
it is in the context manager's scope.
`get_model_by_key()` may raise any of the following exceptions:
`get_model()` may raise any of the following exceptions:
- `UnknownModelException` -- key not in database
- `ModelNotFoundException` -- key in database but model not found at path
- `InvalidModelException` -- the model is guilty of a variety of sins
- `UnknownModelException` -- key not in database
- `ModelNotFoundException` -- key in database but model not found at path
- `NotImplementedException` -- the loader doesn't know how to load this type of model
** TO DO: ** Resolve discrepancy between ModelInfo.location and
ModelConfig.path.
### Emitting model loading events
When the `context` argument is passed to `get_model()`, it will
When the `context` argument is passed to `load_model_*()`, it will
retrieve the invocation event bus from the passed `InvocationContext`
object to emit events on the invocation bus. The two events are
"model_load_started" and "model_load_completed". Both carry the
@ -1556,10 +1614,174 @@ payload=dict(
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_key=model_key,
submodel=submodel,
submodel_type=submodel,
hash=model_info.hash,
location=str(model_info.location),
precision=str(model_info.precision),
)
```
### Adding Model Loaders
Model loaders are small classes that inherit from the `ModelLoader`
base class. They typically implement one method `_load_model()` whose
signature is:
```
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
```
`_load_model()` will be passed the path to the model on disk, an
optional repository variant (used by the diffusers loaders to select,
e.g. the `fp16` variant, and an optional submodel_type for main and
onnx models.
To install a new loader, place it in
`invokeai/backend/model_manager/load/model_loaders`. Inherit from
`ModelLoader` and use the `@ModelLoaderRegistry.register()` decorator to
indicate what type of models the loader can handle.
Here is a complete example from `generic_diffusers.py`, which is able
to load several different diffusers types:
```
from pathlib import Path
from typing import Optional
from invokeai.backend.model_manager import (
AnyModel,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
from .. import ModelLoader, ModelLoaderRegistry
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
class GenericDiffusersLoader(ModelLoader):
"""Class to load simple diffusers models."""
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
model_class = self._get_hf_load_class(model_path)
if submodel_type is not None:
raise Exception(f"There are no submodels in models of type {model_class}")
variant = model_variant.value if model_variant else None
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore
return result
```
Note that a loader can register itself to handle several different
model types. An exception will be raised if more than one loader tries
to register the same model type.
#### Conversion
Some models require conversion to diffusers format before they can be
loaded. These loaders should override two additional methods:
```
_needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool
_convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
```
The first method accepts the model configuration, the path to where
the unmodified model is currently installed, and a proposed
destination for the converted model. This method returns True if the
model needs to be converted. It typically does this by comparing the
last modification time of the original model file to the modification
time of the converted model. In some cases you will also want to check
the modification date of the configuration record, in the event that
the user has changed something like the scheduler prediction type that
will require the model to be re-converted. See `controlnet.py` for an
example of this logic.
The second method accepts the model configuration, the path to the
original model on disk, and the desired output path for the converted
model. It does whatever it needs to do to get the model into diffusers
format, and returns the Path of the resulting model. (The path should
ordinarily be the same as `output_path`.)
## The ModelManagerService object
For convenience, the API provides a `ModelManagerService` object which
gives a single point of access to the major model manager
services. This object is created at initialization time and can be
found in the global `ApiDependencies.invoker.services.model_manager`
object, or in `context.services.model_manager` from within an
invocation.
In the examples below, we have retrieved the manager using:
```
mm = ApiDependencies.invoker.services.model_manager
```
The following properties and methods will be available:
### mm.store
This retrieves the `ModelRecordService` associated with the
manager. Example:
```
configs = mm.store.get_model_by_attr(name='stable-diffusion-v1-5')
```
### mm.install
This retrieves the `ModelInstallService` associated with the manager.
Example:
```
job = mm.install.heuristic_import(`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
```
### mm.load
This retrieves the `ModelLoaderService` associated with the manager. Example:
```
configs = mm.store.get_model_by_attr(name='stable-diffusion-v1-5')
assert len(configs) > 0
loaded_model = mm.load.load_model(configs[0])
```
The model manager also offers a few convenience shortcuts for loading
models:
### mm.load_model_by_config(model_config, [submodel], [context]) -> LoadedModel
Same as `mm.load.load_model()`.
### mm.load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel
This accepts the combination of the model's name, type and base, which
it passes to the model record config store for retrieval. If a unique
model config is found, this method returns a `LoadedModel`. It can
raise the following exceptions:
```
UnknownModelException -- model with these attributes not known
NotImplementedException -- the loader doesn't know how to load this type of model
ValueError -- more than one model matches this combination of base/type/name
```
### mm.load_model_by_key(key, [submodel], [context]) -> LoadedModel
This method takes a model key, looks it up using the
`ModelRecordServiceBase` object in `mm.store`, and passes the returned
model configuration to `load_model_by_config()`. It may raise a
`NotImplementedException`.

View File

@ -0,0 +1,45 @@
# Invocation API
Each invocation's `invoke` method is provided a single arg - the Invocation
Context.
This object provides access to various methods, used to interact with the
application. Loading and saving images, logging messages, etc.
!!! warning ""
This API may shift slightly until the release of v4.0.0 as we work through a few final updates to the Model Manager.
```py
class MyInvocation(BaseInvocation):
...
def invoke(self, context: InvocationContext) -> ImageOutput:
image_pil = context.images.get_pil(image_name)
# Do something to the image
image_dto = context.images.save(image_pil)
# Log a message
context.logger.info(f"Did something cool, image saved!")
...
```
<!-- prettier-ignore-start -->
::: invokeai.app.services.shared.invocation_context.InvocationContext
options:
members: false
::: invokeai.app.services.shared.invocation_context.ImagesInterface
::: invokeai.app.services.shared.invocation_context.TensorsInterface
::: invokeai.app.services.shared.invocation_context.ConditioningInterface
::: invokeai.app.services.shared.invocation_context.ModelsInterface
::: invokeai.app.services.shared.invocation_context.LoggerInterface
::: invokeai.app.services.shared.invocation_context.ConfigInterface
::: invokeai.app.services.shared.invocation_context.UtilInterface
::: invokeai.app.services.shared.invocation_context.BoardsInterface
<!-- prettier-ignore-end -->

View File

@ -0,0 +1,148 @@
# Invoke v4.0.0 Nodes API Migration guide
Invoke v4.0.0 is versioned as such due to breaking changes to the API utilized
by nodes, both core and custom.
## Motivation
Prior to v4.0.0, the `invokeai` python package has not be set up to be utilized
as a library. That is to say, it didn't have any explicitly public API, and node
authors had to work with the unstable internal application API.
v4.0.0 introduces a stable public API for nodes.
## Changes
There are two node-author-facing changes:
1. Import Paths
1. Invocation Context API
### Import Paths
All public objects are now exported from `invokeai.invocation_api`:
```py
# Old
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
InputField,
InvocationContext,
invocation,
)
from invokeai.app.invocations.primitives import ImageField
# New
from invokeai.invocation_api import (
BaseInvocation,
ImageField,
InputField,
InvocationContext,
invocation,
)
```
It's possible that we've missed some classes you need in your node. Please let
us know if that's the case.
### Invocation Context API
Most nodes utilize the Invocation Context, an object that is passed to the
`invoke` that provides access to data and services a node may need.
Until now, that object and the services it exposed were internal. Exposing them
to nodes means that changes to our internal implementation could break nodes.
The methods on the services are also often fairly complicated and allowed nodes
to footgun.
In v4.0.0, this object has been refactored to be much simpler.
See [INVOCATION_API](./INVOCATION_API.md) for full details of the API.
!!! warning ""
This API may shift slightly until the release of v4.0.0 as we work through a few final updates to the Model Manager.
#### Improved Service Methods
The biggest offender was the image save method:
```py
# Old
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=context.workflow,
)
# New
image_dto = context.images.save(image=image)
```
Other methods are simplified, or enhanced with additional functionality:
```py
# Old
image = context.services.images.get_pil_image(image_name)
# New
image = context.images.get_pil(image_name)
image_cmyk = context.images.get_pil(image_name, "CMYK")
```
We also had some typing issues around tensors:
```py
# Old
# `latents` typed as `torch.Tensor`, but could be `ConditioningFieldData`
latents = context.services.latents.get(self.latents.latents_name)
# `data` typed as `torch.Tenssor,` but could be `ConditioningFieldData`
context.services.latents.save(latents_name, data)
# New - separate methods for tensors and conditioning data w/ correct typing
# Also, the service generates the names
tensor_name = context.tensors.save(tensor)
tensor = context.tensors.load(tensor_name)
# For conditioning
cond_name = context.conditioning.save(cond_data)
cond_data = context.conditioning.load(cond_name)
```
#### Output Construction
Core Outputs have builder functions right on them - no need to manually
construct these objects, or use an extra utility:
```py
# Old
image_output = ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
latents_output = build_latents_output(latents_name=name, latents=latents, seed=None)
noise_output = NoiseOutput(
noise=LatentsField(latents_name=latents_name, seed=seed),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
cond_output = ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
# New
image_output = ImageOutput.build(image_dto)
latents_output = LatentsOutput.build(latents_name=name, latents=noise, seed=self.seed)
noise_output = NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed)
cond_output = ConditioningOutput.build(conditioning_name)
```
You can still create the objects using constructors if you want, but we suggest
using the builder methods.

View File

@ -1,5 +0,0 @@
mkdocs
mkdocs-material>=8, <9
mkdocs-git-revision-date-localized-plugin
mkdocs-redirects==1.2.0

View File

@ -1,5 +0,0 @@
:root {
--md-primary-fg-color: #35A4DB;
--md-primary-fg-color--light: #35A4DB;
--md-primary-fg-color--dark: #35A4DB;
}

View File

@ -4,11 +4,9 @@ from logging import Logger
import torch
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@ -17,24 +15,22 @@ from ..services.board_image_records.board_image_records_sqlite import SqliteBoar
from ..services.board_images.board_images_default import BoardImagesService
from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from ..services.boards.boards_default import BoardService
from ..services.bulk_download.bulk_download_default import BulkDownloadService
from ..services.config import InvokeAIAppConfig
from ..services.download import DownloadQueueService
from ..services.image_files.image_files_disk import DiskImageFileStorage
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
from ..services.images.images_default import ImageService
from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from ..services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
from ..services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.model_install import ModelInstallService
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_metadata import ModelMetadataStoreSQL
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.shared.graph import GraphExecutionState
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService
@ -86,7 +82,7 @@ class ApiDependencies:
board_records = SqliteBoardRecordStorage(db=db)
boards = BoardService()
events = FastAPIEventService(event_handler_id)
graph_execution_manager = ItemStorageMemory[GraphExecutionState]()
bulk_download = BulkDownloadService()
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
@ -96,21 +92,16 @@ class ApiDependencies:
conditioning = ObjectSerializerForwardCache(
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
)
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)
download_queue_service = DownloadQueueService(event_bus=events)
metadata_store = ModelMetadataStore(db=db)
model_install_service = ModelInstallService(
app_config=config,
record_store=model_record_service,
model_metadata_service = ModelMetadataStoreSQL(db=db)
model_manager = ModelManagerService.build_model_manager(
app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
download_queue=download_queue_service,
metadata_store=metadata_store,
event_bus=events,
events=events,
)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor()
queue = MemoryInvocationQueue()
session_processor = DefaultSessionProcessor()
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()
@ -121,22 +112,18 @@ class ApiDependencies:
board_images=board_images,
board_records=board_records,
boards=boards,
bulk_download=bulk_download,
configuration=configuration,
events=events,
graph_execution_manager=graph_execution_manager,
image_files=image_files,
image_records=image_records,
images=images,
invocation_cache=invocation_cache,
logger=logger,
model_manager=model_manager,
model_records=model_record_service,
download_queue=download_queue_service,
model_install=model_install_service,
names=names,
performance_statistics=performance_statistics,
processor=processor,
queue=queue,
session_processor=session_processor,
session_queue=session_queue,
urls=urls,

View File

@ -36,7 +36,7 @@ async def list_downloads() -> List[DownloadJob]:
400: {"description": "Bad request"},
},
)
async def prune_downloads():
async def prune_downloads() -> Response:
"""Prune completed and errored jobs."""
queue = ApiDependencies.invoker.services.download_queue
queue.prune_jobs()
@ -55,7 +55,7 @@ async def download(
) -> DownloadJob:
"""Download the source URL to the file or directory indicted in dest."""
queue = ApiDependencies.invoker.services.download_queue
return queue.download(source, dest, priority, access_token)
return queue.download(source, Path(dest), priority, access_token)
@download_queue_router.get(
@ -87,7 +87,7 @@ async def get_download_job(
)
async def cancel_download_job(
id: int = Path(description="ID of the download job to cancel."),
):
) -> Response:
"""Cancel a download job using its ID."""
try:
queue = ApiDependencies.invoker.services.download_queue
@ -105,7 +105,7 @@ async def cancel_download_job(
204: {"description": "Download jobs have been cancelled"},
},
)
async def cancel_all_download_jobs():
async def cancel_all_download_jobs() -> Response:
"""Cancel all download jobs."""
ApiDependencies.invoker.services.download_queue.cancel_all_jobs()
return Response(status_code=204)

View File

@ -2,7 +2,7 @@ import io
import traceback
from typing import Optional
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
@ -375,16 +375,67 @@ async def unstar_images_in_list(
class ImagesDownloaded(BaseModel):
response: Optional[str] = Field(
description="If defined, the message to display to the user when images begin downloading"
default=None, description="The message to display to the user when images begin downloading"
)
bulk_download_item_name: Optional[str] = Field(
default=None, description="The name of the bulk download item for which events will be emitted"
)
@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded)
@images_router.post(
"/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202
)
async def download_images_from_list(
image_names: list[str] = Body(description="The list of names of images to download", embed=True),
background_tasks: BackgroundTasks,
image_names: Optional[list[str]] = Body(
default=None, description="The list of names of images to download", embed=True
),
board_id: Optional[str] = Body(
default=None, description="The board from which image should be downloaded from", embed=True
default=None, description="The board from which image should be downloaded", embed=True
),
) -> ImagesDownloaded:
# return ImagesDownloaded(response="Your images are downloading")
raise HTTPException(status_code=501, detail="Endpoint is not yet implemented")
if (image_names is None or len(image_names) == 0) and board_id is None:
raise HTTPException(status_code=400, detail="No images or board id specified.")
bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id)
background_tasks.add_task(
ApiDependencies.invoker.services.bulk_download.handler,
image_names,
board_id,
bulk_download_item_id,
)
return ImagesDownloaded(bulk_download_item_name=bulk_download_item_id + ".zip")
@images_router.api_route(
"/download/{bulk_download_item_name}",
methods=["GET"],
operation_id="get_bulk_download_item",
response_class=Response,
responses={
200: {
"description": "Return the complete bulk download item",
"content": {"application/zip": {}},
},
404: {"description": "Image not found"},
},
)
async def get_bulk_download_item(
background_tasks: BackgroundTasks,
bulk_download_item_name: str = Path(description="The bulk_download_item_name of the bulk download item to get"),
) -> FileResponse:
"""Gets a bulk download zip file"""
try:
path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name)
response = FileResponse(
path,
media_type="application/zip",
filename=bulk_download_item_name,
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.delete, bulk_download_item_name)
return response
except Exception:
raise HTTPException(status_code=404)

View File

@ -0,0 +1,751 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records."""
import pathlib
import shutil
from hashlib import sha1
from random import randbytes
from typing import Any, Dict, List, Optional, Set
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict, Field
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
ModelRecordOrderBy,
ModelSummary,
UnknownModelException,
)
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
MainCheckpointConfig,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.search import ModelSearch
from ..dependencies import ApiDependencies
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
class ModelsList(BaseModel):
"""Return list of configs."""
models: List[AnyModelConfig]
model_config = ConfigDict(use_enum_values=True)
class ModelTagSet(BaseModel):
"""Return tags for a set of models."""
key: str
name: str
author: str
tags: Set[str]
##############################################################################
# These are example inputs and outputs that are used in places where Swagger
# is unable to generate a correct example.
##############################################################################
example_model_config = {
"path": "string",
"name": "string",
"base": "sd-1",
"type": "main",
"format": "checkpoint",
"config": "string",
"key": "string",
"original_hash": "string",
"current_hash": "string",
"description": "string",
"source": "string",
"last_modified": 0,
"vae": "string",
"variant": "normal",
"prediction_type": "epsilon",
"repo_variant": "fp16",
"upcast_attention": False,
"ztsnr_training": False,
}
example_model_input = {
"path": "/path/to/model",
"name": "model_name",
"base": "sd-1",
"type": "main",
"format": "checkpoint",
"config": "configs/stable-diffusion/v1-inference.yaml",
"description": "Model description",
"vae": None,
"variant": "normal",
}
example_model_metadata = {
"name": "ip_adapter_sd_image_encoder",
"author": "InvokeAI",
"tags": [
"transformers",
"safetensors",
"clip_vision_model",
"endpoints_compatible",
"region:us",
"has_space",
"license:apache-2.0",
],
"files": [
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
"path": "ip_adapter_sd_image_encoder/README.md",
"size": 628,
"sha256": None,
},
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
"path": "ip_adapter_sd_image_encoder/config.json",
"size": 560,
"sha256": None,
},
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
"path": "ip_adapter_sd_image_encoder/model.safetensors",
"size": 2528373448,
"sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
},
],
"type": "huggingface",
"id": "InvokeAI/ip_adapter_sd_image_encoder",
"tag_dict": {"license": "apache-2.0"},
"last_modified": "2023-09-23T17:33:25Z",
}
##############################################################################
# ROUTES
##############################################################################
@model_manager_router.get(
"/",
operation_id="list_model_records",
)
async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
model_format: Optional[ModelFormat] = Query(
default=None, description="Exact match on the format of the model (e.g. 'diffusers')"
),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_manager.store
found_models: list[AnyModelConfig] = []
if base_models:
for base_model in base_models:
found_models.extend(
record_store.search_by_attr(
base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format
)
)
else:
found_models.extend(
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
)
return ModelsList(models=found_models)
@model_manager_router.get(
"/get_by_attrs",
operation_id="get_model_records_by_attrs",
response_model=AnyModelConfig,
)
async def get_model_records_by_attrs(
name: str = Query(description="The name of the model"),
type: ModelType = Query(description="The type of the model"),
base: BaseModelType = Query(description="The base model of the model"),
) -> AnyModelConfig:
"""Gets a model by its attributes. The main use of this route is to provide backwards compatibility with the old
model manager, which identified models by a combination of name, base and type."""
configs = ApiDependencies.invoker.services.model_manager.store.search_by_attr(
base_model=base, model_type=type, model_name=name
)
if not configs:
raise HTTPException(status_code=404, detail="No model found with these attributes")
return configs[0]
@model_manager_router.get(
"/i/{key}",
operation_id="get_model_record",
responses={
200: {
"description": "The model configuration was retrieved successfully",
"content": {"application/json": {"example": example_model_config}},
},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
},
)
async def get_model_record(
key: str = Path(description="Key of the model record to fetch."),
) -> AnyModelConfig:
"""Get a model record"""
record_store = ApiDependencies.invoker.services.model_manager.store
try:
config: AnyModelConfig = record_store.get_model(key)
return config
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@model_manager_router.get("/summary", operation_id="list_model_summary")
async def list_model_summary(
page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of models per page"),
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
) -> PaginatedResults[ModelSummary]:
"""Gets a page of model summary data."""
record_store = ApiDependencies.invoker.services.model_manager.store
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
return results
@model_manager_router.get(
"/i/{key}/metadata",
operation_id="get_model_metadata",
responses={
200: {
"description": "The model metadata was retrieved successfully",
"content": {"application/json": {"example": example_model_metadata}},
},
400: {"description": "Bad request"},
},
)
async def get_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
) -> Optional[AnyModelRepoMetadata]:
"""Get a model metadata object."""
record_store = ApiDependencies.invoker.services.model_manager.store
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
return result
@model_manager_router.get(
"/tags",
operation_id="list_tags",
)
async def list_tags() -> Set[str]:
"""Get a unique set of all the model tags."""
record_store = ApiDependencies.invoker.services.model_manager.store
result: Set[str] = record_store.list_tags()
return result
class FoundModel(BaseModel):
path: str = Field(description="Path to the model")
is_installed: bool = Field(description="Whether or not the model is already installed")
@model_manager_router.get(
"/scan_folder",
operation_id="scan_for_models",
responses={
200: {"description": "Directory scanned successfully"},
400: {"description": "Invalid directory path"},
},
status_code=200,
response_model=List[FoundModel],
)
async def scan_for_models(
scan_path: str = Query(description="Directory path to search for models", default=None),
) -> List[FoundModel]:
path = pathlib.Path(scan_path)
if not scan_path or not path.is_dir():
raise HTTPException(
status_code=400,
detail=f"The search path '{scan_path}' does not exist or is not directory",
)
search = ModelSearch()
try:
found_model_paths = search.search(path)
models_path = ApiDependencies.invoker.services.configuration.models_path
# If the search path includes the main models directory, we need to exclude core models from the list.
# TODO(MM2): Core models should be handled by the model manager so we can determine if they are installed
# without needing to crawl the filesystem.
core_models_path = pathlib.Path(models_path, "core").resolve()
non_core_model_paths = [p for p in found_model_paths if not p.is_relative_to(core_models_path)]
installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
resolved_installed_model_paths: list[str] = []
installed_model_sources: list[str] = []
# This call lists all installed models.
for model in installed_models:
path = pathlib.Path(model.path)
# If the model has a source, we need to add it to the list of installed sources.
if model.source:
installed_model_sources.append(model.source)
# If the path is not absolute, that means it is in the app models directory, and we need to join it with
# the models path before resolving.
if not path.is_absolute():
resolved_installed_model_paths.append(str(pathlib.Path(models_path, path).resolve()))
continue
resolved_installed_model_paths.append(str(path.resolve()))
scan_results: list[FoundModel] = []
# Check if the model is installed by comparing the resolved paths, appending to the scan result.
for p in non_core_model_paths:
path = str(p)
is_installed = path in resolved_installed_model_paths or path in installed_model_sources
found_model = FoundModel(path=path, is_installed=is_installed)
scan_results.append(found_model)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"An error occurred while searching the directory: {e}",
)
return scan_results
@model_manager_router.get(
"/tags/search",
operation_id="search_by_metadata_tags",
)
async def search_by_metadata_tags(
tags: Set[str] = Query(default=None, description="Tags to search for"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_manager.store
results = record_store.search_by_metadata_tag(tags)
return ModelsList(models=results)
@model_manager_router.patch(
"/i/{key}",
operation_id="update_model_record",
responses={
200: {
"description": "The model was updated successfully",
"content": {"application/json": {"example": example_model_config}},
},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
409: {"description": "There is already a model corresponding to the new name"},
},
status_code=200,
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
info: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
],
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
try:
model_response: AnyModelConfig = record_store.update_model(key, config=info)
logger.info(f"Updated model: {key}")
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return model_response
@model_manager_router.delete(
"/i/{key}",
operation_id="del_model_record",
responses={
204: {"description": "Model deleted successfully"},
404: {"description": "Model not found"},
},
status_code=204,
)
async def del_model_record(
key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
"""
Delete model record from database.
The configuration record will be removed. The corresponding weights files will be
deleted as well if they reside within the InvokeAI "models" directory.
"""
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_manager.install
installer.delete(key)
logger.info(f"Deleted model: {key}")
return Response(status_code=204)
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@model_manager_router.post(
"/i/",
operation_id="add_model_record",
responses={
201: {
"description": "The model added successfully",
"content": {"application/json": {"example": example_model_config}},
},
409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"},
},
status_code=201,
)
async def add_model_record(
config: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
],
) -> AnyModelConfig:
"""Add a model using the configuration information appropriate for its type."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
if config.key == "<NOKEY>":
config.key = sha1(randbytes(100)).hexdigest()
logger.info(f"Created model {config.key} for {config.name}")
try:
record_store.add_model(config.key, config)
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
# now fetch it out
result: AnyModelConfig = record_store.get_model(config.key)
return result
@model_manager_router.post(
"/install",
operation_id="install_model",
responses={
201: {"description": "The model imported successfully"},
415: {"description": "Unrecognized file/folder format"},
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
)
async def install_model(
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
# TODO(MM2): Can we type this?
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None,
example={"name": "string", "description": "string"},
),
access_token: Optional[str] = None,
) -> ModelInstallJob:
"""Install a model using a string identifier.
`source` can be any of the following.
1. A path on the local filesystem ('C:\\users\\fred\\model.safetensors')
2. A Url pointing to a single downloadable model file
3. A HuggingFace repo_id with any of the following formats:
- model/name
- model/name:fp16:vae
- model/name::vae -- use default precision
- model/name:fp16:path/to/model.safetensors
- model/name::path/to/model.safetensors
`config` is an optional dict containing model configuration values that will override
the ones that are probed automatically.
`access_token` is an optional access token for use with Urls that require
authentication.
Models will be downloaded, probed, configured and installed in a
series of background threads. The return object has `status` attribute
that can be used to monitor progress.
See the documentation for `import_model_record` for more information on
interpreting the job information returned by this route.
"""
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_manager.install
result: ModelInstallJob = installer.heuristic_import(
source=source,
config=config,
access_token=access_token,
)
logger.info(f"Started installation of {source}")
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=424, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return result
@model_manager_router.get(
"/import",
operation_id="list_model_install_jobs",
)
async def list_model_install_jobs() -> List[ModelInstallJob]:
"""Return the list of model install jobs.
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
the nature of the job and its progress. The `status` is one of:
* "waiting" -- Job is waiting in the queue to run
* "downloading" -- Model file(s) are downloading
* "running" -- Model has downloaded and the model probing and registration process is running
* "completed" -- Installation completed successfully
* "error" -- An error occurred. Details will be in the "error_type" and "error" fields.
* "cancelled" -- Job was cancelled before completion.
Once completed, information about the model such as its size, base
model, type, and metadata can be retrieved from the `config_out`
field. For multi-file models such as diffusers, information on individual files
can be retrieved from `download_parts`.
See the example and schema below for more information.
"""
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs()
return jobs
@model_manager_router.get(
"/import/{id}",
operation_id="get_model_install_job",
responses={
200: {"description": "Success"},
404: {"description": "No such job"},
},
)
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
"""
Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs'
for information on the format of the return value.
"""
try:
result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id)
return result
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@model_manager_router.delete(
"/import/{id}",
operation_id="cancel_model_install_job",
responses={
201: {"description": "The job was cancelled successfully"},
415: {"description": "No such job"},
},
status_code=201,
)
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
"""Cancel the model install job(s) corresponding to the given job ID."""
installer = ApiDependencies.invoker.services.model_manager.install
try:
job = installer.get_job_by_id(id)
except ValueError as e:
raise HTTPException(status_code=415, detail=str(e))
installer.cancel_job(job)
@model_manager_router.patch(
"/import",
operation_id="prune_model_install_jobs",
responses={
204: {"description": "All completed and errored jobs have been pruned"},
400: {"description": "Bad request"},
},
)
async def prune_model_install_jobs() -> Response:
"""Prune all completed and errored jobs from the install job list."""
ApiDependencies.invoker.services.model_manager.install.prune_jobs()
return Response(status_code=204)
@model_manager_router.patch(
"/sync",
operation_id="sync_models_to_config",
responses={
204: {"description": "Model config record database resynced with files on disk"},
400: {"description": "Bad request"},
},
)
async def sync_models_to_config() -> Response:
"""
Traverse the models and autoimport directories.
Model files without a corresponding
record in the database are added. Orphan records without a models file are deleted.
"""
ApiDependencies.invoker.services.model_manager.install.sync_to_config()
return Response(status_code=204)
@model_manager_router.put(
"/convert/{key}",
operation_id="convert_model",
responses={
200: {
"description": "Model converted successfully",
"content": {"application/json": {"example": example_model_config}},
},
400: {"description": "Bad request"},
404: {"description": "Model not found"},
409: {"description": "There is already a model registered at this location"},
},
)
async def convert_model(
key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
) -> AnyModelConfig:
"""
Permanently convert a model into diffusers format, replacing the safetensors version.
Note that during the conversion process the key and model hash will change.
The return value is the model configuration for the converted model.
"""
model_manager = ApiDependencies.invoker.services.model_manager
logger = ApiDependencies.invoker.services.logger
loader = ApiDependencies.invoker.services.model_manager.load
store = ApiDependencies.invoker.services.model_manager.store
installer = ApiDependencies.invoker.services.model_manager.install
try:
model_config = store.get_model(key)
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=424, detail=str(e))
if not isinstance(model_config, MainCheckpointConfig):
logger.error(f"The model with key {key} is not a main checkpoint model.")
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
# loading the model will convert it into a cached diffusers file
model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
# Get the path of the converted model from the loader
cache_path = loader.convert_cache.cache_path(key)
assert cache_path.exists()
# temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name
model_config.name = f"{original_name}.DELETE"
store.update_model(key, config=model_config)
# install the diffusers
try:
new_key = installer.install_path(
cache_path,
config={
"name": original_name,
"description": model_config.description,
"original_hash": model_config.original_hash,
"source": model_config.source,
},
)
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
# get the original metadata
if orig_metadata := store.get_metadata(key):
store.metadata_store.add_metadata(new_key, orig_metadata)
# delete the original safetensors file
installer.delete(key)
# delete the cached version
shutil.rmtree(cache_path)
# return the config record for the new diffusers directory
new_config: AnyModelConfig = store.get_model(new_key)
return new_config
@model_manager_router.put(
"/merge",
operation_id="merge",
responses={
200: {
"description": "Model converted successfully",
"content": {"application/json": {"example": example_model_config}},
},
400: {"description": "Bad request"},
404: {"description": "Model not found"},
409: {"description": "There is already a model registered at this location"},
},
)
async def merge(
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
force: bool = Body(
description="Force merging of models created with different versions of diffusers",
default=False,
),
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
merge_dest_directory: Optional[str] = Body(
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
default=None,
),
) -> AnyModelConfig:
"""
Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
```
Argument Description [default]
-------- ----------------------
keys List of 2-3 model keys to merge together. All models must use the same base type.
merged_model_name Name for the merged model [Concat model names]
alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
merge_dest_directory Specify a directory to store the merged model in [models directory]
```
"""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
installer = ApiDependencies.invoker.services.model_manager.install
merger = ModelMerger(installer)
model_names = [installer.record_store.get_model(x).name for x in keys]
response = merger.merge_diffusion_models_and_save(
model_keys=keys,
merged_model_name=merged_model_name or "+".join(model_names),
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory=dest,
)
except UnknownModelException:
raise HTTPException(
status_code=404,
detail=f"One or more of the models '{keys}' not found",
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response

View File

@ -1,472 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records."""
import pathlib
from hashlib import sha1
from random import randbytes
from typing import Any, Dict, List, Optional, Set
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
ModelRecordOrderBy,
ModelSummary,
UnknownModelException,
)
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..dependencies import ApiDependencies
model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"])
class ModelsList(BaseModel):
"""Return list of configs."""
models: List[AnyModelConfig]
model_config = ConfigDict(use_enum_values=True)
class ModelTagSet(BaseModel):
"""Return tags for a set of models."""
key: str
name: str
author: str
tags: Set[str]
@model_records_router.get(
"/",
operation_id="list_model_records",
)
async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
model_format: Optional[ModelFormat] = Query(
default=None, description="Exact match on the format of the model (e.g. 'diffusers')"
),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
found_models: list[AnyModelConfig] = []
if base_models:
for base_model in base_models:
found_models.extend(
record_store.search_by_attr(
base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format
)
)
else:
found_models.extend(
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
)
return ModelsList(models=found_models)
@model_records_router.get(
"/i/{key}",
operation_id="get_model_record",
responses={
200: {"description": "Success"},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
},
)
async def get_model_record(
key: str = Path(description="Key of the model record to fetch."),
) -> AnyModelConfig:
"""Get a model record"""
record_store = ApiDependencies.invoker.services.model_records
try:
return record_store.get_model(key)
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.get("/meta", operation_id="list_model_summary")
async def list_model_summary(
page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of models per page"),
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
) -> PaginatedResults[ModelSummary]:
"""Gets a page of model summary data."""
return ApiDependencies.invoker.services.model_records.list_models(page=page, per_page=per_page, order_by=order_by)
@model_records_router.get(
"/meta/i/{key}",
operation_id="get_model_metadata",
responses={
200: {"description": "Success"},
400: {"description": "Bad request"},
404: {"description": "No metadata available"},
},
)
async def get_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
) -> Optional[AnyModelRepoMetadata]:
"""Get a model metadata object."""
record_store = ApiDependencies.invoker.services.model_records
result = record_store.get_metadata(key)
if not result:
raise HTTPException(status_code=404, detail="No metadata for a model with this key")
return result
@model_records_router.get(
"/tags",
operation_id="list_tags",
)
async def list_tags() -> Set[str]:
"""Get a unique set of all the model tags."""
record_store = ApiDependencies.invoker.services.model_records
return record_store.list_tags()
@model_records_router.get(
"/tags/search",
operation_id="search_by_metadata_tags",
)
async def search_by_metadata_tags(
tags: Set[str] = Query(default=None, description="Tags to search for"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
results = record_store.search_by_metadata_tag(tags)
return ModelsList(models=results)
@model_records_router.patch(
"/i/{key}",
operation_id="update_model_record",
responses={
200: {"description": "The model was updated successfully"},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
409: {"description": "There is already a model corresponding to the new name"},
},
status_code=200,
response_model=AnyModelConfig,
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
try:
model_response = record_store.update_model(key, config=info)
logger.info(f"Updated model: {key}")
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return model_response
@model_records_router.delete(
"/i/{key}",
operation_id="del_model_record",
responses={
204: {"description": "Model deleted successfully"},
404: {"description": "Model not found"},
},
status_code=204,
)
async def del_model_record(
key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
"""
Delete model record from database.
The configuration record will be removed. The corresponding weights files will be
deleted as well if they reside within the InvokeAI "models" directory.
"""
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_install
installer.delete(key)
logger.info(f"Deleted model: {key}")
return Response(status_code=204)
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.post(
"/i/",
operation_id="add_model_record",
responses={
201: {"description": "The model added successfully"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"},
},
status_code=201,
)
async def add_model_record(
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
) -> AnyModelConfig:
"""Add a model using the configuration information appropriate for its type."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
if config.key == "<NOKEY>":
config.key = sha1(randbytes(100)).hexdigest()
logger.info(f"Created model {config.key} for {config.name}")
try:
record_store.add_model(config.key, config)
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
# now fetch it out
return record_store.get_model(config.key)
@model_records_router.post(
"/import",
operation_id="import_model_record",
responses={
201: {"description": "The model imported successfully"},
415: {"description": "Unrecognized file/folder format"},
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
)
async def import_model(
source: ModelSource,
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None,
),
) -> ModelInstallJob:
"""Add a model using its local path, repo_id, or remote URL.
Models will be downloaded, probed, configured and installed in a
series of background threads. The return object has `status` attribute
that can be used to monitor progress.
The source object is a discriminated Union of LocalModelSource,
HFModelSource and URLModelSource. Set the "type" field to the
appropriate value:
* To install a local path using LocalModelSource, pass a source of form:
`{
"type": "local",
"path": "/path/to/model",
"inplace": false
}`
The "inplace" flag, if true, will register the model in place in its
current filesystem location. Otherwise, the model will be copied
into the InvokeAI models directory.
* To install a HuggingFace repo_id using HFModelSource, pass a source of form:
`{
"type": "hf",
"repo_id": "stabilityai/stable-diffusion-2.0",
"variant": "fp16",
"subfolder": "vae",
"access_token": "f5820a918aaf01"
}`
The `variant`, `subfolder` and `access_token` fields are optional.
* To install a remote model using an arbitrary URL, pass:
`{
"type": "url",
"url": "http://www.civitai.com/models/123456",
"access_token": "f5820a918aaf01"
}`
The `access_token` field is optonal
The model's configuration record will be probed and filled in
automatically. To override the default guesses, pass "metadata"
with a Dict containing the attributes you wish to override.
Installation occurs in the background. Either use list_model_install_jobs()
to poll for completion, or listen on the event bus for the following events:
"model_install_running"
"model_install_completed"
"model_install_error"
On successful completion, the event's payload will contain the field "key"
containing the installed ID of the model. On an error, the event's payload
will contain the fields "error_type" and "error" describing the nature of the
error and its traceback, respectively.
"""
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_install
result: ModelInstallJob = installer.import_model(
source=source,
config=config,
)
logger.info(f"Started installation of {source}")
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=424, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return result
@model_records_router.get(
"/import",
operation_id="list_model_install_jobs",
)
async def list_model_install_jobs() -> List[ModelInstallJob]:
"""Return list of model install jobs."""
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs()
return jobs
@model_records_router.get(
"/import/{id}",
operation_id="get_model_install_job",
responses={
200: {"description": "Success"},
404: {"description": "No such job"},
},
)
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
"""Return model install job corresponding to the given source."""
try:
return ApiDependencies.invoker.services.model_install.get_job_by_id(id)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.delete(
"/import/{id}",
operation_id="cancel_model_install_job",
responses={
201: {"description": "The job was cancelled successfully"},
415: {"description": "No such job"},
},
status_code=201,
)
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
"""Cancel the model install job(s) corresponding to the given job ID."""
installer = ApiDependencies.invoker.services.model_install
try:
job = installer.get_job_by_id(id)
except ValueError as e:
raise HTTPException(status_code=415, detail=str(e))
installer.cancel_job(job)
@model_records_router.patch(
"/import",
operation_id="prune_model_install_jobs",
responses={
204: {"description": "All completed and errored jobs have been pruned"},
400: {"description": "Bad request"},
},
)
async def prune_model_install_jobs() -> Response:
"""Prune all completed and errored jobs from the install job list."""
ApiDependencies.invoker.services.model_install.prune_jobs()
return Response(status_code=204)
@model_records_router.patch(
"/sync",
operation_id="sync_models_to_config",
responses={
204: {"description": "Model config record database resynced with files on disk"},
400: {"description": "Bad request"},
},
)
async def sync_models_to_config() -> Response:
"""
Traverse the models and autoimport directories.
Model files without a corresponding
record in the database are added. Orphan records without a models file are deleted.
"""
ApiDependencies.invoker.services.model_install.sync_to_config()
return Response(status_code=204)
@model_records_router.put(
"/merge",
operation_id="merge",
)
async def merge(
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
force: bool = Body(
description="Force merging of models created with different versions of diffusers",
default=False,
),
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
merge_dest_directory: Optional[str] = Body(
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
default=None,
),
) -> AnyModelConfig:
"""
Merge diffusers models.
keys: List of 2-3 model keys to merge together. All models must use the same base type.
merged_model_name: Name for the merged model [Concat model names]
alpha: Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
force: If true, force the merge even if the models were generated by different versions of the diffusers library [False]
interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
merge_dest_directory: Specify a directory to store the merged model in [models directory]
"""
print(f"here i am, keys={keys}")
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
installer = ApiDependencies.invoker.services.model_install
merger = ModelMerger(installer)
model_names = [installer.record_store.get_model(x).name for x in keys]
response = merger.merge_diffusion_models_and_save(
model_keys=keys,
merged_model_name=merged_model_name or "+".join(model_names),
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory=dest,
)
except UnknownModelException:
raise HTTPException(
status_code=404,
detail=f"One or more of the models '{keys}' not found",
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response

View File

@ -1,427 +0,0 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
import pathlib
from typing import Annotated, List, Literal, Optional, Union
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from starlette.exceptions import HTTPException
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import MergeInterpolationMethod
from invokeai.backend.model_management.models import (
OPENAPI_MODEL_CONFIGS,
InvalidModelException,
ModelNotFoundException,
SchedulerPredictionType,
)
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
UpdateModelResponseValidator = TypeAdapter(UpdateModelResponse)
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelResponseValidator = TypeAdapter(ImportModelResponse)
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ConvertModelResponseValidator = TypeAdapter(ConvertModelResponse)
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
model_config = ConfigDict(use_enum_values=True)
ModelsListValidator = TypeAdapter(ModelsList)
@models_router.get(
"/",
operation_id="list_models",
responses={200: {"model": ModelsList}},
)
async def list_models(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
"""Gets a list of models"""
if base_models and len(base_models) > 0:
models_raw = []
for base_model in base_models:
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
else:
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
models = ModelsListValidator.validate_python({"models": models_raw})
return models
@models_router.patch(
"/{base_model}/{model_type}/{model_name}",
operation_id="update_model",
responses={
200: {"description": "The model was updated successfully"},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
409: {"description": "There is already a model corresponding to the new name"},
},
status_code=200,
response_model=UpdateModelResponse,
)
async def update_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> UpdateModelResponse:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
try:
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
# rename operation requested
if info.model_name != model_name or info.base_model != base_model:
ApiDependencies.invoker.services.model_manager.rename_model(
base_model=base_model,
model_type=model_type,
model_name=model_name,
new_name=info.model_name,
new_base=info.base_model,
)
logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
# update information to support an update of attributes
model_name = info.model_name
base_model = info.base_model
new_info = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
if new_info.get("path") != previous_info.get(
"path"
): # model manager moved model path during rename - don't overwrite it
info.path = new_info.get("path")
# replace empty string values with None/null to avoid phenomenon of vae: ''
info_dict = info.model_dump()
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
ApiDependencies.invoker.services.model_manager.update_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_attributes=info_dict,
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
model_response = UpdateModelResponseValidator.validate_python(model_raw)
except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e))
return model_response
@models_router.post(
"/import",
operation_id="import_model",
responses={
201: {"description": "The model imported successfully"},
404: {"description": "The model could not be found"},
415: {"description": "Unrecognized file/folder format"},
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_model=ImportModelResponse,
)
async def import_model(
location: str = Body(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
description="Prediction type for SDv2 checkpoints and rare SDv1 checkpoints",
default=None,
),
) -> ImportModelResponse:
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
location = location.strip("\"' ")
items_to_import = {location}
prediction_types = {x.value: x for x in SchedulerPredictionType}
logger = ApiDependencies.invoker.services.logger
try:
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import=items_to_import,
prediction_type_helper=lambda x: prediction_types.get(prediction_type),
)
info = installed_models.get(location)
if not info:
logger.error("Import failed")
raise HTTPException(status_code=415)
logger.info(f"Successfully imported {location}, got {info}")
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.name, base_model=info.base_model, model_type=info.model_type
)
return ImportModelResponseValidator.validate_python(model_raw)
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.post(
"/add",
operation_id="add_model",
responses={
201: {"description": "The model added successfully"},
404: {"description": "The model could not be found"},
424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_model=ImportModelResponse,
)
async def add_model(
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> ImportModelResponse:
"""Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
logger = ApiDependencies.invoker.services.logger
try:
ApiDependencies.invoker.services.model_manager.add_model(
info.model_name,
info.base_model,
info.model_type,
model_attributes=info.model_dump(),
)
logger.info(f"Successfully added {info.model_name}")
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.model_name,
base_model=info.base_model,
model_type=info.model_type,
)
return ImportModelResponseValidator.validate_python(model_raw)
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.delete(
"/{base_model}/{model_type}/{model_name}",
operation_id="del_model",
responses={
204: {"description": "Model deleted successfully"},
404: {"description": "Model not found"},
},
status_code=204,
response_model=None,
)
async def delete_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
) -> Response:
"""Delete Model"""
logger = ApiDependencies.invoker.services.logger
try:
ApiDependencies.invoker.services.model_manager.del_model(
model_name, base_model=base_model, model_type=model_type
)
logger.info(f"Deleted model: {model_name}")
return Response(status_code=204)
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@models_router.put(
"/convert/{base_model}/{model_type}/{model_name}",
operation_id="convert_model",
responses={
200: {"description": "Model converted successfully"},
400: {"description": "Bad request"},
404: {"description": "Model not found"},
},
status_code=200,
response_model=ConvertModelResponse,
)
async def convert_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
convert_dest_directory: Optional[str] = Query(
default=None, description="Save the converted model to the designated directory"
),
) -> ConvertModelResponse:
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Converting model: {model_name}")
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
ApiDependencies.invoker.services.model_manager.convert_model(
model_name,
base_model=base_model,
model_type=model_type,
convert_dest_directory=dest,
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name, base_model=base_model, model_type=model_type
)
response = ConvertModelResponseValidator.validate_python(model_raw)
except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
@models_router.get(
"/search",
operation_id="search_for_models",
responses={
200: {"description": "Directory searched successfully"},
404: {"description": "Invalid directory path"},
},
status_code=200,
response_model=List[pathlib.Path],
)
async def search_for_models(
search_path: pathlib.Path = Query(description="Directory path to search for models"),
) -> List[pathlib.Path]:
if not search_path.is_dir():
raise HTTPException(
status_code=404,
detail=f"The search path '{search_path}' does not exist or is not directory",
)
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
@models_router.get(
"/ckpt_confs",
operation_id="list_ckpt_configs",
responses={
200: {"description": "paths retrieved successfully"},
},
status_code=200,
response_model=List[pathlib.Path],
)
async def list_ckpt_configs() -> List[pathlib.Path]:
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
@models_router.post(
"/sync",
operation_id="sync_to_config",
responses={
201: {"description": "synchronization successful"},
},
status_code=201,
response_model=bool,
)
async def sync_to_config() -> bool:
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
in-memory data structures with disk data structures."""
ApiDependencies.invoker.services.model_manager.sync_to_config()
return True
# There's some weird pydantic-fastapi behaviour that requires this to be a separate class
# TODO: After a few updates, see if it works inside the route operation handler?
class MergeModelsBody(BaseModel):
model_names: List[str] = Field(description="model name", min_length=2, max_length=3)
merged_model_name: Optional[str] = Field(description="Name of destination model")
alpha: Optional[float] = Field(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5)
interp: Optional[MergeInterpolationMethod] = Field(description="Interpolation method")
force: Optional[bool] = Field(
description="Force merging of models created with different versions of diffusers",
default=False,
)
merge_dest_directory: Optional[str] = Field(
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
default=None,
)
model_config = ConfigDict(protected_namespaces=())
@models_router.put(
"/merge/{base_model}",
operation_id="merge_models",
responses={
200: {"description": "Model converted successfully"},
400: {"description": "Incompatible models"},
404: {"description": "One or more models not found"},
},
status_code=200,
response_model=MergeModelResponse,
)
async def merge_models(
body: Annotated[MergeModelsBody, Body(description="Model configuration", embed=True)],
base_model: BaseModelType = Path(description="Base model"),
) -> MergeModelResponse:
"""Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(
f"Merging models: {body.model_names} into {body.merge_dest_directory or '<MODELS>'}/{body.merged_model_name}"
)
dest = pathlib.Path(body.merge_dest_directory) if body.merge_dest_directory else None
result = ApiDependencies.invoker.services.model_manager.merge_models(
model_names=body.model_names,
base_model=base_model,
merged_model_name=body.merged_model_name or "+".join(body.model_names),
alpha=body.alpha,
interp=body.interp,
force=body.force,
merge_dest_directory=dest,
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
result.name,
base_model=base_model,
model_type=ModelType.Main,
)
response = ConvertModelResponseValidator.validate_python(model_raw)
except ModelNotFoundException:
raise HTTPException(
status_code=404,
detail=f"One or more of the models '{body.model_names}' not found",
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response

View File

@ -1,276 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from fastapi import HTTPException, Path
from fastapi.routing import APIRouter
from ...services.shared.graph import GraphExecutionState
from ..dependencies import ApiDependencies
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
# @session_router.post(
# "/",
# operation_id="create_session",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid json"},
# },
# deprecated=True,
# )
# async def create_session(
# queue_id: str = Query(default="", description="The id of the queue to associate the session with"),
# graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
# ) -> GraphExecutionState:
# """Creates a new session, optionally initializing it with an invocation graph"""
# session = ApiDependencies.invoker.create_execution_state(queue_id=queue_id, graph=graph)
# return session
# @session_router.get(
# "/",
# operation_id="list_sessions",
# responses={200: {"model": PaginatedResults[GraphExecutionState]}},
# deprecated=True,
# )
# async def list_sessions(
# page: int = Query(default=0, description="The page of results to get"),
# per_page: int = Query(default=10, description="The number of results per page"),
# query: str = Query(default="", description="The query string to search for"),
# ) -> PaginatedResults[GraphExecutionState]:
# """Gets a list of sessions, optionally searching"""
# if query == "":
# result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
# else:
# result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
# return result
@session_router.get(
"/{session_id}",
operation_id="get_session",
responses={
200: {"model": GraphExecutionState},
404: {"description": "Session not found"},
},
)
async def get_session(
session_id: str = Path(description="The id of the session to get"),
) -> GraphExecutionState:
"""Gets a session"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
raise HTTPException(status_code=404)
else:
return session
# @session_router.post(
# "/{session_id}/nodes",
# operation_id="add_node",
# responses={
# 200: {"model": str},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def add_node(
# session_id: str = Path(description="The id of the session"),
# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
# description="The node to add"
# ),
# ) -> str:
# """Adds a node to the graph"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# session.add_node(node)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session.id
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# @session_router.put(
# "/{session_id}/nodes/{node_path}",
# operation_id="update_node",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def update_node(
# session_id: str = Path(description="The id of the session"),
# node_path: str = Path(description="The path to the node in the graph"),
# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
# description="The new node"
# ),
# ) -> GraphExecutionState:
# """Updates a node in the graph and removes all linked edges"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# session.update_node(node_path, node)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# @session_router.delete(
# "/{session_id}/nodes/{node_path}",
# operation_id="delete_node",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def delete_node(
# session_id: str = Path(description="The id of the session"),
# node_path: str = Path(description="The path to the node to delete"),
# ) -> GraphExecutionState:
# """Deletes a node in the graph and removes all linked edges"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# session.delete_node(node_path)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# @session_router.post(
# "/{session_id}/edges",
# operation_id="add_edge",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def add_edge(
# session_id: str = Path(description="The id of the session"),
# edge: Edge = Body(description="The edge to add"),
# ) -> GraphExecutionState:
# """Adds an edge to the graph"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# session.add_edge(edge)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# # TODO: the edge being in the path here is really ugly, find a better solution
# @session_router.delete(
# "/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}",
# operation_id="delete_edge",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def delete_edge(
# session_id: str = Path(description="The id of the session"),
# from_node_id: str = Path(description="The id of the node the edge is coming from"),
# from_field: str = Path(description="The field of the node the edge is coming from"),
# to_node_id: str = Path(description="The id of the node the edge is going to"),
# to_field: str = Path(description="The field of the node the edge is going to"),
# ) -> GraphExecutionState:
# """Deletes an edge from the graph"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# edge = Edge(
# source=EdgeConnection(node_id=from_node_id, field=from_field),
# destination=EdgeConnection(node_id=to_node_id, field=to_field),
# )
# session.delete_edge(edge)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# @session_router.put(
# "/{session_id}/invoke",
# operation_id="invoke_session",
# responses={
# 200: {"model": None},
# 202: {"description": "The invocation is queued"},
# 400: {"description": "The session has no invocations ready to invoke"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def invoke_session(
# queue_id: str = Query(description="The id of the queue to associate the session with"),
# session_id: str = Path(description="The id of the session to invoke"),
# all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
# ) -> Response:
# """Invokes a session"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# if session.is_complete():
# raise HTTPException(status_code=400)
# ApiDependencies.invoker.invoke(queue_id, session, invoke_all=all)
# return Response(status_code=202)
# @session_router.delete(
# "/{session_id}/invoke",
# operation_id="cancel_session_invoke",
# responses={202: {"description": "The invocation is canceled"}},
# deprecated=True,
# )
# async def cancel_session_invoke(
# session_id: str = Path(description="The id of the session to cancel"),
# ) -> Response:
# """Invokes a session"""
# ApiDependencies.invoker.cancel(session_id)
# return Response(status_code=202)

View File

@ -12,16 +12,26 @@ class SocketIO:
__sio: AsyncServer
__app: ASGIApp
__sub_queue: str = "subscribe_queue"
__unsub_queue: str = "unsubscribe_queue"
__sub_bulk_download: str = "subscribe_bulk_download"
__unsub_bulk_download: str = "unsubscribe_bulk_download"
def __init__(self, app: FastAPI):
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")
app.mount("/ws", self.__app)
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue)
self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download)
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download)
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event)
async def _handle_queue_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
@ -39,3 +49,18 @@ class SocketIO:
async def _handle_model_event(self, event: Event) -> None:
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])
async def _handle_bulk_download_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["bulk_download_id"],
)
async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.enter_room(sid, data["bulk_download_id"])
async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.leave_room(sid, data["bulk_download_id"])

View File

@ -2,6 +2,7 @@
# which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file.
import sys
from contextlib import asynccontextmanager
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.version.invokeai_version import __version__
@ -48,10 +49,8 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
boards,
download_queue,
images,
model_records,
models,
model_manager,
session_queue,
sessions,
utilities,
workflows,
)
@ -73,9 +72,25 @@ logger = InvokeAILogger.get_logger(config=app_config)
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
@asynccontextmanager
async def lifespan(app: FastAPI):
# Add startup event to load dependencies
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
yield
# Shut down threads
ApiDependencies.shutdown()
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(title="Invoke - Community Edition", docs_url=None, redoc_url=None, separate_input_output_schemas=False)
app = FastAPI(
title="Invoke - Community Edition",
docs_url=None,
redoc_url=None,
separate_input_output_schemas=False,
lifespan=lifespan,
)
# Add event handler
event_handler_id: int = id(app)
@ -98,24 +113,9 @@ app.add_middleware(
app.add_middleware(GZipMiddleware, minimum_size=1000)
# Add startup event to load dependencies
@app.on_event("startup")
async def startup_event() -> None:
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
# Shut down threads
@app.on_event("shutdown")
async def shutdown_event() -> None:
ApiDependencies.shutdown()
# Include all routers
app.include_router(sessions.session_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
app.include_router(model_records.model_records_router, prefix="/api")
app.include_router(model_manager.model_manager_router, prefix="/api")
app.include_router(download_queue.download_queue_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
@ -153,6 +153,8 @@ def custom_openapi() -> dict[str, Any]:
# TODO: note that we assume the schema_key here is the TYPE.__name__
# This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema["title"]
openapi_schema["components"]["schemas"][schema_key] = output_schema
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
# Add Node Editor UI helper schemas
ui_config_schemas = models_json_schema(
@ -175,23 +177,24 @@ def custom_openapi() -> dict[str, Any]:
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
invoker_schema["output"] = outputs_ref
invoker_schema["class"] = "invocation"
openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output"
from invokeai.backend.model_management.models import get_model_config_enums
# This code no longer seems to be necessary?
# Leave it here just in case
#
# from invokeai.backend.model_manager import get_model_config_formats
# formats = get_model_config_formats()
# for model_config_name, enum_set in formats.items():
for model_config_format_enum in set(get_model_config_enums()):
name = model_config_format_enum.__qualname__
# if model_config_name in openapi_schema["components"]["schemas"]:
# # print(f"Config with name {name} already defined")
# continue
if name in openapi_schema["components"]["schemas"]:
# print(f"Config with name {name} already defined")
continue
openapi_schema["components"]["schemas"][name] = {
"title": name,
"description": "An enumeration.",
"type": "string",
"enum": [v.value for v in model_config_format_enum],
}
# openapi_schema["components"]["schemas"][model_config_name] = {
# "title": model_config_name,
# "description": "An enumeration.",
# "type": "string",
# "enum": [v.value for v in enum_set],
# }
app.openapi_schema = openapi_schema
return app.openapi_schema

View File

@ -8,13 +8,26 @@ import warnings
from abc import ABC, abstractmethod
from enum import Enum
from inspect import signature
from types import UnionType
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
ClassVar,
Iterable,
Literal,
Optional,
Type,
TypeVar,
Union,
cast,
)
import semver
from pydantic import BaseModel, ConfigDict, Field, create_model
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from typing_extensions import TypeAliasType
from invokeai.app.invocations.fields import (
FieldKind,
@ -84,6 +97,7 @@ class BaseInvocationOutput(BaseModel):
"""
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
@classmethod
def register_output(cls, output: BaseInvocationOutput) -> None:
@ -96,10 +110,14 @@ class BaseInvocationOutput(BaseModel):
return cls._output_classes
@classmethod
def get_outputs_union(cls) -> UnionType:
"""Gets a union of all invocation outputs."""
outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type]
return outputs_union # type: ignore [return-value]
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
if not cls._typeadapter:
InvocationOutputsUnion = TypeAliasType(
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
return cls._typeadapter
@classmethod
def get_output_types(cls) -> Iterable[str]:
@ -148,6 +166,7 @@ class BaseInvocation(ABC, BaseModel):
"""
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
@classmethod
def get_type(cls) -> str:
@ -160,10 +179,14 @@ class BaseInvocation(ABC, BaseModel):
cls._invocation_classes.add(invocation)
@classmethod
def get_invocations_union(cls) -> UnionType:
"""Gets a union of all invocation types."""
invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type]
return invocations_union # type: ignore [return-value]
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
if not cls._typeadapter:
InvocationsUnion = TypeAliasType(
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(InvocationsUnion)
return cls._typeadapter
@classmethod
def get_invocations(cls) -> Iterable[BaseInvocation]:

View File

@ -1,35 +1,25 @@
from typing import List, Optional, Union
from typing import Iterator, List, Optional, Tuple, Union, cast
import torch
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
OutputField,
UIComponent,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningFieldData,
ExtraConditioningInfo,
SDXLConditioningInfo,
)
from invokeai.backend.util.devices import torch_dtype
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import ModelNotFoundException, ModelType
from ...backend.util.devices import torch_dtype
from ..util.ti_utils import extract_ti_triggers_from_prompt
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .model import ClipField
# unconditioned: Optional[torch.Tensor]
@ -65,39 +55,26 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel)
def _lora_loader():
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
return
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in extract_ti_triggers_from_prompt(self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
(
name,
context.models.load(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model,
)
)
except ModelNotFoundException:
# print(e)
# import traceback
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
with (
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
tokenizer,
ti_manager,
),
@ -105,8 +82,9 @@ class CompelInvocation(BaseInvocation):
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
):
assert isinstance(text_encoder, CLIPTextModel)
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
@ -144,6 +122,8 @@ class CompelInvocation(BaseInvocation):
class SDXLPromptInvocationBase:
"""Prompt processor for SDXL models."""
def run_clip_compel(
self,
context: InvocationContext,
@ -152,20 +132,25 @@ class SDXLPromptInvocationBase:
get_pooled: bool,
lora_prefix: str,
zero_on_empty: bool,
):
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
# return zero on empty
if prompt == "" and zero_on_empty:
cpu_text_encoder = text_encoder_info.context.model
cpu_text_encoder = text_encoder_info.model
assert isinstance(cpu_text_encoder, torch.nn.Module)
c = torch.zeros(
(
1,
cpu_text_encoder.config.max_position_embeddings,
cpu_text_encoder.config.hidden_size,
),
dtype=text_encoder_info.context.cache.precision,
dtype=cpu_text_encoder.dtype,
)
if get_pooled:
c_pooled = torch.zeros(
@ -176,37 +161,21 @@ class SDXLPromptInvocationBase:
c_pooled = None
return c, c_pooled, None
def _lora_loader():
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_field.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
lora_model = lora_info.model
assert isinstance(lora_model, LoRAModelRaw)
yield (lora_model, lora.weight)
del lora_info
return
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in extract_ti_triggers_from_prompt(prompt):
name = trigger[1:-1]
try:
ti_list.append(
(
name,
context.models.load(
model_name=name,
base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model,
)
)
except ModelNotFoundException:
# print(e)
# import traceback
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
with (
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
tokenizer,
ti_manager,
),
@ -214,8 +183,10 @@ class SDXLPromptInvocationBase:
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
):
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
text_encoder = cast(CLIPTextModel, text_encoder)
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
@ -332,6 +303,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
dim=1,
)
assert c2_pooled is not None
conditioning_data = ConditioningFieldData(
conditionings=[
SDXLConditioningInfo(
@ -380,6 +352,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
assert c2_pooled is not None
conditioning_data = ConditioningFieldData(
conditionings=[
SDXLConditioningInfo(
@ -414,7 +387,7 @@ class ClipSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model."""
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
self.clip.skipped_layers += self.skipped_layers
@ -424,9 +397,9 @@ class ClipSkipInvocation(BaseInvocation):
def get_max_token_count(
tokenizer,
tokenizer: CLIPTokenizer,
prompt: Union[FlattenedPrompt, Blend, Conjunction],
truncate_if_too_long=False,
truncate_if_too_long: bool = False,
) -> int:
if type(prompt) is Blend:
blend: Blend = prompt
@ -438,7 +411,9 @@ def get_max_token_count(
return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
def get_tokens_for_prompt_object(
tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True
) -> List[str]:
if type(parsed_prompt) is Blend:
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
@ -451,24 +426,29 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
for x in parsed_prompt.children
]
text = " ".join(text_fragments)
tokens = tokenizer.tokenize(text)
tokens: List[str] = tokenizer.tokenize(text)
if truncate_if_too_long:
max_tokens_length = tokenizer.model_max_length - 2 # typically 75
tokens = tokens[0:max_tokens_length]
return tokens
def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
def log_tokenization_for_conjunction(
c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
) -> None:
display_label_prefix = display_label_prefix or ""
for i, p in enumerate(c.prompts):
if len(c.prompts) > 1:
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
else:
assert display_label_prefix is not None
this_display_label_prefix = display_label_prefix
log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)
def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
def log_tokenization_for_prompt_object(
p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
) -> None:
display_label_prefix = display_label_prefix or ""
if type(p) is Blend:
blend: Blend = p
@ -508,7 +488,12 @@ def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokeniz
log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
def log_tokenization_for_text(
text: str,
tokenizer: CLIPTokenizer,
display_label: Optional[str] = None,
truncate_if_too_long: Optional[bool] = False,
) -> None:
"""shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '

View File

@ -12,3 +12,6 @@ The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
"""A literal type representing the valid scheduler names."""
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
"""A literal type for PIL image modes supported by Invoke"""

View File

@ -23,7 +23,7 @@ from controlnet_aux import (
)
from controlnet_aux.util import HWC3, ade_palette
from PIL import Image
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from invokeai.app.invocations.fields import (
FieldDescriptions,
@ -39,14 +39,8 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
from invokeai.backend.model_management.models.base import BaseModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
CONTROLNET_RESIZE_VALUES = Literal[
@ -60,10 +54,7 @@ CONTROLNET_RESIZE_VALUES = Literal[
class ControlNetModelField(BaseModel):
"""ControlNet model field"""
model_name: str = Field(description="Name of the ControlNet model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
key: str = Field(description="Model config record key for the ControlNet model")
class ControlField(BaseModel):
@ -152,8 +143,12 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
# superclass just passes through image without processing
return image
def load_image(self, context: InvocationContext) -> Image.Image:
# allows override for any special formatting specific to the preprocessor
return context.images.get_pil(self.image.image_name, "RGB")
def invoke(self, context: InvocationContext) -> ImageOutput:
raw_image = context.images.get_pil(self.image.image_name)
raw_image = self.load_image(context)
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
@ -190,6 +185,10 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
)
def load_image(self, context: InvocationContext) -> Image.Image:
# Keep alpha channel for Canny processing to detect edges of transparent areas
return context.images.get_pil(self.image.image_name, "RGBA")
def run_processor(self, image):
canny_processor = CannyDetector()
processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
@ -424,10 +423,6 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
def run_processor(self, image):
# MediaPipeFaceDetector throws an error if image has alpha channel
# so convert to RGB if needed
if image.mode == "RGBA":
image = image.convert("RGB")
mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
return processed_image
@ -557,7 +552,6 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
color_map_tile_size: int = InputField(default=64, ge=0, description=FieldDescriptions.tile_size)
def run_processor(self, image: Image.Image):
image = image.convert("RGB")
np_image = np.array(image, dtype=np.uint8)
height, width = np_image.shape[:2]
@ -582,7 +576,7 @@ DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
title="Depth Anything Processor",
tags=["controlnet", "depth", "depth anything"],
category="controlnet",
version="1.0.0",
version="1.0.1",
)
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a depth map based on the Depth Anything algorithm"""
@ -591,16 +585,12 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
default="small", description="The size of the depth model to use"
)
resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
offload: bool = InputField(default=False)
def run_processor(self, image: Image.Image):
depth_anything_detector = DepthAnythingDetector()
depth_anything_detector.load_model(model_size=self.model_size)
if image.mode == "RGBA":
image = image.convert("RGB")
processed_image = depth_anything_detector(image=image, resolution=self.resolution, offload=self.offload)
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
return processed_image
@ -619,7 +609,7 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
draw_hands: bool = InputField(default=False)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image):
def run_processor(self, image: Image.Image):
dw_openpose = DWOpenposeDetector()
processed_image = dw_openpose(
image,

View File

@ -199,6 +199,7 @@ class DenoiseMaskField(BaseModel):
mask_name: str = Field(description="The name of the mask image")
masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents")
gradient: bool = Field(default=False, description="Used for gradient inpainting")
class LatentsField(BaseModel):

View File

@ -7,6 +7,7 @@ import cv2
import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import (
ColorField,
FieldDescriptions,
@ -21,11 +22,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
from .baseinvocation import (
BaseInvocation,
Classification,
invocation,
)
from .baseinvocation import BaseInvocation, Classification, invocation
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.1")
@ -263,9 +260,6 @@ class ImageChannelInvocation(BaseInvocation, WithMetadata, WithBoard):
return ImageOutput.build(image_dto)
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
@invocation(
"img_conv",
title="Convert Image Mode",
@ -936,3 +930,40 @@ class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)
@invocation(
"canvas_paste_back",
title="Canvas Paste Back",
tags=["image", "combine"],
category="image",
version="1.0.0",
)
class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Combines two images by using the mask provided. Intended for use on the Unified Canvas."""
source_image: ImageField = InputField(description="The source image")
target_image: ImageField = InputField(default=None, description="The target image")
mask: ImageField = InputField(
description="The mask to use when pasting",
)
mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")
def _prepare_mask(self, mask: Image.Image) -> Image.Image:
mask_array = numpy.array(mask)
kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
dilated_mask_array = cv2.erode(mask_array, kernel, iterations=3)
dilated_mask = Image.fromarray(dilated_mask_array)
if self.mask_blur > 0:
mask = dilated_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
return ImageOps.invert(mask.convert("L"))
def invoke(self, context: InvocationContext) -> ImageOutput:
source_image = context.images.get_pil(self.source_image.image_name)
target_image = context.images.get_pil(self.target_image.image_name)
mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
source_image.paste(target_image, (0, 0), mask)
image_dto = context.images.save(image=source_image)
return ImageOutput.build(image_dto)

View File

@ -1,8 +1,8 @@
import os
from builtins import float
from typing import List, Union
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@ -14,22 +14,16 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
from invokeai.backend.model_manager.config import BaseModelType, ModelType
# LS: Consider moving these two classes into model.py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
key: str = Field(description="Key to the IP-Adapter model")
class CLIPVisionModelField(BaseModel):
model_name: str = Field(description="Name of the CLIP Vision image encoder model")
base_model: BaseModelType = Field(description="Base model (usually 'Any')")
model_config = ConfigDict(protected_namespaces=())
key: str = Field(description="Key to the CLIP Vision image encoder model")
class IPAdapterField(BaseModel):
@ -46,12 +40,12 @@ class IPAdapterField(BaseModel):
@field_validator("weight")
@classmethod
def validate_ip_adapter_weight(cls, v):
def validate_ip_adapter_weight(cls, v: float) -> float:
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
def validate_begin_end_step_percent(self) -> Self:
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
@ -84,33 +78,25 @@ class IPAdapterInvocation(BaseInvocation):
@field_validator("weight")
@classmethod
def validate_ip_adapter_weight(cls, v):
def validate_ip_adapter_weight(cls, v: float) -> float:
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
def validate_begin_end_step_percent(self) -> Self:
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_info(
self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter
)
# HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model
# directly, and 2) we are reading from disk every time this invocation is called without caching the result.
# A better solution would be to store the image encoder model reference in the IP-Adapter model info, but this
# is currently messy due to differences between how the model info is generated when installing a model from
# disk vs. downloading the model.
image_encoder_model_id = get_ip_adapter_image_encoder_model_id(
os.path.join(context.config.get().models_path, ip_adapter_info["path"])
)
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_model = CLIPVisionModelField(
model_name=image_encoder_model_name,
base_model=BaseModelType.Any,
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
assert len(image_encoder_models) == 1
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)
return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,

View File

@ -3,13 +3,15 @@
import math
from contextlib import ExitStack
from functools import singledispatchmethod
from typing import List, Literal, Optional, Union
from typing import Any, Iterator, List, Literal, Optional, Tuple, Union
import einops
import numpy as np
import numpy.typing as npt
import torch
import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.configuration_utils import ConfigMixin
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.adapter import T2IAdapter
from diffusers.models.attention_processor import (
@ -18,8 +20,10 @@ from diffusers.models.attention_processor import (
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers import DPMSolverSDEScheduler
from diffusers.schedulers import SchedulerMixin as Scheduler
from PIL import Image, ImageFilter
from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize
@ -47,13 +51,13 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType, LoadedModel
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
from invokeai.backend.util.silence_warnings import SilenceWarnings
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import BaseModelType
from ...backend.model_management.seamless import set_seamless
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
IPAdapterData,
@ -124,10 +128,10 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
ui_order=4,
)
def prep_mask_tensor(self, mask_image):
def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
if mask_image.mode != "L":
mask_image = mask_image.convert("L")
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
if mask_tensor.dim() == 3:
mask_tensor = mask_tensor.unsqueeze(0)
# if shape is not None:
@ -138,21 +142,21 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
if self.image is not None:
image = context.images.get_pil(self.image.image_name)
image = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image.dim() == 3:
image = image.unsqueeze(0)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
else:
image = None
image_tensor = None
mask = self.prep_mask_tensor(
context.images.get_pil(self.mask.image_name),
)
if image is not None:
if image_tensor is not None:
vae_info = context.models.load(**self.vae.vae.model_dump())
img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
# TODO:
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
@ -165,6 +169,62 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
return DenoiseMaskOutput.build(
mask_name=mask_name,
masked_latents_name=masked_latents_name,
gradient=False,
)
@invocation(
"create_gradient_mask",
title="Create Gradient Mask",
tags=["mask", "denoise"],
category="latents",
version="1.0.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
edge_radius: int = InputField(
default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
)
coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
minimum_denoise: float = InputField(
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
if self.coherence_mode == "Box Blur":
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
else: # Gaussian Blur OR Staged
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
# redistribute blur so that the edges are 0 and blur out to 1
blur_tensor = (blur_tensor - 0.5) * 2
threshold = 1 - self.minimum_denoise
if self.coherence_mode == "Staged":
# wherever the blur_tensor is masked to any degree, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
# multiply original mask to force actually masked regions to 0
blur_tensor = mask_tensor * blur_tensor
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
return DenoiseMaskOutput.build(
mask_name=mask_name,
masked_latents_name=None,
gradient=True,
)
@ -183,7 +243,7 @@ def get_scheduler(
scheduler_config = scheduler_config["_backup"]
scheduler_config = {
**scheduler_config,
**scheduler_extra_config,
**scheduler_extra_config, # FIXME
"_backup": scheduler_config,
}
@ -196,6 +256,7 @@ def get_scheduler(
# hack copied over from generate.py
if not hasattr(scheduler, "uses_inpainting_model"):
scheduler.uses_inpainting_model = lambda: False
assert isinstance(scheduler, Scheduler)
return scheduler
@ -279,7 +340,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
@field_validator("cfg_scale")
def ge_one(cls, v):
def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
@ -293,9 +354,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
def get_conditioning_data(
self,
context: InvocationContext,
scheduler,
unet,
seed,
scheduler: Scheduler,
unet: UNet2DConditionModel,
seed: int,
) -> ConditioningData:
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
@ -318,7 +379,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
),
)
conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME
scheduler,
# for ddim scheduler
eta=0.0, # ddim_eta
@ -330,8 +391,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
def create_pipeline(
self,
unet,
scheduler,
unet: UNet2DConditionModel,
scheduler: Scheduler,
) -> StableDiffusionGeneratorPipeline:
# TODO:
# configure_model_padding(
@ -342,10 +403,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
class FakeVae:
class FakeVaeConfig:
def __init__(self):
def __init__(self) -> None:
self.block_out_channels = [0]
def __init__(self):
def __init__(self) -> None:
self.config = FakeVae.FakeVaeConfig()
return StableDiffusionGeneratorPipeline(
@ -362,11 +423,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
def prep_control_data(
self,
context: InvocationContext,
control_input: Union[ControlField, List[ControlField]],
control_input: Optional[Union[ControlField, List[ControlField]]],
latents_shape: List[int],
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]:
) -> Optional[List[ControlNetData]]:
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
@ -388,13 +449,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# and if weight is None, populate with default 1.0?
controlnet_data = []
for control_info in control_list:
control_model = exit_stack.enter_context(
context.models.load(
model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model,
)
)
control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))
# control_models.append(control_model)
control_image_field = control_info.image
@ -456,25 +511,17 @@ class DenoiseLatentsInvocation(BaseInvocation):
conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.models.load(
model_name=single_ip_adapter.ip_adapter_model.model_name,
model_type=ModelType.IPAdapter,
base_model=single_ip_adapter.ip_adapter_model.base_model,
)
context.models.load(key=single_ip_adapter.ip_adapter_model.key)
)
image_encoder_model_info = context.models.load(
model_name=single_ip_adapter.image_encoder_model.model_name,
model_type=ModelType.CLIPVision,
base_model=single_ip_adapter.image_encoder_model.base_model,
)
image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_images = single_ip_adapter.image
if not isinstance(single_ipa_images, list):
single_ipa_images = [single_ipa_images]
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
single_ipa_image_fields = [single_ipa_image_fields]
single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_images]
single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
@ -518,25 +565,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_info = context.models.load(
model_name=t2i_adapter_field.t2i_adapter_model.model_name,
model_type=ModelType.T2IAdapter,
base_model=t2i_adapter_field.t2i_adapter_model.base_model,
)
t2i_adapter_model_config = context.models.get_config(key=t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
image = context.images.get_pil(t2i_adapter_field.image.image_name)
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
if t2i_adapter_model_config.base == BaseModelType.StableDiffusion1:
max_unet_downscale = 8
elif t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusionXL:
elif t2i_adapter_model_config.base == BaseModelType.StableDiffusionXL:
max_unet_downscale = 4
else:
raise ValueError(
f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'."
)
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.")
t2i_adapter_model: T2IAdapter
with t2i_adapter_model_info as t2i_adapter_model:
with t2i_adapter_loaded_model as t2i_adapter_model:
total_downscale_factor = t2i_adapter_model.total_downscale_factor
# Resize the T2I-Adapter input image.
@ -556,7 +598,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
do_classifier_free_guidance=False,
width=t2i_input_width,
height=t2i_input_height,
num_channels=t2i_adapter_model.config.in_channels,
num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict
device=t2i_adapter_model.device,
dtype=t2i_adapter_model.dtype,
resize_mode=t2i_adapter_field.resize_mode,
@ -581,7 +623,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
# original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
def init_scheduler(
self,
scheduler: Union[Scheduler, ConfigMixin],
device: torch.device,
steps: int,
denoising_start: float,
denoising_end: float,
) -> Tuple[int, List[int], int]:
assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False):
scheduler.set_timesteps(steps, device="cpu")
timesteps = scheduler.timesteps.to(device=device)
@ -593,11 +643,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
_timesteps = timesteps[:: scheduler.order]
# get start timestep index
t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start)))
t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start)))
t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps)))
# get end timestep index
t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end)))
t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end)))
t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:])))
# apply order to indexes
@ -610,9 +660,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
return num_inference_steps, timesteps, init_timestep
def prep_inpaint_mask(self, context: InvocationContext, latents):
def prep_inpaint_mask(
self, context: InvocationContext, latents: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
if self.denoise_mask is None:
return None, None
return None, None, False
mask = context.tensors.load(self.denoise_mask.mask_name)
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
@ -621,7 +673,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
else:
masked_latents = None
return 1 - mask, masked_latents
return 1 - mask, masked_latents, self.denoise_mask.gradient
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
@ -648,7 +700,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
if seed is None:
seed = 0
mask, masked_latents = self.prep_inpaint_mask(context, latents)
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
# below. Investigate whether this is appropriate.
@ -659,25 +711,30 @@ class DenoiseLatentsInvocation(BaseInvocation):
do_classifier_free_guidance=True,
)
def step_callback(state: PipelineIntermediateState):
context.util.sd_step_callback(state, self.unet.unet.base_model)
# get the unet's config so that we can pass the base to dispatch_progress()
unet_config = context.models.get_config(self.unet.unet.key)
def _lora_loader():
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
yield (lora_info.model, lora.weight)
del lora_info
return
unet_info = context.models.load(**self.unet.unet.model_dump())
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
set_seamless(unet_info.context.model, self.unet.seamless_axes),
ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config),
set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME
unet_info as unet,
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
@ -731,6 +788,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed=seed,
mask=mask,
masked_latents=masked_latents,
gradient_mask=gradient_mask,
num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data,
control_data=controlnet_data,
@ -776,7 +834,8 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae_info = context.models.load(**self.vae.vae.model_dump())
with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module)
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)
@ -948,8 +1007,9 @@ class ImageToLatentsInvocation(BaseInvocation):
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
@staticmethod
def vae_encode(vae_info, upcast, tiled, image_tensor):
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
with vae_info as vae:
assert isinstance(vae, torch.nn.Module)
orig_dtype = vae.dtype
if upcast:
vae.to(dtype=torch.float32)
@ -1010,14 +1070,19 @@ class ImageToLatentsInvocation(BaseInvocation):
@singledispatchmethod
@staticmethod
def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
assert isinstance(vae, torch.nn.Module)
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
latents: torch.Tensor = image_tensor_dist.sample().to(
dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible!
return latents
@_encode_to_tensor.register
@staticmethod
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
return vae.encode(image_tensor).latents
assert isinstance(vae, torch.nn.Module)
latents: torch.FloatTensor = vae.encode(image_tensor).latents
return latents
@invocation(
@ -1050,7 +1115,12 @@ class BlendLatentsInvocation(BaseInvocation):
# TODO:
device = choose_torch_device()
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
def slerp(
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
v0: Union[torch.Tensor, npt.NDArray[Any]],
v1: Union[torch.Tensor, npt.NDArray[Any]],
DOT_THRESHOLD: float = 0.9995,
) -> Union[torch.Tensor, npt.NDArray[Any]]:
"""
Spherical linear interpolation
Args:
@ -1083,12 +1153,16 @@ class BlendLatentsInvocation(BaseInvocation):
v2 = s0 * v0 + s1 * v1
if inputs_are_torch:
v2 = torch.from_numpy(v2).to(device)
return v2
v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
return v2_torch
else:
assert isinstance(v2, np.ndarray)
return v2
# blend
blended_latents = slerp(self.alpha, latents_a, latents_b)
bl = slerp(self.alpha, latents_a, latents_b)
assert isinstance(bl, torch.Tensor)
blended_latents: torch.Tensor = bl # for type checking convenience
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
blended_latents = blended_latents.to("cpu")
@ -1181,15 +1255,16 @@ class IdealSizeInvocation(BaseInvocation):
description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in initial generation artifacts if too large)",
)
def trim_to_multiple_of(self, *args, multiple_of=LATENT_SCALE_FACTOR):
def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]:
return tuple((x - x % multiple_of) for x in args)
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
unet_config = context.models.get_config(**self.unet.unet.model_dump())
aspect = self.width / self.height
dimension = 512
if self.unet.unet.base_model == BaseModelType.StableDiffusion2:
dimension: float = 512
if unet_config.base == BaseModelType.StableDiffusion2:
dimension = 768
elif self.unet.unet.base_model == BaseModelType.StableDiffusionXL:
elif unet_config.base == BaseModelType.StableDiffusionXL:
dimension = 1024
dimension = dimension * self.multiplier
min_dimension = math.floor(dimension * 0.5)

View File

@ -33,7 +33,7 @@ class MetadataItemField(BaseModel):
class LoRAMetadataField(BaseModel):
"""LoRA Metadata Field"""
lora: LoRAModelField = Field(description=FieldDescriptions.lora_model)
model: LoRAModelField = Field(description=FieldDescriptions.lora_model)
weight: float = Field(description=FieldDescriptions.lora_weight)
@ -114,7 +114,7 @@ GENERATION_MODES = Literal[
]
@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.0.1")
@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.1.1")
class CoreMetadataInvocation(BaseInvocation):
"""Collects core generation metadata into a MetadataField"""

View File

@ -1,13 +1,13 @@
import copy
from typing import List, Optional
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, Field
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from ...backend.model_management import BaseModelType, ModelType, SubModelType
from ...backend.model_manager import SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -17,12 +17,8 @@ from .baseinvocation import (
class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Info to load submodel")
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
model_config = ConfigDict(protected_namespaces=())
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
class LoraInfo(ModelInfo):
@ -52,7 +48,7 @@ class VaeField(BaseModel):
@invocation_output("unet_output")
class UNetOutput(BaseInvocationOutput):
"""Base class for invocations that output a UNet field"""
"""Base class for invocations that output a UNet field."""
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
@ -81,20 +77,13 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
class MainModelField(BaseModel):
"""Main model field"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Model Type")
model_config = ConfigDict(protected_namespaces=())
key: str = Field(description="Model key")
class LoRAModelField(BaseModel):
"""LoRA model field"""
model_name: str = Field(description="Name of the LoRA model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
key: str = Field(description="LoRA model key")
@invocation(
@ -111,85 +100,40 @@ class MainModelLoaderInvocation(BaseInvocation):
# TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
key = self.model.key
# TODO: not found exceptions
if not context.models.exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
"""
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Tokenizer,
):
raise Exception(
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
):
raise Exception(
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,
):
raise Exception(
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
)
"""
if not context.models.exists(key):
raise Exception(f"Unknown model {key}")
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.UNet,
key=key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Scheduler,
key=key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer,
key=key,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder,
key=key,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Vae,
key=key,
submodel_type=SubModelType.Vae,
),
),
)
@ -226,21 +170,16 @@ class LoraLoaderInvocation(BaseInvocation):
if self.lora is None:
raise Exception("No LoRA provided")
base_model = self.lora.base_model
lora_name = self.lora.model_name
lora_key = self.lora.key
if not context.models.exists(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
):
raise Exception(f"Unkown lora name: {lora_name}!")
if not context.models.exists(lora_key):
raise Exception(f"Unkown lora: {lora_key}!")
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
raise Exception(f'Lora "{lora_name}" already applied to unet')
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip')
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip')
output = LoraLoaderOutput()
@ -248,10 +187,8 @@ class LoraLoaderInvocation(BaseInvocation):
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
key=lora_key,
submodel_type=None,
weight=self.weight,
)
)
@ -260,10 +197,8 @@ class LoraLoaderInvocation(BaseInvocation):
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
key=lora_key,
submodel_type=None,
weight=self.weight,
)
)
@ -315,24 +250,19 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
if self.lora is None:
raise Exception("No LoRA provided")
base_model = self.lora.base_model
lora_name = self.lora.model_name
lora_key = self.lora.key
if not context.models.exists(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
):
raise Exception(f"Unknown lora name: {lora_name}!")
if not context.models.exists(lora_key):
raise Exception(f"Unknown lora: {lora_key}!")
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
raise Exception(f'Lora "{lora_name}" already applied to unet')
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip')
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip')
if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip2')
if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip2')
output = SDXLLoraLoaderOutput()
@ -340,10 +270,8 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
key=lora_key,
submodel_type=None,
weight=self.weight,
)
)
@ -352,10 +280,8 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
key=lora_key,
submodel_type=None,
weight=self.weight,
)
)
@ -364,10 +290,8 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
output.clip2 = copy.deepcopy(self.clip2)
output.clip2.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
key=lora_key,
submodel_type=None,
weight=self.weight,
)
)
@ -378,10 +302,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
class VAEModelField(BaseModel):
"""Vae model field"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
key: str = Field(description="Model's key")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
@ -395,25 +316,12 @@ class VaeLoaderInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> VAEOutput:
base_model = self.vae_model.base_model
model_name = self.vae_model.model_name
model_type = ModelType.Vae
key = self.vae_model.key
if not context.models.exists(
base_model=base_model,
model_name=model_name,
model_type=model_type,
):
raise Exception(f"Unkown vae name: {model_name}!")
return VAEOutput(
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
)
)
if not context.models.exists(key):
raise Exception(f"Unkown vae: {key}!")
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
@invocation_output("seamless_output")

View File

@ -299,9 +299,13 @@ class DenoiseMaskOutput(BaseInvocationOutput):
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
@classmethod
def build(cls, mask_name: str, masked_latents_name: Optional[str] = None) -> "DenoiseMaskOutput":
def build(
cls, mask_name: str, masked_latents_name: Optional[str] = None, gradient: bool = False
) -> "DenoiseMaskOutput":
return cls(
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name),
denoise_mask=DenoiseMaskField(
mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=gradient
),
)

View File

@ -1,7 +1,7 @@
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import SubModelType
from ...backend.model_management import ModelType, SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -40,72 +40,52 @@ class SDXLModelLoaderInvocation(BaseInvocation):
# TODO: precision?
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
model_key = self.model.key
# TODO: not found exceptions
if not context.models.exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
return SDXLModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.UNet,
key=model_key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Scheduler,
key=model_key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer,
key=model_key,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder,
key=model_key,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
clip2=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer2,
key=model_key,
submodel_type=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder2,
key=model_key,
submodel_type=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Vae,
key=model_key,
submodel_type=SubModelType.Vae,
),
),
)
@ -129,56 +109,40 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
# TODO: precision?
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
model_key = self.model.key
# TODO: not found exceptions
if not context.models.exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
return SDXLRefinerModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.UNet,
key=model_key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Scheduler,
key=model_key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip2=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer2,
key=model_key,
submodel_type=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder2,
key=model_key,
submodel_type=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Vae,
key=model_key,
submodel_type=SubModelType.Vae,
),
),
)

View File

@ -1,6 +1,6 @@
from typing import Union
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@ -12,14 +12,10 @@ from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESI
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_management.models.base import BaseModelType
class T2IAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the T2I-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
key: str = Field(description="Model record key for the T2I-Adapter model")
class T2IAdapterField(BaseModel):

View File

@ -0,0 +1,44 @@
from abc import ABC, abstractmethod
from typing import Optional
class BulkDownloadBase(ABC):
"""Responsible for creating a zip file containing the images specified by the given image names or board id."""
@abstractmethod
def handler(
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
) -> None:
"""
Create a zip file containing the images specified by the given image names or board id.
:param image_names: A list of image names to include in the zip file.
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
:param bulk_download_item_id: The bulk_download_item_id that will be used to retrieve the bulk download item when it is prepared, if none is provided a uuid will be generated.
"""
@abstractmethod
def get_path(self, bulk_download_item_name: str) -> str:
"""
Get the path to the bulk download file.
:param bulk_download_item_name: The name of the bulk download item.
:return: The path to the bulk download file.
"""
@abstractmethod
def generate_item_id(self, board_id: Optional[str]) -> str:
"""
Generate an item ID for a bulk download item.
:param board_id: The ID of the board whose name is to be included in the item id.
:return: The generated item ID.
"""
@abstractmethod
def delete(self, bulk_download_item_name: str) -> None:
"""
Delete the bulk download file.
:param bulk_download_item_name: The name of the bulk download item.
"""

View File

@ -0,0 +1,25 @@
DEFAULT_BULK_DOWNLOAD_ID = "default"
class BulkDownloadException(Exception):
"""Exception raised when a bulk download fails."""
def __init__(self, message="Bulk download failed"):
super().__init__(message)
self.message = message
class BulkDownloadTargetException(BulkDownloadException):
"""Exception raised when a bulk download target is not found."""
def __init__(self, message="The bulk download target was not found"):
super().__init__(message)
self.message = message
class BulkDownloadParametersException(BulkDownloadException):
"""Exception raised when a bulk download parameter is invalid."""
def __init__(self, message="No image names or board ID provided"):
super().__init__(message)
self.message = message

View File

@ -0,0 +1,157 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Union
from zipfile import ZipFile
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
from invokeai.app.services.bulk_download.bulk_download_common import (
DEFAULT_BULK_DOWNLOAD_ID,
BulkDownloadException,
BulkDownloadParametersException,
BulkDownloadTargetException,
)
from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invoker import Invoker
from invokeai.app.util.misc import uuid_string
from .bulk_download_base import BulkDownloadBase
class BulkDownloadService(BulkDownloadBase):
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def __init__(self):
self._temp_directory = TemporaryDirectory()
self._bulk_downloads_folder = Path(self._temp_directory.name) / "bulk_downloads"
self._bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
def handler(
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
) -> None:
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
bulk_download_item_id = bulk_download_item_id or uuid_string()
bulk_download_item_name = bulk_download_item_id + ".zip"
self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
try:
image_dtos: list[ImageDTO] = []
if board_id:
image_dtos = self._board_handler(board_id)
elif image_names:
image_dtos = self._image_handler(image_names)
else:
raise BulkDownloadParametersException()
bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id)
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
except (
ImageRecordNotFoundException,
BoardRecordNotFoundException,
BulkDownloadException,
BulkDownloadParametersException,
) as e:
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e)
except Exception as e:
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e)
self._invoker.services.logger.error("Problem bulk downloading images.")
raise e
def _image_handler(self, image_names: list[str]) -> list[ImageDTO]:
return [self._invoker.services.images.get_dto(image_name) for image_name in image_names]
def _board_handler(self, board_id: str) -> list[ImageDTO]:
image_names = self._invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
return self._image_handler(image_names)
def generate_item_id(self, board_id: Optional[str]) -> str:
return uuid_string() if board_id is None else self._get_clean_board_name(board_id) + "_" + uuid_string()
def _get_clean_board_name(self, board_id: str) -> str:
if board_id == "none":
return "Uncategorized"
return self._clean_string_to_path_safe(self._invoker.services.board_records.get(board_id).board_name)
def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: str) -> str:
"""
Create a zip file containing the images specified by the given image names or board id.
If download with the same bulk_download_id already exists, it will be overwritten.
:return: The name of the zip file.
"""
zip_file_name = bulk_download_item_id + ".zip"
zip_file_path = self._bulk_downloads_folder / (zip_file_name)
with ZipFile(zip_file_path, "w") as zip_file:
for image_dto in image_dtos:
image_zip_path = Path(image_dto.image_category.value) / image_dto.image_name
image_disk_path = self._invoker.services.images.get_path(image_dto.image_name)
zip_file.write(image_disk_path, arcname=image_zip_path)
return str(zip_file_name)
# from https://stackoverflow.com/questions/7406102/create-sane-safe-filename-from-any-unsafe-string
def _clean_string_to_path_safe(self, s: str) -> str:
"""Clean a string to be path safe."""
return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " " or c == "_" or c == "-"]).rstrip()
def _signal_job_started(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Signal that a bulk download job has started."""
if self._invoker:
assert bulk_download_id is not None
self._invoker.services.events.emit_bulk_download_started(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
def _signal_job_completed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Signal that a bulk download job has completed."""
if self._invoker:
assert bulk_download_id is not None
assert bulk_download_item_name is not None
self._invoker.services.events.emit_bulk_download_completed(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
def _signal_job_failed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, exception: Exception
) -> None:
"""Signal that a bulk download job has failed."""
if self._invoker:
assert bulk_download_id is not None
assert exception is not None
self._invoker.services.events.emit_bulk_download_failed(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=str(exception),
)
def stop(self, *args, **kwargs):
self._temp_directory.cleanup()
def delete(self, bulk_download_item_name: str) -> None:
path = self.get_path(bulk_download_item_name)
Path(path).unlink()
def get_path(self, bulk_download_item_name: str) -> str:
path = str(self._bulk_downloads_folder / bulk_download_item_name)
if not self._is_valid_path(path):
raise BulkDownloadTargetException()
return path
def _is_valid_path(self, path: Union[str, Path]) -> bool:
"""Validates the path given for a bulk download."""
path = path if isinstance(path, Path) else Path(path)
return path.exists()

View File

@ -27,11 +27,11 @@ class InvokeAISettings(BaseSettings):
"""Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
initconf: ClassVar[Optional[DictConfig]] = None
argparse_groups: ClassVar[Dict] = {}
argparse_groups: ClassVar[Dict[str, Any]] = {}
model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
def parse_args(self, argv: Optional[List[str]] = sys.argv[1:]) -> None:
"""Call to parse command-line arguments."""
parser = self.get_parser()
opt, unknown_opts = parser.parse_known_args(argv)
@ -68,7 +68,7 @@ class InvokeAISettings(BaseSettings):
return OmegaConf.to_yaml(conf)
@classmethod
def add_parser_arguments(cls, parser):
def add_parser_arguments(cls, parser: ArgumentParser) -> None:
"""Dynamically create arguments for a settings parser."""
if "type" in get_type_hints(cls):
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
@ -117,7 +117,8 @@ class InvokeAISettings(BaseSettings):
"""Return the category of a setting."""
hints = get_type_hints(cls)
if command_field in hints:
return get_args(hints[command_field])[0]
result: str = get_args(hints[command_field])[0]
return result
else:
return "Uncategorized"
@ -155,10 +156,11 @@ class InvokeAISettings(BaseSettings):
"lora_dir",
"embedding_dir",
"controlnet_dir",
"conf_path",
]
@classmethod
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
def add_field_argument(cls, command_parser, name: str, field, default_override=None) -> None:
"""Add the argparse arguments for a setting parser."""
field_type = get_type_hints(cls).get(name)
default = (

View File

@ -21,7 +21,7 @@ class PagingArgumentParser(argparse.ArgumentParser):
It also supports reading defaults from an init file.
"""
def print_help(self, file=None):
def print_help(self, file=None) -> None:
text = self.format_help()
pydoc.pager(text)

View File

@ -30,7 +30,6 @@ InvokeAI:
lora_dir: null
embedding_dir: null
controlnet_dir: null
conf_path: configs/models.yaml
models_dir: models
legacy_conf_dir: configs/stable-diffusion
db_dir: databases
@ -123,7 +122,6 @@ a Path object:
root_path - path to InvokeAI root
output_path - path to default outputs directory
model_conf_path - path to models.yaml
conf - alias for the above
embedding_path - path to the embeddings directory
lora_path - path to the LoRA directory
@ -163,7 +161,6 @@ two configs are kept in separate sections of the config file:
InvokeAI:
Paths:
root: /home/lstein/invokeai-main
conf_path: configs/models.yaml
legacy_conf_dir: configs/stable-diffusion
outdir: outputs
...
@ -173,7 +170,7 @@ from __future__ import annotations
import os
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union
from typing import Any, ClassVar, Dict, List, Literal, Optional
from omegaconf import DictConfig, OmegaConf
from pydantic import Field
@ -185,7 +182,9 @@ from .config_base import InvokeAISettings
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_MAX_VRAM = 0.5
DEFAULT_RAM_CACHE = 10.0
DEFAULT_VRAM_CACHE = 0.25
DEFAULT_CONVERT_CACHE = 20.0
class Categories(object):
@ -235,8 +234,8 @@ class InvokeAIAppConfig(InvokeAISettings):
# PATHS
root : Optional[Path] = Field(default=None, description='InvokeAI runtime root directory', json_schema_extra=Categories.Paths)
autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths)
legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
@ -260,8 +259,10 @@ class InvokeAIAppConfig(InvokeAISettings):
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
# CACHE
ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
convert_cache : float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache)
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache)
@ -296,6 +297,7 @@ class InvokeAIAppConfig(InvokeAISettings):
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
# this is not referred to in the source code and can be removed entirely
#free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
@ -404,6 +406,11 @@ class InvokeAIAppConfig(InvokeAISettings):
"""Path to the models directory."""
return self._resolve(self.models_dir)
@property
def models_convert_cache_path(self) -> Path:
"""Path to the converted cache models directory."""
return self._resolve(self.convert_cache_dir)
@property
def custom_nodes_path(self) -> Path:
"""Path to the custom nodes directory."""
@ -433,15 +440,20 @@ class InvokeAIAppConfig(InvokeAISettings):
return True
@property
def ram_cache_size(self) -> Union[Literal["auto"], float]:
"""Return the ram cache size using the legacy or modern setting."""
def ram_cache_size(self) -> float:
"""Return the ram cache size using the legacy or modern setting (GB)."""
return self.max_cache_size or self.ram
@property
def vram_cache_size(self) -> Union[Literal["auto"], float]:
"""Return the vram cache size using the legacy or modern setting."""
def vram_cache_size(self) -> float:
"""Return the vram cache size using the legacy or modern setting (GB)."""
return self.max_vram_cache_size or self.vram
@property
def convert_cache_size(self) -> float:
"""Return the convert cache size on disk (GB)."""
return self.convert_cache
@property
def use_cpu(self) -> bool:
"""Return true if the device is set to CPU or the always_use_cpu flag is set."""

View File

@ -260,3 +260,16 @@ class DownloadQueueServiceBase(ABC):
def join(self) -> None:
"""Wait until all jobs are off the queue."""
pass
@abstractmethod
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
"""Wait until the indicated download job has reached a terminal state.
This will block until the indicated install job has completed,
been cancelled, or errored out.
:param job: The job to wait on.
:param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if
the job hasn't completed within the indicated time.
"""
pass

View File

@ -4,10 +4,11 @@
import os
import re
import threading
import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Set
import requests
from pydantic.networks import AnyHttpUrl
@ -48,11 +49,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
:param requests_session: Optional requests.sessions.Session object, for unit tests.
"""
self._jobs = {}
self._jobs: Dict[int, DownloadJob] = {}
self._next_job_id = 0
self._queue = PriorityQueue()
self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
self._stop_event = threading.Event()
self._worker_pool = set()
self._job_completed_event = threading.Event()
self._worker_pool: Set[threading.Thread] = set()
self._lock = threading.Lock()
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
self._event_bus = event_bus
@ -188,6 +190,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
if not job.in_terminal_state:
self.cancel_job(job)
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
start = time.time()
while not job.in_terminal_state:
if self._job_completed_event.wait(timeout=0.25): # in case we miss an event
self._job_completed_event.clear()
if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded")
return job
def _start_workers(self, max_workers: int) -> None:
"""Start the requested number of worker threads."""
self._stop_event.clear()
@ -223,6 +235,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
finally:
job.job_ended = get_iso_timestamp()
self._job_completed_event.set() # signal a change to terminal state
self._queue.task_done()
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
@ -407,11 +420,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
# Example on_progress event handler to display a TQDM status bar
# Activate with:
# download_service.download('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().job_update
# download_service.download(DownloadJob('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().update))
class TqdmProgress(object):
"""TQDM-based progress bar object to use in on_progress handlers."""
_bars: Dict[int, tqdm] # the tqdm object
_bars: Dict[int, tqdm] # type: ignore
_last: Dict[int, int] # last bytes downloaded
def __init__(self) -> None: # noqa D107

View File

@ -3,7 +3,7 @@
from typing import Any, Dict, List, Optional, Union
from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
@ -11,12 +11,12 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueStatus,
)
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_management.model_manager import LoadedModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager import AnyModelConfig
class EventServiceBase:
queue_event: str = "queue_event"
bulk_download_event: str = "bulk_download_event"
download_event: str = "download_event"
model_event: str = "model_event"
@ -25,6 +25,14 @@ class EventServiceBase:
def dispatch(self, event_name: str, payload: Any) -> None:
pass
def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None:
"""Bulk download events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.bulk_download_event,
payload={"event": event_name, "data": payload},
)
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
"""Queue events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
@ -171,10 +179,7 @@ class EventServiceBase:
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
model_config: AnyModelConfig,
) -> None:
"""Emitted when a model is requested"""
self.__emit_queue_event(
@ -184,10 +189,7 @@ class EventServiceBase:
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_name": model_name,
"base_model": base_model,
"model_type": model_type,
"submodel": submodel,
"model_config": model_config.model_dump(),
},
)
@ -197,11 +199,7 @@ class EventServiceBase:
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
loaded_model_info: LoadedModelInfo,
model_config: AnyModelConfig,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_queue_event(
@ -211,59 +209,7 @@ class EventServiceBase:
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_name": model_name,
"base_model": base_model,
"model_type": model_type,
"submodel": submodel,
"hash": loaded_model_info.hash,
"location": str(loaded_model_info.location),
"precision": str(loaded_model_info.precision),
},
)
def emit_session_retrieval_error(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when session retrieval fails"""
self.__emit_queue_event(
event_name="session_retrieval_error",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"error_type": error_type,
"error": error,
},
)
def emit_invocation_retrieval_error(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when invocation retrieval fails"""
self.__emit_queue_event(
event_name="invocation_retrieval_error",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node_id": node_id,
"error_type": error_type,
"error": error,
"model_config": model_config.model_dump(),
},
)
@ -411,6 +357,7 @@ class EventServiceBase:
bytes: int,
total_bytes: int,
parts: List[Dict[str, Union[str, int]]],
id: int,
) -> None:
"""
Emit at intervals while the install job is in progress (remote models only).
@ -430,6 +377,7 @@ class EventServiceBase:
"bytes": bytes,
"total_bytes": total_bytes,
"parts": parts,
"id": id,
},
)
@ -444,7 +392,7 @@ class EventServiceBase:
payload={"source": source},
)
def emit_model_install_completed(self, source: str, key: str, total_bytes: Optional[int] = None) -> None:
def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None:
"""
Emit when an install job is completed successfully.
@ -454,11 +402,7 @@ class EventServiceBase:
"""
self.__emit_model_event(
event_name="model_install_completed",
payload={
"source": source,
"total_bytes": total_bytes,
"key": key,
},
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
)
def emit_model_install_cancelled(self, source: str) -> None:
@ -472,12 +416,7 @@ class EventServiceBase:
payload={"source": source},
)
def emit_model_install_error(
self,
source: str,
error_type: str,
error: str,
) -> None:
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None:
"""
Emit when an install job encounters an exception.
@ -487,9 +426,45 @@ class EventServiceBase:
"""
self.__emit_model_event(
event_name="model_install_error",
payload={"source": source, "error_type": error_type, "error": error, "id": id},
)
def emit_bulk_download_started(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Emitted when a bulk download starts"""
self._emit_bulk_download_event(
event_name="bulk_download_started",
payload={
"source": source,
"error_type": error_type,
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
def emit_bulk_download_completed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Emitted when a bulk download completes"""
self._emit_bulk_download_event(
event_name="bulk_download_completed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
def emit_bulk_download_failed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> None:
"""Emitted when a bulk download fails"""
self._emit_bulk_download_event(
event_name="bulk_download_failed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
"error": error,
},
)

View File

@ -1,5 +0,0 @@
from abc import ABC
class InvocationProcessorABC(ABC): # noqa: B024
pass

View File

@ -1,15 +0,0 @@
from pydantic import BaseModel, Field
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")
class CanceledException(Exception):
"""Execution canceled by user."""
pass

View File

@ -1,241 +0,0 @@
import time
import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Event, Thread
from typing import Optional
import invokeai.backend.util.logging as logger
from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
from invokeai.app.services.invocation_stats.invocation_stats_common import (
GESStatsNotFoundError,
)
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
from invokeai.app.util.profiler import Profiler
from ..invoker import Invoker
from .invocation_processor_base import InvocationProcessorABC
from .invocation_processor_common import CanceledException
class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread
__stop_event: Event
__invoker: Invoker
__threadLimit: BoundedSemaphore
def start(self, invoker: Invoker) -> None:
# if we do want multithreading at some point, we could make this configurable
self.__threadLimit = BoundedSemaphore(1)
self.__invoker = invoker
self.__stop_event = Event()
self.__invoker_thread = Thread(
name="invoker_processor",
target=self.__process,
kwargs={"stop_event": self.__stop_event},
)
self.__invoker_thread.daemon = True # TODO: make async and do not use threads
self.__invoker_thread.start()
def stop(self, *args, **kwargs) -> None:
self.__stop_event.set()
def __process(self, stop_event: Event):
try:
self.__threadLimit.acquire()
queue_item: Optional[InvocationQueueItem] = None
profiler = (
Profiler(
logger=self.__invoker.services.logger,
output_dir=self.__invoker.services.configuration.profiles_path,
prefix=self.__invoker.services.configuration.profile_prefix,
)
if self.__invoker.services.configuration.profile_graphs
else None
)
def stats_cleanup(graph_execution_state_id: str) -> None:
if profiler:
profile_path = profiler.stop()
stats_path = profile_path.with_suffix(".json")
self.__invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=graph_execution_state_id, output_path=stats_path
)
with suppress(GESStatsNotFoundError):
self.__invoker.services.performance_statistics.log_stats(graph_execution_state_id)
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state_id)
while not stop_event.is_set():
try:
queue_item = self.__invoker.services.queue.get()
except Exception as e:
self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)
if not queue_item: # Probably stopping
# do not hammer the queue
time.sleep(0.5)
continue
if profiler and profiler.profile_id != queue_item.graph_execution_state_id:
profiler.start(profile_id=queue_item.graph_execution_state_id)
try:
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
queue_item.graph_execution_state_id
)
except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
self.__invoker.services.events.emit_session_retrieval_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=queue_item.graph_execution_state_id,
error_type=e.__class__.__name__,
error=traceback.format_exc(),
)
continue
try:
invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id)
except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
self.__invoker.services.events.emit_invocation_retrieval_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=queue_item.graph_execution_state_id,
node_id=queue_item.invocation_id,
error_type=e.__class__.__name__,
error=traceback.format_exc(),
)
continue
# get the source node id to provide to clients (the prepared node id is not as useful)
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
# Send starting event
self.__invoker.services.events.emit_invocation_started(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.model_dump(),
source_node_id=source_node_id,
)
# Invoke
try:
graph_id = graph_execution_state.id
with self.__invoker.services.performance_statistics.collect_stats(invocation, graph_id):
# use the internal invoke_internal(), which wraps the node's invoke() method,
# which handles a few things:
# - nodes that require a value, but get it only from a connection
# - referencing the invocation cache instead of executing the node
context_data = InvocationContextData(
invocation=invocation,
session_id=graph_id,
workflow=queue_item.workflow,
source_node_id=source_node_id,
queue_id=queue_item.session_queue_id,
queue_item_id=queue_item.session_queue_item_id,
batch_id=queue_item.session_queue_batch_id,
)
context = build_invocation_context(
services=self.__invoker.services,
context_data=context_data,
)
outputs = invocation.invoke_internal(context=context, services=self.__invoker.services)
# Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
continue
# Save outputs and history
graph_execution_state.complete(invocation.id, outputs)
# Save the state changes
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.model_dump(),
source_node_id=source_node_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
pass
except CanceledException:
stats_cleanup(graph_execution_state.id)
pass
except Exception as e:
error = traceback.format_exc()
logger.error(error)
# Save error
graph_execution_state.set_node_error(invocation.id, error)
# Save the state changes
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
# Send error event
self.__invoker.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.model_dump(),
source_node_id=source_node_id,
error_type=e.__class__.__name__,
error=error,
)
pass
# Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
continue
# Queue any further commands if invoking all
is_complete = graph_execution_state.is_complete()
if queue_item.invoke_all and not is_complete:
try:
self.__invoker.invoke(
session_queue_batch_id=queue_item.session_queue_batch_id,
session_queue_item_id=queue_item.session_queue_item_id,
session_queue_id=queue_item.session_queue_id,
graph_execution_state=graph_execution_state,
workflow=queue_item.workflow,
invoke_all=True,
)
except Exception as e:
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
self.__invoker.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.model_dump(),
source_node_id=source_node_id,
error_type=e.__class__.__name__,
error=traceback.format_exc(),
)
elif is_complete:
self.__invoker.services.events.emit_graph_execution_complete(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
)
stats_cleanup(graph_execution_state.id)
except KeyboardInterrupt:
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
finally:
self.__threadLimit.release()

View File

@ -1,26 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from abc import ABC, abstractmethod
from typing import Optional
from .invocation_queue_common import InvocationQueueItem
class InvocationQueueABC(ABC):
"""Abstract base class for all invocation queues"""
@abstractmethod
def get(self) -> InvocationQueueItem:
pass
@abstractmethod
def put(self, item: Optional[InvocationQueueItem]) -> None:
pass
@abstractmethod
def cancel(self, graph_execution_state_id: str) -> None:
pass
@abstractmethod
def is_canceled(self, graph_execution_state_id: str) -> bool:
pass

View File

@ -1,23 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import time
from typing import Optional
from pydantic import BaseModel, Field
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class InvocationQueueItem(BaseModel):
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
invocation_id: str = Field(description="The ID of the node being invoked")
session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
session_queue_item_id: int = Field(
description="The ID of session queue item from which this invocation queue item came"
)
session_queue_batch_id: str = Field(
description="The ID of the session batch from which this invocation queue item came"
)
workflow: Optional[WorkflowWithoutID] = Field(description="The workflow associated with this queue item")
invoke_all: bool = Field(default=False)
timestamp: float = Field(default_factory=time.time)

View File

@ -1,44 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import time
from queue import Queue
from typing import Optional
from .invocation_queue_base import InvocationQueueABC
from .invocation_queue_common import InvocationQueueItem
class MemoryInvocationQueue(InvocationQueueABC):
__queue: Queue
__cancellations: dict[str, float]
def __init__(self):
self.__queue = Queue()
self.__cancellations = {}
def get(self) -> InvocationQueueItem:
item = self.__queue.get()
while (
isinstance(item, InvocationQueueItem)
and item.graph_execution_state_id in self.__cancellations
and self.__cancellations[item.graph_execution_state_id] > item.timestamp
):
item = self.__queue.get()
# Clear old items
for graph_execution_state_id in list(self.__cancellations.keys()):
if self.__cancellations[graph_execution_state_id] < item.timestamp:
del self.__cancellations[graph_execution_state_id]
return item
def put(self, item: Optional[InvocationQueueItem]) -> None:
self.__queue.put(item)
def cancel(self, graph_execution_state_id: str) -> None:
if graph_execution_state_id not in self.__cancellations:
self.__cancellations[graph_execution_state_id] = time.time()
def is_canceled(self, graph_execution_state_id: str) -> bool:
return graph_execution_state_id in self.__cancellations

View File

@ -16,6 +16,7 @@ if TYPE_CHECKING:
from .board_images.board_images_base import BoardImagesServiceABC
from .board_records.board_records_base import BoardRecordStorageBase
from .boards.boards_base import BoardServiceABC
from .bulk_download.bulk_download_base import BulkDownloadBase
from .config import InvokeAIAppConfig
from .download import DownloadQueueServiceBase
from .events.events_base import EventServiceBase
@ -23,17 +24,11 @@ if TYPE_CHECKING:
from .image_records.image_records_base import ImageRecordStorageBase
from .images.images_base import ImageServiceABC
from .invocation_cache.invocation_cache_base import InvocationCacheBase
from .invocation_processor.invocation_processor_base import InvocationProcessorABC
from .invocation_queue.invocation_queue_base import InvocationQueueABC
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .item_storage.item_storage_base import ItemStorageABC
from .model_install import ModelInstallServiceBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .model_records import ModelRecordServiceBase
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase
from .shared.graph import GraphExecutionState
from .urls.urls_base import UrlServiceBase
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase
@ -47,20 +42,16 @@ class InvocationServices:
board_image_records: "BoardImageRecordStorageBase",
boards: "BoardServiceABC",
board_records: "BoardRecordStorageBase",
bulk_download: "BulkDownloadBase",
configuration: "InvokeAIAppConfig",
events: "EventServiceBase",
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
images: "ImageServiceABC",
image_files: "ImageFileStorageBase",
image_records: "ImageRecordStorageBase",
logger: "Logger",
model_manager: "ModelManagerServiceBase",
model_records: "ModelRecordServiceBase",
download_queue: "DownloadQueueServiceBase",
model_install: "ModelInstallServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
session_queue: "SessionQueueBase",
session_processor: "SessionProcessorBase",
invocation_cache: "InvocationCacheBase",
@ -74,20 +65,16 @@ class InvocationServices:
self.board_image_records = board_image_records
self.boards = boards
self.board_records = board_records
self.bulk_download = bulk_download
self.configuration = configuration
self.events = events
self.graph_execution_manager = graph_execution_manager
self.images = images
self.image_files = image_files
self.image_records = image_records
self.logger = logger
self.model_manager = model_manager
self.model_records = model_records
self.download_queue = download_queue
self.model_install = model_install
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue
self.session_queue = session_queue
self.session_processor = session_processor
self.invocation_cache = invocation_cache

View File

@ -3,7 +3,7 @@
Usage:
statistics = InvocationStatsService(graph_execution_manager)
statistics = InvocationStatsService()
with statistics.collect_stats(invocation, graph_execution_state.id):
... execute graphs...
statistics.log_stats()
@ -29,8 +29,8 @@ writes to the system log is stored in InvocationServices.performance_statistics.
"""
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
from pathlib import Path
from typing import ContextManager
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary
@ -40,18 +40,17 @@ class InvocationStatsServiceBase(ABC):
"Abstract base class for recording node memory/time performance statistics"
@abstractmethod
def __init__(self):
def __init__(self) -> None:
"""
Initialize the InvocationStatsService and reset counters to zero
"""
pass
@abstractmethod
def collect_stats(
self,
invocation: BaseInvocation,
graph_execution_state_id: str,
) -> AbstractContextManager:
) -> ContextManager[None]:
"""
Return a context object that will capture the statistics on the execution
of invocaation. Use with: to place around the part of the code that executes the invocation.
@ -61,16 +60,12 @@ class InvocationStatsServiceBase(ABC):
pass
@abstractmethod
def reset_stats(self, graph_execution_state_id: str):
"""
Reset all statistics for the indicated graph.
:param graph_execution_state_id: The id of the session whose stats to reset.
:raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
"""
def reset_stats(self):
"""Reset all stored statistics."""
pass
@abstractmethod
def log_stats(self, graph_execution_state_id: str):
def log_stats(self, graph_execution_state_id: str) -> None:
"""
Write out the accumulated statistics to the log or somewhere else.
:param graph_execution_state_id: The id of the session whose stats to log.

View File

@ -2,6 +2,7 @@ import json
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Generator
import psutil
import torch
@ -9,8 +10,7 @@ import torch
import invokeai.backend.util.logging as logger
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError
from invokeai.backend.model_management.model_cache import CacheStats
from invokeai.backend.model_manager.load.model_cache import CacheStats
from .invocation_stats_base import InvocationStatsServiceBase
from .invocation_stats_common import (
@ -41,22 +41,23 @@ class InvocationStatsService(InvocationStatsServiceBase):
self._invoker = invoker
@contextmanager
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str):
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Generator[None, None, None]:
# This is to handle case of the model manager not being initialized, which happens
# during some tests.
services = self._invoker.services
if not self._stats.get(graph_execution_state_id):
# First time we're seeing this graph_execution_state_id.
self._stats[graph_execution_state_id] = GraphExecutionStats()
self._cache_stats[graph_execution_state_id] = CacheStats()
# Prune stale stats. There should be none since we're starting a new graph, but just in case.
self._prune_stale_stats()
# Record state before the invocation.
start_time = time.time()
start_ram = psutil.Process().memory_info().rss
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
if self._invoker.services.model_manager:
self._invoker.services.model_manager.collect_cache_stats(self._cache_stats[graph_execution_state_id])
assert services.model_manager.load is not None
services.model_manager.load.ram_cache.stats = self._cache_stats[graph_execution_state_id]
try:
# Let the invocation run.
@ -73,42 +74,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
)
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
def _prune_stale_stats(self):
"""Check all graphs being tracked and prune any that have completed/errored.
This shouldn't be necessary, but we don't have totally robust upstream handling of graph completions/errors, so
for now we call this function periodically to prevent them from accumulating.
"""
to_prune: list[str] = []
for graph_execution_state_id in self._stats:
try:
graph_execution_state = self._invoker.services.graph_execution_manager.get(graph_execution_state_id)
except ItemNotFoundError:
# TODO(ryand): What would cause this? Should this exception just be allowed to propagate?
logger.warning(f"Failed to get graph state for {graph_execution_state_id}.")
continue
if not graph_execution_state.is_complete():
# The graph is still running, don't prune it.
continue
to_prune.append(graph_execution_state_id)
for graph_execution_state_id in to_prune:
del self._stats[graph_execution_state_id]
del self._cache_stats[graph_execution_state_id]
if len(to_prune) > 0:
logger.info(f"Pruned stale graph stats for {to_prune}.")
def reset_stats(self, graph_execution_state_id: str):
try:
del self._stats[graph_execution_state_id]
del self._cache_stats[graph_execution_state_id]
except KeyError as e:
raise GESStatsNotFoundError(
f"Attempted to clear statistics for unknown graph {graph_execution_state_id}: {e}."
) from e
def reset_stats(self):
self._stats = {}
self._cache_stats = {}
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)

View File

@ -1,12 +1,7 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Optional
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from .invocation_queue.invocation_queue_common import InvocationQueueItem
from .invocation_services import InvocationServices
from .shared.graph import Graph, GraphExecutionState
class Invoker:
@ -18,51 +13,6 @@ class Invoker:
self.services = services
self._start()
def invoke(
self,
session_queue_id: str,
session_queue_item_id: int,
session_queue_batch_id: str,
graph_execution_state: GraphExecutionState,
workflow: Optional[WorkflowWithoutID] = None,
invoke_all: bool = False,
) -> Optional[str]:
"""Determines the next node to invoke and enqueues it, preparing if needed.
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
# Get the next invocation
invocation = graph_execution_state.next()
if not invocation:
return None
# Save the execution state
self.services.graph_execution_manager.set(graph_execution_state)
# Queue the invocation
self.services.queue.put(
InvocationQueueItem(
session_queue_id=session_queue_id,
session_queue_item_id=session_queue_item_id,
session_queue_batch_id=session_queue_batch_id,
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
workflow=workflow,
invoke_all=invoke_all,
)
)
return invocation.id
def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState:
"""Creates a new execution state for the given graph"""
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
self.services.graph_execution_manager.set(new_state)
return new_state
def cancel(self, graph_execution_state_id: str) -> None:
"""Cancels the given execution state"""
self.services.queue.cancel(graph_execution_state_id)
def __start_service(self, service) -> None:
# Call start() method on any services that have it
start_op = getattr(service, "start", None)
@ -85,5 +35,3 @@ class Invoker:
# First stop all services
for service in vars(self.services):
self.__stop_service(getattr(self.services, service))
self.services.queue.put(None)

View File

@ -14,11 +14,13 @@ from typing_extensions import Annotated
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..model_metadata import ModelMetadataStoreBase
class InstallStatus(str, Enum):
@ -127,8 +129,8 @@ class HFModelSource(StringLikeSource):
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
base += f":{self.variant or ''}"
base += f":{self.subfolder}" if self.subfolder else ""
base += f" ({self.variant})" if self.variant else ""
return base
@ -154,6 +156,7 @@ class ModelInstallJob(BaseModel):
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
@ -175,6 +178,12 @@ class ModelInstallJob(BaseModel):
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception"
)
error_traceback: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the exception traceback"
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
@ -182,7 +191,10 @@ class ModelInstallJob(BaseModel):
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self._exception = e
self.error = str(e)
self.error_traceback = self._format_error(e)
self.status = InstallStatus.ERROR
self.error_reason = self._exception.__class__.__name__ if self._exception else None
def cancel(self) -> None:
"""Call to cancel the job."""
@ -193,10 +205,9 @@ class ModelInstallJob(BaseModel):
"""Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None
@property
def error(self) -> Optional[str]:
def _format_error(self, exception: Exception) -> str:
"""Error traceback."""
return "".join(traceback.format_exception(self._exception)) if self._exception else None
return "".join(traceback.format_exception(exception))
@property
def cancelled(self) -> bool:
@ -243,7 +254,7 @@ class ModelInstallServiceBase(ABC):
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
metadata_store: ModelMetadataStore,
metadata_store: ModelMetadataStoreBase,
event_bus: Optional["EventServiceBase"] = None,
):
"""
@ -324,6 +335,43 @@ class ModelInstallServiceBase(ABC):
:returns id: The string ID of the registered model.
"""
@abstractmethod
def heuristic_import(
self,
source: str,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
) -> ModelInstallJob:
r"""Install the indicated model using heuristics to interpret user intentions.
:param source: String source
:param config: Optional dict. Any fields in this dict
will override corresponding autoassigned probe fields in the
model's config record as described in `import_model()`.
:param access_token: Optional access token for remote sources.
The source can be:
1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`)
2. An http or https URL (`https://foo.bar/foo`)
3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`)
We extend the HuggingFace repo_id syntax to include the variant and the
subfolder or path. The following are acceptable alternatives:
stabilityai/stable-diffusion-v4
stabilityai/stable-diffusion-v4:fp16
stabilityai/stable-diffusion-v4:fp16:vae
stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
stabilityai/stable-diffusion-v4:onnx:vae
Because a local file path can look like a huggingface repo_id, the logic
first checks whether the path exists on disk, and if not, it is treated as
a parseable huggingface repo.
The previous support for recursing into a local folder and loading all model-like files
has been removed.
"""
pass
@abstractmethod
def import_model(
self,
@ -385,6 +433,18 @@ class ModelInstallServiceBase(ABC):
def cancel_job(self, job: ModelInstallJob) -> None:
"""Cancel the indicated job."""
@abstractmethod
def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob:
"""Wait for the indicated job to reach a terminal state.
This will block until the indicated install job has completed,
been cancelled, or errored out.
:param job: The job to wait on.
:param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if
the job hasn't completed within the indicated time.
"""
@abstractmethod
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]:
"""
@ -394,7 +454,8 @@ class ModelInstallServiceBase(ABC):
completed, been cancelled, or errored out.
:param timeout: Wait up to indicated number of seconds. Raise an Exception('timeout') if
installs do not complete within the indicated time.
installs do not complete within the indicated time. A timeout of zero (the default)
will block indefinitely until the installs complete.
"""
@abstractmethod
@ -410,3 +471,22 @@ class ModelInstallServiceBase(ABC):
@abstractmethod
def sync_to_config(self) -> None:
"""Synchronize models on disk to those in the model record database."""
@abstractmethod
def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
"""
Download the model file located at source to the models cache and return its Path.
:param source: A Url or a string that can be converted into one.
:param access_token: Optional access token to access restricted resources.
The model file will be downloaded into the system-wide model cache
(`models/.cache`) if it isn't already there. Note that the model cache
is periodically cleared of infrequently-used entries when the model
converter runs.
Note that this doesn't automaticallly install or register the model, but is
intended for use by nodes that need access to models that aren't directly
supported by InvokeAI. The downloading process takes advantage of the download queue
to avoid interrupting other operations.
"""

View File

@ -17,10 +17,10 @@ from pydantic.networks import AnyHttpUrl
from requests import Session
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
@ -33,7 +33,6 @@ from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
CivitaiMetadataFetch,
HuggingFaceMetadataFetch,
ModelMetadataStore,
ModelMetadataWithFiles,
RemoteModelFile,
)
@ -50,6 +49,7 @@ from .model_install_base import (
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
StringLikeSource,
URLModelSource,
)
@ -64,7 +64,6 @@ class ModelInstallService(ModelInstallServiceBase):
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
metadata_store: Optional[ModelMetadataStore] = None,
event_bus: Optional[EventServiceBase] = None,
session: Optional[Session] = None,
):
@ -86,19 +85,13 @@ class ModelInstallService(ModelInstallServiceBase):
self._lock = threading.Lock()
self._stop_event = threading.Event()
self._downloads_changed_event = threading.Event()
self._install_completed_event = threading.Event()
self._download_queue = download_queue
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
self._running = False
self._session = session
self._next_job_id = 0
# There may not necessarily be a metadata store initialized
# so we create one and initialize it with the same sql database
# used by the record store service.
if metadata_store:
self._metadata_store = metadata_store
else:
assert isinstance(record_store, ModelRecordServiceSQL)
self._metadata_store = ModelMetadataStore(record_store.db)
self._metadata_store = record_store.metadata_store # for convenience
@property
def app_config(self) -> InvokeAIAppConfig: # noqa D102
@ -145,7 +138,7 @@ class ModelInstallService(ModelInstallServiceBase):
) -> str: # noqa D102
model_path = Path(model_path)
config = config or {}
if config.get("source") is None:
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
return self._register(model_path, config)
@ -156,12 +149,18 @@ class ModelInstallService(ModelInstallServiceBase):
) -> str: # noqa D102
model_path = Path(model_path)
config = config or {}
if config.get("source") is None:
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
info: AnyModelConfig = self._probe_model(Path(model_path), config)
old_hash = info.original_hash
dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name
old_hash = info.current_hash
if preferred_name := config.get("name"):
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
dest_path = (
self.app_config.models_path / info.base.value / info.type.value / (preferred_name or model_path.name)
)
try:
new_path = self._copy_model(model_path, dest_path)
except FileExistsError as excp:
@ -177,7 +176,40 @@ class ModelInstallService(ModelInstallServiceBase):
info,
)
def heuristic_import(
self,
source: str,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
) -> ModelInstallJob:
variants = "|".join(ModelRepoVariant.__members__.values())
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
source_obj: Optional[StringLikeSource] = None
if Path(source).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source))
elif match := re.match(hf_repoid_re, source):
source_obj = HFModelSource(
repo_id=match.group(1),
variant=match.group(2) if match.group(2) else None, # pass None rather than ''
subfolder=Path(match.group(3)) if match.group(3) else None,
access_token=access_token,
)
elif re.match(r"^https?://[^/]+", source):
source_obj = URLModelSource(
url=AnyHttpUrl(source),
access_token=access_token,
)
else:
raise ValueError(f"Unsupported model source: '{source}'")
return self.import_model(source_obj, config)
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
if similar_jobs:
self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.")
return similar_jobs[0]
if isinstance(source, LocalModelSource):
install_job = self._import_local_model(source, config)
self._install_queue.put(install_job) # synchronously install
@ -207,14 +239,25 @@ class ModelInstallService(ModelInstallServiceBase):
assert isinstance(jobs[0], ModelInstallJob)
return jobs[0]
def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob:
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
start = time.time()
while not job.in_terminal_state:
if self._install_completed_event.wait(timeout=5): # in case we miss an event
self._install_completed_event.clear()
if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded")
return job
# TODO: Better name? Maybe wait_for_jobs()? Maybe too easily confused with above
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102
"""Block until all installation jobs are done."""
start = time.time()
while len(self._download_cache) > 0:
if self._downloads_changed_event.wait(timeout=5): # in case we miss an event
if self._downloads_changed_event.wait(timeout=0.25): # in case we miss an event
self._downloads_changed_event.clear()
if timeout > 0 and time.time() - start > timeout:
raise Exception("Timeout exceeded")
raise TimeoutError("Timeout exceeded")
self._install_queue.join()
return self._install_jobs
@ -268,6 +311,38 @@ class ModelInstallService(ModelInstallServiceBase):
path.unlink()
self.unregister(key)
def download_and_cache(
self,
source: Union[str, AnyHttpUrl],
access_token: Optional[str] = None,
timeout: int = 0,
) -> Path:
"""Download the model file located at source to the models cache and return its Path."""
model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32]
model_path = self._app_config.models_convert_cache_path / model_hash
# We expect the cache directory to contain one and only one downloaded file.
# We don't know the file's name in advance, as it is set by the download
# content-disposition header.
if model_path.exists():
contents = [x for x in model_path.iterdir() if x.is_file()]
if len(contents) > 0:
return contents[0]
model_path.mkdir(parents=True, exist_ok=True)
job = self._download_queue.download(
source=AnyHttpUrl(str(source)),
dest=model_path,
access_token=access_token,
on_progress=TqdmProgress().update,
)
self._download_queue.wait_for_job(job, timeout)
if job.complete:
assert job.download_path is not None
return job.download_path
else:
raise Exception(job.error)
# --------------------------------------------------------------------------------------------
# Internal functions that manage the installer threads
# --------------------------------------------------------------------------------------------
@ -300,6 +375,7 @@ class ModelInstallService(ModelInstallServiceBase):
job.total_bytes = self._stat_size(job.local_path)
job.bytes = job.total_bytes
self._signal_job_running(job)
job.config_in["source"] = str(job.source)
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
else:
@ -330,6 +406,7 @@ class ModelInstallService(ModelInstallServiceBase):
# if this is an install of a remote file, then clean up the temporary directory
if job._install_tmpdir is not None:
rmtree(job._install_tmpdir)
self._install_completed_event.set()
self._install_queue.task_done()
self._logger.info("Install thread exiting")
@ -465,8 +542,10 @@ class ModelInstallService(ModelInstallServiceBase):
def _register(
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
) -> str:
info = info or ModelProbe.probe(model_path, config)
key = self._create_key()
if config and not config.get("key", None):
config["key"] = key
info = info or ModelProbe.probe(model_path, config)
model_path = model_path.absolute()
if model_path.is_relative_to(self.app_config.models_path):
@ -479,8 +558,8 @@ class ModelInstallService(ModelInstallServiceBase):
# make config relative to our root
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
self.record_store.add_model(key, info)
return key
self.record_store.add_model(info.key, info)
return info.key
def _next_id(self) -> int:
with self._lock:
@ -489,10 +568,10 @@ class ModelInstallService(ModelInstallServiceBase):
return id
@staticmethod
def _guess_variant() -> ModelRepoVariant:
def _guess_variant() -> Optional[ModelRepoVariant]:
"""Guess the best HuggingFace variant type to download."""
precision = choose_precision(choose_torch_device())
return ModelRepoVariant.FP16 if precision == "float16" else ModelRepoVariant.DEFAULT
return ModelRepoVariant.FP16 if precision == "float16" else None
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
return ModelInstallJob(
@ -517,7 +596,7 @@ class ModelInstallService(ModelInstallServiceBase):
if not source.access_token:
self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id)
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
assert isinstance(metadata, ModelMetadataWithFiles)
remote_files = metadata.download_urls(
variant=source.variant or self._guess_variant(),
@ -565,6 +644,8 @@ class ModelInstallService(ModelInstallServiceBase):
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
# Currently the tmpdir isn't automatically removed at exit because it is
# being held in a daemon thread.
if len(remote_files) == 0:
raise ValueError(f"{source}: No downloadable files found")
tmpdir = Path(
mkdtemp(
dir=self._app_config.models_path,
@ -580,6 +661,16 @@ class ModelInstallService(ModelInstallServiceBase):
bytes=0,
total_bytes=0,
)
# In the event that there is a subfolder specified in the source,
# we need to remove it from the destination path in order to avoid
# creating unwanted subfolders
if hasattr(source, "subfolder") and source.subfolder:
root = Path(remote_files[0].path.parts[0])
subfolder = root / source.subfolder
else:
root = Path(".")
subfolder = Path(".")
# we remember the path up to the top of the tmpdir so that it may be
# removed safely at the end of the install process.
install_job._install_tmpdir = tmpdir
@ -589,7 +680,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.debug(f"remote_files={remote_files}")
for model_file in remote_files:
url = model_file.url
path = model_file.path
path = root / model_file.path.relative_to(subfolder)
self._logger.info(f"Downloading {url} => {path}")
install_job.total_bytes += model_file.size
assert hasattr(source, "access_token")
@ -652,6 +743,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._signal_job_downloading(install_job)
def _download_complete_callback(self, download_job: DownloadJob) -> None:
self._logger.info(f"{download_job.source}: model download complete")
with self._lock:
install_job = self._download_cache[download_job.source]
self._download_cache.pop(download_job.source, None)
@ -684,7 +776,7 @@ class ModelInstallService(ModelInstallServiceBase):
if not install_job:
return
self._downloads_changed_event.set()
self._logger.warning(f"Download {download_job.source} cancelled.")
self._logger.warning(f"{download_job.source}: model download cancelled")
# if install job has already registered an error, then do not replace its status with cancelled
if not install_job.errored:
install_job.cancel()
@ -731,6 +823,7 @@ class ModelInstallService(ModelInstallServiceBase):
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
id=job.id,
)
def _signal_job_completed(self, job: ModelInstallJob) -> None:
@ -743,7 +836,7 @@ class ModelInstallService(ModelInstallServiceBase):
assert job.local_path is not None
assert job.config_out is not None
key = job.config_out.key
self._event_bus.emit_model_install_completed(str(job.source), key)
self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id)
def _signal_job_errored(self, job: ModelInstallJob) -> None:
self._logger.info(f"{job.source}: model installation encountered an exception: {job.error_type}\n{job.error}")
@ -752,7 +845,7 @@ class ModelInstallService(ModelInstallServiceBase):
error = job.error
assert error_type is not None
assert error is not None
self._event_bus.emit_model_install_error(str(job.source), error_type, error)
self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id)
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
self._logger.info(f"{job.source}: model installation was cancelled")

View File

@ -0,0 +1,6 @@
"""Initialization file for model load service module."""
from .model_load_base import ModelLoadServiceBase
from .model_load_default import ModelLoadService
__all__ = ["ModelLoadServiceBase", "ModelLoadService"]

View File

@ -0,0 +1,40 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
"""Base class for model loader."""
from abc import ABC, abstractmethod
from typing import Optional
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
class ModelLoadServiceBase(ABC):
"""Wrapper around AnyModelLoader."""
@abstractmethod
def load_model(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch.
:param context_data: Invocation context data used for event reporting
"""
@property
@abstractmethod
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the RAM cache used by this loader."""
@property
@abstractmethod
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""

View File

@ -0,0 +1,113 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
"""Implementation of model loader service."""
from typing import Optional, Type
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import (
LoadedModel,
ModelLoaderRegistry,
ModelLoaderRegistryBase,
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.util.logging import InvokeAILogger
from .model_load_base import ModelLoadServiceBase
class ModelLoadService(ModelLoadServiceBase):
"""Wrapper around ModelLoaderRegistry."""
def __init__(
self,
app_config: InvokeAIAppConfig,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
):
"""Initialize the model load service."""
logger = InvokeAILogger.get_logger(self.__class__.__name__)
logger.setLevel(app_config.log_level.upper())
self._logger = logger
self._app_config = app_config
self._ram_cache = ram_cache
self._convert_cache = convert_cache
self._registry = registry
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the RAM cache used by this loader."""
return self._ram_cache
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
return self._convert_cache
def load_model(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
"""
if context_data:
self._emit_load_event(
context_data=context_data,
model_config=model_config,
)
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model: LoadedModel = implementation(
app_config=self._app_config,
logger=self._logger,
ram_cache=self._ram_cache,
convert_cache=self._convert_cache,
).load_model(model_config, submodel_type)
if context_data:
self._emit_load_event(
context_data=context_data,
model_config=model_config,
loaded=True,
)
return loaded_model
def _emit_load_event(
self,
context_data: InvocationContextData,
model_config: AnyModelConfig,
loaded: Optional[bool] = False,
) -> None:
if not self._invoker:
return
if not loaded:
self._invoker.services.events.emit_model_load_started(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
)
else:
self._invoker.services.events.emit_model_load_completed(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
)

View File

@ -1 +1,17 @@
from .model_manager_default import ModelManagerService # noqa F401
"""Initialization file for model manager service."""
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import LoadedModel
from .model_manager_default import ModelManagerService, ModelManagerServiceBase
__all__ = [
"ModelManagerServiceBase",
"ModelManagerService",
"AnyModel",
"AnyModelConfig",
"BaseModelType",
"ModelType",
"SubModelType",
"LoadedModel",
]

View File

@ -1,283 +1,101 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from __future__ import annotations
from abc import ABC, abstractmethod
from logging import Logger
from pathlib import Path
from typing import Callable, List, Literal, Optional, Tuple, Union
from typing import Optional
from pydantic import Field
import torch
from typing_extensions import Self
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_management import (
AddModelResult,
BaseModelType,
LoadedModelInfo,
MergeInterpolationMethod,
ModelType,
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.model_management.model_cache import CacheStats
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallServiceBase
from ..model_load import ModelLoadServiceBase
from ..model_records import ModelRecordServiceBase
class ModelManagerServiceBase(ABC):
"""Responsible for managing models on disk and in memory"""
"""Abstract base class for the model manager service."""
# attributes:
# store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.")
# install: ModelInstallServiceBase = Field(description="An instance of the model install service.")
# load: ModelLoadServiceBase = Field(description="An instance of the model load service.")
@classmethod
@abstractmethod
def __init__(
self,
config: InvokeAIAppConfig,
logger: Logger,
):
def build_model_manager(
cls,
app_config: InvokeAIAppConfig,
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
execution_device: torch.device,
) -> Self:
"""
Initialize with the path to the models.yaml config file.
Optional parameters are the torch device type, precision, max_models,
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
Construct the model manager service instance.
Use it rather than the __init__ constructor. This class
method simplifies the construction considerably.
"""
pass
@property
@abstractmethod
def get_model(
def store(self) -> ModelRecordServiceBase:
"""Return the ModelRecordServiceBase used to store and retrieve configuration records."""
pass
@property
@abstractmethod
def load(self) -> ModelLoadServiceBase:
"""Return the ModelLoadServiceBase used to load models from their configuration records."""
pass
@property
@abstractmethod
def install(self) -> ModelInstallServiceBase:
"""Return the ModelInstallServiceBase used to download and manipulate model files."""
pass
@abstractmethod
def start(self, invoker: Invoker) -> None:
pass
@abstractmethod
def stop(self, invoker: Invoker) -> None:
pass
@abstractmethod
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass
@abstractmethod
def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass
@abstractmethod
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModelInfo:
"""Retrieve the indicated model with name and type.
submodel can be used to get a part (such as the vae)
of a diffusers pipeline."""
pass
@property
@abstractmethod
def logger(self):
pass
@abstractmethod
def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
pass
@abstractmethod
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
Uses the exact format as the omegaconf stanza.
"""
pass
@abstractmethod
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
"""
Return a dict of models in the format:
{ model_type1:
{ model_name1: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description,
'format': 'folder'|'safetensors'|'ckpt'
},
model_name2: { etc }
},
model_type2:
{ model_name_n: etc
}
"""
pass
@abstractmethod
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Return information about the model using the same format as list_models()
"""
pass
@abstractmethod
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
"""
pass
@abstractmethod
def add_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
pass
@abstractmethod
def update_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with a
ModelNotFoundException if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
pass
@abstractmethod
def del_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk.
"""
pass
@abstractmethod
def rename_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str,
):
"""
Rename the indicated model.
"""
pass
@abstractmethod
def list_checkpoint_configs(self) -> List[Path]:
"""
List the checkpoint config paths from ROOT/configs/stable-diffusion.
"""
pass
@abstractmethod
def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Literal[ModelType.Main, ModelType.Vae],
) -> AddModelResult:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
This will raise a ValueError unless the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
pass
@abstractmethod
def heuristic_import(
self,
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> dict[str, AddModelResult]:
"""Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
The prediction type helper is necessary to distinguish between
models based on Stable Diffusion 2 Base (requiring
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
(requiring SchedulerPredictionType.VPrediction). It is
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
"""
pass
@abstractmethod
def merge_models(
self,
model_names: List[str] = Field(
default=None, min_length=2, max_length=3, description="List of model names to merge"
),
base_model: Union[BaseModelType, str] = Field(
default=None, description="Base model shared by all models to be merged"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
) -> AddModelResult:
"""
Merge two to three diffusrs pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
:param base_model: Base model to use for all models
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
pass
@abstractmethod
def search_for_models(self, directory: Path) -> List[Path]:
"""
Return list of all models found in the designated directory.
"""
pass
@abstractmethod
def sync_to_config(self):
"""
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
"""
pass
@abstractmethod
def collect_cache_stats(self, cache_stats: CacheStats):
"""
Reset model cache statistics for graph with graph_id.
"""
pass
@abstractmethod
def commit(self, conf_file: Optional[Path] = None) -> None:
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the
original file/database used to initialize the object.
"""
) -> LoadedModel:
pass

View File

@ -1,421 +1,155 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""
from __future__ import annotations
from logging import Logger
from pathlib import Path
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
from typing import Optional
import torch
from pydantic import Field
from typing_extensions import Self
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_management import (
AddModelResult,
BaseModelType,
LoadedModelInfo,
MergeInterpolationMethod,
ModelManager,
ModelMerger,
ModelNotFoundException,
ModelType,
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.model_management.model_cache import CacheStats
from invokeai.backend.model_management.model_search import FindModels
from invokeai.backend.util import choose_precision, choose_torch_device
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallService, ModelInstallServiceBase
from ..model_load import ModelLoadService, ModelLoadServiceBase
from ..model_records import ModelRecordServiceBase, UnknownModelException
from .model_manager_base import ModelManagerServiceBase
if TYPE_CHECKING:
pass
# simple implementation
class ModelManagerService(ModelManagerServiceBase):
"""Responsible for managing models on disk and in memory"""
"""
The ModelManagerService handles various aspects of model installation, maintenance and loading.
It bundles three distinct services:
model_manager.store -- Routines to manage the database of model configuration records.
model_manager.install -- Routines to install, move and delete models.
model_manager.load -- Routines to load models into memory.
"""
def __init__(
self,
config: InvokeAIAppConfig,
logger: Logger,
store: ModelRecordServiceBase,
install: ModelInstallServiceBase,
load: ModelLoadServiceBase,
):
"""
Initialize with the path to the models.yaml config file.
Optional parameters are the torch device type, precision, max_models,
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
"""
if config.model_conf_path and config.model_conf_path.exists():
config_file = config.model_conf_path
else:
config_file = config.root_dir / "configs/models.yaml"
self._store = store
self._install = install
self._load = load
logger.debug(f"Config file={config_file}")
@property
def store(self) -> ModelRecordServiceBase:
return self._store
device = torch.device(choose_torch_device())
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
logger.info(f"GPU device = {device} {device_name}")
@property
def install(self) -> ModelInstallServiceBase:
return self._install
precision = config.precision
if precision == "auto":
precision = choose_precision(device)
dtype = torch.float32 if precision == "float32" else torch.float16
# this is transitional backward compatibility
# support for the deprecated `max_loaded_models`
# configuration value. If present, then the
# cache size is set to 2.5 GB times
# the number of max_loaded_models. Otherwise
# use new `ram_cache_size` config setting
max_cache_size = config.ram_cache_size
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
sequential_offload = config.sequential_guidance
self.mgr = ModelManager(
config=config_file,
device_type=device,
precision=dtype,
max_cache_size=max_cache_size,
sequential_offload=sequential_offload,
logger=logger,
)
logger.info("Model manager service initialized")
@property
def load(self) -> ModelLoadServiceBase:
return self._load
def start(self, invoker: Invoker) -> None:
self._invoker: Optional[Invoker] = invoker
for service in [self._store, self._install, self._load]:
if hasattr(service, "start"):
service.start(invoker)
def get_model(
def stop(self, invoker: Invoker) -> None:
for service in [self._store, self._install, self._load]:
if hasattr(service, "stop"):
service.stop(invoker)
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
return self.load.load_model(model_config, submodel_type, context_data)
def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
config = self.store.get_model(key)
return self.load.load_model(config, submodel_type, context_data)
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModelInfo:
"""
Retrieve the indicated model. submodel can be used to get a
part (such as the vae) of a diffusers mode.
) -> LoadedModel:
"""
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
# we can emit model loading events if we are executing with access to the invocation context
if context_data is not None:
self._emit_load_event(
context_data=context_data,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
)
This is provided for API compatability with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that ws returned.
loaded_model_info = self.mgr.get_model(
model_name,
base_model,
model_type,
submodel,
)
:param model_name: Name of to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
if context_data is not None:
self._emit_load_event(
context_data=context_data,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
loaded_model_info=loaded_model_info,
)
return loaded_model_info
def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""
Given a model name, returns True if it is a valid
identifier.
"""
return self.mgr.model_exists(
model_name,
base_model,
model_type,
)
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
"""
return self.mgr.model_info(model_name, base_model, model_type)
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
"""
return self.mgr.model_names()
def list_models(
self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None
) -> list[dict]:
"""
Return a list of models.
"""
return self.mgr.list_models(base_model, model_type)
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
"""
Return information about the model using the same format as list_models()
"""
return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
def add_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
self.logger.debug(f"add/update model {model_name}")
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
def update_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with a
ModelNotFoundException exception if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
self.logger.debug(f"update model {model_name}")
if not self.model_exists(model_name, base_model, model_type):
raise ModelNotFoundException(f"Unknown model {model_name}")
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
def del_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well.
"""
self.logger.debug(f"delete model {model_name}")
self.mgr.del_model(model_name, base_model, model_type)
self.mgr.commit()
def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Literal[ModelType.Main, ModelType.Vae],
convert_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
) -> AddModelResult:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
This will raise a ValueError unless the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
self.logger.debug(f"convert model {model_name}")
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
def collect_cache_stats(self, cache_stats: CacheStats):
"""
Reset model cache statistics for graph with graph_id.
"""
self.mgr.cache.stats = cache_stats
def commit(self, conf_file: Optional[Path] = None):
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the
original file/database used to initialize the object.
"""
return self.mgr.commit(conf_file)
def _emit_load_event(
self,
context_data: InvocationContextData,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
loaded_model_info: Optional[LoadedModelInfo] = None,
):
if self._invoker is None:
return
if self._invoker.services.queue.is_canceled(context_data.session_id):
raise CanceledException()
if loaded_model_info:
self._invoker.services.events.emit_model_load_completed(
queue_id=context_data.queue_id,
queue_item_id=context_data.queue_item_id,
queue_batch_id=context_data.batch_id,
graph_execution_state_id=context_data.session_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
loaded_model_info=loaded_model_info,
)
configs = self.store.search_by_attr(model_name, base_model, model_type)
if len(configs) == 0:
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
elif len(configs) > 1:
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
else:
self._invoker.services.events.emit_model_load_started(
queue_id=context_data.queue_id,
queue_item_id=context_data.queue_item_id,
queue_batch_id=context_data.batch_id,
graph_execution_state_id=context_data.session_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
)
return self.load.load_model(configs[0], submodel, context_data)
@property
def logger(self):
return self.mgr.logger
@classmethod
def build_model_manager(
cls,
app_config: InvokeAIAppConfig,
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
execution_device: torch.device = choose_torch_device(),
) -> Self:
"""
Construct the model manager service instance.
def heuristic_import(
self,
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> dict[str, AddModelResult]:
"""Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
For simplicity, use this class method rather than the __init__ constructor.
"""
logger = InvokeAILogger.get_logger(cls.__name__)
logger.setLevel(app_config.log_level.upper())
The prediction type helper is necessary to distinguish between
models based on Stable Diffusion 2 Base (requiring
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
(requiring SchedulerPredictionType.VPrediction). It is
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
"""
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
def merge_models(
self,
model_names: List[str] = Field(
default=None, min_length=2, max_length=3, description="List of model names to merge"
),
base_model: Union[BaseModelType, str] = Field(
default=None, description="Base model shared by all models to be merged"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: float = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: bool = False,
merge_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
) -> AddModelResult:
"""
Merge two to three diffusrs pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
:param base_model: Base model to use for all models
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
merger = ModelMerger(self.mgr)
try:
result = merger.merge_diffusion_models_and_save(
model_names=model_names,
base_model=base_model,
merged_model_name=merged_model_name,
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory=merge_dest_directory,
)
except AssertionError as e:
raise ValueError(e)
return result
def search_for_models(self, directory: Path) -> List[Path]:
"""
Return list of all models found in the designated directory.
"""
search = FindModels([directory], self.logger)
return search.list_models()
def sync_to_config(self):
"""
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
"""
return self.mgr.sync_to_config()
def list_checkpoint_configs(self) -> List[Path]:
"""
List the checkpoint config paths from ROOT/configs/stable-diffusion.
"""
config = self.mgr.app_config
conf_path = config.legacy_conf_path
root_path = config.root_path
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
def rename_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: Optional[str] = None,
new_base: Optional[BaseModelType] = None,
):
"""
Rename the indicated model. Can provide a new name and/or a new base.
:param model_name: Current name of the model
:param base_model: Current base of the model
:param model_type: Model type (can't be changed)
:param new_name: New name for the model
:param new_base: New base for the model
"""
self.mgr.rename_model(
base_model=base_model,
model_type=model_type,
model_name=model_name,
new_name=new_name,
new_base=new_base,
ram_cache = ModelCache(
max_cache_size=app_config.ram_cache_size,
max_vram_cache_size=app_config.vram_cache_size,
logger=logger,
execution_device=execution_device,
)
convert_cache = ModelConvertCache(
cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size
)
loader = ModelLoadService(
app_config=app_config,
ram_cache=ram_cache,
convert_cache=convert_cache,
registry=ModelLoaderRegistry,
)
installer = ModelInstallService(
app_config=app_config,
record_store=model_record_service,
download_queue=download_queue,
event_bus=events,
)
return cls(store=model_record_service, install=installer, load=loader)

View File

@ -0,0 +1,9 @@
"""Init file for ModelMetadataStoreService module."""
from .metadata_store_base import ModelMetadataStoreBase
from .metadata_store_sql import ModelMetadataStoreSQL
__all__ = [
"ModelMetadataStoreBase",
"ModelMetadataStoreSQL",
]

View File

@ -0,0 +1,65 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Storage for Model Metadata
"""
from abc import ABC, abstractmethod
from typing import List, Set, Tuple
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class ModelMetadataStoreBase(ABC):
"""Store, search and fetch model metadata retrieved from remote repositories."""
@abstractmethod
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
@abstractmethod
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
@abstractmethod
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
@abstractmethod
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
@abstractmethod
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
@abstractmethod
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
@abstractmethod
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
@abstractmethod
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""

View File

@ -0,0 +1,222 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Storage for Model Metadata
"""
import sqlite3
from typing import List, Optional, Set, Tuple
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
from .metadata_store_base import ModelMetadataStoreBase
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
"""Store, search and fetch model metadata retrieved from remote repositories."""
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
"""
Add a block of repo metadata to a model record.
The model record config must already exist in the database with the
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_metadata(
id,
metadata
)
VALUES (?,?);
""",
(
model_key,
json_serialized,
),
)
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
raise UnknownMetadataException from excp
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata FROM model_metadata
WHERE id=?;
""",
(model_key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownMetadataException("model metadata not found")
return ModelMetadataFetchBase.from_json(rows[0])
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
"""Dump out all the metadata."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT id,metadata FROM model_metadata;
""",
(),
)
rows = self._cursor.fetchall()
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
"""
Update metadata corresponding to the model with the indicated key.
:param model_key: Existing model key in the `model_config` table
:param metadata: ModelRepoMetadata object to update
"""
json_serialized = metadata.model_dump_json() # turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_metadata
SET
metadata=?
WHERE id=?;
""",
(json_serialized, model_key),
)
if self._cursor.rowcount == 0:
raise UnknownMetadataException("model metadata not found")
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_metadata(model_key)
def list_tags(self) -> Set[str]:
"""Return all tags in the tags table."""
self._cursor.execute(
"""--sql
select tag_text from tags;
"""
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_tag(self, tags: Set[str]) -> Set[str]:
"""Return the keys of models containing all of the listed tags."""
with self._db.lock:
try:
matches: Optional[Set[str]] = None
for tag in tags:
self._cursor.execute(
"""--sql
SELECT a.model_id FROM model_tags AS a,
tags AS b
WHERE a.tag_id=b.tag_id
AND b.tag_text=?;
""",
(tag,),
)
model_keys = {x[0] for x in self._cursor.fetchall()}
if matches is None:
matches = model_keys
matches = matches.intersection(model_keys)
except sqlite3.Error as e:
raise e
return matches if matches else set()
def search_by_author(self, author: str) -> Set[str]:
"""Return the keys of models authored by the indicated author."""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE author=?;
""",
(author,),
)
return {x[0] for x in self._cursor.fetchall()}
def search_by_name(self, name: str) -> Set[str]:
"""
Return the keys of models with the indicated name.
Note that this is the name of the model given to it by
the remote source. The user may have changed the local
name. The local name will be located in the model config
record object.
"""
self._cursor.execute(
"""--sql
SELECT id FROM model_metadata
WHERE name=?;
""",
(name,),
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
"""Update tags for the model referenced by model_key."""
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)

View File

@ -11,8 +11,15 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
from invokeai.backend.model_manager import (
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..model_metadata import ModelMetadataStoreBase
class DuplicateModelException(Exception):
@ -104,7 +111,7 @@ class ModelRecordServiceBase(ABC):
@property
@abstractmethod
def metadata_store(self) -> ModelMetadataStore:
def metadata_store(self) -> ModelMetadataStoreBase:
"""Return a ModelMetadataStore initialized on the same database."""
pass
@ -146,7 +153,7 @@ class ModelRecordServiceBase(ABC):
@abstractmethod
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the databse.
Return True if a model with the indicated key exists in the database.
:param key: Unique key for the model to be deleted
"""

View File

@ -54,8 +54,9 @@ from invokeai.backend.model_manager.config import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import (
DuplicateModelException,
@ -69,16 +70,16 @@ from .model_records_base import (
class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database."""
def __init__(self, db: SqliteDatabase):
def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
:param db: Sqlite connection object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
self._cursor = db.conn.cursor()
self._metadata_store = metadata_store
@property
def db(self) -> SqliteDatabase:
@ -158,7 +159,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db.conn.rollback()
raise e
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
"""
Update the model, returning the updated version.
@ -199,7 +200,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
SELECT config, strftime('%s',updated_at) FROM model_config
WHERE id=?;
""",
(key,),
@ -207,7 +208,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]))
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def exists(self, key: str) -> bool:
@ -265,12 +266,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
f"""--sql
select config FROM model_config
select config, strftime('%s',updated_at) FROM model_config
{where};
""",
tuple(bindings),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
results = [
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
]
return results
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
@ -279,12 +282,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
SELECT config, strftime('%s',updated_at) FROM model_config
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
results = [
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
]
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
@ -293,18 +298,20 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
SELECT config, strftime('%s',updated_at) FROM model_config
WHERE original_hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
results = [
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
]
return results
@property
def metadata_store(self) -> ModelMetadataStore:
def metadata_store(self) -> ModelMetadataStoreBase:
"""Return a ModelMetadataStore initialized on the same database."""
return ModelMetadataStore(self._db)
return self._metadata_store
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""
@ -325,18 +332,18 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param tags: Set of tags to search for. All tags must be present.
"""
store = ModelMetadataStore(self._db)
store = ModelMetadataStoreSQL(self._db)
keys = store.search_by_tag(tags)
return [self.get_model(x) for x in keys]
def list_tags(self) -> Set[str]:
"""Return a unique set of all the model tags in the metadata database."""
store = ModelMetadataStore(self._db)
store = ModelMetadataStoreSQL(self._db)
return store.list_tags()
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
"""List metadata for all models that have it."""
store = ModelMetadataStore(self._db)
store = ModelMetadataStoreSQL(self._db)
return store.list_all_metadata()
def list_models(

View File

@ -4,3 +4,17 @@ from pydantic import BaseModel, Field
class SessionProcessorStatus(BaseModel):
is_started: bool = Field(description="Whether the session processor is started")
is_processing: bool = Field(description="Whether a session is being processed")
class CanceledException(Exception):
"""Execution canceled by user."""
pass
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")

View File

@ -1,4 +1,5 @@
import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from typing import Optional
@ -6,136 +7,270 @@ from typing import Optional
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
from invokeai.app.util.profiler import Profiler
from ..invoker import Invoker
from .session_processor_base import SessionProcessorBase
from .session_processor_common import SessionProcessorStatus
POLLING_INTERVAL = 1
THREAD_LIMIT = 1
class DefaultSessionProcessor(SessionProcessorBase):
def start(self, invoker: Invoker) -> None:
self.__invoker: Invoker = invoker
self.__queue_item: Optional[SessionQueueItem] = None
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
self._invoker: Invoker = invoker
self._queue_item: Optional[SessionQueueItem] = None
self._invocation: Optional[BaseInvocation] = None
self.__resume_event = ThreadEvent()
self.__stop_event = ThreadEvent()
self.__poll_now_event = ThreadEvent()
self._resume_event = ThreadEvent()
self._stop_event = ThreadEvent()
self._poll_now_event = ThreadEvent()
self._cancel_event = ThreadEvent()
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
self.__threadLimit = BoundedSemaphore(THREAD_LIMIT)
self.__thread = Thread(
self._thread_limit = thread_limit
self._thread_semaphore = BoundedSemaphore(thread_limit)
self._polling_interval = polling_interval
# If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
# the profiler will create a new profile for each session.
self._profiler = (
Profiler(
logger=self._invoker.services.logger,
output_dir=self._invoker.services.configuration.profiles_path,
prefix=self._invoker.services.configuration.profile_prefix,
)
if self._invoker.services.configuration.profile_graphs
else None
)
self._thread = Thread(
name="session_processor",
target=self.__process,
target=self._process,
kwargs={
"stop_event": self.__stop_event,
"poll_now_event": self.__poll_now_event,
"resume_event": self.__resume_event,
"stop_event": self._stop_event,
"poll_now_event": self._poll_now_event,
"resume_event": self._resume_event,
"cancel_event": self._cancel_event,
},
)
self.__thread.start()
self._thread.start()
def stop(self, *args, **kwargs) -> None:
self.__stop_event.set()
self._stop_event.set()
def _poll_now(self) -> None:
self.__poll_now_event.set()
self._poll_now_event.set()
async def _on_queue_event(self, event: FastAPIEvent) -> None:
event_name = event[1]["event"]
# This was a match statement, but match is not supported on python 3.9
if event_name in [
"graph_execution_state_complete",
"invocation_error",
"session_retrieval_error",
"invocation_retrieval_error",
]:
self.__queue_item = None
self._poll_now()
elif (
event_name == "session_canceled"
and self.__queue_item is not None
and self.__queue_item.session_id == event[1]["data"]["graph_execution_state_id"]
):
self.__queue_item = None
if event_name == "session_canceled" or event_name == "queue_cleared":
# These both mean we should cancel the current session.
self._cancel_event.set()
self._poll_now()
elif event_name == "batch_enqueued":
self._poll_now()
elif event_name == "queue_cleared":
self.__queue_item = None
self._poll_now()
def resume(self) -> SessionProcessorStatus:
if not self.__resume_event.is_set():
self.__resume_event.set()
if not self._resume_event.is_set():
self._resume_event.set()
return self.get_status()
def pause(self) -> SessionProcessorStatus:
if self.__resume_event.is_set():
self.__resume_event.clear()
if self._resume_event.is_set():
self._resume_event.clear()
return self.get_status()
def get_status(self) -> SessionProcessorStatus:
return SessionProcessorStatus(
is_started=self.__resume_event.is_set(),
is_processing=self.__queue_item is not None,
is_started=self._resume_event.is_set(),
is_processing=self._queue_item is not None,
)
def __process(
def _process(
self,
stop_event: ThreadEvent,
poll_now_event: ThreadEvent,
resume_event: ThreadEvent,
cancel_event: ThreadEvent,
):
# Outermost processor try block; any unhandled exception is a fatal processor error
try:
self._thread_semaphore.acquire()
stop_event.clear()
resume_event.set()
self.__threadLimit.acquire()
queue_item: Optional[SessionQueueItem] = None
cancel_event.clear()
while not stop_event.is_set():
poll_now_event.clear()
# Middle processor try block; any unhandled exception is a non-fatal processor error
try:
# do not dequeue if there is already a session running
if self.__queue_item is None and resume_event.is_set():
queue_item = self.__invoker.services.session_queue.dequeue()
# Get the next session to process
self._queue_item = self._invoker.services.session_queue.dequeue()
if self._queue_item is not None and resume_event.is_set():
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
if queue_item is not None:
self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}")
self.__queue_item = queue_item
self.__invoker.services.graph_execution_manager.set(queue_item.session)
self.__invoker.invoke(
session_queue_batch_id=queue_item.batch_id,
session_queue_id=queue_item.queue_id,
session_queue_item_id=queue_item.item_id,
graph_execution_state=queue_item.session,
workflow=queue_item.workflow,
invoke_all=True,
# If profiling is enabled, start the profiler
if self._profiler is not None:
self._profiler.start(profile_id=self._queue_item.session_id)
# Prepare invocations and take the first
self._invocation = self._queue_item.session.next()
# Loop over invocations until the session is complete or canceled
while self._invocation is not None and not cancel_event.is_set():
# get the source node id to provide to clients (the prepared node id is not as useful)
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
# Send starting event
self._invoker.services.events.emit_invocation_started(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session_id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
)
queue_item = None
if queue_item is None:
self.__invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(POLLING_INTERVAL)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
try:
with self._invoker.services.performance_statistics.collect_stats(
self._invocation, self._queue_item.session.id
):
# Build invocation context (the node-facing API)
data = InvocationContextData(
invocation=self._invocation,
source_invocation_id=source_invocation_id,
queue_item=self._queue_item,
)
context = build_invocation_context(
data=data,
services=self._invoker.services,
cancel_event=self._cancel_event,
)
# Invoke the node
outputs = self._invocation.invoke_internal(
context=context, services=self._invoker.services
)
# Save outputs and history
self._queue_item.session.complete(self._invocation.id, outputs)
# Send complete event
self._invoker.services.events.emit_invocation_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
self._queue_item.session.set_node_error(self._invocation.id, error)
self._invoker.services.logger.error(
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
)
# Send error event
self._invoker.services.events.emit_invocation_error(
queue_batch_id=self._queue_item.session_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
error_type=e.__class__.__name__,
error=error,
)
pass
# The session is complete if the all invocations are complete or there was an error
if self._queue_item.session.is_complete() or cancel_event.is_set():
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats()
# Set the invocation to None to prepare for the next session
self._invocation = None
else:
# Prepare the next invocation
self._invocation = self._queue_item.session.next()
# The session is complete, immediately poll for next session
self._queue_item = None
poll_now_event.set()
else:
# The queue was empty, wait for next polling interval or event to try again
self._invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(self._polling_interval)
continue
except Exception as e:
self.__invoker.services.logger.error(f"Error in session processor: {e}")
if queue_item is not None:
self.__invoker.services.session_queue.cancel_queue_item(
queue_item.item_id, error=traceback.format_exc()
except Exception:
# Non-fatal error in processor
self._invoker.services.logger.error(
f"Non-fatal error in session processor:\n{traceback.format_exc()}"
)
# Cancel the queue item
if self._queue_item is not None:
self._invoker.services.session_queue.cancel_queue_item(
self._queue_item.item_id, error=traceback.format_exc()
)
poll_now_event.wait(POLLING_INTERVAL)
# Reset the invocation to None to prepare for the next session
self._invocation = None
# Immediately poll for next queue item
poll_now_event.wait(self._polling_interval)
continue
except Exception as e:
self.__invoker.services.logger.error(f"Fatal Error in session processor: {e}")
except Exception:
# Fatal error in processor, log and pass - we're done here
self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}")
pass
finally:
stop_event.clear()
poll_now_event.clear()
self.__queue_item = None
self.__threadLimit.release()
self._queue_item = None
self._thread_semaphore.release()

View File

@ -60,7 +60,7 @@ class SqliteSessionQueue(SessionQueueBase):
# This was a match statement, but match is not supported on python 3.9
if event_name == "graph_execution_state_complete":
await self._handle_complete_event(event)
elif event_name in ["invocation_error", "session_retrieval_error", "invocation_retrieval_error"]:
elif event_name == "invocation_error":
await self._handle_error_event(event)
elif event_name == "session_canceled":
await self._handle_cancel_event(event)
@ -429,7 +429,6 @@ class SqliteSessionQueue(SessionQueueBase):
if queue_item.status not in ["canceled", "failed", "completed"]:
status = "failed" if error is not None else "canceled"
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here
self.__invoker.services.queue.cancel(queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
@ -471,7 +470,6 @@ class SqliteSessionQueue(SessionQueueBase):
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self.__invoker.services.queue.cancel(current_queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
@ -523,7 +521,6 @@ class SqliteSessionQueue(SessionQueueBase):
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self.__invoker.services.queue.cancel(current_queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,

View File

@ -1,92 +0,0 @@
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
from ...invocations.compel import CompelInvocation
from ...invocations.image import ImageNSFWBlurInvocation
from ...invocations.latent import DenoiseLatentsInvocation, LatentsToImageInvocation
from ...invocations.noise import NoiseInvocation
from ...invocations.primitives import IntegerInvocation
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74"
def create_text_to_image() -> LibraryGraph:
graph = Graph(
nodes={
"width": IntegerInvocation(id="width", value=512),
"height": IntegerInvocation(id="height", value=512),
"seed": IntegerInvocation(id="seed", value=-1),
"3": NoiseInvocation(id="3"),
"4": CompelInvocation(id="4"),
"5": CompelInvocation(id="5"),
"6": DenoiseLatentsInvocation(id="6"),
"7": LatentsToImageInvocation(id="7"),
"8": ImageNSFWBlurInvocation(id="8"),
},
edges=[
Edge(
source=EdgeConnection(node_id="width", field="value"),
destination=EdgeConnection(node_id="3", field="width"),
),
Edge(
source=EdgeConnection(node_id="height", field="value"),
destination=EdgeConnection(node_id="3", field="height"),
),
Edge(
source=EdgeConnection(node_id="seed", field="value"),
destination=EdgeConnection(node_id="3", field="seed"),
),
Edge(
source=EdgeConnection(node_id="3", field="noise"),
destination=EdgeConnection(node_id="6", field="noise"),
),
Edge(
source=EdgeConnection(node_id="6", field="latents"),
destination=EdgeConnection(node_id="7", field="latents"),
),
Edge(
source=EdgeConnection(node_id="4", field="conditioning"),
destination=EdgeConnection(node_id="6", field="positive_conditioning"),
),
Edge(
source=EdgeConnection(node_id="5", field="conditioning"),
destination=EdgeConnection(node_id="6", field="negative_conditioning"),
),
Edge(
source=EdgeConnection(node_id="7", field="image"),
destination=EdgeConnection(node_id="8", field="image"),
),
],
)
return LibraryGraph(
id=default_text_to_image_graph_id,
name="t2i",
description="Converts text to an image",
graph=graph,
exposed_inputs=[
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
ExposedNodeInput(node_path="width", field="value", alias="width"),
ExposedNodeInput(node_path="height", field="value", alias="height"),
ExposedNodeInput(node_path="seed", field="value", alias="seed"),
],
exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")],
)
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
graphs: list[LibraryGraph] = []
text_to_image = graph_library.get(default_text_to_image_graph_id)
# TODO: Check if the graph is the same as the default one, and if not, update it
# if text_to_image is None:
text_to_image = create_text_to_image()
graph_library.set(text_to_image)
graphs.append(text_to_image)
return graphs

View File

@ -5,8 +5,14 @@ import itertools
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
import networkx as nx
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
from pydantic import (
BaseModel,
GetJsonSchemaHandler,
field_validator,
)
from pydantic.fields import Field
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema
# Importing * is bad karma but needed here for node detection
from invokeai.app.invocations import * # noqa: F401 F403
@ -176,10 +182,6 @@ class NodeIdMismatchError(ValueError):
pass
class InvalidSubGraphError(ValueError):
pass
class CyclicalGraphError(ValueError):
pass
@ -188,25 +190,6 @@ class UnknownGraphValidationError(ValueError):
pass
# TODO: Create and use an Empty output?
@invocation_output("graph_output")
class GraphInvocationOutput(BaseInvocationOutput):
pass
# TODO: Fill this out and move to invocations
@invocation("graph", version="1.0.0")
class GraphInvocation(BaseInvocation):
"""Execute a graph"""
# TODO: figure out how to create a default here
graph: "Graph" = InputField(description="The graph to run", default=None)
def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
"""Invoke with provided services and return outputs."""
return GraphInvocationOutput()
@invocation_output("iterate_output")
class IterateInvocationOutput(BaseInvocationOutput):
"""Used to connect iteration outputs. Will be expanded to a specific output."""
@ -260,21 +243,73 @@ class CollectInvocation(BaseInvocation):
return CollectInvocationOutput(collection=copy.copy(self.collection))
InvocationsUnion: Any = BaseInvocation.get_invocations_union()
InvocationOutputsUnion: Any = BaseInvocationOutput.get_outputs_union()
class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=uuid_string)
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
description="The nodes in this graph", default_factory=dict
)
nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
edges: list[Edge] = Field(
description="The connections between nodes and their fields in this graph",
default_factory=list,
)
@field_validator("nodes", mode="plain")
@classmethod
def validate_nodes(cls, v: dict[str, Any]):
"""Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
# Invocations register themselves as their python modules are executed. The union of all invocations is
# constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
#
# It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
# we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
# invocations will cause a graph to fail if they are used.
#
# We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
# pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
#
# This same pattern is used in `GraphExecutionState`.
nodes: dict[str, BaseInvocation] = {}
typeadapter = BaseInvocation.get_typeadapter()
for node_id, node in v.items():
nodes[node_id] = typeadapter.validate_python(node)
return nodes
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
# fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
# the generated schema as options for the `nodes` field.
#
# The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
# with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
# expected.
#
# You might be tempted to do something like this:
#
# ```py
# cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
# delattr(cloned_model, "validate_nodes")
# cloned_model.model_rebuild(force=True)
# json_schema = handler(cloned_model.__pydantic_core_schema__)
# ```
#
# Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
# to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
#
# This same pattern is used in `GraphExecutionState`.
class Graph(BaseModel):
id: Optional[str] = Field(default=None, description="The id of this graph")
nodes: dict[
str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
] = Field(description="The nodes in this graph")
edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
json_schema = handler(Graph.__pydantic_core_schema__)
json_schema = handler.resolve_ref_schema(json_schema)
return json_schema
def add_node(self, node: BaseInvocation) -> None:
"""Adds a node to a graph
@ -286,41 +321,21 @@ class Graph(BaseModel):
self.nodes[node.id] = node
def _get_graph_and_node(self, node_path: str) -> tuple["Graph", str]:
"""Returns the graph and node id for a node path."""
# Materialized graphs may have nodes at the top level
if node_path in self.nodes:
return (self, node_path)
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
if node_id not in self.nodes:
raise NodeNotFoundError(f"Node {node_path} not found in graph")
node = self.nodes[node_id]
if not isinstance(node, GraphInvocation):
# There's more node path left but this isn't a graph - failure
raise NodeNotFoundError("Node path terminated early at a non-graph node")
return node.graph._get_graph_and_node(node_path[node_path.index(".") + 1 :])
def delete_node(self, node_path: str) -> None:
def delete_node(self, node_id: str) -> None:
"""Deletes a node from a graph"""
try:
graph, node_id = self._get_graph_and_node(node_path)
# Delete edges for this node
input_edges = self._get_input_edges_and_graphs(node_path)
output_edges = self._get_output_edges_and_graphs(node_path)
input_edges = self._get_input_edges(node_id)
output_edges = self._get_output_edges(node_id)
for edge_graph, _, edge in input_edges:
edge_graph.delete_edge(edge)
for edge in input_edges:
self.delete_edge(edge)
for edge_graph, _, edge in output_edges:
edge_graph.delete_edge(edge)
for edge in output_edges:
self.delete_edge(edge)
del graph.nodes[node_id]
del self.nodes[node_id]
except NodeNotFoundError:
pass # Ignore, not doesn't exist (should this throw?)
@ -370,13 +385,6 @@ class Graph(BaseModel):
if k != v.id:
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
# Validate all subgraphs
for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)):
try:
gn.graph.validate_self()
except Exception as e:
raise InvalidSubGraphError(f"Subgraph {gn.id} is invalid") from e
# Validate that all edges match nodes and fields in the graph
for edge in self.edges:
source_node = self.nodes.get(edge.source.node_id, None)
@ -438,7 +446,6 @@ class Graph(BaseModel):
except (
DuplicateNodeIdError,
NodeIdMismatchError,
InvalidSubGraphError,
NodeNotFoundError,
NodeFieldNotFoundError,
CyclicalGraphError,
@ -459,7 +466,7 @@ class Graph(BaseModel):
def _validate_edge(self, edge: Edge):
"""Validates that a new edge doesn't create a cycle in the graph"""
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
# Validate that the nodes exist
try:
from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id)
@ -526,171 +533,90 @@ class Graph(BaseModel):
f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
def has_node(self, node_path: str) -> bool:
def has_node(self, node_id: str) -> bool:
"""Determines whether or not a node exists in the graph."""
try:
n = self.get_node(node_path)
if n is not None:
return True
else:
return False
_ = self.get_node(node_id)
return True
except NodeNotFoundError:
return False
def get_node(self, node_path: str) -> BaseInvocation:
"""Gets a node from the graph using a node path."""
# Materialized graphs may have nodes at the top level
graph, node_id = self._get_graph_and_node(node_path)
return graph.nodes[node_id]
def get_node(self, node_id: str) -> BaseInvocation:
"""Gets a node from the graph."""
try:
return self.nodes[node_id]
except KeyError as e:
raise NodeNotFoundError(f"Node {node_id} not found in graph") from e
def _get_node_path(self, node_id: str, prefix: Optional[str] = None) -> str:
return node_id if prefix is None or prefix == "" else f"{prefix}.{node_id}"
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
def update_node(self, node_id: str, new_node: BaseInvocation) -> None:
"""Updates a node in the graph."""
graph, node_id = self._get_graph_and_node(node_path)
node = graph.nodes[node_id]
node = self.nodes[node_id]
# Ensure the node type matches the new node
if type(node) is not type(new_node):
raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")
raise TypeError(f"Node {node_id} is type {type(node)} but new node is type {type(new_node)}")
# Ensure the new id is either the same or is not in the graph
prefix = None if "." not in node_path else node_path[: node_path.rindex(".")]
new_path = self._get_node_path(new_node.id, prefix=prefix)
if new_node.id != node.id and self.has_node(new_path):
raise NodeAlreadyInGraphError("Node with id {new_node.id} already exists in graph")
if new_node.id != node.id and self.has_node(new_node.id):
raise NodeAlreadyInGraphError(f"Node with id {new_node.id} already exists in graph")
# Set the new node in the graph
graph.nodes[new_node.id] = new_node
self.nodes[new_node.id] = new_node
if new_node.id != node.id:
input_edges = self._get_input_edges_and_graphs(node_path)
output_edges = self._get_output_edges_and_graphs(node_path)
input_edges = self._get_input_edges(node_id)
output_edges = self._get_output_edges(node_id)
# Delete node and all edges
graph.delete_node(node_path)
self.delete_node(node_id)
# Create new edges for each input and output
for graph, _, edge in input_edges:
# Remove the graph prefix from the node path
new_graph_node_path = (
new_node.id
if "." not in edge.destination.node_id
else f'{edge.destination.node_id[edge.destination.node_id.rindex("."):]}.{new_node.id}'
)
graph.add_edge(
for edge in input_edges:
self.add_edge(
Edge(
source=edge.source,
destination=EdgeConnection(node_id=new_graph_node_path, field=edge.destination.field),
destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
)
)
for graph, _, edge in output_edges:
# Remove the graph prefix from the node path
new_graph_node_path = (
new_node.id
if "." not in edge.source.node_id
else f'{edge.source.node_id[edge.source.node_id.rindex("."):]}.{new_node.id}'
)
graph.add_edge(
for edge in output_edges:
self.add_edge(
Edge(
source=EdgeConnection(node_id=new_graph_node_path, field=edge.source.field),
source=EdgeConnection(node_id=new_node.id, field=edge.source.field),
destination=edge.destination,
)
)
def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[Edge]:
"""Gets all input edges for a node"""
edges = self._get_input_edges_and_graphs(node_path)
def _get_input_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]:
"""Gets all input edges for a node. If field is provided, only edges to that field are returned."""
# Filter to edges that match the field
filtered_edges = (e for e in edges if field is None or e[2].destination.field == field)
edges = [e for e in self.edges if e.destination.node_id == node_id]
# Create full node paths for each edge
return [
Edge(
source=EdgeConnection(
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
field=e.source.field,
),
destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field,
),
)
for _, prefix, e in filtered_edges
]
if field is None:
return edges
def _get_input_edges_and_graphs(
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", Union[str, None], Edge]]:
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
edges = []
filtered_edges = [e for e in edges if e.destination.field == field]
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
return filtered_edges
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
node = self.nodes[node_id]
def _get_output_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]:
"""Gets all output edges for a node. If field is provided, only edges from that field are returned."""
edges = [e for e in self.edges if e.source.node_id == node_id]
if isinstance(node, GraphInvocation):
graph = node.graph
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
edges.extend(graph_edges)
if field is None:
return edges
return edges
filtered_edges = [e for e in edges if e.source.field == field]
def _get_output_edges(self, node_path: str, field: str) -> list[Edge]:
"""Gets all output edges for a node"""
edges = self._get_output_edges_and_graphs(node_path)
# Filter to edges that match the field
filtered_edges = (e for e in edges if e[2].source.field == field)
# Create full node paths for each edge
return [
Edge(
source=EdgeConnection(
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
field=e.source.field,
),
destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field,
),
)
for _, prefix, e in filtered_edges
]
def _get_output_edges_and_graphs(
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", Union[str, None], Edge]]:
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
edges = []
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
node = self.nodes[node_id]
if isinstance(node, GraphInvocation):
graph = node.graph
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
edges.extend(graph_edges)
return edges
return filtered_edges
def _is_iterator_connection_valid(
self,
node_path: str,
node_id: str,
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
inputs = [e.source for e in self._get_input_edges(node_path, "collection")]
outputs = [e.destination for e in self._get_output_edges(node_path, "item")]
inputs = [e.source for e in self._get_input_edges(node_id, "collection")]
outputs = [e.destination for e in self._get_output_edges(node_id, "item")]
if new_input is not None:
inputs.append(new_input)
@ -718,12 +644,12 @@ class Graph(BaseModel):
def _is_collector_connection_valid(
self,
node_path: str,
node_id: str,
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
inputs = [e.source for e in self._get_input_edges(node_path, "item")]
outputs = [e.destination for e in self._get_output_edges(node_path, "collection")]
inputs = [e.source for e in self._get_input_edges(node_id, "item")]
outputs = [e.destination for e in self._get_output_edges(node_id, "collection")]
if new_input is not None:
inputs.append(new_input)
@ -779,27 +705,17 @@ class Graph(BaseModel):
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
return g
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None) -> nx.DiGraph:
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
g = nx_graph or nx.DiGraph()
# Add all nodes from this graph except graph/iteration nodes
g.add_nodes_from(
[
self._get_node_path(n.id, prefix)
for n in self.nodes.values()
if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)
]
)
# Expand graph nodes
for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)):
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
g.add_nodes_from([n.id for n in self.nodes.values() if not isinstance(n, IterateInvocation)])
# TODO: figure out if iteration nodes need to be expanded
unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
g.add_edges_from([(e[0], e[1]) for e in unique_edges])
return g
@ -824,9 +740,7 @@ class GraphExecutionState(BaseModel):
)
# The results of executed nodes
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(
description="The results of node executions", default_factory=dict
)
results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
# Errors raised when executing nodes
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
@ -843,27 +757,51 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)
@field_validator("results", mode="plain")
@classmethod
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
"""Validates the results in the GES by retrieving a union of all output types and validating each result."""
# See the comment in `Graph.validate_nodes` for an explanation of this logic.
results: dict[str, BaseInvocationOutput] = {}
typeadapter = BaseInvocationOutput.get_typeadapter()
for result_id, result in v.items():
results[result_id] = typeadapter.validate_python(result)
return results
@field_validator("graph")
def graph_is_valid(cls, v: Graph):
"""Validates that the graph is valid"""
v.validate_self()
return v
model_config = ConfigDict(
json_schema_extra={
"required": [
"id",
"graph",
"execution_graph",
"executed",
"executed_history",
"results",
"errors",
"prepared_source_mapping",
"source_prepared_mapping",
]
}
)
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution"""
id: str = Field(description="The id of the execution state")
graph: Graph = Field(description="The graph being executed")
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
executed: set[str] = Field(description="The set of node ids that have been executed")
executed_history: list[str] = Field(
description="The list of node ids that have been executed, in order of execution"
)
results: dict[
str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
] = Field(description="The results of node executions")
errors: dict[str, str] = Field(description="Errors raised when executing nodes")
prepared_source_mapping: dict[str, str] = Field(
description="The map of prepared nodes to original graph nodes"
)
source_prepared_mapping: dict[str, set[str]] = Field(
description="The map of original graph nodes to prepared nodes"
)
json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
json_schema = handler.resolve_ref_schema(json_schema)
return json_schema
def next(self) -> Optional[BaseInvocation]:
"""Gets the next node ready to execute."""
@ -919,17 +857,17 @@ class GraphExecutionState(BaseModel):
"""Returns true if the graph has any errors"""
return len(self.errors) > 0
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
def _create_execution_node(self, node_id: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
"""Prepares an iteration node and connects all edges, returning the new node id"""
node = self.graph.get_node(node_path)
node = self.graph.get_node(node_id)
self_iteration_count = -1
# If this is an iterator node, we must create a copy for each iteration
if isinstance(node, IterateInvocation):
# Get input collection edge (should error if there are no inputs)
input_collection_edge = next(iter(self.graph._get_input_edges(node_path, "collection")))
input_collection_edge = next(iter(self.graph._get_input_edges(node_id, "collection")))
input_collection_prepared_node_id = next(
n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id
)
@ -943,7 +881,7 @@ class GraphExecutionState(BaseModel):
return new_nodes
# Get all input edges
input_edges = self.graph._get_input_edges(node_path)
input_edges = self.graph._get_input_edges(node_id)
# Create new edges for this iteration
# For collect nodes, this may contain multiple inputs to the same field
@ -970,10 +908,10 @@ class GraphExecutionState(BaseModel):
# Add to execution graph
self.execution_graph.add_node(new_node)
self.prepared_source_mapping[new_node.id] = node_path
if node_path not in self.source_prepared_mapping:
self.source_prepared_mapping[node_path] = set()
self.source_prepared_mapping[node_path].add(new_node.id)
self.prepared_source_mapping[new_node.id] = node_id
if node_id not in self.source_prepared_mapping:
self.source_prepared_mapping[node_id] = set()
self.source_prepared_mapping[node_id].add(new_node.id)
# Add new edges to execution graph
for edge in new_edges:
@ -1077,13 +1015,13 @@ class GraphExecutionState(BaseModel):
def _get_iteration_node(
self,
source_node_path: str,
source_node_id: str,
graph: nx.DiGraph,
execution_graph: nx.DiGraph,
prepared_iterator_nodes: list[str],
) -> Optional[str]:
"""Gets the prepared version of the specified source node that matches every iteration specified"""
prepared_nodes = self.source_prepared_mapping[source_node_path]
prepared_nodes = self.source_prepared_mapping[source_node_id]
if len(prepared_nodes) == 1:
return next(iter(prepared_nodes))
@ -1094,7 +1032,7 @@ class GraphExecutionState(BaseModel):
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)]
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)]
return next(
(n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
@ -1163,19 +1101,19 @@ class GraphExecutionState(BaseModel):
def add_node(self, node: BaseInvocation) -> None:
self.graph.add_node(node)
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
if not self._is_node_updatable(node_path):
def update_node(self, node_id: str, new_node: BaseInvocation) -> None:
if not self._is_node_updatable(node_id):
raise NodeAlreadyExecutedError(
f"Node {node_path} has already been prepared or executed and cannot be updated"
f"Node {node_id} has already been prepared or executed and cannot be updated"
)
self.graph.update_node(node_path, new_node)
self.graph.update_node(node_id, new_node)
def delete_node(self, node_path: str) -> None:
if not self._is_node_updatable(node_path):
def delete_node(self, node_id: str) -> None:
if not self._is_node_updatable(node_id):
raise NodeAlreadyExecutedError(
f"Node {node_path} has already been prepared or executed and cannot be deleted"
f"Node {node_id} has already been prepared or executed and cannot be deleted"
)
self.graph.delete_node(node_path)
self.graph.delete_node(node_id)
def add_edge(self, edge: Edge) -> None:
if not self._is_node_updatable(edge.destination.node_id):
@ -1190,63 +1128,3 @@ class GraphExecutionState(BaseModel):
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot have a source edge deleted"
)
self.graph.delete_edge(edge)
class ExposedNodeInput(BaseModel):
node_path: str = Field(description="The node path to the node with the input")
field: str = Field(description="The field name of the input")
alias: str = Field(description="The alias of the input")
class ExposedNodeOutput(BaseModel):
node_path: str = Field(description="The node path to the node with the output")
field: str = Field(description="The field name of the output")
alias: str = Field(description="The alias of the output")
class LibraryGraph(BaseModel):
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid_string)
graph: Graph = Field(description="The graph")
name: str = Field(description="The name of the graph")
description: str = Field(description="The description of the graph")
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
exposed_outputs: list[ExposedNodeOutput] = Field(
description="The outputs exposed by this graph", default_factory=list
)
@field_validator("exposed_inputs", "exposed_outputs")
def validate_exposed_aliases(cls, v: list[Union[ExposedNodeInput, ExposedNodeOutput]]):
if len(v) != len({i.alias for i in v}):
raise ValueError("Duplicate exposed alias")
return v
@model_validator(mode="after")
def validate_exposed_nodes(cls, values):
graph = values.graph
# Validate exposed inputs
for exposed_input in values.exposed_inputs:
if not graph.has_node(exposed_input.node_path):
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
node = graph.get_node(exposed_input.node_path)
if get_input_field(node, exposed_input.field) is None:
raise ValueError(
f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}"
)
# Validate exposed outputs
for exposed_output in values.exposed_outputs:
if not graph.has_node(exposed_output.node_path):
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
node = graph.get_node(exposed_output.node_path)
if get_output_field(node, exposed_output.field) is None:
raise ValueError(
f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}"
)
return values
GraphInvocation.model_rebuild(force=True)
Graph.model_rebuild(force=True)
GraphExecutionState.model_rebuild(force=True)

View File

@ -1,24 +1,28 @@
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from PIL.Image import Image
from torch import Tensor
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.model_manager import LoadedModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
"""
The InvocationContext provides access to various services and data about the current invocation.
@ -45,99 +49,102 @@ Note: The docstrings are in weird places, but that's where they must be to get I
@dataclass
class InvocationContextData:
queue_item: "SessionQueueItem"
"""The queue item that is being executed."""
invocation: "BaseInvocation"
"""The invocation that is being executed."""
session_id: str
"""The session that is being executed."""
queue_id: str
"""The queue in which the session is being executed."""
source_node_id: str
"""The ID of the node from which the currently executing invocation was prepared."""
queue_item_id: int
"""The ID of the queue item that is being executed."""
batch_id: str
"""The ID of the batch that is being executed."""
workflow: Optional[WorkflowWithoutID] = None
"""The workflow associated with this queue item, if any."""
source_invocation_id: str
"""The ID of the invocation from which the currently executing invocation was prepared."""
class InvocationContextInterface:
def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
def __init__(self, services: InvocationServices, data: InvocationContextData) -> None:
self._services = services
self._context_data = context_data
self._data = data
class BoardsInterface(InvocationContextInterface):
def create(self, board_name: str) -> BoardDTO:
"""
Creates a board.
"""Creates a board.
:param board_name: The name of the board to create.
Args:
board_name: The name of the board to create.
Returns:
The created board DTO.
"""
return self._services.boards.create(board_name)
def get_dto(self, board_id: str) -> BoardDTO:
"""
Gets a board DTO.
"""Gets a board DTO.
:param board_id: The ID of the board to get.
Args:
board_id: The ID of the board to get.
Returns:
The board DTO.
"""
return self._services.boards.get_dto(board_id)
def get_all(self) -> list[BoardDTO]:
"""
Gets all boards.
"""Gets all boards.
Returns:
A list of all boards.
"""
return self._services.boards.get_all()
def add_image_to_board(self, board_id: str, image_name: str) -> None:
"""
Adds an image to a board.
"""Adds an image to a board.
:param board_id: The ID of the board to add the image to.
:param image_name: The name of the image to add to the board.
Args:
board_id: The ID of the board to add the image to.
image_name: The name of the image to add to the board.
"""
return self._services.board_images.add_image_to_board(board_id, image_name)
def get_all_image_names_for_board(self, board_id: str) -> list[str]:
"""
Gets all image names for a board.
"""Gets all image names for a board.
:param board_id: The ID of the board to get the image names for.
Args:
board_id: The ID of the board to get the image names for.
Returns:
A list of all image names for the board.
"""
return self._services.board_images.get_all_board_image_names_for_board(board_id)
class LoggerInterface(InvocationContextInterface):
def debug(self, message: str) -> None:
"""
Logs a debug message.
"""Logs a debug message.
:param message: The message to log.
Args:
message: The message to log.
"""
self._services.logger.debug(message)
def info(self, message: str) -> None:
"""
Logs an info message.
"""Logs an info message.
:param message: The message to log.
Args:
message: The message to log.
"""
self._services.logger.info(message)
def warning(self, message: str) -> None:
"""
Logs a warning message.
"""Logs a warning message.
:param message: The message to log.
Args:
message: The message to log.
"""
self._services.logger.warning(message)
def error(self, message: str) -> None:
"""
Logs an error message.
"""Logs an error message.
:param message: The message to log.
Args:
message: The message to log.
"""
self._services.logger.error(message)
@ -150,164 +157,286 @@ class ImagesInterface(InvocationContextInterface):
image_category: ImageCategory = ImageCategory.GENERAL,
metadata: Optional[MetadataField] = None,
) -> ImageDTO:
"""
Saves an image, returning its DTO.
"""Saves an image, returning its DTO.
If the current queue item has a workflow or metadata, it is automatically saved with the image.
:param image: The image to save, as a PIL image.
:param board_id: The board ID to add the image to, if it should be added. It the invocation \
Args:
image: The image to save, as a PIL image.
board_id: The board ID to add the image to, if it should be added. It the invocation \
inherits from `WithBoard`, that board will be used automatically. **Use this only if \
you want to override or provide a board manually!**
:param image_category: The category of the image. Only the GENERAL category is added \
image_category: The category of the image. Only the GENERAL category is added \
to the gallery.
:param metadata: The metadata to save with the image, if it should have any. If the \
metadata: The metadata to save with the image, if it should have any. If the \
invocation inherits from `WithMetadata`, that metadata will be used automatically. \
**Use this only if you want to override or provide metadata manually!**
Returns:
The saved image DTO.
"""
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
metadata_ = None
if metadata:
metadata_ = metadata
elif isinstance(self._context_data.invocation, WithMetadata):
metadata_ = self._context_data.invocation.metadata
elif isinstance(self._data.invocation, WithMetadata):
metadata_ = self._data.invocation.metadata
# If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None.
board_id_ = None
if board_id:
board_id_ = board_id
elif isinstance(self._context_data.invocation, WithBoard) and self._context_data.invocation.board:
board_id_ = self._context_data.invocation.board.board_id
elif isinstance(self._data.invocation, WithBoard) and self._data.invocation.board:
board_id_ = self._data.invocation.board.board_id
return self._services.images.create(
image=image,
is_intermediate=self._context_data.invocation.is_intermediate,
is_intermediate=self._data.invocation.is_intermediate,
image_category=image_category,
board_id=board_id_,
metadata=metadata_,
image_origin=ResourceOrigin.INTERNAL,
workflow=self._context_data.workflow,
session_id=self._context_data.session_id,
node_id=self._context_data.invocation.id,
workflow=self._data.queue_item.workflow,
session_id=self._data.queue_item.session_id,
node_id=self._data.invocation.id,
)
def get_pil(self, image_name: str) -> Image:
"""
Gets an image as a PIL Image object.
def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image:
"""Gets an image as a PIL Image object.
:param image_name: The name of the image to get.
Args:
image_name: The name of the image to get.
mode: The color mode to convert the image to. If None, the original mode is used.
Returns:
The image as a PIL Image object.
"""
return self._services.images.get_pil_image(image_name)
image = self._services.images.get_pil_image(image_name)
if mode and mode != image.mode:
try:
image = image.convert(mode)
except ValueError:
self._services.logger.warning(
f"Could not convert image from {image.mode} to {mode}. Using original mode instead."
)
return image
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
"""
Gets an image's metadata, if it has any.
"""Gets an image's metadata, if it has any.
:param image_name: The name of the image to get the metadata for.
Args:
image_name: The name of the image to get the metadata for.
Returns:
The image's metadata, if it has any.
"""
return self._services.images.get_metadata(image_name)
def get_dto(self, image_name: str) -> ImageDTO:
"""
Gets an image as an ImageDTO object.
"""Gets an image as an ImageDTO object.
:param image_name: The name of the image to get.
Args:
image_name: The name of the image to get.
Returns:
The image as an ImageDTO object.
"""
return self._services.images.get_dto(image_name)
class TensorsInterface(InvocationContextInterface):
def save(self, tensor: Tensor) -> str:
"""
Saves a tensor, returning its name.
"""Saves a tensor, returning its name.
:param tensor: The tensor to save.
Args:
tensor: The tensor to save.
Returns:
The name of the saved tensor.
"""
name = self._services.tensors.save(obj=tensor)
return name
def load(self, name: str) -> Tensor:
"""
Loads a tensor by name.
"""Loads a tensor by name.
:param name: The name of the tensor to load.
Args:
name: The name of the tensor to load.
Returns:
The loaded tensor.
"""
return self._services.tensors.load(name)
class ConditioningInterface(InvocationContextInterface):
def save(self, conditioning_data: ConditioningFieldData) -> str:
"""
Saves a conditioning data object, returning its name.
"""Saves a conditioning data object, returning its name.
:param conditioning_context_data: The conditioning data to save.
Args:
conditioning_data: The conditioning data to save.
Returns:
The name of the saved conditioning data.
"""
name = self._services.conditioning.save(obj=conditioning_data)
return name
def load(self, name: str) -> ConditioningFieldData:
"""
Loads conditioning data by name.
"""Loads conditioning data by name.
:param name: The name of the conditioning data to load.
Args:
name: The name of the conditioning data to load.
Returns:
The loaded conditioning data.
"""
return self._services.conditioning.load(name)
class ModelsInterface(InvocationContextInterface):
def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool:
"""
Checks if a model exists.
def exists(self, key: str) -> bool:
"""Checks if a model exists.
:param model_name: The name of the model to check.
:param base_model: The base model of the model to check.
:param model_type: The type of the model to check.
"""
return self._services.model_manager.model_exists(model_name, base_model, model_type)
Args:
key: The key of the model.
def load(
self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
) -> LoadedModelInfo:
Returns:
True if the model exists, False if not.
"""
Loads a model.
return self._services.model_manager.store.exists(key)
:param model_name: The name of the model to get.
:param base_model: The base model of the model to get.
:param model_type: The type of the model to get.
:param submodel: The submodel of the model to get.
:returns: An object representing the loaded model.
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Loads a model.
Args:
key: The key of the model.
submodel_type: The submodel of the model to get.
Returns:
An object representing the loaded model.
"""
# The model manager emits events as it loads the model. It needs the context data to build
# the event payloads.
return self._services.model_manager.get_model(
model_name, base_model, model_type, submodel, context_data=self._context_data
return self._services.model_manager.load_model_by_key(
key=key, submodel_type=submodel_type, context_data=self._data
)
def get_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Gets a model's info, an dict-like object.
def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""Loads a model by its attributes.
:param model_name: The name of the model to get.
:param base_model: The base model of the model to get.
:param model_type: The type of the model to get.
Args:
name: Name of the model.
base: The models' base type, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
type: Type of the model, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
submodel_type: The type of submodel to load, e.g. `SubModelType.UNet`, `SubModelType.TextEncoder`, etc. Only main
models have submodels.
Returns:
An object representing the loaded model.
"""
return self._services.model_manager.model_info(model_name, base_model, model_type)
return self._services.model_manager.load_model_by_attr(
model_name=name,
base_model=base,
model_type=type,
submodel=submodel_type,
context_data=self._data,
)
def get_config(self, key: str) -> AnyModelConfig:
"""Gets a model's config.
Args:
key: The key of the model.
Returns:
The model's config.
"""
return self._services.model_manager.store.get_model(key=key)
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""Gets a model's metadata, if it has any.
Args:
key: The key of the model.
Returns:
The model's metadata, if it has any.
"""
return self._services.model_manager.store.get_metadata(key=key)
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
"""Searches for models by path.
Args:
path: The path to search for.
Returns:
A list of models that match the path.
"""
return self._services.model_manager.store.search_by_path(path)
def search_by_attrs(
self,
name: Optional[str] = None,
base: Optional[BaseModelType] = None,
type: Optional[ModelType] = None,
format: Optional[ModelFormat] = None,
) -> list[AnyModelConfig]:
"""Searches for models by attributes.
Args:
name: The name to search for (exact match).
base: The base to search for, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
type: Type type of model to search for, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
format: The format of model to search for, e.g. `ModelFormat.Checkpoint`, `ModelFormat.Diffusers`, etc.
Returns:
A list of models that match the attributes.
"""
return self._services.model_manager.store.search_by_attr(
model_name=name,
base_model=base,
model_type=type,
model_format=format,
)
class ConfigInterface(InvocationContextInterface):
def get(self) -> InvokeAIAppConfig:
"""Gets the app's config."""
"""Gets the app's config.
Returns:
The app's config.
"""
return self._services.configuration.get_config()
class UtilInterface(InvocationContextInterface):
def __init__(
self, services: InvocationServices, data: InvocationContextData, cancel_event: threading.Event
) -> None:
super().__init__(services, data)
self._cancel_event = cancel_event
def is_canceled(self) -> bool:
"""Checks if the current session has been canceled.
Returns:
True if the current session has been canceled, False if not.
"""
return self._cancel_event.is_set()
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
"""
The step callback emits a progress event with the current step, the total number of
@ -315,27 +444,32 @@ class UtilInterface(InvocationContextInterface):
This should be called after each denoising step.
:param intermediate_state: The intermediate state of the diffusion pipeline.
:param base_model: The base model for the current denoising step.
Args:
intermediate_state: The intermediate state of the diffusion pipeline.
base_model: The base model for the current denoising step.
"""
# The step callback needs access to the events and the invocation queue services, but this
# represents a dangerous level of access.
#
# We wrap the step callback so that nodes do not have direct access to these services.
stable_diffusion_step_callback(
context_data=self._context_data,
context_data=self._data,
intermediate_state=intermediate_state,
base_model=base_model,
invocation_queue=self._services.queue,
events=self._services.events,
is_canceled=self.is_canceled,
)
class InvocationContext:
"""
The `InvocationContext` provides access to various services and data for the current invocation.
"""Provides access to various services and data for the current invocation.
Attributes:
images (ImagesInterface): Methods to save, get and update images and their metadata.
tensors (TensorsInterface): Methods to save and get tensors, including image, noise, masks, and masked images.
conditioning (ConditioningInterface): Methods to save and get conditioning data.
models (ModelsInterface): Methods to check if a model exists, get a model, and get a model's info.
logger (LoggerInterface): The app logger.
config (ConfigInterface): The app config.
util (UtilInterface): Utility methods, including a method to check if an invocation was canceled and step callbacks.
boards (BoardsInterface): Methods to interact with boards.
"""
def __init__(
@ -348,50 +482,54 @@ class InvocationContext:
config: ConfigInterface,
util: UtilInterface,
boards: BoardsInterface,
context_data: InvocationContextData,
data: InvocationContextData,
services: InvocationServices,
) -> None:
self.images = images
"""Provides methods to save, get and update images and their metadata."""
"""Methods to save, get and update images and their metadata."""
self.tensors = tensors
"""Provides methods to save and get tensors, including image, noise, masks, and masked images."""
"""Methods to save and get tensors, including image, noise, masks, and masked images."""
self.conditioning = conditioning
"""Provides methods to save and get conditioning data."""
"""Methods to save and get conditioning data."""
self.models = models
"""Provides methods to check if a model exists, get a model, and get a model's info."""
"""Methods to check if a model exists, get a model, and get a model's info."""
self.logger = logger
"""Provides access to the app logger."""
"""The app logger."""
self.config = config
"""Provides access to the app's config."""
"""The app config."""
self.util = util
"""Provides utility methods."""
"""Utility methods, including a method to check if an invocation was canceled and step callbacks."""
self.boards = boards
"""Provides methods to interact with boards."""
self._data = context_data
"""Provides data about the current queue item and invocation. This is an internal API and may change without warning."""
"""Methods to interact with boards."""
self._data = data
"""An internal API providing access to data about the current queue item and invocation. You probably shouldn't use this. It may change without warning."""
self._services = services
"""Provides access to the full application services. This is an internal API and may change without warning."""
"""An internal API providing access to all application services. You probably shouldn't use this. It may change without warning."""
def build_invocation_context(
services: InvocationServices,
context_data: InvocationContextData,
data: InvocationContextData,
cancel_event: threading.Event,
) -> InvocationContext:
"""
Builds the invocation context for a specific invocation execution.
"""Builds the invocation context for a specific invocation execution.
:param invocation_services: The invocation services to wrap.
:param invocation_context_data: The invocation context data.
Args:
services: The invocation services to wrap.
data: The invocation context data.
Returns:
The invocation context.
"""
logger = LoggerInterface(services=services, context_data=context_data)
images = ImagesInterface(services=services, context_data=context_data)
tensors = TensorsInterface(services=services, context_data=context_data)
models = ModelsInterface(services=services, context_data=context_data)
config = ConfigInterface(services=services, context_data=context_data)
util = UtilInterface(services=services, context_data=context_data)
conditioning = ConditioningInterface(services=services, context_data=context_data)
boards = BoardsInterface(services=services, context_data=context_data)
logger = LoggerInterface(services=services, data=data)
images = ImagesInterface(services=services, data=data)
tensors = TensorsInterface(services=services, data=data)
models = ModelsInterface(services=services, data=data)
config = ConfigInterface(services=services, data=data)
util = UtilInterface(services=services, data=data, cancel_event=cancel_event)
conditioning = ConditioningInterface(services=services, data=data)
boards = BoardsInterface(services=services, data=data)
ctx = InvocationContext(
images=images,
@ -399,7 +537,7 @@ def build_invocation_context(
config=config,
tensors=tensors,
models=models,
context_data=context_data,
data=data,
util=util,
conditioning=conditioning,
services=services,

View File

@ -8,6 +8,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import build_migration_3
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -33,6 +34,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_3(app_config=config, logger=logger))
migrator.register_migration(build_migration_4())
migrator.register_migration(build_migration_5())
migrator.register_migration(build_migration_6())
migrator.run_migrations()
return db

View File

@ -0,0 +1,62 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration6Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._recreate_model_triggers(cursor)
self._delete_ip_adapters(cursor)
def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None:
"""
Adds the timestamp trigger to the model_config table.
This trigger was inadvertently dropped in earlier migration scripts.
"""
cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
AFTER UPDATE
ON model_config FOR EACH ROW
BEGIN
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
)
def _delete_ip_adapters(self, cursor: sqlite3.Cursor) -> None:
"""
Delete all the IP adapters.
The model manager will automatically find and re-add them after the migration
is done. This allows the manager to add the correct image encoder to their
configuration records.
"""
cursor.execute(
"""--sql
DELETE FROM model_config
WHERE type='ip_adapter';
"""
)
def build_migration_6() -> Migration:
"""
Build the migration from database version 5 to 6.
This migration does the following:
- Adds the model_config_updated_at trigger if it does not exist
- Delete all ip_adapter models so that the model prober can find and
update with the correct image processor model.
"""
migration_6 = Migration(
from_version=5,
to_version=6,
callback=Migration6Callback(),
)
return migration_6

View File

@ -5,7 +5,7 @@ import uuid
import numpy as np
def get_timestamp():
def get_timestamp() -> int:
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
@ -20,16 +20,16 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
SEED_MAX = np.iinfo(np.uint32).max
def get_random_seed():
def get_random_seed() -> int:
rng = np.random.default_rng(seed=None)
return int(rng.integers(0, SEED_MAX))
def uuid_string():
def uuid_string() -> str:
res = uuid.uuid4()
return str(res)
def is_optional(value: typing.Any):
def is_optional(value: typing.Any) -> bool:
"""Checks if a value is typed as Optional. Note that Optional is sugar for Union[x, None]."""
return typing.get_origin(value) is typing.Union and type(None) in typing.get_args(value)

View File

@ -1,17 +1,16 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable
import torch
from PIL import Image
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
from invokeai.app.services.session_processor.session_processor_common import CanceledException, ProgressImage
from invokeai.backend.model_manager.config import BaseModelType
from ...backend.model_management.models import BaseModelType
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC
from invokeai.app.services.shared.invocation_context import InvocationContextData
@ -34,10 +33,10 @@ def stable_diffusion_step_callback(
context_data: "InvocationContextData",
intermediate_state: PipelineIntermediateState,
base_model: BaseModelType,
invocation_queue: "InvocationQueueABC",
events: "EventServiceBase",
is_canceled: Callable[[], bool],
) -> None:
if invocation_queue.is_canceled(context_data.session_id):
if is_canceled():
raise CanceledException
# Some schedulers report not only the noisy latents at the current timestep,
@ -115,12 +114,12 @@ def stable_diffusion_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG")
events.emit_generator_progress(
queue_id=context_data.queue_id,
queue_item_id=context_data.queue_item_id,
queue_batch_id=context_data.batch_id,
graph_execution_state_id=context_data.session_id,
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
node_id=context_data.invocation.id,
source_node_id=context_data.source_node_id,
source_node_id=context_data.source_invocation_id,
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=intermediate_state.step,
order=intermediate_state.order,

View File

@ -1,8 +1,47 @@
import re
from typing import List, Tuple
import invokeai.backend.util.logging as logger
from invokeai.app.services.model_records import UnknownModelException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.textual_inversion import TextualInversionModelRaw
def extract_ti_triggers_from_prompt(prompt: str) -> list[str]:
ti_triggers = []
def extract_ti_triggers_from_prompt(prompt: str) -> List[str]:
ti_triggers: List[str] = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
ti_triggers.append(trigger)
ti_triggers.append(str(trigger))
return ti_triggers
def generate_ti_list(
prompt: str, base: BaseModelType, context: InvocationContext
) -> List[Tuple[str, TextualInversionModelRaw]]:
ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
for trigger in extract_ti_triggers_from_prompt(prompt):
name_or_key = trigger[1:-1]
try:
loaded_model = context.models.load(key=name_or_key)
model = loaded_model.model
assert isinstance(model, TextualInversionModelRaw)
assert loaded_model.config.base == base
ti_list.append((name_or_key, model))
except UnknownModelException:
try:
loaded_model = context.models.load_by_attrs(
name=name_or_key, base=base, type=ModelType.TextualInversion
)
model = loaded_model.model
assert isinstance(model, TextualInversionModelRaw)
assert loaded_model.config.base == base
ti_list.append((name_or_key, model))
except UnknownModelException:
pass
except ValueError:
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
except AssertionError:
logger.warning(f'trigger: "{trigger}" not a valid textual inversion model for this graph')
except Exception:
logger.warning(f'Failed to load TI model for trigger: "{trigger}"')
return ti_list

View File

@ -1,12 +1,3 @@
"""
Initialization file for invokeai.backend
"""
from .model_management import ( # noqa: F401
BaseModelType,
LoadedModelInfo,
ModelCache,
ModelManager,
ModelType,
SubModelType,
)
from .model_management.models import SilenceWarnings # noqa: F401

View File

@ -0,0 +1,4 @@
"""Initialization file for invokeai.backend.embeddings modules."""
# from .model_patcher import ModelPatcher
# __all__ = ["ModelPatcher"]

View File

@ -0,0 +1,12 @@
"""Base class for LoRA and Textual Inversion models.
The EmbeddingRaw class is the base class of LoRAModelRaw and TextualInversionModelRaw,
and is used for type checking of calls to the model patcher.
The use of "Raw" here is a historical artifact, and carried forward in
order to avoid confusion.
"""
class EmbeddingModelRaw:
"""Base class for LoRA and Textual Inversion models."""

View File

@ -17,6 +17,8 @@ from invokeai.backend.util.util import download_with_progress_bar
config = InvokeAIAppConfig.get_config()
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
DEPTH_ANYTHING_MODELS = {
"large": {
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
@ -53,9 +55,9 @@ transform = Compose(
class DepthAnythingDetector:
def __init__(self) -> None:
self.model = None
self.model_size: Union[Literal["large", "base", "small"], None] = None
self.model_size: Union[DEPTH_ANYTHING_MODEL_SIZES, None] = None
def load_model(self, model_size=Literal["large", "base", "small"]):
def load_model(self, model_size: DEPTH_ANYTHING_MODEL_SIZES = "small"):
DEPTH_ANYTHING_MODEL_PATH = pathlib.Path(config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"])
if not DEPTH_ANYTHING_MODEL_PATH.exists():
download_with_progress_bar(DEPTH_ANYTHING_MODELS[model_size]["url"], DEPTH_ANYTHING_MODEL_PATH)
@ -84,16 +86,19 @@ class DepthAnythingDetector:
self.model.to(device)
return self
def __call__(self, image, resolution=512, offload=False):
image = np.array(image, dtype=np.uint8)
image = image[:, :, ::-1] / 255.0
def __call__(self, image: Image.Image, resolution: int = 512):
if self.model is None:
raise Exception("Depth Anything Model not loaded")
image_height, image_width = image.shape[:2]
image = transform({"image": image})["image"]
image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device())
np_image = np.array(image, dtype=np.uint8)
np_image = np_image[:, :, ::-1] / 255.0
image_height, image_width = np_image.shape[:2]
np_image = transform({"image": image})["image"]
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())
with torch.no_grad():
depth = self.model(image)
depth = self.model(tensor_image)
depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
@ -103,7 +108,4 @@ class DepthAnythingDetector:
new_height = int(image_height * (resolution / image_width))
depth_map = depth_map.resize((resolution, new_height))
if offload:
del self.model
return depth_map

View File

@ -8,8 +8,8 @@ from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend import SilenceWarnings
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.silence_warnings import SilenceWarnings
config = InvokeAIAppConfig.get_config()

View File

@ -8,7 +8,6 @@ from invokeai.app.services.config import InvokeAIAppConfig
def check_invokeai_root(config: InvokeAIAppConfig):
try:
assert config.model_conf_path.exists(), f"{config.model_conf_path} not found"
assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
assert config.models_path.exists(), f"{config.models_path} not found"
if not config.ignore_missing_core_models:

View File

@ -1,14 +1,11 @@
"""Utility (backend) functions used by model_install.py"""
import re
from logging import Logger
from pathlib import Path
from typing import Any, Dict, List, Optional
import omegaconf
from huggingface_hub import HfFolder
from pydantic import BaseModel, Field
from pydantic.dataclasses import dataclass
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm
@ -18,13 +15,10 @@ from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
from invokeai.app.services.model_install import (
HFModelSource,
LocalModelSource,
ModelInstallService,
ModelInstallServiceBase,
ModelSource,
URLModelSource,
)
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import (
@ -36,7 +30,7 @@ from invokeai.backend.model_manager.metadata import UnknownMetadataException
from invokeai.backend.util.logging import InvokeAILogger
# name of the starter models file
INITIAL_MODELS = "INITIAL_MODELS2.yaml"
INITIAL_MODELS = "INITIAL_MODELS.yaml"
def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
@ -44,7 +38,7 @@ def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordService
logger = InvokeAILogger.get_logger(config=app_config)
image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
db = init_db(config=app_config, logger=logger, image_files=image_files)
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db)
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
return obj
@ -53,12 +47,10 @@ def initialize_installer(
) -> ModelInstallServiceBase:
"""Return an initialized ModelInstallService object."""
record_store = initialize_record_store(app_config)
metadata_store = record_store.metadata_store
download_queue = DownloadQueueService()
installer = ModelInstallService(
app_config=app_config,
record_store=record_store,
metadata_store=metadata_store,
download_queue=download_queue,
event_bus=event_bus,
)
@ -98,11 +90,13 @@ class TqdmEventService(EventServiceBase):
super().__init__()
self._bars: Dict[str, tqdm] = {}
self._last: Dict[str, int] = {}
self._logger = InvokeAILogger.get_logger(__name__)
def dispatch(self, event_name: str, payload: Any) -> None:
"""Dispatch an event by appending it to self.events."""
data = payload["data"]
source = data["source"]
if payload["event"] == "model_install_downloading":
data = payload["data"]
dest = data["local_path"]
total_bytes = data["total_bytes"]
bytes = data["bytes"]
@ -111,6 +105,12 @@ class TqdmEventService(EventServiceBase):
self._last[dest] = 0
self._bars[dest].update(bytes - self._last[dest])
self._last[dest] = bytes
elif payload["event"] == "model_install_completed":
self._logger.info(f"{source}: installed successfully.")
elif payload["event"] == "model_install_error":
self._logger.warning(f"{source}: installation failed with error {data['error']}")
elif payload["event"] == "model_install_cancelled":
self._logger.warning(f"{source}: installation cancelled")
class InstallHelper(object):
@ -218,29 +218,13 @@ class InstallHelper(object):
additional_models.append(reverse_source[requirement])
model_list.extend(additional_models)
def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource:
assert model_info.source
model_path_id_or_url = model_info.source.strip("\"' ")
model_path = Path(model_path_id_or_url)
if model_path.exists(): # local file on disk
return LocalModelSource(path=model_path.absolute(), inplace=True)
if re.match(r"^[^/]+/[^/]+$", model_path_id_or_url): # hugging face repo_id
return HFModelSource(
repo_id=model_path_id_or_url,
access_token=HfFolder.get_token(),
subfolder=model_info.subfolder,
)
if re.match(r"^(http|https):", model_path_id_or_url):
return URLModelSource(url=AnyHttpUrl(model_path_id_or_url))
raise ValueError(f"Unsupported model source: {model_path_id_or_url}")
def add_or_delete(self, selections: InstallSelections) -> None:
"""Add or delete selected models."""
installer = self._installer
self._add_required_models(selections.install_models)
for model in selections.install_models:
source = self._make_install_source(model)
assert model.source
model_path_id_or_url = model.source.strip("\"' ")
config = (
{
"description": model.description,
@ -251,12 +235,12 @@ class InstallHelper(object):
)
try:
installer.import_model(
source=source,
installer.heuristic_import(
source=model_path_id_or_url,
config=config,
)
except (UnknownMetadataException, InvalidModelConfigException, HTTPError, OSError) as e:
self._logger.warning(f"{source}: {e}")
self._logger.warning(f"{model.source}: {e}")
for model_to_remove in selections.remove_models:
parts = model_to_remove.split("/")
@ -270,12 +254,14 @@ class InstallHelper(object):
model_name=model_name,
)
if len(matches) > 1:
print(f"{model} is ambiguous. Please use model_type:model_name (e.g. main:my_model) to disambiguate.")
self._logger.error(
"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate"
)
elif not matches:
print(f"{model}: unknown model")
self._logger.error(f"{model_to_remove}: unknown model")
else:
for m in matches:
print(f"Deleting {m.type}:{m.name}")
self._logger.info(f"Deleting {m.type}:{m.name}")
installer.delete(m.key)
installer.wait_for_installs()

View File

@ -18,31 +18,30 @@ from argparse import Namespace
from enum import Enum
from pathlib import Path
from shutil import get_terminal_size
from typing import Any, get_args, get_type_hints
from typing import Any, Optional, Set, Tuple, Type, get_args, get_type_hints
from urllib import request
import npyscreen
import omegaconf
import psutil
import torch
import transformers
import yaml
from diffusers import AutoencoderKL
from diffusers import AutoencoderKL, ModelMixin
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder
from huggingface_hub import login as hf_hub_login
from omegaconf import OmegaConf
from pydantic import ValidationError
from omegaconf import DictConfig, OmegaConf
from pydantic.error_wrappers import ValidationError
from tqdm import tqdm
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, hf_download_from_pretrained
from invokeai.backend.model_management.model_probe import BaseModelType, ModelType
from invokeai.backend.model_manager import BaseModelType, ModelType
from invokeai.backend.util import choose_precision, choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
from invokeai.frontend.install.model_install import addModelsForm
# TO DO - Move all the frontend code into invokeai.frontend.install
from invokeai.frontend.install.widgets import (
@ -61,7 +60,7 @@ warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
def get_literal_fields(field) -> list[Any]:
def get_literal_fields(field: str) -> Tuple[Any]:
return get_args(get_type_hints(InvokeAIAppConfig).get(field))
@ -80,12 +79,13 @@ ATTENTION_SLICE_CHOICES = get_literal_fields("attention_slice_size")
GENERATION_OPT_CHOICES = ["sequential_guidance", "force_tiled_decode", "lazy_offload"]
GB = 1073741824 # GB in bytes
HAS_CUDA = torch.cuda.is_available()
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0)
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0.0, 0.0)
MAX_VRAM /= GB
MAX_RAM = psutil.virtual_memory().total / GB
FORCE_FULL_PRECISION = False
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# This is the InvokeAI initialization file, which contains command-line default values.
# Feel free to edit. If anything goes wrong, you can re-initialize this file by deleting
@ -96,13 +96,15 @@ logger = InvokeAILogger.get_logger()
class DummyWidgetValue(Enum):
"""Dummy widget values."""
zero = 0
true = True
false = False
# --------------------------------------------
def postscript(errors: None):
def postscript(errors: Set[str]) -> None:
if not any(errors):
message = f"""
** INVOKEAI INSTALLATION SUCCESSFUL **
@ -112,9 +114,6 @@ then run one of the following commands to start InvokeAI.
Web UI:
invokeai-web
Command-line client:
invokeai
If you installed using an installation script, run:
{config.root_path}/invoke.{"bat" if sys.platform == "win32" else "sh"}
@ -143,7 +142,7 @@ def yes_or_no(prompt: str, default_yes=True):
# ---------------------------------------------
def HfLogin(access_token) -> str:
def HfLogin(access_token) -> None:
"""
Helper for logging in to Huggingface
The stdout capture is needed to hide the irrelevant "git credential helper" warning
@ -162,7 +161,7 @@ def HfLogin(access_token) -> str:
# -------------------------------------
class ProgressBar:
def __init__(self, model_name="file"):
def __init__(self, model_name: str = "file"):
self.pbar = None
self.name = model_name
@ -179,6 +178,22 @@ class ProgressBar:
self.pbar.update(block_size)
# ---------------------------------------------
def hf_download_from_pretrained(model_class: Type[ModelMixin], model_name: str, destination: Path, **kwargs: Any):
filter = lambda x: "fp16 is not a valid" not in x.getMessage() # noqa E731
logger.addFilter(filter)
try:
model = model_class.from_pretrained(
model_name,
resume_download=True,
**kwargs,
)
model.save_pretrained(destination, safe_serialization=True)
finally:
logger.removeFilter(filter)
return destination
# ---------------------------------------------
def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"):
try:
@ -249,6 +264,7 @@ def download_conversion_models():
# ---------------------------------------------
# TO DO: use the download queue here.
def download_realesrgan():
logger.info("Installing ESRGAN Upscaling models...")
URLs = [
@ -288,18 +304,19 @@ def download_lama():
# ---------------------------------------------
def download_support_models():
def download_support_models() -> None:
download_realesrgan()
download_lama()
download_conversion_models()
# -------------------------------------
def get_root(root: str = None) -> str:
def get_root(root: Optional[str] = None) -> str:
if root:
return root
elif os.environ.get("INVOKEAI_ROOT"):
return os.environ.get("INVOKEAI_ROOT")
elif root := os.environ.get("INVOKEAI_ROOT"):
assert root is not None
return root
else:
return str(config.root_path)
@ -390,7 +407,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
begin_entry_at=3,
max_height=2,
relx=30,
max_width=56,
max_width=80,
scroll_exit=True,
)
self.add_widget_intelligent(
@ -455,6 +472,25 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
max_width=110,
scroll_exit=True,
)
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="Model disk conversion cache size (GB). This is used to cache safetensors files that need to be converted to diffusers..",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.disk = self.add_widget_intelligent(
npyscreen.Slider,
value=clip(old_opts.convert_cache, range=(0, 100), step=0.5),
out_of=100,
lowest=0.0,
step=0.5,
relx=8,
scroll_exit=True,
)
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).",
@ -495,6 +531,14 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
)
else:
self.vram = DummyWidgetValue.zero
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.FixedText,
value="Location of the database used to store model path and configuration information:",
editable=False,
color="CONTROL",
)
self.nextrely += 1
self.outdir = self.add_widget_intelligent(
FileBox,
@ -506,19 +550,21 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
labelColor="GOOD",
begin_entry_at=40,
max_height=3,
max_width=127,
scroll_exit=True,
)
self.autoimport_dirs = {}
self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
FileBox,
name="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
value=str(config.root_path / config.autoimport_dir),
name="Optional folder to scan for new checkpoints, ControlNets, LoRAs and TI models",
value=str(config.root_path / config.autoimport_dir) if config.autoimport_dir else "",
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=32,
max_height=3,
max_width=127,
scroll_exit=True,
)
self.nextrely += 1
@ -555,6 +601,10 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
self.attention_slice_label.hidden = not show
self.attention_slice_size.hidden = not show
def show_hide_model_conf_override(self, value):
self.model_conf_override.hidden = value
self.model_conf_override.display()
def on_ok(self):
options = self.marshall_arguments()
if self.validate_field_values(options):
@ -584,18 +634,21 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
else:
return True
def marshall_arguments(self):
def marshall_arguments(self) -> Namespace:
new_opts = Namespace()
for attr in [
"ram",
"vram",
"convert_cache",
"outdir",
]:
if hasattr(self, attr):
setattr(new_opts, attr, getattr(self, attr).value)
for attr in self.autoimport_dirs:
if not self.autoimport_dirs[attr].value:
continue
directory = Path(self.autoimport_dirs[attr].value)
if directory.is_relative_to(config.root_path):
directory = directory.relative_to(config.root_path)
@ -610,18 +663,18 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
generation_options = [GENERATION_OPT_CHOICES[x] for x in self.generation_options.value]
for v in GENERATION_OPT_CHOICES:
setattr(new_opts, v, v in generation_options)
return new_opts
class EditOptApplication(npyscreen.NPSAppManaged):
def __init__(self, program_opts: Namespace, invokeai_opts: Namespace):
def __init__(self, program_opts: Namespace, invokeai_opts: InvokeAIAppConfig, install_helper: InstallHelper):
super().__init__()
self.program_opts = program_opts
self.invokeai_opts = invokeai_opts
self.user_cancelled = False
self.autoload_pending = True
self.install_selections = default_user_selections(program_opts)
self.install_helper = install_helper
self.install_selections = default_user_selections(program_opts, install_helper)
def onStart(self):
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
@ -640,15 +693,6 @@ class EditOptApplication(npyscreen.NPSAppManaged):
cycle_widgets=False,
)
def new_opts(self):
return self.options.marshall_arguments()
def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Namespace:
editApp = EditOptApplication(program_opts, invokeai_opts)
editApp.run()
return editApp.new_opts()
def default_ramcache() -> float:
"""Run a heuristic for the default RAM cache based on installed RAM."""
@ -660,27 +704,19 @@ def default_ramcache() -> float:
) # 2.1 is just large enough for sd 1.5 ;-)
def default_startup_options(init_file: Path) -> Namespace:
def default_startup_options(init_file: Path) -> InvokeAIAppConfig:
opts = InvokeAIAppConfig.get_config()
opts.ram = opts.ram or default_ramcache()
opts.ram = default_ramcache()
opts.precision = "float32" if FORCE_FULL_PRECISION else choose_precision(torch.device(choose_torch_device()))
return opts
def default_user_selections(program_opts: Namespace) -> InstallSelections:
try:
installer = ModelInstall(config)
except omegaconf.errors.ConfigKeyError:
logger.warning("Your models.yaml file is corrupt or out of date. Reinitializing")
initialize_rootdir(config.root_path, True)
installer = ModelInstall(config)
models = installer.all_models()
def default_user_selections(program_opts: Namespace, install_helper: InstallHelper) -> InstallSelections:
default_model = install_helper.default_model()
assert default_model is not None
default_models = [default_model] if program_opts.default_only else install_helper.recommended_models()
return InstallSelections(
install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id]
if program_opts.default_only
else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
if program_opts.yes_to_all
else [],
install_models=default_models if program_opts.yes_to_all else [],
)
@ -716,22 +752,12 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False):
path.mkdir(parents=True, exist_ok=True)
def maybe_create_models_yaml(root: Path):
models_yaml = root / "configs" / "models.yaml"
if models_yaml.exists():
if OmegaConf.load(models_yaml).get("__metadata__"): # up to date
return
else:
logger.info("Creating new models.yaml, original saved as models.yaml.orig")
models_yaml.rename(models_yaml.parent / "models.yaml.orig")
with open(models_yaml, "w") as yaml_file:
yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
# -------------------------------------
def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace):
invokeai_opts = default_startup_options(initfile)
def run_console_ui(
program_opts: Namespace, initfile: Path, install_helper: InstallHelper
) -> Tuple[Optional[Namespace], Optional[InstallSelections]]:
first_time = not (config.root_path / "invokeai.yaml").exists()
invokeai_opts = default_startup_options(initfile) if first_time else config
invokeai_opts.root = program_opts.root
if not set_min_terminal_size(MIN_COLS, MIN_LINES):
@ -739,13 +765,7 @@ def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace
"Could not increase terminal size. Try running again with a larger window or smaller font size."
)
# the install-models application spawns a subprocess to install
# models, and will crash unless this is set before running.
import torch
torch.multiprocessing.set_start_method("spawn")
editApp = EditOptApplication(program_opts, invokeai_opts)
editApp = EditOptApplication(program_opts, invokeai_opts, install_helper)
editApp.run()
if editApp.user_cancelled:
return (None, None)
@ -754,7 +774,7 @@ def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace
# -------------------------------------
def write_opts(opts: Namespace, init_file: Path):
def write_opts(opts: InvokeAIAppConfig, init_file: Path) -> None:
"""
Update the invokeai.yaml file with values from current settings.
"""
@ -762,7 +782,7 @@ def write_opts(opts: Namespace, init_file: Path):
new_config = InvokeAIAppConfig.get_config()
new_config.root = config.root
for key, value in opts.__dict__.items():
for key, value in vars(opts).items():
if hasattr(new_config, key):
setattr(new_config, key, value)
@ -779,7 +799,7 @@ def default_output_dir() -> Path:
# -------------------------------------
def write_default_options(program_opts: Namespace, initfile: Path):
def write_default_options(program_opts: Namespace, initfile: Path) -> None:
opt = default_startup_options(initfile)
write_opts(opt, initfile)
@ -789,16 +809,11 @@ def write_default_options(program_opts: Namespace, initfile: Path):
# the legacy Args object in order to parse
# the old init file and write out the new
# yaml format.
def migrate_init_file(legacy_format: Path):
def migrate_init_file(legacy_format: Path) -> None:
old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
new = InvokeAIAppConfig.get_config()
fields = [
x
for x, y in InvokeAIAppConfig.model_fields.items()
if (y.json_schema_extra.get("category", None) if y.json_schema_extra else None) != "DEPRECATED"
]
for attr in fields:
for attr in InvokeAIAppConfig.model_fields.keys():
if hasattr(old, attr):
try:
setattr(new, attr, getattr(old, attr))
@ -819,7 +834,7 @@ def migrate_init_file(legacy_format: Path):
# -------------------------------------
def migrate_models(root: Path):
def migrate_models(root: Path) -> None:
from invokeai.backend.install.migrate_to_3 import do_migrate
do_migrate(root, root)
@ -838,7 +853,9 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool:
):
logger.info("** Migrating invokeai.init to invokeai.yaml")
migrate_init_file(old_init_file)
config.parse_args(argv=[], conf=OmegaConf.load(new_init_file))
omegaconf = OmegaConf.load(new_init_file)
assert isinstance(omegaconf, DictConfig)
config.parse_args(argv=[], conf=omegaconf)
if old_hub.exists():
migrate_models(config.root_path)
@ -850,6 +867,7 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool:
# -------------------------------------
def main() -> None:
global FORCE_FULL_PRECISION # FIXME
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
parser.add_argument(
"--skip-sd-weights",
@ -901,7 +919,6 @@ def main() -> None:
help="path to root of install directory",
)
opt = parser.parse_args()
invoke_args = []
if opt.root:
invoke_args.extend(["--root", opt.root])
@ -911,6 +928,7 @@ def main() -> None:
logger = InvokeAILogger().get_logger(config=config)
errors = set()
FORCE_FULL_PRECISION = opt.full_precision # FIXME global
try:
# if we do a root migration/upgrade, then we are keeping previous
@ -921,14 +939,18 @@ def main() -> None:
# run this unconditionally in case new directories need to be added
initialize_rootdir(config.root_path, opt.yes_to_all)
models_to_download = default_user_selections(opt)
# this will initialize and populate the models tables if not present
install_helper = InstallHelper(config, logger)
models_to_download = default_user_selections(opt, install_helper)
new_init_file = config.root_path / "invokeai.yaml"
if opt.yes_to_all:
write_default_options(opt, new_init_file)
init_options = Namespace(precision="float32" if opt.full_precision else "float16")
else:
init_options, models_to_download = run_console_ui(opt, new_init_file)
init_options, models_to_download = run_console_ui(opt, new_init_file, install_helper)
if init_options:
write_opts(init_options, new_init_file)
else:
@ -943,10 +965,12 @@ def main() -> None:
if opt.skip_sd_weights:
logger.warning("Skipping diffusion weights download per user request")
elif models_to_download:
process_and_execute(opt, models_to_download)
install_helper.add_or_delete(models_to_download)
postscript(errors=errors)
if not opt.yes_to_all:
input("Press any key to continue...")
except WindowTooSmallException as e:

View File

@ -1,591 +0,0 @@
"""
Migrate the models directory and models.yaml file from an existing
InvokeAI 2.3 installation to 3.0.0.
"""
import argparse
import os
import shutil
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Union
import diffusers
import transformers
import yaml
from diffusers import AutoencoderKL, StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from omegaconf import DictConfig, OmegaConf
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager
from invokeai.backend.model_management.model_probe import BaseModelType, ModelProbe, ModelProbeInfo, ModelType
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
diffusers.logging.set_verbosity_error()
# holder for paths that we will migrate
@dataclass
class ModelPaths:
models: Path
embeddings: Path
loras: Path
controlnets: Path
class MigrateTo3(object):
def __init__(
self,
from_root: Path,
to_models: Path,
model_manager: ModelManager,
src_paths: ModelPaths,
):
self.root_directory = from_root
self.dest_models = to_models
self.mgr = model_manager
self.src_paths = src_paths
@classmethod
def initialize_yaml(cls, yaml_file: Path):
with open(yaml_file, "w") as file:
file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
def create_directory_structure(self):
"""
Create the basic directory structure for the models folder.
"""
for model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
for model_type in [
ModelType.Main,
ModelType.Vae,
ModelType.Lora,
ModelType.ControlNet,
ModelType.TextualInversion,
]:
path = self.dest_models / model_base.value / model_type.value
path.mkdir(parents=True, exist_ok=True)
path = self.dest_models / "core"
path.mkdir(parents=True, exist_ok=True)
@staticmethod
def copy_file(src: Path, dest: Path):
"""
copy a single file with logging
"""
if dest.exists():
logger.info(f"Skipping existing {str(dest)}")
return
logger.info(f"Copying {str(src)} to {str(dest)}")
try:
shutil.copy(src, dest)
except Exception as e:
logger.error(f"COPY FAILED: {str(e)}")
@staticmethod
def copy_dir(src: Path, dest: Path):
"""
Recursively copy a directory with logging
"""
if dest.exists():
logger.info(f"Skipping existing {str(dest)}")
return
logger.info(f"Copying {str(src)} to {str(dest)}")
try:
shutil.copytree(src, dest)
except Exception as e:
logger.error(f"COPY FAILED: {str(e)}")
def migrate_models(self, src_dir: Path):
"""
Recursively walk through src directory, probe anything
that looks like a model, and copy the model into the
appropriate location within the destination models directory.
"""
directories_scanned = set()
for root, dirs, files in os.walk(src_dir, followlinks=True):
for d in dirs:
try:
model = Path(root, d)
info = ModelProbe().heuristic_probe(model)
if not info:
continue
dest = self._model_probe_to_path(info) / model.name
self.copy_dir(model, dest)
directories_scanned.add(model)
except Exception as e:
logger.error(str(e))
except KeyboardInterrupt:
raise
for f in files:
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
# let them be copied as part of a tree copy operation
try:
if f in {"learned_embeds.bin", "pytorch_lora_weights.bin"}:
continue
model = Path(root, f)
if model.parent in directories_scanned:
continue
info = ModelProbe().heuristic_probe(model)
if not info:
continue
dest = self._model_probe_to_path(info) / f
self.copy_file(model, dest)
except Exception as e:
logger.error(str(e))
except KeyboardInterrupt:
raise
def migrate_support_models(self):
"""
Copy the clipseg, upscaler, and restoration models to their new
locations.
"""
dest_directory = self.dest_models
if (self.root_directory / "models/clipseg").exists():
self.copy_dir(self.root_directory / "models/clipseg", dest_directory / "core/misc/clipseg")
if (self.root_directory / "models/realesrgan").exists():
self.copy_dir(self.root_directory / "models/realesrgan", dest_directory / "core/upscaling/realesrgan")
for d in ["codeformer", "gfpgan"]:
path = self.root_directory / "models" / d
if path.exists():
self.copy_dir(path, dest_directory / f"core/face_restoration/{d}")
def migrate_tuning_models(self):
"""
Migrate the embeddings, loras and controlnets directories to their new homes.
"""
for src in [self.src_paths.embeddings, self.src_paths.loras, self.src_paths.controlnets]:
if not src:
continue
if src.is_dir():
logger.info(f"Scanning {src}")
self.migrate_models(src)
else:
logger.info(f"{src} directory not found; skipping")
continue
def migrate_conversion_models(self):
"""
Migrate all the models that are needed by the ckpt_to_diffusers conversion
script.
"""
dest_directory = self.dest_models
kwargs = {
"cache_dir": self.root_directory / "models/hub",
# local_files_only = True
}
try:
logger.info("Migrating core tokenizers and text encoders")
target_dir = dest_directory / "core" / "convert"
self._migrate_pretrained(
BertTokenizerFast, repo_id="bert-base-uncased", dest=target_dir / "bert-base-uncased", **kwargs
)
# sd-1
repo_id = "openai/clip-vit-large-patch14"
self._migrate_pretrained(
CLIPTokenizer, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", **kwargs
)
self._migrate_pretrained(
CLIPTextModel, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", force=True, **kwargs
)
# sd-2
repo_id = "stabilityai/stable-diffusion-2"
self._migrate_pretrained(
CLIPTokenizer,
repo_id=repo_id,
dest=target_dir / "stable-diffusion-2-clip" / "tokenizer",
**{"subfolder": "tokenizer", **kwargs},
)
self._migrate_pretrained(
CLIPTextModel,
repo_id=repo_id,
dest=target_dir / "stable-diffusion-2-clip" / "text_encoder",
**{"subfolder": "text_encoder", **kwargs},
)
# VAE
logger.info("Migrating stable diffusion VAE")
self._migrate_pretrained(
AutoencoderKL, repo_id="stabilityai/sd-vae-ft-mse", dest=target_dir / "sd-vae-ft-mse", **kwargs
)
# safety checking
logger.info("Migrating safety checker")
repo_id = "CompVis/stable-diffusion-safety-checker"
self._migrate_pretrained(
AutoFeatureExtractor, repo_id=repo_id, dest=target_dir / "stable-diffusion-safety-checker", **kwargs
)
self._migrate_pretrained(
StableDiffusionSafetyChecker,
repo_id=repo_id,
dest=target_dir / "stable-diffusion-safety-checker",
**kwargs,
)
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
def _model_probe_to_path(self, info: ModelProbeInfo) -> Path:
return Path(self.dest_models, info.base_type.value, info.model_type.value)
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force: bool = False, **kwargs):
if dest.exists() and not force:
logger.info(f"Skipping existing {dest}")
return
model = model_class.from_pretrained(repo_id, **kwargs)
self._save_pretrained(model, dest, overwrite=force)
def _save_pretrained(self, model, dest: Path, overwrite: bool = False):
model_name = dest.name
if overwrite:
model.save_pretrained(dest, safe_serialization=True)
else:
download_path = dest.with_name(f"{model_name}.downloading")
model.save_pretrained(download_path, safe_serialization=True)
download_path.replace(dest)
def _download_vae(self, repo_id: str, subfolder: str = None) -> Path:
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / "models/hub", subfolder=subfolder)
info = ModelProbe().heuristic_probe(vae)
_, model_name = repo_id.split("/")
dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
vae.save_pretrained(dest, safe_serialization=True)
return dest
def _vae_path(self, vae: Union[str, dict]) -> Path:
"""
Convert 2.3 VAE stanza to a straight path.
"""
vae_path = None
# First get a path
if isinstance(vae, str):
vae_path = vae
elif isinstance(vae, DictConfig):
if p := vae.get("path"):
vae_path = p
elif repo_id := vae.get("repo_id"):
if repo_id == "stabilityai/sd-vae-ft-mse": # this guy is already downloaded
vae_path = "models/core/convert/sd-vae-ft-mse"
return vae_path
else:
vae_path = self._download_vae(repo_id, vae.get("subfolder"))
assert vae_path is not None, "Couldn't find VAE for this model"
# if the VAE is in the old models directory, then we must move it into the new
# one. VAEs outside of this directory can stay where they are.
vae_path = Path(vae_path)
if vae_path.is_relative_to(self.src_paths.models):
info = ModelProbe().heuristic_probe(vae_path)
dest = self._model_probe_to_path(info) / vae_path.name
if not dest.exists():
if vae_path.is_dir():
self.copy_dir(vae_path, dest)
else:
self.copy_file(vae_path, dest)
vae_path = dest
if vae_path.is_relative_to(self.dest_models):
rel_path = vae_path.relative_to(self.dest_models)
return Path("models", rel_path)
else:
return vae_path
def migrate_repo_id(self, repo_id: str, model_name: str = None, **extra_config):
"""
Migrate a locally-cached diffusers pipeline identified with a repo_id
"""
dest_dir = self.dest_models
cache = self.root_directory / "models/hub"
kwargs = {
"cache_dir": cache,
"safety_checker": None,
# local_files_only = True,
}
owner, repo_name = repo_id.split("/")
model_name = model_name or repo_name
model = cache / "--".join(["models", owner, repo_name])
if len(list(model.glob("snapshots/**/model_index.json"))) == 0:
return
revisions = [x.name for x in model.glob("refs/*")]
# if an fp16 is available we use that
revision = "fp16" if len(revisions) > 1 and "fp16" in revisions else revisions[0]
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, revision=revision, **kwargs)
info = ModelProbe().heuristic_probe(pipeline)
if not info:
return
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
return
dest = self._model_probe_to_path(info) / model_name
self._save_pretrained(pipeline, dest)
rel_path = Path("models", dest.relative_to(dest_dir))
self._add_model(model_name, info, rel_path, **extra_config)
def migrate_path(self, location: Path, model_name: str = None, **extra_config):
"""
Migrate a model referred to using 'weights' or 'path'
"""
# handle relative paths
dest_dir = self.dest_models
location = self.root_directory / location
model_name = model_name or location.stem
info = ModelProbe().heuristic_probe(location)
if not info:
return
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
return
# uh oh, weights is in the old models directory - move it into the new one
if Path(location).is_relative_to(self.src_paths.models):
dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name)
if location.is_dir():
self.copy_dir(location, dest)
else:
self.copy_file(location, dest)
location = Path("models", info.base_type.value, info.model_type.value, location.name)
self._add_model(model_name, info, location, **extra_config)
def _add_model(self, model_name: str, info: ModelProbeInfo, location: Path, **extra_config):
if info.model_type != ModelType.Main:
return
self.mgr.add_model(
model_name=model_name,
base_model=info.base_type,
model_type=info.model_type,
clobber=True,
model_attributes={
"path": str(location),
"description": f"A {info.base_type.value} {info.model_type.value} model",
"model_format": info.format,
"variant": info.variant_type.value,
**extra_config,
},
)
def migrate_defined_models(self):
"""
Migrate models defined in models.yaml
"""
# find any models referred to in old models.yaml
conf = OmegaConf.load(self.root_directory / "configs/models.yaml")
for model_name, stanza in conf.items():
try:
passthru_args = {}
if vae := stanza.get("vae"):
try:
passthru_args["vae"] = str(self._vae_path(vae))
except Exception as e:
logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"')
logger.warning(str(e))
if config := stanza.get("config"):
passthru_args["config"] = config
if description := stanza.get("description"):
passthru_args["description"] = description
if repo_id := stanza.get("repo_id"):
logger.info(f"Migrating diffusers model {model_name}")
self.migrate_repo_id(repo_id, model_name, **passthru_args)
elif location := stanza.get("weights"):
logger.info(f"Migrating checkpoint model {model_name}")
self.migrate_path(Path(location), model_name, **passthru_args)
elif location := stanza.get("path"):
logger.info(f"Migrating diffusers model {model_name}")
self.migrate_path(Path(location), model_name, **passthru_args)
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
def migrate(self):
self.create_directory_structure()
# the configure script is doing this
self.migrate_support_models()
self.migrate_conversion_models()
self.migrate_tuning_models()
self.migrate_defined_models()
def _parse_legacy_initfile(root: Path, initfile: Path) -> ModelPaths:
"""
Returns tuple of (embedding_path, lora_path, controlnet_path)
"""
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
parser.add_argument(
"--embedding_directory",
"--embedding_path",
type=Path,
dest="embedding_path",
default=Path("embeddings"),
)
parser.add_argument(
"--lora_directory",
dest="lora_path",
type=Path,
default=Path("loras"),
)
opt, _ = parser.parse_known_args([f"@{str(initfile)}"])
return ModelPaths(
models=root / "models",
embeddings=root / str(opt.embedding_path).strip('"'),
loras=root / str(opt.lora_path).strip('"'),
controlnets=root / "controlnets",
)
def _parse_legacy_yamlfile(root: Path, initfile: Path) -> ModelPaths:
"""
Returns tuple of (embedding_path, lora_path, controlnet_path)
"""
# Don't use the config object because it is unforgiving of version updates
# Just use omegaconf directly
opt = OmegaConf.load(initfile)
paths = opt.InvokeAI.Paths
models = paths.get("models_dir", "models")
embeddings = paths.get("embedding_dir", "embeddings")
loras = paths.get("lora_dir", "loras")
controlnets = paths.get("controlnet_dir", "controlnets")
return ModelPaths(
models=root / models if models else None,
embeddings=root / embeddings if embeddings else None,
loras=root / loras if loras else None,
controlnets=root / controlnets if controlnets else None,
)
def get_legacy_embeddings(root: Path) -> ModelPaths:
path = root / "invokeai.init"
if path.exists():
return _parse_legacy_initfile(root, path)
path = root / "invokeai.yaml"
if path.exists():
return _parse_legacy_yamlfile(root, path)
def do_migrate(src_directory: Path, dest_directory: Path):
"""
Migrate models from src to dest InvokeAI root directories
"""
config_file = dest_directory / "configs" / "models.yaml.3"
dest_models = dest_directory / "models.3"
version_3 = (dest_directory / "models" / "core").exists()
# Here we create the destination models.yaml file.
# If we are writing into a version 3 directory and the
# file already exists, then we write into a copy of it to
# avoid deleting its previous customizations. Otherwise we
# create a new empty one.
if version_3: # write into the dest directory
try:
shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
except Exception:
MigrateTo3.initialize_yaml(config_file)
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
(dest_directory / "models").replace(dest_models)
else:
MigrateTo3.initialize_yaml(config_file)
mgr = ModelManager(config_file)
paths = get_legacy_embeddings(src_directory)
migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, model_manager=mgr, src_paths=paths)
migrator.migrate()
print("Migration successful.")
if not version_3:
(dest_directory / "models").replace(src_directory / "models.orig")
print(f"Original models directory moved to {dest_directory}/models.orig")
(dest_directory / "configs" / "models.yaml").replace(src_directory / "configs" / "models.yaml.orig")
print(f"Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig")
config_file.replace(config_file.with_suffix(""))
dest_models.replace(dest_models.with_suffix(""))
def main():
parser = argparse.ArgumentParser(
prog="invokeai-migrate3",
description="""
This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format
'--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.a
The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively.
It is safe to provide the same directory for both arguments, but it is better to use the invokeai_configure
script, which will perform a full upgrade in place.""",
)
parser.add_argument(
"--from-directory",
dest="src_root",
type=Path,
required=True,
help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")',
)
parser.add_argument(
"--to-directory",
dest="dest_root",
type=Path,
required=True,
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")',
)
args = parser.parse_args()
src_root = args.src_root
assert src_root.is_dir(), f"{src_root} is not a valid directory"
assert (src_root / "models").is_dir(), f"{src_root} does not contain a 'models' subdirectory"
assert (src_root / "models" / "hub").exists(), f"{src_root} does not contain a version 2.3 models directory"
assert (src_root / "invokeai.init").exists() or (
src_root / "invokeai.yaml"
).exists(), f"{src_root} does not contain an InvokeAI init file."
dest_root = args.dest_root
assert dest_root.is_dir(), f"{dest_root} is not a valid directory"
config = InvokeAIAppConfig.get_config()
config.parse_args(["--root", str(dest_root)])
# TODO: revisit - don't rely on invokeai.yaml to exist yet!
dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists()
if not dest_is_setup:
from invokeai.backend.install.invokeai_configure import initialize_rootdir
initialize_rootdir(dest_root, True)
do_migrate(src_root, dest_root)
if __name__ == "__main__":
main()

View File

@ -1,637 +0,0 @@
"""
Utility (backend) functions used by model_install.py
"""
import os
import re
import shutil
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Callable, Dict, List, Optional, Set, Union
import requests
import torch
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from huggingface_hub import HfApi, HfFolder, hf_hub_url
from omegaconf import OmegaConf
from tqdm import tqdm
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType
from invokeai.backend.model_management.model_probe import ModelProbe, ModelProbeInfo, SchedulerPredictionType
from invokeai.backend.util import download_with_resume
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
from ..util.logging import InvokeAILogger
warnings.filterwarnings("ignore")
# --------------------------globals-----------------------
config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger(name="InvokeAI")
# the initial "configs" dir is now bundled in the `invokeai.configs` package
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
Config_preamble = """
# This file describes the alternative machine learning models
# available to InvokeAI script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
"""
LEGACY_CONFIGS = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: {
SchedulerPredictionType.Epsilon: "v1-inference.yaml",
SchedulerPredictionType.VPrediction: "v1-inference-v.yaml",
},
ModelVariantType.Inpaint: {
SchedulerPredictionType.Epsilon: "v1-inpainting-inference.yaml",
SchedulerPredictionType.VPrediction: "v1-inpainting-inference-v.yaml",
},
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: {
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
},
ModelVariantType.Inpaint: {
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
},
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "sd_xl_base.yaml",
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml",
},
}
@dataclass
class InstallSelections:
install_models: List[str] = field(default_factory=list)
remove_models: List[str] = field(default_factory=list)
@dataclass
class ModelLoadInfo:
name: str
model_type: ModelType
base_type: BaseModelType
path: Optional[Path] = None
repo_id: Optional[str] = None
subfolder: Optional[str] = None
description: str = ""
installed: bool = False
recommended: bool = False
default: bool = False
requires: Optional[List[str]] = field(default_factory=list)
class ModelInstall(object):
def __init__(
self,
config: InvokeAIAppConfig,
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
model_manager: Optional[ModelManager] = None,
access_token: Optional[str] = None,
civitai_api_key: Optional[str] = None,
):
self.config = config
self.mgr = model_manager or ModelManager(config.model_conf_path)
self.datasets = OmegaConf.load(Dataset_path)
self.prediction_helper = prediction_type_helper
self.access_token = access_token or HfFolder.get_token()
self.civitai_api_key = civitai_api_key or config.civitai_api_key
self.reverse_paths = self._reverse_paths(self.datasets)
def all_models(self) -> Dict[str, ModelLoadInfo]:
"""
Return dict of model_key=>ModelLoadInfo objects.
This method consolidates and simplifies the entries in both
models.yaml and INITIAL_MODELS.yaml so that they can
be treated uniformly. It also sorts the models alphabetically
by their name, to improve the display somewhat.
"""
model_dict = {}
# first populate with the entries in INITIAL_MODELS.yaml
for key, value in self.datasets.items():
name, base, model_type = ModelManager.parse_key(key)
value["name"] = name
value["base_type"] = base
value["model_type"] = model_type
model_info = ModelLoadInfo(**value)
if model_info.subfolder and model_info.repo_id:
model_info.repo_id += f":{model_info.subfolder}"
model_dict[key] = model_info
# supplement with entries in models.yaml
installed_models = list(self.mgr.list_models())
for md in installed_models:
base = md["base_model"]
model_type = md["model_type"]
name = md["model_name"]
key = ModelManager.create_key(name, base, model_type)
if key in model_dict:
model_dict[key].installed = True
else:
model_dict[key] = ModelLoadInfo(
name=name,
base_type=base,
model_type=model_type,
path=value.get("path"),
installed=True,
)
return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())}
def _is_autoloaded(self, model_info: dict) -> bool:
path = model_info.get("path")
if not path:
return False
for autodir in ["autoimport_dir", "lora_dir", "embedding_dir", "controlnet_dir"]:
if autodir_path := getattr(self.config, autodir):
autodir_path = self.config.root_path / autodir_path
if Path(path).is_relative_to(autodir_path):
return True
return False
def list_models(self, model_type):
installed = self.mgr.list_models(model_type=model_type)
print()
print(f"Installed models of type `{model_type}`:")
print(f"{'Model Key':50} Model Path")
for i in installed:
print(f"{'/'.join([i['base_model'],i['model_type'],i['model_name']]):50} {i['path']}")
print()
# logic here a little reversed to maintain backward compatibility
def starter_models(self, all_models: bool = False) -> Set[str]:
models = set()
for key, _value in self.datasets.items():
name, base, model_type = ModelManager.parse_key(key)
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
models.add(key)
return models
def recommended_models(self) -> Set[str]:
starters = self.starter_models(all_models=True)
return {x for x in starters if self.datasets[x].get("recommended", False)}
def default_model(self) -> str:
starters = self.starter_models()
defaults = [x for x in starters if self.datasets[x].get("default", False)]
return defaults[0]
def install(self, selections: InstallSelections):
verbosity = dlogging.get_verbosity() # quench NSFW nags
dlogging.set_verbosity_error()
job = 1
jobs = len(selections.remove_models) + len(selections.install_models)
# remove requested models
for key in selections.remove_models:
name, base, mtype = self.mgr.parse_key(key)
logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]")
try:
self.mgr.del_model(name, base, mtype)
except FileNotFoundError as e:
logger.warning(e)
job += 1
# add requested models
self._remove_installed(selections.install_models)
self._add_required_models(selections.install_models)
for path in selections.install_models:
logger.info(f"Installing {path} [{job}/{jobs}]")
try:
self.heuristic_import(path)
except (ValueError, KeyError) as e:
logger.error(str(e))
job += 1
dlogging.set_verbosity(verbosity)
self.mgr.commit()
def heuristic_import(
self,
model_path_id_or_url: Union[str, Path],
models_installed: Set[Path] = None,
) -> Dict[str, AddModelResult]:
"""
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
:param models_installed: Set of installed models, used for recursive invocation
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
"""
if not models_installed:
models_installed = {}
model_path_id_or_url = str(model_path_id_or_url).strip("\"' ")
# A little hack to allow nested routines to retrieve info on the requested ID
self.current_id = model_path_id_or_url
path = Path(model_path_id_or_url)
# fix relative paths
if path.exists() and not path.is_absolute():
path = path.absolute() # make relative to current WD
# checkpoint file, or similar
if path.is_file():
models_installed.update({str(path): self._install_path(path)})
# folders style or similar
elif path.is_dir() and any(
(path / x).exists()
for x in {
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"pytorch_lora_weights.safetensors",
}
):
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
# recursive scan
elif path.is_dir():
for child in path.iterdir():
self.heuristic_import(child, models_installed=models_installed)
# huggingface repo
elif len(str(model_path_id_or_url).split("/")) == 2:
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
# a URL
elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
else:
raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping")
return models_installed
def _remove_installed(self, model_list: List[str]):
all_models = self.all_models()
models_to_remove = []
for path in model_list:
key = self.reverse_paths.get(path)
if key and all_models[key].installed:
models_to_remove.append(path)
for path in models_to_remove:
logger.warning(f"{path} already installed. Skipping")
model_list.remove(path)
def _add_required_models(self, model_list: List[str]):
additional_models = []
all_models = self.all_models()
for path in model_list:
if not (key := self.reverse_paths.get(path)):
continue
for requirement in all_models[key].requires:
requirement_key = self.reverse_paths.get(requirement)
if not all_models[requirement_key].installed:
additional_models.append(requirement)
model_list.extend(additional_models)
# install a model from a local path. The optional info parameter is there to prevent
# the model from being probed twice in the event that it has already been probed.
def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult:
info = info or ModelProbe().heuristic_probe(path, self.prediction_helper)
if not info:
logger.warning(f"Unable to parse format of {path}")
return None
model_name = path.stem if path.is_file() else path.name
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
raise ValueError(f'A model named "{model_name}" is already installed.')
attributes = self._make_attributes(path, info)
return self.mgr.add_model(
model_name=model_name,
base_model=info.base_type,
model_type=info.model_type,
model_attributes=attributes,
)
def _install_url(self, url: str) -> AddModelResult:
with TemporaryDirectory(dir=self.config.models_path) as staging:
CIVITAI_RE = r".*civitai.com.*"
civit_url = re.match(CIVITAI_RE, url, re.IGNORECASE)
location = download_with_resume(
url, Path(staging), access_token=self.civitai_api_key if civit_url else None
)
if not location:
logger.error(f"Unable to download {url}. Skipping.")
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
dest.parent.mkdir(parents=True, exist_ok=True)
models_path = shutil.move(location, dest)
# staged version will be garbage-collected at this time
return self._install_path(Path(models_path), info)
def _install_repo(self, repo_id: str) -> AddModelResult:
# hack to recover models stored in subfolders --
# Required to get the "v2" model of monster-labs/control_v1p_sd15_qrcode_monster
subfolder = None
if match := re.match(r"^([^/]+/[^/]+):(\w+)$", repo_id):
repo_id = match.group(1)
subfolder = match.group(2)
hinfo = HfApi().model_info(repo_id)
# we try to figure out how to download this most economically
# list all the files in the repo
files = [x.rfilename for x in hinfo.siblings]
if subfolder:
files = [x for x in files if x.startswith(f"{subfolder}/")]
prefix = f"{subfolder}/" if subfolder else ""
location = None
with TemporaryDirectory(dir=self.config.models_path) as staging:
staging = Path(staging)
if f"{prefix}model_index.json" in files:
location = self._download_hf_pipeline(repo_id, staging, subfolder=subfolder) # pipeline
elif f"{prefix}unet/model.onnx" in files:
location = self._download_hf_model(repo_id, files, staging)
else:
for suffix in ["safetensors", "bin"]:
if f"{prefix}pytorch_lora_weights.{suffix}" in files:
location = self._download_hf_model(
repo_id, [f"pytorch_lora_weights.{suffix}"], staging, subfolder=subfolder
) # LoRA
break
elif (
self.config.precision == "float16" and f"{prefix}diffusion_pytorch_model.fp16.{suffix}" in files
): # vae, controlnet or some other standalone
files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"]
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
break
elif f"{prefix}diffusion_pytorch_model.{suffix}" in files:
files = ["config.json", f"diffusion_pytorch_model.{suffix}"]
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
break
elif f"{prefix}learned_embeds.{suffix}" in files:
location = self._download_hf_model(
repo_id, [f"learned_embeds.{suffix}"], staging, subfolder=subfolder
)
break
elif (
f"{prefix}image_encoder.txt" in files and f"{prefix}ip_adapter.{suffix}" in files
): # IP-Adapter
files = ["image_encoder.txt", f"ip_adapter.{suffix}"]
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
break
elif f"{prefix}model.{suffix}" in files and f"{prefix}config.json" in files:
# This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted
# by InvokeAI for use with IP-Adapters.
files = ["config.json", f"model.{suffix}"]
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
break
if not location:
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
return {}
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
if not info:
logger.warning(f"Could not probe {location}. Skipping install.")
return {}
dest = (
self.config.models_path
/ info.base_type.value
/ info.model_type.value
/ self._get_model_name(repo_id, location)
)
if dest.exists():
shutil.rmtree(dest)
shutil.copytree(location, dest)
return self._install_path(dest, info)
def _get_model_name(self, path_name: str, location: Path) -> str:
"""
Calculate a name for the model - primitive implementation.
"""
if key := self.reverse_paths.get(path_name):
(name, base, mtype) = ModelManager.parse_key(key)
return name
elif location.is_dir():
return location.name
else:
return location.stem
def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict:
model_name = path.name if path.is_dir() else path.stem
description = f"{info.base_type.value} {info.model_type.value} model {model_name}"
if key := self.reverse_paths.get(self.current_id):
if key in self.datasets:
description = self.datasets[key].get("description") or description
rel_path = self.relative_to_root(path, self.config.models_path)
attributes = {
"path": str(rel_path),
"description": str(description),
"model_format": info.format,
}
legacy_conf = None
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
attributes.update(
{
"variant": info.variant_type,
}
)
if info.format == "checkpoint":
try:
possible_conf = path.with_suffix(".yaml")
if possible_conf.exists():
legacy_conf = str(self.relative_to_root(possible_conf))
elif info.base_type in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
legacy_conf = Path(
self.config.legacy_conf_dir,
LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type],
)
else:
legacy_conf = Path(
self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type]
)
except KeyError:
legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml") # best guess
if info.model_type == ModelType.ControlNet and info.format == "checkpoint":
possible_conf = path.with_suffix(".yaml")
if possible_conf.exists():
legacy_conf = str(self.relative_to_root(possible_conf))
else:
legacy_conf = Path(
self.config.root_path,
"configs/controlnet",
("cldm_v15.yaml" if info.base_type == BaseModelType("sd-1") else "cldm_v21.yaml"),
)
if legacy_conf:
attributes.update({"config": str(legacy_conf)})
return attributes
def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
root = root or self.config.root_path
if path.is_relative_to(root):
return path.relative_to(root)
else:
return path
def _download_hf_pipeline(self, repo_id: str, staging: Path, subfolder: str = None) -> Path:
"""
Retrieve a StableDiffusion model from cache or remote and then
does a save_pretrained() to the indicated staging area.
"""
_, name = repo_id.split("/")
precision = torch_dtype(choose_torch_device())
variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"]
model = None
for variant in variants:
try:
model = DiffusionPipeline.from_pretrained(
repo_id,
variant=variant,
torch_dtype=precision,
safety_checker=None,
subfolder=subfolder,
)
except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors
if "fp16" not in str(e):
print(e)
if model:
break
if not model:
logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
return None
model.save_pretrained(staging / name, safe_serialization=True)
return staging / name
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
_, name = repo_id.split("/")
location = staging / name
paths = []
for filename in files:
filePath = Path(filename)
p = hf_download_with_resume(
repo_id,
model_dir=location / filePath.parent,
model_name=filePath.name,
access_token=self.access_token,
subfolder=filePath.parent / subfolder if subfolder else filePath.parent,
)
if p:
paths.append(p)
else:
logger.warning(f"Could not download {filename} from {repo_id}.")
return location if len(paths) > 0 else None
@classmethod
def _reverse_paths(cls, datasets) -> dict:
"""
Reverse mapping from repo_id/path to destination name.
"""
return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()}
# -------------------------------------
def yes_or_no(prompt: str, default_yes=True):
default = "y" if default_yes else "n"
response = input(f"{prompt} [{default}] ") or default
if default_yes:
return response[0] not in ("n", "N")
else:
return response[0] in ("y", "Y")
# ---------------------------------------------
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
logger = InvokeAILogger.get_logger("InvokeAI")
logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage())
model = model_class.from_pretrained(
model_name,
resume_download=True,
**kwargs,
)
model.save_pretrained(destination, safe_serialization=True)
return destination
# ---------------------------------------------
def hf_download_with_resume(
repo_id: str,
model_dir: str,
model_name: str,
model_dest: Path = None,
access_token: str = None,
subfolder: str = None,
) -> Path:
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
os.makedirs(model_dir, exist_ok=True)
url = hf_hub_url(repo_id, model_name, subfolder=subfolder)
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
open_mode = "wb"
exist_size = 0
if os.path.exists(model_dest):
exist_size = os.path.getsize(model_dest)
header["Range"] = f"bytes={exist_size}-"
open_mode = "ab"
resp = requests.get(url, headers=header, stream=True)
total = int(resp.headers.get("content-length", 0))
if resp.status_code == 416: # "range not satisfiable", which means nothing to return
logger.info(f"{model_name}: complete file found. Skipping.")
return model_dest
elif resp.status_code == 404:
logger.warning("File not found")
return None
elif resp.status_code != 200:
logger.warning(f"{model_name}: {resp.reason}")
elif exist_size > 0:
logger.info(f"{model_name}: partial file found. Resuming...")
else:
logger.info(f"{model_name}: Downloading...")
try:
with (
open(model_dest, open_mode) as file,
tqdm(
desc=model_name,
initial=exist_size,
total=total + exist_size,
unit="iB",
unit_scale=True,
unit_divisor=1000,
) as bar,
):
for data in resp.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)
except Exception as e:
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
return None
return model_dest

View File

@ -8,8 +8,8 @@ from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
from invokeai.backend.model_management.models.base import calc_model_size_by_data
from ..raw_model import RawModel
from .resampler import Resampler
@ -92,7 +92,7 @@ class MLPProjModel(torch.nn.Module):
return clip_extra_context_tokens
class IPAdapter:
class IPAdapter(RawModel):
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
def __init__(
@ -124,6 +124,9 @@ class IPAdapter:
self.attn_weights.to(device=self.device, dtype=self.dtype)
def calc_size(self):
# workaround for circular import
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
def _init_image_proj_model(self, state_dict):

View File

@ -1,98 +1,17 @@
# Copyright (c) 2024 The InvokeAI Development team
"""LoRA model support."""
import bisect
import os
from enum import Enum
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
import torch
from safetensors.torch import load_file
from typing_extensions import Self
from .base import (
BaseModelType,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelNotFoundException,
ModelType,
SubModelType,
classproperty,
)
from invokeai.backend.model_manager import BaseModelType
class LoRAModelFormat(str, Enum):
LyCORIS = "lycoris"
Diffusers = "diffusers"
class LoRAModel(ModelBase):
# model_size: int
class Config(ModelConfigBase):
model_format: LoRAModelFormat # TODO:
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora
super().__init__(model_path, base_model, model_type)
self.model_size = os.path.getsize(self.model_path)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in lora")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in lora")
model = LoRAModelRaw.from_checkpoint(
file_path=self.model_path,
dtype=torch_dtype,
base_model=self.base_model,
)
self.model_size = model.calc_size()
return model
@classproperty
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException()
if os.path.isdir(path):
for ext in ["safetensors", "bin"]:
if os.path.exists(os.path.join(path, f"pytorch_lora_weights.{ext}")):
return LoRAModelFormat.Diffusers
if os.path.isfile(path):
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return LoRAModelFormat.LyCORIS
raise InvalidModelException(f"Not a valid model: {path}")
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
for ext in ["safetensors", "bin"]: # return path to the safetensors file inside the folder
path = Path(model_path, f"pytorch_lora_weights.{ext}")
if path.exists():
return path
else:
return model_path
from .raw_model import RawModel
class LoRALayerBase:
@ -108,7 +27,7 @@ class LoRALayerBase:
def __init__(
self,
layer_key: str,
values: dict,
values: Dict[str, torch.Tensor],
):
if "alpha" in values:
self.alpha = values["alpha"].item()
@ -116,7 +35,7 @@ class LoRALayerBase:
self.alpha = None
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
self.bias = torch.sparse_coo_tensor(
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
@ -128,7 +47,7 @@ class LoRALayerBase:
self.rank = None # set in layer implementation
self.layer_key = layer_key
def get_weight(self, orig_weight: torch.Tensor):
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
raise NotImplementedError()
def calc_size(self) -> int:
@ -142,7 +61,7 @@ class LoRALayerBase:
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
) -> None:
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
@ -156,20 +75,20 @@ class LoRALayer(LoRALayerBase):
def __init__(
self,
layer_key: str,
values: dict,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
if "lora_mid.weight" in values:
self.mid = values["lora_mid.weight"]
self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
else:
self.mid = None
self.rank = self.down.shape[0]
def get_weight(self, orig_weight: torch.Tensor):
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@ -190,7 +109,7 @@ class LoRALayer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
) -> None:
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
@ -208,11 +127,7 @@ class LoHALayer(LoRALayerBase):
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
@ -221,20 +136,20 @@ class LoHALayer(LoRALayerBase):
self.w2_b = values["hada_w2_b"]
if "hada_t1" in values:
self.t1 = values["hada_t1"]
self.t1: Optional[torch.Tensor] = values["hada_t1"]
else:
self.t1 = None
if "hada_t2" in values:
self.t2 = values["hada_t2"]
self.t2: Optional[torch.Tensor] = values["hada_t2"]
else:
self.t2 = None
self.rank = self.w1_b.shape[0]
def get_weight(self, orig_weight: torch.Tensor):
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
@ -254,7 +169,7 @@ class LoHALayer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
) -> None:
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
@ -280,12 +195,12 @@ class LoKRLayer(LoRALayerBase):
def __init__(
self,
layer_key: str,
values: dict,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
if "lokr_w1" in values:
self.w1 = values["lokr_w1"]
self.w1: Optional[torch.Tensor] = values["lokr_w1"]
self.w1_a = None
self.w1_b = None
else:
@ -294,7 +209,7 @@ class LoKRLayer(LoRALayerBase):
self.w1_b = values["lokr_w1_b"]
if "lokr_w2" in values:
self.w2 = values["lokr_w2"]
self.w2: Optional[torch.Tensor] = values["lokr_w2"]
self.w2_a = None
self.w2_b = None
else:
@ -303,7 +218,7 @@ class LoKRLayer(LoRALayerBase):
self.w2_b = values["lokr_w2_b"]
if "lokr_t2" in values:
self.t2 = values["lokr_t2"]
self.t2: Optional[torch.Tensor] = values["lokr_t2"]
else:
self.t2 = None
@ -314,14 +229,18 @@ class LoKRLayer(LoRALayerBase):
else:
self.rank = None # unscaled
def get_weight(self, orig_weight: torch.Tensor):
w1 = self.w1
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
w1: Optional[torch.Tensor] = self.w1
if w1 is None:
assert self.w1_a is not None
assert self.w1_b is not None
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
assert self.w2_a is not None
assert self.w2_b is not None
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
@ -329,6 +248,8 @@ class LoKRLayer(LoRALayerBase):
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
assert w1 is not None
assert w2 is not None
weight = torch.kron(w1, w2)
return weight
@ -344,18 +265,22 @@ class LoKRLayer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
) -> None:
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
assert self.w1_a is not None
assert self.w1_b is not None
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
assert self.w2_a is not None
assert self.w2_b is not None
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
@ -369,7 +294,7 @@ class FullLayer(LoRALayerBase):
def __init__(
self,
layer_key: str,
values: dict,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
@ -382,7 +307,7 @@ class FullLayer(LoRALayerBase):
self.rank = None # unscaled
def get_weight(self, orig_weight: torch.Tensor):
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
@ -394,7 +319,7 @@ class FullLayer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
@ -407,7 +332,7 @@ class IA3Layer(LoRALayerBase):
def __init__(
self,
layer_key: str,
values: dict,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
@ -416,10 +341,11 @@ class IA3Layer(LoRALayerBase):
self.rank = None # unscaled
def get_weight(self, orig_weight: torch.Tensor):
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
assert orig_weight is not None
return orig_weight * weight
def calc_size(self) -> int:
@ -439,28 +365,30 @@ class IA3Layer(LoRALayerBase):
self.on_input = self.on_input.to(device=device, dtype=dtype)
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw: # (torch.nn.Module):
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
class LoRAModelRaw(RawModel): # (torch.nn.Module):
_name: str
layers: Dict[str, LoRALayer]
layers: Dict[str, AnyLoRALayer]
def __init__(
self,
name: str,
layers: Dict[str, LoRALayer],
layers: Dict[str, AnyLoRALayer],
):
self._name = name
self.layers = layers
@property
def name(self):
def name(self) -> str:
return self._name
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
) -> None:
# TODO: try revert if exception?
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
@ -472,7 +400,7 @@ class LoRAModelRaw: # (torch.nn.Module):
return model_size
@classmethod
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict):
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
@ -536,7 +464,7 @@ class LoRAModelRaw: # (torch.nn.Module):
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
base_model: Optional[BaseModelType] = None,
):
) -> Self:
device = device or torch.device("cpu")
dtype = dtype or torch.float32
@ -544,16 +472,16 @@ class LoRAModelRaw: # (torch.nn.Module):
file_path = Path(file_path)
model = cls(
name=file_path.stem, # TODO:
name=file_path.stem,
layers={},
)
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
sd = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
sd = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(state_dict)
state_dict = cls._group_state(sd)
if base_model == BaseModelType.StableDiffusionXL:
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
@ -561,7 +489,7 @@ class LoRAModelRaw: # (torch.nn.Module):
for layer_key, values in state_dict.items():
# lora and locon
if "lora_down.weight" in values:
layer = LoRALayer(layer_key, values)
layer: AnyLoRALayer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_b" in values:
@ -592,8 +520,8 @@ class LoRAModelRaw: # (torch.nn.Module):
return model
@staticmethod
def _group_state(state_dict: dict):
state_dict_groupped = {}
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
@ -606,7 +534,7 @@ class LoRAModelRaw: # (torch.nn.Module):
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map():
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
unet_conversion_map_layer = []

View File

@ -1,27 +0,0 @@
# Model Cache
## `glibc` Memory Allocator Fragmentation
Python (and PyTorch) relies on the memory allocator from the C Standard Library (`libc`). On linux, with the GNU C Standard Library implementation (`glibc`), our memory access patterns have been observed to cause severe memory fragmentation. This fragmentation results in large amounts of memory that has been freed but can't be released back to the OS. Loading models from disk and moving them between CPU/CUDA seem to be the operations that contribute most to the fragmentation. This memory fragmentation issue can result in OOM crashes during frequent model switching, even if `max_cache_size` is set to a reasonable value (e.g. a OOM crash with `max_cache_size=16` on a system with 32GB of RAM).
This problem may also exist on other OSes, and other `libc` implementations. But, at the time of writing, it has only been investigated on linux with `glibc`.
To better understand how the `glibc` memory allocator works, see these references:
- Basics: https://www.gnu.org/software/libc/manual/html_node/The-GNU-Allocator.html
- Details: https://sourceware.org/glibc/wiki/MallocInternals
Note the differences between memory allocated as chunks in an arena vs. memory allocated with `mmap`. Under `glibc`'s default configuration, most model tensors get allocated as chunks in an arena making them vulnerable to the problem of fragmentation.
We can work around this memory fragmentation issue by setting the following env var:
```bash
# Force blocks >1MB to be allocated with `mmap` so that they are released to the system immediately when they are freed.
MALLOC_MMAP_THRESHOLD_=1048576
```
See the following references for more information about the `malloc` tunable parameters:
- https://www.gnu.org/software/libc/manual/html_node/Malloc-Tunable-Parameters.html
- https://www.gnu.org/software/libc/manual/html_node/Memory-Allocation-Tunables.html
- https://man7.org/linux/man-pages/man3/mallopt.3.html
The model cache emits debug logs that provide visibility into the state of the `libc` memory allocator. See the `LibcUtil` class for more info on how these `libc` malloc stats are collected.

View File

@ -1,20 +0,0 @@
# ruff: noqa: I001, F401
"""
Initialization file for invokeai.backend.model_management
"""
# This import must be first
from .model_manager import AddModelResult, LoadedModelInfo, ModelManager, SchedulerPredictionType
from .lora import ModelPatcher, ONNXModelPatcher
from .model_cache import ModelCache
from .models import (
BaseModelType,
DuplicateModelException,
ModelNotFoundException,
ModelType,
ModelVariantType,
SubModelType,
)
# This import must be last
from .model_merge import MergeInterpolationMethod, ModelMerger

View File

@ -1,31 +0,0 @@
# Copyright (c) 2024 Lincoln Stein and the InvokeAI Development Team
"""
This module exports the function has_baked_in_sdxl_vae().
It returns True if an SDXL checkpoint model has the original SDXL 1.0 VAE,
which doesn't work properly in fp16 mode.
"""
import hashlib
from pathlib import Path
from safetensors.torch import load_file
SDXL_1_0_VAE_HASH = "bc40b16c3a0fa4625abdfc01c04ffc21bf3cefa6af6c7768ec61eb1f1ac0da51"
def has_baked_in_sdxl_vae(checkpoint_path: Path) -> bool:
"""Return true if the checkpoint contains a custom (non SDXL-1.0) VAE."""
hash = _vae_hash(checkpoint_path)
return hash != SDXL_1_0_VAE_HASH
def _vae_hash(checkpoint_path: Path) -> str:
checkpoint = load_file(checkpoint_path, device="cpu")
vae_keys = [x for x in checkpoint.keys() if x.startswith("first_stage_model.")]
hash = hashlib.new("sha256")
for key in vae_keys:
value = checkpoint[key]
hash.update(bytes(key, "UTF-8"))
hash.update(bytes(str(value), "UTF-8"))
return hash.hexdigest()

View File

@ -1,553 +0,0 @@
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
The cache returns context manager generators designed to load the
model into the GPU within the context, and unload outside the
context. Use like this:
cache = ModelCache(max_cache_size=7.5)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
do_something_in_GPU(SD1,SD2)
"""
import gc
import hashlib
import math
import os
import sys
import time
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union, types
import torch
import invokeai.backend.util.logging as logger
from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
from ..util.devices import choose_torch_device
from .models import BaseModelType, ModelBase, ModelType, SubModelType
if choose_torch_device() == torch.device("mps"):
from torch import mps
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
# amount of GPU memory to hold in reserve for use by generations (GB)
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
# actual size of a gig
GIG = 1073741824
# Size of a MB in bytes.
MB = 2**20
@dataclass
class CacheStats(object):
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
# {submodel_key => size}
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
class ModelLocker(object):
"Forward declaration"
pass
class ModelCache(object):
"Forward declaration"
pass
class _CacheRecord:
size: int
model: Any
cache: ModelCache
_locks: int
def __init__(self, cache, model: Any, size: int):
self.size = size
self.model = model
self.cache = cache
self._locks = 0
def lock(self):
self._locks += 1
def unlock(self):
self._locks -= 1
assert self._locks >= 0
@property
def locked(self):
return self._locks > 0
@property
def loaded(self):
if self.model is not None and hasattr(self.model, "device"):
return self.model.device != self.cache.storage_device
else:
return False
class ModelCache(object):
def __init__(
self,
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16,
sequential_offload: bool = False,
lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
logger: types.ModuleType = logger,
log_memory_usage: bool = False,
):
"""
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
"""
self.model_infos: Dict[str, ModelBase] = {}
# allow lazy offloading only when vram cache enabled
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self.precision: torch.dtype = precision
self.max_cache_size: float = max_cache_size
self.max_vram_cache_size: float = max_vram_cache_size
self.execution_device: torch.device = execution_device
self.storage_device: torch.device = storage_device
self.sha_chunksize = sha_chunksize
self.logger = logger
self._log_memory_usage = log_memory_usage
# used for stats collection
self.stats = None
self._cached_models = {}
self._cache_stack = []
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:
return MemorySnapshot.capture()
return None
def get_key(
self,
model_path: str,
base_model: BaseModelType,
model_type: ModelType,
submodel_type: Optional[SubModelType] = None,
):
key = f"{model_path}:{base_model}:{model_type}"
if submodel_type:
key += f":{submodel_type}"
return key
def _get_model_info(
self,
model_path: str,
model_class: Type[ModelBase],
base_model: BaseModelType,
model_type: ModelType,
):
model_info_key = self.get_key(
model_path=model_path,
base_model=base_model,
model_type=model_type,
submodel_type=None,
)
if model_info_key not in self.model_infos:
self.model_infos[model_info_key] = model_class(
model_path,
base_model,
model_type,
)
return self.model_infos[model_info_key]
# TODO: args
def get_model(
self,
model_path: Union[str, Path],
model_class: Type[ModelBase],
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
gpu_load: bool = True,
) -> Any:
if not isinstance(model_path, Path):
model_path = Path(model_path)
if not os.path.exists(model_path):
raise Exception(f"Model not found: {model_path}")
model_info = self._get_model_info(
model_path=model_path,
model_class=model_class,
base_model=base_model,
model_type=model_type,
)
key = self.get_key(
model_path=model_path,
base_model=base_model,
model_type=model_type,
submodel_type=submodel,
)
# TODO: lock for no copies on simultaneous calls?
cache_entry = self._cached_models.get(key, None)
if cache_entry is None:
self.logger.info(
f"Loading model {model_path}, type"
f" {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
)
if self.stats:
self.stats.misses += 1
self_reported_model_size_before_load = model_info.get_size(submodel)
# Remove old models from the cache to make room for the new model.
self._make_cache_room(self_reported_model_size_before_load)
# Load the model from disk and capture a memory snapshot before/after.
start_load_time = time.time()
snapshot_before = self._capture_memory_snapshot()
with skip_torch_weight_init():
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
snapshot_after = self._capture_memory_snapshot()
end_load_time = time.time()
self_reported_model_size_after_load = model_info.get_size(submodel)
self.logger.debug(
f"Moved model '{key}' from disk to cpu in {(end_load_time-start_load_time):.2f}s.\n"
f"Self-reported size before/after load: {(self_reported_model_size_before_load/GIG):.3f}GB /"
f" {(self_reported_model_size_after_load/GIG):.3f}GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
if abs(self_reported_model_size_after_load - self_reported_model_size_before_load) > 10 * MB:
self.logger.debug(
f"Model '{key}' mis-reported its size before load. Self-reported size before/after load:"
f" {(self_reported_model_size_before_load/GIG):.2f}GB /"
f" {(self_reported_model_size_after_load/GIG):.2f}GB."
)
cache_entry = _CacheRecord(self, model, self_reported_model_size_after_load)
self._cached_models[key] = cache_entry
else:
if self.stats:
self.stats.hits += 1
if self.stats:
self.stats.cache_size = self.max_cache_size * GIG
self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[key] = max(
self.stats.loaded_model_sizes.get(key, 0), model_info.get_size(submodel)
)
with suppress(Exception):
self._cache_stack.remove(key)
self._cache_stack.append(key)
return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)
def _move_model_to_device(self, key: str, target_device: torch.device):
cache_entry = self._cached_models[key]
source_device = cache_entry.model.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support
# multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type:
return
start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot()
cache_entry.model.to(target_device)
snapshot_after = self._capture_memory_snapshot()
end_model_to_time = time.time()
self.logger.debug(
f"Moved model '{key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n"
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
if (
snapshot_before is not None
and snapshot_after is not None
and snapshot_before.vram is not None
and snapshot_after.vram is not None
):
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
# If the estimated model size does not match the change in VRAM, log a warning.
if not math.isclose(
vram_change,
cache_entry.size,
rel_tol=0.1,
abs_tol=10 * MB,
):
self.logger.debug(
f"Moving model '{key}' from {source_device} to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GIG):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
class ModelLocker(object):
def __init__(self, cache, key, model, gpu_load, size_needed):
"""
:param cache: The model_cache object
:param key: The key of the model to lock in GPU
:param model: The model to lock
:param gpu_load: True if load into gpu
:param size_needed: Size of the model to load
"""
self.gpu_load = gpu_load
self.cache = cache
self.key = key
self.model = model
self.size_needed = size_needed
self.cache_entry = self.cache._cached_models[self.key]
def __enter__(self) -> Any:
if not hasattr(self.model, "to"):
return self.model
# NOTE that the model has to have the to() method in order for this
# code to move it into GPU!
if self.gpu_load:
self.cache_entry.lock()
try:
if self.cache.lazy_offloading:
self.cache._offload_unlocked_models(self.size_needed)
self.cache._move_model_to_device(self.key, self.cache.execution_device)
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
self.cache._print_cuda_stats()
except Exception:
self.cache_entry.unlock()
raise
# TODO: not fully understand
# in the event that the caller wants the model in RAM, we
# move it into CPU if it is in GPU and not locked
elif self.cache_entry.loaded and not self.cache_entry.locked:
self.cache._move_model_to_device(self.key, self.cache.storage_device)
return self.model
def __exit__(self, type, value, traceback):
if not hasattr(self.model, "to"):
return
self.cache_entry.unlock()
if not self.cache.lazy_offloading:
self.cache._offload_unlocked_models()
self.cache._print_cuda_stats()
# TODO: should it be called untrack_model?
def uncache_model(self, cache_id: str):
with suppress(ValueError):
self._cache_stack.remove(cache_id)
self._cached_models.pop(cache_id, None)
def model_hash(
self,
model_path: Union[str, Path],
) -> str:
"""
Given the HF repo id or path to a model on disk, returns a unique
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
:param model_path: Path to model file/directory on disk.
"""
return self._local_model_hash(model_path)
def cache_size(self) -> float:
"""Return the current size of the cache, in GB."""
return self._cache_size() / GIG
def _has_cuda(self) -> bool:
return self.execution_device.type == "cuda"
def _print_cuda_stats(self):
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
ram = "%4.2fG" % self.cache_size()
cached_models = 0
loaded_models = 0
locked_models = 0
for model_info in self._cached_models.values():
cached_models += 1
if model_info.loaded:
loaded_models += 1
if model_info.locked:
locked_models += 1
self.logger.debug(
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ ="
f" {cached_models}/{loaded_models}/{locked_models}"
)
def _cache_size(self) -> int:
return sum([m.size for m in self._cached_models.values()])
def _make_cache_room(self, model_size):
# calculate how much memory this model will require
# multiplier = 2 if self.precision==torch.float32 else 1
bytes_needed = model_size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
current_size = self._cache_size()
if current_size + bytes_needed > maximum_size:
self.logger.debug(
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GIG):.2f} GB"
)
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
pos = 0
models_cleared = 0
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
refs = sys.getrefcount(cache_entry.model)
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
# https://docs.python.org/3/library/gc.html#gc.get_referrers
# manualy clear local variable references of just finished function calls
# for some reason python don't want to collect it even by gc.collect() immidiately
if refs > 2:
while True:
cleared = False
for referrer in gc.get_referrers(cache_entry.model):
if type(referrer).__name__ == "frame":
# RuntimeError: cannot clear an executing frame
with suppress(RuntimeError):
referrer.clear()
cleared = True
# break
# repeat if referrers changes(due to frame clear), else exit loop
if cleared:
gc.collect()
else:
break
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
f" refs: {refs}"
)
# Expected refs:
# 1 from cache_entry
# 1 from getrefcount function
# 1 from onnx runtime object
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
self.logger.debug(
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
current_size -= cache_entry.size
models_cleared += 1
if self.stats:
self.stats.cleared += 1
del self._cache_stack[pos]
del self._cached_models[model_key]
del cache_entry
else:
pos += 1
if models_cleared > 0:
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
# is high even if no garbage gets collected.)
#
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
# - If models had to be cleared, it's a signal that we are close to our memory limit.
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
# collected.
#
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
# immediately when their reference count hits 0.
gc.collect()
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
def _offload_unlocked_models(self, size_needed: int = 0):
reserved = self.max_vram_cache_size * GIG
vram_in_use = torch.cuda.memory_allocated()
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
if not cache_entry.locked and cache_entry.loaded:
self._move_model_to_device(model_key, self.storage_device)
vram_in_use = torch.cuda.memory_allocated()
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
sha = hashlib.sha256()
path = Path(model_path)
hashpath = path / "checksum.sha256"
if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
with open(hashpath) as f:
hash = f.read()
return hash
self.logger.debug(f"computing hash of model {path.name}")
for file in list(path.rglob("*.ckpt")) + list(path.rglob("*.safetensors")) + list(path.rglob("*.pth")):
with open(file, "rb") as f:
while chunk := f.read(self.sha_chunksize):
sha.update(chunk)
hash = sha.hexdigest()
with open(hashpath, "w") as f:
f.write(hash)
return hash

File diff suppressed because it is too large Load Diff

View File

@ -1,140 +0,0 @@
"""
invokeai.backend.model_management.model_merge exports:
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import warnings
from enum import Enum
from pathlib import Path
from typing import List, Optional, Union
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
import invokeai.backend.util.logging as logger
from ...backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType
class MergeInterpolationMethod(str, Enum):
WeightedSum = "weighted_sum"
Sigmoid = "sigmoid"
InvSigmoid = "inv_sigmoid"
AddDifference = "add_difference"
class ModelMerger(object):
def __init__(self, manager: ModelManager):
self.manager = manager
def merge_diffusion_models(
self,
model_paths: List[Path],
alpha: float = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: bool = False,
**kwargs,
) -> DiffusionPipeline:
"""
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
with warnings.catch_warnings():
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
pipe = DiffusionPipeline.from_pretrained(
model_paths[0],
custom_pipeline="checkpoint_merger",
)
merged_pipe = pipe.merge(
pretrained_model_name_or_path_list=model_paths,
alpha=alpha,
interp=interp.value if interp else None, # diffusers API treats None as "weighted sum"
force=force,
**kwargs,
)
dlogging.set_verbosity(verbosity)
return merged_pipe
def merge_diffusion_models_and_save(
self,
model_names: List[str],
base_model: Union[BaseModelType, str],
merged_model_name: str,
alpha: float = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: bool = False,
merge_dest_directory: Optional[Path] = None,
**kwargs,
) -> AddModelResult:
"""
:param models: up to three models, designated by their InvokeAI models.yaml model name
:param base_model: base model (must be the same for all merged models!)
:param merged_model_name: name for new model
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
model_paths = []
config = self.manager.app_config
base_model = BaseModelType(base_model)
vae = None
for mod in model_names:
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
assert info, f"model {mod}, base_model {base_model}, is unknown"
assert (
info["model_format"] == "diffusers"
), f"{mod} is not a diffusers model. It must be optimized before merging"
assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
assert (
len(model_names) <= 2 or interp == MergeInterpolationMethod.AddDifference
), "When merging three models, only the 'add_difference' merge method is supported"
# pick up the first model's vae
if mod == model_names[0]:
vae = info.get("vae")
model_paths.extend([(config.root_path / info["path"]).as_posix()])
merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp)
logger.debug(f"interp = {interp}, merge_method={merge_method}")
merged_pipe = self.merge_diffusion_models(model_paths, alpha, merge_method, force, **kwargs)
dump_path = (
Path(merge_dest_directory)
if merge_dest_directory
else config.models_path / base_model.value / ModelType.Main.value
)
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = (dump_path / merged_model_name).as_posix()
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
attributes = {
"path": dump_path,
"description": f"Merge of models {', '.join(model_names)}",
"model_format": "diffusers",
"variant": ModelVariantType.Normal.value,
"vae": vae,
}
return self.manager.add_model(
merged_model_name,
base_model=base_model,
model_type=ModelType.Main,
model_attributes=attributes,
clobber=True,
)

View File

@ -1,664 +0,0 @@
import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, Literal, Optional, Union
import safetensors.torch
import torch
from diffusers import ConfigMixin, ModelMixin
from picklescan.scanner import scan_file_path
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
from .models import (
BaseModelType,
InvalidModelException,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SilenceWarnings,
)
from .models.base import read_checkpoint_meta
from .util import lora_token_vector_length
@dataclass
class ModelProbeInfo(object):
model_type: ModelType
base_type: BaseModelType
variant_type: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
image_size: int
name: Optional[str] = None
description: Optional[str] = None
class ProbeBase(object):
"""forward declaration"""
pass
class ModelProbe(object):
PROBES = {
"diffusers": {},
"checkpoint": {},
"onnx": {},
}
CLASS2TYPE = {
"StableDiffusionPipeline": ModelType.Main,
"StableDiffusionInpaintPipeline": ModelType.Main,
"StableDiffusionXLPipeline": ModelType.Main,
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"LatentConsistencyModelPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae,
"AutoencoderTiny": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
"T2IAdapter": ModelType.T2IAdapter,
}
@classmethod
def register_probe(
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase
):
cls.PROBES[format][model_type] = probe_class
@classmethod
def heuristic_probe(
cls,
model: Union[Dict, ModelMixin, Path],
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
) -> ModelProbeInfo:
if isinstance(model, Path):
return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper)
elif isinstance(model, (dict, ModelMixin, ConfigMixin)):
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
else:
raise InvalidModelException("model parameter {model} is neither a Path, nor a model")
@classmethod
def probe(
cls,
model_path: Path,
model: Optional[Union[Dict, ModelMixin]] = None,
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> ModelProbeInfo:
"""
Probe the model at model_path and return sufficient information about it
to place it somewhere in the models directory hierarchy. If the model is
already loaded into memory, you may provide it as model in order to avoid
opening it a second time. The prediction_type_helper callable is a function that receives
the path to the model and returns the SchedulerPredictionType.
"""
if model_path:
format_type = "diffusers" if model_path.is_dir() else "checkpoint"
else:
format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint"
model_info = None
try:
model_type = (
cls.get_model_type_from_folder(model_path, model)
if format_type == "diffusers"
else cls.get_model_type_from_checkpoint(model_path, model)
)
format_type = "onnx" if model_type == ModelType.ONNX else format_type
probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class:
return None
probe = probe_class(model_path, model, prediction_type_helper)
base_type = probe.get_base_type()
variant_type = probe.get_variant_type()
prediction_type = probe.get_scheduler_prediction_type()
name = cls.get_model_name(model_path)
description = f"{base_type.value} {model_type.value} model {name}"
format = probe.get_format()
model_info = ModelProbeInfo(
model_type=model_type,
base_type=base_type,
variant_type=variant_type,
prediction_type=prediction_type,
name=name,
description=description,
upcast_attention=(
base_type == BaseModelType.StableDiffusion2
and prediction_type == SchedulerPredictionType.VPrediction
),
format=format,
image_size=(
1024
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
else (
768
if (
base_type == BaseModelType.StableDiffusion2
and prediction_type == SchedulerPredictionType.VPrediction
)
else 512
)
),
)
except Exception:
raise
return model_info
@classmethod
def get_model_name(cls, model_path: Path) -> str:
if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
return model_path.stem
else:
return model_path.name
@classmethod
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
return None
if model_path.name == "learned_embeds.bin":
return ModelType.TextualInversion
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
ckpt = ckpt.get("state_dict", ckpt)
for key in ckpt.keys():
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
return ModelType.Main
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
return ModelType.Vae
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
return ModelType.Lora
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
return ModelType.Lora
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
return ModelType.ControlNet
elif key in {"emb_params", "string_to_param"}:
return ModelType.TextualInversion
else:
# diffusers-ti
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion
raise InvalidModelException(f"Unable to determine model type for {model_path}")
@classmethod
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType:
"""
Get the model type of a hugging-face style folder.
"""
class_name = None
error_hint = None
if model:
class_name = model.__class__.__name__
else:
for suffix in ["bin", "safetensors"]:
if (folder_path / f"learned_embeds.{suffix}").exists():
return ModelType.TextualInversion
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
return ModelType.Lora
if (folder_path / "unet/model.onnx").exists():
return ModelType.ONNX
if (folder_path / "image_encoder.txt").exists():
return ModelType.IPAdapter
i = folder_path / "model_index.json"
c = folder_path / "config.json"
config_path = i if i.exists() else c if c.exists() else None
if config_path:
with open(config_path, "r") as file:
conf = json.load(file)
if "_class_name" in conf:
class_name = conf["_class_name"]
elif "architectures" in conf:
class_name = conf["architectures"][0]
else:
class_name = None
else:
error_hint = f"No model_index.json or config.json found in {folder_path}."
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
return type
else:
error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
# give up
raise InvalidModelException(
f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
)
@classmethod
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
with SilenceWarnings():
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
cls._scan_model(model_path, model_path)
return torch.load(model_path, map_location="cpu")
else:
return safetensors.torch.load_file(model_path)
@classmethod
def _scan_model(cls, model_name, checkpoint):
"""
Apply picklescanner to the indicated checkpoint and issue a warning
and option to exit if an infected file is identified.
"""
# scan model
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
# ##################################################3
# Checkpoint probing
# ##################################################3
class ProbeBase(object):
def get_base_type(self) -> BaseModelType:
pass
def get_variant_type(self) -> ModelVariantType:
pass
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
pass
def get_format(self) -> str:
pass
class CheckpointProbeBase(ProbeBase):
def __init__(
self, checkpoint_path: Path, checkpoint: dict, helper: Callable[[Path], SchedulerPredictionType] = None
) -> BaseModelType:
self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
self.checkpoint_path = checkpoint_path
self.helper = helper
def get_base_type(self) -> BaseModelType:
pass
def get_format(self) -> str:
return "checkpoint"
def get_variant_type(self) -> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint)
if model_type != ModelType.Main:
return ModelVariantType.Normal
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
if in_channels == 9:
return ModelVariantType.Inpaint
elif in_channels == 5:
return ModelVariantType.Depth
elif in_channels == 4:
return ModelVariantType.Normal
else:
raise InvalidModelException(
f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}"
)
class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
return BaseModelType.StableDiffusionXL
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
return BaseModelType.StableDiffusionXLRefiner
else:
raise InvalidModelException("Cannot determine base type")
def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
"""Return model prediction type."""
# if there is a .yaml associated with this checkpoint, then we do not need
# to probe for the prediction type as it will be ignored.
if self.checkpoint_path and self.checkpoint_path.with_suffix(".yaml").exists():
return None
type = self.get_base_type()
if type == BaseModelType.StableDiffusion2:
checkpoint = self.checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
if "global_step" in checkpoint:
if checkpoint["global_step"] == 220000:
return SchedulerPredictionType.Epsilon
elif checkpoint["global_step"] == 110000:
return SchedulerPredictionType.VPrediction
if self.helper and self.checkpoint_path:
if helper_guess := self.helper(self.checkpoint_path):
return helper_guess
return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts
elif type == BaseModelType.StableDiffusion1:
if self.helper and self.checkpoint_path:
if helper_guess := self.helper(self.checkpoint_path):
return helper_guess
return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts
else:
return None
class VaeCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
# I can't find any standalone 2.X VAEs to test with!
return BaseModelType.StableDiffusion1
class LoRACheckpointProbe(CheckpointProbeBase):
def get_format(self) -> str:
return "lycoris"
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
token_vector_length = lora_token_vector_length(checkpoint)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:
return BaseModelType.StableDiffusion2
elif token_vector_length == 1280:
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")
class TextualInversionCheckpointProbe(CheckpointProbeBase):
def get_format(self) -> str:
return None
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
if "string_to_token" in checkpoint:
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
elif "emb_params" in checkpoint:
token_dim = checkpoint["emb_params"].shape[-1]
elif "clip_g" in checkpoint:
token_dim = checkpoint["clip_g"].shape[-1]
else:
token_dim = list(checkpoint.values())[0].shape[-1]
if token_dim == 768:
return BaseModelType.StableDiffusion1
elif token_dim == 1024:
return BaseModelType.StableDiffusion2
elif token_dim == 1280:
return BaseModelType.StableDiffusionXL
else:
return None
class ControlNetCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
for key_name in (
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
):
if key_name not in checkpoint:
continue
if checkpoint[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
elif checkpoint[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
elif self.checkpoint_path and self.helper:
return self.helper(self.checkpoint_path)
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
class IPAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
########################################################
# classes for probing folders
#######################################################
class FolderProbeBase(ProbeBase):
def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None): # not used
self.model = model
self.folder_path = folder_path
def get_variant_type(self) -> ModelVariantType:
return ModelVariantType.Normal
def get_format(self) -> str:
return "diffusers"
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
if self.model:
unet_conf = self.model.unet.config
else:
with open(self.folder_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf["cross_attention_dim"] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf["cross_attention_dim"] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf["cross_attention_dim"] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelException(f"Unknown base model for {self.folder_path}")
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
if self.model:
scheduler_conf = self.model.scheduler.config
else:
with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file:
scheduler_conf = json.load(file)
if scheduler_conf["prediction_type"] == "v_prediction":
return SchedulerPredictionType.VPrediction
elif scheduler_conf["prediction_type"] == "epsilon":
return SchedulerPredictionType.Epsilon
else:
return None
def get_variant_type(self) -> ModelVariantType:
# This only works for pipelines! Any kind of
# exception results in our returning the
# "normal" variant type
try:
if self.model:
conf = self.model.unet.config
else:
config_file = self.folder_path / "unet" / "config.json"
with open(config_file, "r") as file:
conf = json.load(file)
in_channels = conf["in_channels"]
if in_channels == 9:
return ModelVariantType.Inpaint
elif in_channels == 5:
return ModelVariantType.Depth
elif in_channels == 4:
return ModelVariantType.Normal
except Exception:
pass
return ModelVariantType.Normal
class VaeFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
if self._config_looks_like_sdxl():
return BaseModelType.StableDiffusionXL
elif self._name_looks_like_sdxl():
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
return BaseModelType.StableDiffusionXL
else:
return BaseModelType.StableDiffusion1
def _config_looks_like_sdxl(self) -> bool:
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
config_file = self.folder_path / "config.json"
if not config_file.exists():
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
with open(config_file, "r") as file:
config = json.load(file)
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
def _name_looks_like_sdxl(self) -> bool:
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
def _guess_name(self) -> str:
name = self.folder_path.name
if name == "vae":
name = self.folder_path.parent.name
return name
class TextualInversionFolderProbe(FolderProbeBase):
def get_format(self) -> str:
return None
def get_base_type(self) -> BaseModelType:
path = self.folder_path / "learned_embeds.bin"
if not path.exists():
return None
checkpoint = ModelProbe._scan_and_load_checkpoint(path)
return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()
class ONNXFolderProbe(FolderProbeBase):
def get_format(self) -> str:
return "onnx"
def get_base_type(self) -> BaseModelType:
return BaseModelType.StableDiffusion1
def get_variant_type(self) -> ModelVariantType:
return ModelVariantType.Normal
class ControlNetFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.folder_path / "config.json"
if not config_file.exists():
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
with open(config_file, "r") as file:
config = json.load(file)
# no obvious way to distinguish between sd2-base and sd2-768
dimension = config["cross_attention_dim"]
base_model = (
BaseModelType.StableDiffusion1
if dimension == 768
else (
BaseModelType.StableDiffusion2
if dimension == 1024
else BaseModelType.StableDiffusionXL
if dimension == 2048
else None
)
)
if not base_model:
raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
return base_model
class LoRAFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
model_file = None
for suffix in ["safetensors", "bin"]:
base_file = self.folder_path / f"pytorch_lora_weights.{suffix}"
if base_file.exists():
model_file = base_file
break
if not model_file:
raise InvalidModelException("Unknown LoRA format encountered")
return LoRACheckpointProbe(model_file, None).get_base_type()
class IPAdapterFolderProbe(FolderProbeBase):
def get_format(self) -> str:
return IPAdapterModelFormat.InvokeAI.value
def get_base_type(self) -> BaseModelType:
model_file = self.folder_path / "ip_adapter.bin"
if not model_file.exists():
raise InvalidModelException("Unknown IP-Adapter model format.")
state_dict = torch.load(model_file, map_location="cpu")
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
if cross_attention_dim == 768:
return BaseModelType.StableDiffusion1
elif cross_attention_dim == 1024:
return BaseModelType.StableDiffusion2
elif cross_attention_dim == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")
class CLIPVisionFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any
class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.folder_path / "config.json"
if not config_file.exists():
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
with open(config_file, "r") as file:
config = json.load(file)
adapter_type = config.get("adapter_type", None)
if adapter_type == "full_adapter_xl":
return BaseModelType.StableDiffusionXL
elif adapter_type == "full_adapter" or "light_adapter":
# I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
return BaseModelType.StableDiffusion1
else:
raise InvalidModelException(
f"Unable to determine base model for '{self.folder_path}' (adapter_type = {adapter_type})."
)
############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)

View File

@ -1,112 +0,0 @@
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Abstract base class for recursive directory search for models.
"""
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Set, types
import invokeai.backend.util.logging as logger
class ModelSearch(ABC):
def __init__(self, directories: List[Path], logger: types.ModuleType = logger):
"""
Initialize a recursive model directory search.
:param directories: List of directory Paths to recurse through
:param logger: Logger to use
"""
self.directories = directories
self.logger = logger
self._items_scanned = 0
self._models_found = 0
self._scanned_dirs = set()
self._scanned_paths = set()
self._pruned_paths = set()
@abstractmethod
def on_search_started(self):
"""
Called before the scan starts.
"""
pass
@abstractmethod
def on_model_found(self, model: Path):
"""
Process a found model. Raise an exception if something goes wrong.
:param model: Model to process - could be a directory or checkpoint.
"""
pass
@abstractmethod
def on_search_completed(self):
"""
Perform some activity when the scan is completed. May use instance
variables, items_scanned and models_found
"""
pass
def search(self):
self.on_search_started()
for dir in self.directories:
self.walk_directory(dir)
self.on_search_completed()
def walk_directory(self, path: Path):
for root, dirs, files in os.walk(path, followlinks=True):
if str(Path(root).name).startswith("."):
self._pruned_paths.add(root)
if any(Path(root).is_relative_to(x) for x in self._pruned_paths):
continue
self._items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path in self._scanned_paths or path.parent in self._scanned_dirs:
self._scanned_dirs.add(path)
continue
if any(
(path / x).exists()
for x in {
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"image_encoder.txt",
}
):
try:
self.on_model_found(path)
self._models_found += 1
self._scanned_dirs.add(path)
except Exception as e:
self.logger.warning(f"Failed to process '{path}': {e}")
for f in files:
path = Path(root) / f
if path.parent in self._scanned_dirs:
continue
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
try:
self.on_model_found(path)
self._models_found += 1
except Exception as e:
self.logger.warning(f"Failed to process '{path}': {e}")
class FindModels(ModelSearch):
def on_search_started(self):
self.models_found: Set[Path] = set()
def on_model_found(self, model: Path):
self.models_found.add(model)
def on_search_completed(self):
pass
def list_models(self) -> List[Path]:
self.search()
return list(self.models_found)

View File

@ -1,167 +0,0 @@
import inspect
from enum import Enum
from typing import Literal, get_origin
from pydantic import BaseModel, ConfigDict, create_model
from .base import ( # noqa: F401
BaseModelType,
DuplicateModelException,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelError,
ModelNotFoundException,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SilenceWarnings,
SubModelType,
)
from .clip_vision import CLIPVisionModel
from .controlnet import ControlNetModel # TODO:
from .ip_adapter import IPAdapterModel
from .lora import LoRAModel
from .sdxl import StableDiffusionXLModel
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model
from .t2i_adapter import T2IAdapterModel
from .textual_inversion import TextualInversionModel
from .vae import VaeModel
MODEL_CLASSES = {
BaseModelType.StableDiffusion1: {
ModelType.ONNX: ONNXStableDiffusion1Model,
ModelType.Main: StableDiffusion1Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
BaseModelType.StableDiffusion2: {
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.Main: StableDiffusion2Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
BaseModelType.StableDiffusionXL: {
ModelType.Main: StableDiffusionXLModel,
ModelType.Vae: VaeModel,
# will not work until support written
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
BaseModelType.StableDiffusionXLRefiner: {
ModelType.Main: StableDiffusionXLModel,
ModelType.Vae: VaeModel,
# will not work until support written
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
BaseModelType.Any: {
ModelType.CLIPVision: CLIPVisionModel,
# The following model types are not expected to be used with BaseModelType.Any.
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.Main: StableDiffusion2Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
# BaseModelType.Kandinsky2_1: {
# ModelType.Main: Kandinsky2_1Model,
# ModelType.MoVQ: MoVQModel,
# ModelType.Lora: LoRAModel,
# ModelType.ControlNet: ControlNetModel,
# ModelType.TextualInversion: TextualInversionModel,
# },
}
MODEL_CONFIGS = []
OPENAPI_MODEL_CONFIGS = []
class OpenAPIModelInfoBase(BaseModel):
model_name: str
base_model: BaseModelType
model_type: ModelType
model_config = ConfigDict(protected_namespaces=())
for _base_model, models in MODEL_CLASSES.items():
for model_type, model_class in models.items():
model_configs = set(model_class._get_configs().values())
model_configs.discard(None)
MODEL_CONFIGS.extend(model_configs)
# LS: sort to get the checkpoint configs first, which makes
# for a better template in the Swagger docs
for cfg in sorted(model_configs, key=lambda x: str(x)):
model_name, cfg_name = cfg.__qualname__.split(".")[-2:]
openapi_cfg_name = model_name + cfg_name
if openapi_cfg_name in vars():
continue
api_wrapper = create_model(
openapi_cfg_name,
__base__=(cfg, OpenAPIModelInfoBase),
model_type=(Literal[model_type], model_type), # type: ignore
)
vars()[openapi_cfg_name] = api_wrapper
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
def get_model_config_enums():
enums = []
for model_config in MODEL_CONFIGS:
if hasattr(inspect, "get_annotations"):
fields = inspect.get_annotations(model_config)
else:
fields = model_config.__annotations__
try:
field = fields["model_format"]
except Exception:
raise Exception("format field not found")
# model_format: None
# model_format: SomeModelFormat
# model_format: Literal[SomeModelFormat.Diffusers]
# model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint]
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
enums.append(field)
elif get_origin(field) is Literal and all(
isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__
):
enums.append(type(field.__args__[0]))
elif field is None:
pass
else:
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
return enums

View File

@ -1,681 +0,0 @@
import inspect
import json
import os
import sys
import typing
import warnings
from abc import ABCMeta, abstractmethod
from contextlib import suppress
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union
import numpy as np
import onnx
import safetensors.torch
import torch
from diffusers import ConfigMixin, DiffusionPipeline
from diffusers import logging as diffusers_logging
from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
from picklescan.scanner import scan_file_path
from pydantic import BaseModel, ConfigDict, Field
from transformers import logging as transformers_logging
class DuplicateModelException(Exception):
pass
class InvalidModelException(Exception):
pass
class ModelNotFoundException(Exception):
pass
class BaseModelType(str, Enum):
Any = "any" # For models that are not associated with any particular base model.
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
# Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
ONNX = "onnx"
Main = "main"
Vae = "vae"
Lora = "lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
T2IAdapter = "t2i_adapter"
class SubModelType(str, Enum):
UNet = "unet"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Vae = "vae"
VaeDecoder = "vae_decoder"
VaeEncoder = "vae_encoder"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
# MoVQ = "movq"
class ModelVariantType(str, Enum):
Normal = "normal"
Inpaint = "inpaint"
Depth = "depth"
class SchedulerPredictionType(str, Enum):
Epsilon = "epsilon"
VPrediction = "v_prediction"
Sample = "sample"
class ModelError(str, Enum):
NotFound = "not_found"
def model_config_json_schema_extra(schema: dict[str, Any]) -> None:
if "required" not in schema:
schema["required"] = []
schema["required"].append("model_type")
class ModelConfigBase(BaseModel):
path: str # or Path
description: Optional[str] = Field(None)
model_format: Optional[str] = Field(None)
error: Optional[ModelError] = Field(None)
model_config = ConfigDict(
use_enum_values=True, protected_namespaces=(), json_schema_extra=model_config_json_schema_extra
)
class EmptyConfigLoader(ConfigMixin):
@classmethod
def load_config(cls, *args, **kwargs):
cls.config_name = kwargs.pop("config_name")
return super().load_config(*args, **kwargs)
T_co = TypeVar("T_co", covariant=True)
class classproperty(Generic[T_co]):
def __init__(self, fget: Callable[[Any], T_co]) -> None:
self.fget = fget
def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
return self.fget(owner)
def __set__(self, instance: Optional[Any], value: Any) -> None:
raise AttributeError("cannot set attribute")
class ModelBase(metaclass=ABCMeta):
# model_path: str
# base_model: BaseModelType
# model_type: ModelType
def __init__(
self,
model_path: str,
base_model: BaseModelType,
model_type: ModelType,
):
self.model_path = model_path
self.base_model = base_model
self.model_type = model_type
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
if len(subtypes) < 2:
raise Exception("Invalid subfolder definition!")
if all(t is None for t in subtypes):
return None
elif any(t is None for t in subtypes):
raise Exception(f"Unsupported definition: {subtypes}")
if subtypes[0] in ["diffusers", "transformers"]:
res_type = sys.modules[subtypes[0]]
subtypes = subtypes[1:]
else:
res_type = sys.modules["diffusers"]
res_type = res_type.pipelines
for subtype in subtypes:
res_type = getattr(res_type, subtype)
return res_type
@classmethod
def _get_configs(cls):
with suppress(Exception):
return cls.__configs
configs = {}
for name in dir(cls):
if name.startswith("__"):
continue
value = getattr(cls, name)
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
continue
if hasattr(inspect, "get_annotations"):
fields = inspect.get_annotations(value)
else:
fields = value.__annotations__
try:
field = fields["model_format"]
except Exception:
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
for model_format in field:
configs[model_format.value] = value
elif typing.get_origin(field) is Literal and all(
isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__
):
for model_format in field.__args__:
configs[model_format.value] = value
elif field is None:
configs[None] = value
else:
raise Exception(f"Unsupported format definition in {cls.__qualname__}")
cls.__configs = configs
return cls.__configs
@classmethod
def create_config(cls, **kwargs) -> ModelConfigBase:
if "model_format" not in kwargs:
raise Exception("Field 'model_format' not found in model config")
configs = cls._get_configs()
return configs[kwargs["model_format"]](**kwargs)
@classmethod
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
return cls.create_config(
path=path,
model_format=cls.detect_format(path),
)
@classmethod
@abstractmethod
def detect_format(cls, path: str) -> str:
raise NotImplementedError()
@classproperty
@abstractmethod
def save_to_config(cls) -> bool:
raise NotImplementedError()
@abstractmethod
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
raise NotImplementedError()
@abstractmethod
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
) -> Any:
raise NotImplementedError()
class DiffusersModel(ModelBase):
# child_types: Dict[str, Type]
# child_sizes: Dict[str, int]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
super().__init__(model_path, base_model, model_type)
self.child_types: Dict[str, Type] = {}
self.child_sizes: Dict[str, int] = {}
try:
config_data = DiffusionPipeline.load_config(self.model_path)
# config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
except Exception:
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
config_data.pop("_ignore_files", None)
# retrieve all folder_names that contain relevant files
child_components = [k for k, v in config_data.items() if isinstance(v, list)]
for child_name in child_components:
child_type = self._hf_definition_to_type(config_data[child_name])
self.child_types[child_name] = child_type
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is None:
return sum(self.child_sizes.values())
else:
return self.child_sizes[child_type]
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
# return pipeline in different function to pass more arguments
if child_type is None:
raise Exception("Child model type can't be null on diffusers model")
if child_type not in self.child_types:
return None # TODO: or raise
if torch_dtype == torch.float16:
variants = ["fp16", None]
else:
variants = [None, "fp16"]
# TODO: better error handling(differentiate not found from others)
for variant in variants:
try:
# TODO: set cache_dir to /dev/null to be sure that cache not used?
model = self.child_types[child_type].from_pretrained(
self.model_path,
subfolder=child_type.value,
torch_dtype=torch_dtype,
variant=variant,
local_files_only=True,
)
break
except Exception as e:
if not str(e).startswith("Error no file"):
print("====ERR LOAD====")
print(f"{variant}: {e}")
pass
else:
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
# calc more accurate size
self.child_sizes[child_type] = calc_model_size_by_data(model)
return model
# def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, variant: Optional[str] = None):
if subfolder is not None:
model_path = os.path.join(model_path, subfolder)
# this can happen when, for example, the safety checker
# is not downloaded.
if not os.path.exists(model_path):
return 0
all_files = os.listdir(model_path)
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f}
bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f}
other_files = set(all_files) - fp16_files - bit8_files
if variant is None:
files = other_files
elif variant == "fp16":
files = fp16_files
elif variant == "8bit":
files = bit8_files
else:
raise NotImplementedError(f"Unknown variant: {variant}")
# try read from index if exists
index_postfix = ".index.json"
if variant is not None:
index_postfix = f".index.{variant}.json"
for file in files:
if not file.endswith(index_postfix):
continue
try:
with open(os.path.join(model_path, file), "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except Exception:
pass
# calculate files size if there is no index file
formats = [
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
]
for file_format in formats:
model_files = [f for f in files if f.endswith(file_format)]
if len(model_files) == 0:
continue
model_size = 0
for model_file in model_files:
file_stats = os.stat(os.path.join(model_path, model_file))
model_size += file_stats.st_size
return model_size
# raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
def calc_model_size_by_data(model) -> int:
if isinstance(model, DiffusionPipeline):
return _calc_pipeline_by_data(model)
elif isinstance(model, torch.nn.Module):
return _calc_model_by_data(model)
elif isinstance(model, IAIOnnxRuntimeModel):
return _calc_onnx_model_by_data(model)
else:
return 0
def _calc_pipeline_by_data(pipeline) -> int:
res = 0
for submodel_key in pipeline.components.keys():
submodel = getattr(pipeline, submodel_key)
if submodel is not None and isinstance(submodel, torch.nn.Module):
res += _calc_model_by_data(submodel)
return res
def _calc_model_by_data(model) -> int:
mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
mem = mem_params + mem_bufs # in bytes
return mem
def _calc_onnx_model_by_data(model) -> int:
tensor_size = model.tensors.size() * 2 # The session doubles this
mem = tensor_size # in bytes
return mem
def _fast_safetensors_reader(path: str):
checkpoint = {}
device = torch.device("meta")
with open(path, "rb") as f:
definition_len = int.from_bytes(f.read(8), "little")
definition_json = f.read(definition_len)
definition = json.loads(definition_json)
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {
"pt",
"torch",
"pytorch",
}:
raise Exception("Supported only pytorch safetensors files")
definition.pop("__metadata__", None)
for key, info in definition.items():
dtype = {
"I8": torch.int8,
"I16": torch.int16,
"I32": torch.int32,
"I64": torch.int64,
"F16": torch.float16,
"F32": torch.float32,
"F64": torch.float64,
}[info["dtype"]]
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
return checkpoint
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
if str(path).endswith(".safetensors"):
try:
checkpoint = _fast_safetensors_reader(path)
except Exception:
# TODO: create issue for support "meta"?
checkpoint = safetensors.torch.load_file(path, device="cpu")
else:
if scan:
scan_result = scan_file_path(path)
if scan_result.infected_files != 0:
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint
class SilenceWarnings(object):
def __init__(self):
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self):
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter("ignore")
def __exit__(self, type, value, traceback):
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter("default")
ONNX_WEIGHTS_NAME = "model.onnx"
class IAIOnnxRuntimeModel:
class _tensor_access:
def __init__(self, model):
self.model = model
self.indexes = {}
for idx, obj in enumerate(self.model.proto.graph.initializer):
self.indexes[obj.name] = idx
def __getitem__(self, key: str):
value = self.model.proto.graph.initializer[self.indexes[key]]
return numpy_helper.to_array(value)
def __setitem__(self, key: str, value: np.ndarray):
new_node = numpy_helper.from_array(value)
# set_external_data(new_node, location="in-memory-location")
new_node.name = key
# new_node.ClearField("raw_data")
del self.model.proto.graph.initializer[self.indexes[key]]
self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
# self.model.data[key] = OrtValue.ortvalue_from_numpy(value)
# __delitem__
def __contains__(self, key: str):
return self.indexes[key] in self.model.proto.graph.initializer
def items(self):
raise NotImplementedError("tensor.items")
# return [(obj.name, obj) for obj in self.raw_proto]
def keys(self):
return self.indexes.keys()
def values(self):
raise NotImplementedError("tensor.values")
# return [obj for obj in self.raw_proto]
def size(self):
bytesSum = 0
for node in self.model.proto.graph.initializer:
bytesSum += sys.getsizeof(node.raw_data)
return bytesSum
class _access_helper:
def __init__(self, raw_proto):
self.indexes = {}
self.raw_proto = raw_proto
for idx, obj in enumerate(raw_proto):
self.indexes[obj.name] = idx
def __getitem__(self, key: str):
return self.raw_proto[self.indexes[key]]
def __setitem__(self, key: str, value):
index = self.indexes[key]
del self.raw_proto[index]
self.raw_proto.insert(index, value)
# __delitem__
def __contains__(self, key: str):
return key in self.indexes
def items(self):
return [(obj.name, obj) for obj in self.raw_proto]
def keys(self):
return self.indexes.keys()
def values(self):
return list(self.raw_proto)
def __init__(self, model_path: str, provider: Optional[str]):
self.path = model_path
self.session = None
self.provider = provider
"""
self.data_path = self.path + "_data"
if not os.path.exists(self.data_path):
print(f"Moving model tensors to separate file: {self.data_path}")
tmp_proto = onnx.load(model_path, load_external_data=True)
onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False)
del tmp_proto
gc.collect()
self.proto = onnx.load(model_path, load_external_data=False)
"""
self.proto = onnx.load(model_path, load_external_data=True)
# self.data = dict()
# for tensor in self.proto.graph.initializer:
# name = tensor.name
# if tensor.HasField("raw_data"):
# npt = numpy_helper.to_array(tensor)
# orv = OrtValue.ortvalue_from_numpy(npt)
# # self.data[name] = orv
# # set_external_data(tensor, location="in-memory-location")
# tensor.name = name
# # tensor.ClearField("raw_data")
self.nodes = self._access_helper(self.proto.graph.node)
# self.initializers = self._access_helper(self.proto.graph.initializer)
# print(self.proto.graph.input)
# print(self.proto.graph.initializer)
self.tensors = self._tensor_access(self)
# TODO: integrate with model manager/cache
def create_session(self, height=None, width=None):
if self.session is None or self.session_width != width or self.session_height != height:
# onnx.save(self.proto, "tmp.onnx")
# onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
# TODO: something to be able to get weight when they already moved outside of model proto
# (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
sess = SessionOptions()
# self._external_data.update(**external_data)
# sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
# sess.enable_profiling = True
# sess.intra_op_num_threads = 1
# sess.inter_op_num_threads = 1
# sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL
# sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
# sess.enable_cpu_mem_arena = True
# sess.enable_mem_pattern = True
# sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
self.session_height = height
self.session_width = width
if height and width:
sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height)
sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width)
sess.add_free_dimension_override_by_name("unet_time_batch", 1)
providers = []
if self.provider:
providers.append(self.provider)
else:
providers = get_available_providers()
if "TensorrtExecutionProvider" in providers:
providers.remove("TensorrtExecutionProvider")
try:
self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess)
except Exception as e:
raise e
# self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
# self.io_binding = self.session.io_binding()
def release_session(self):
self.session = None
import gc
gc.collect()
return
def __call__(self, **kwargs):
if self.session is None:
raise Exception("You should call create_session before running model")
inputs = {k: np.array(v) for k, v in kwargs.items()}
# output_names = self.session.get_outputs()
# for k in inputs:
# self.io_binding.bind_cpu_input(k, inputs[k])
# for name in output_names:
# self.io_binding.bind_output(name.name)
# self.session.run_with_iobinding(self.io_binding, None)
# return self.io_binding.copy_outputs_to_cpu()
return self.session.run(None, inputs)
# compatability with diffusers load code
@classmethod
def from_pretrained(
cls,
model_id: Union[str, Path],
subfolder: Union[str, Path] = None,
file_name: Optional[str] = None,
provider: Optional[str] = None,
sess_options: Optional["SessionOptions"] = None,
**kwargs,
):
file_name = file_name or ONNX_WEIGHTS_NAME
if os.path.isdir(model_id):
model_path = model_id
if subfolder is not None:
model_path = os.path.join(model_path, subfolder)
model_path = os.path.join(model_path, file_name)
else:
model_path = model_id
# load model from local directory
if not os.path.isfile(model_path):
raise Exception(f"Model not found: {model_path}")
# TODO: session options
return cls(model_path, provider=provider)

View File

@ -1,82 +0,0 @@
import os
from enum import Enum
from typing import Literal, Optional
import torch
from transformers import CLIPVisionModelWithProjection
from invokeai.backend.model_management.models.base import (
BaseModelType,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelType,
SubModelType,
calc_model_size_by_data,
calc_model_size_by_fs,
classproperty,
)
class CLIPVisionModelFormat(str, Enum):
Diffusers = "diffusers"
class CLIPVisionModel(ModelBase):
class DiffusersConfig(ModelConfigBase):
model_format: Literal[CLIPVisionModelFormat.Diffusers]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.CLIPVision
super().__init__(model_path, base_model, model_type)
self.model_size = calc_model_size_by_fs(self.model_path)
@classmethod
def detect_format(cls, path: str) -> str:
if not os.path.exists(path):
raise ModuleNotFoundError(f"No CLIP Vision model at path '{path}'.")
if os.path.isdir(path) and os.path.exists(os.path.join(path, "config.json")):
return CLIPVisionModelFormat.Diffusers
raise InvalidModelException(f"Unexpected CLIP Vision model format: {path}")
@classproperty
def save_to_config(cls) -> bool:
return True
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
if child_type is not None:
raise ValueError("There are no child models in a CLIP Vision model.")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
) -> CLIPVisionModelWithProjection:
if child_type is not None:
raise ValueError("There are no child models in a CLIP Vision model.")
model = CLIPVisionModelWithProjection.from_pretrained(self.model_path, torch_dtype=torch_dtype)
# Calculate a more accurate model size.
self.model_size = calc_model_size_by_data(model)
return model
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
format = cls.detect_format(model_path)
if format == CLIPVisionModelFormat.Diffusers:
return model_path
else:
raise ValueError(f"Unsupported format: '{format}'.")

View File

@ -1,163 +0,0 @@
import os
from enum import Enum
from pathlib import Path
from typing import Literal, Optional
import torch
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from .base import (
BaseModelType,
EmptyConfigLoader,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelNotFoundException,
ModelType,
SubModelType,
calc_model_size_by_data,
calc_model_size_by_fs,
classproperty,
)
class ControlNetModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class ControlNetModel(ModelBase):
# model_class: Type
# model_size: int
class DiffusersConfig(ModelConfigBase):
model_format: Literal[ControlNetModelFormat.Diffusers]
class CheckpointConfig(ModelConfigBase):
model_format: Literal[ControlNetModelFormat.Checkpoint]
config: str
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.ControlNet
super().__init__(model_path, base_model, model_type)
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
# config = json.loads(os.path.join(self.model_path, "config.json"))
except Exception:
raise Exception("Invalid controlnet model! (config.json not found or invalid)")
model_class_name = config.get("_class_name", None)
if model_class_name not in {"ControlNetModel"}:
raise Exception(f"Invalid ControlNet model! Unknown _class_name: {model_class_name}")
try:
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
except Exception:
raise Exception("Invalid ControlNet model!")
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in controlnet model")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There are no child models in controlnet model")
model = None
for variant in ["fp16", None]:
try:
model = self.model_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
variant=variant,
)
break
except Exception:
pass
if not model:
raise ModelNotFoundException()
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model
@classproperty
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException()
if os.path.isdir(path):
if os.path.exists(os.path.join(path, "config.json")):
return ControlNetModelFormat.Diffusers
if os.path.isfile(path):
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]):
return ControlNetModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {path}")
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint:
return _convert_controlnet_ckpt_and_cache(
model_path=model_path,
model_config=config.config,
output_path=output_path,
base_model=base_model,
)
else:
return model_path
def _convert_controlnet_ckpt_and_cache(
model_path: str,
output_path: str,
base_model: BaseModelType,
model_config: str,
) -> str:
"""
Convert the controlnet from checkpoint format to diffusers format,
cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
print(f"DEBUG: controlnet config = {model_config}")
app_config = InvokeAIAppConfig.get_config()
weights = app_config.root_path / model_path
output_path = Path(output_path)
logger.info(f"Converting {weights} to diffusers format")
# return cached version if it exists
if output_path.exists():
return output_path
# to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
convert_controlnet_to_diffusers(
weights,
output_path,
original_config_file=app_config.root_path / model_config,
image_size=512,
scan_needed=True,
from_safetensors=weights.suffix == ".safetensors",
)
return output_path

View File

@ -1,98 +0,0 @@
import os
import typing
from enum import Enum
from typing import Literal, Optional
import torch
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, build_ip_adapter
from invokeai.backend.model_management.models.base import (
BaseModelType,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelType,
SubModelType,
calc_model_size_by_fs,
classproperty,
)
class IPAdapterModelFormat(str, Enum):
# The custom IP-Adapter model format defined by InvokeAI.
InvokeAI = "invokeai"
class IPAdapterModel(ModelBase):
class InvokeAIConfig(ModelConfigBase):
model_format: Literal[IPAdapterModelFormat.InvokeAI]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.IPAdapter
super().__init__(model_path, base_model, model_type)
self.model_size = calc_model_size_by_fs(self.model_path)
@classmethod
def detect_format(cls, path: str) -> str:
if not os.path.exists(path):
raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.")
if os.path.isdir(path):
model_file = os.path.join(path, "ip_adapter.bin")
image_encoder_config_file = os.path.join(path, "image_encoder.txt")
if os.path.exists(model_file) and os.path.exists(image_encoder_config_file):
return IPAdapterModelFormat.InvokeAI
raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}")
@classproperty
def save_to_config(cls) -> bool:
return True
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
if child_type is not None:
raise ValueError("There are no child models in an IP-Adapter model.")
return self.model_size
def get_model(
self,
torch_dtype: torch.dtype,
child_type: Optional[SubModelType] = None,
) -> typing.Union[IPAdapter, IPAdapterPlus]:
if child_type is not None:
raise ValueError("There are no child models in an IP-Adapter model.")
model = build_ip_adapter(
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"),
device=torch.device("cpu"),
dtype=torch_dtype,
)
self.model_size = model.calc_size()
return model
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
format = cls.detect_format(model_path)
if format == IPAdapterModelFormat.InvokeAI:
return model_path
else:
raise ValueError(f"Unsupported format: '{format}'.")
def get_ip_adapter_image_encoder_model_id(model_path: str):
"""Read the ID of the image encoder associated with the IP-Adapter at `model_path`."""
image_encoder_config_file = os.path.join(model_path, "image_encoder.txt")
with open(image_encoder_config_file, "r") as f:
image_encoder_model = f.readline().strip()
return image_encoder_model

View File

@ -1,148 +0,0 @@
import json
import os
from enum import Enum
from pathlib import Path
from typing import Literal, Optional
from omegaconf import OmegaConf
from pydantic import Field
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management.detect_baked_in_vae import has_baked_in_sdxl_vae
from invokeai.backend.util.logging import InvokeAILogger
from .base import (
BaseModelType,
DiffusersModel,
InvalidModelException,
ModelConfigBase,
ModelType,
ModelVariantType,
classproperty,
read_checkpoint_meta,
)
class StableDiffusionXLModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusionXLModel(DiffusersModel):
# TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class CheckpointConfig(ModelConfigBase):
model_format: Literal[StableDiffusionXLModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: str
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}
assert model_type == ModelType.Main
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusionXL,
model_type=ModelType.Main,
)
@classmethod
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
if model_format == StableDiffusionXLModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
else:
checkpoint = read_checkpoint_meta(path)
checkpoint = checkpoint.get("state_dict", checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == StableDiffusionXLModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config["in_channels"]
else:
raise InvalidModelException(f"{path} is not a recognized Stable Diffusion diffusers model")
else:
raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")
if in_channels == 9:
variant = ModelVariantType.Inpaint
elif in_channels == 5:
variant = ModelVariantType.Depth
elif in_channels == 4:
variant = ModelVariantType.Normal
else:
raise Exception("Unkown stable diffusion 2.* model format")
if ckpt_config_path is None:
# avoid circular import
from .stable_diffusion import _select_ckpt_config
ckpt_config_path = _select_ckpt_config(kwargs.get("model_base", BaseModelType.StableDiffusionXL), variant)
return cls.create_config(
path=path,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
)
@classproperty
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
return StableDiffusionXLModelFormat.Diffusers
else:
return StableDiffusionXLModelFormat.Checkpoint
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
# The convert script adapted from the diffusers package uses
# strings for the base model type. To avoid making too many
# source code changes, we simply translate here
if Path(output_path).exists():
return output_path
if isinstance(config, cls.CheckpointConfig):
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
# Hack in VAE-fp16 fix - If model sdxl-vae-fp16-fix is installed,
# then we bake it into the converted model unless there is already
# a nonstandard VAE installed.
kwargs = {}
app_config = InvokeAIAppConfig.get_config()
vae_path = app_config.models_path / "sdxl/vae/sdxl-vae-fp16-fix"
if vae_path.exists() and not has_baked_in_sdxl_vae(Path(model_path)):
InvokeAILogger.get_logger().warning("No baked-in VAE detected. Inserting sdxl-vae-fp16-fix.")
kwargs["vae_path"] = vae_path
return _convert_ckpt_and_cache(
version=base_model,
model_config=config,
output_path=output_path,
use_safetensors=True,
**kwargs,
)
else:
return model_path

Some files were not shown because too many files have changed in this diff Show More